TensorFlow 2.0 Tutorial - Convolutional Autoencoder

The most complete TensorFlow 2.0 beginner tutorial series, continuously updated, is available from Doit on zhuanlan.zhihu.com.

For the complete TensorFlow 2.0 tutorial code, see https://github.com/czy36mengfei/tensorflow2_tutorials_chinese (stars welcome).

This tutorial is compiled from my personal study notes reproducing the official TensorFlow 2.0 tutorials, with explanations in Chinese for readers who prefer Chinese-language tutorials. Official tutorials: https://www.tensorflow.org


1. Import the data

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
print(tf.__version__)
2.0.0-alpha0
# Load MNIST, add a channel dimension, and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = tf.expand_dims(x_train.astype('float32'), -1) / 255.0
x_test = tf.expand_dims(x_test.astype('float32'), -1) / 255.0

print(x_train.shape, ' ', y_train.shape)
print(x_test.shape, ' ', y_test.shape)
(60000, 28, 28, 1)   (60000,)
(10000, 28, 28, 1)   (10000,)
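
As an optional aside (not part of the original notebook): since an autoencoder is trained on (image, image) pairs, the same data can also be wrapped in a tf.data pipeline. The sketch below assumes the x_train tensor prepared above and an arbitrary batch size of 64.

# Optional sketch: a tf.data pipeline that yields (input, target) pairs,
# where the target equals the input (reconstruction objective).
train_ds = (tf.data.Dataset.from_tensor_slices((x_train, x_train))
            .shuffle(1024)
            .batch(64)
            .prefetch(tf.data.experimental.AUTOTUNE))
# The dataset could then be passed directly to fit: auto_encoder.fit(train_ds, epochs=1)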



2. Build the network

inputs = layers.Input(shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3]), name='inputs')
print(inputs.shape)
# Encoder: Conv2D + MaxPool2D compresses 28x28x1 -> 14x14x16
code = layers.Conv2D(16, (3,3), activation='relu', padding='same')(inputs)
code = layers.MaxPool2D((2,2), padding='same')(code)
print(code.shape)
# Decoder: Conv2D + UpSampling2D expands 14x14x16 -> 28x28x16
decoded = layers.Conv2D(16, (3,3), activation='relu', padding='same')(code)
decoded = layers.UpSampling2D((2,2))(decoded)
print(decoded.shape)
# Final Conv2D with sigmoid reconstructs the single-channel image in [0, 1]
outputs = layers.Conv2D(1, (3,3), activation='sigmoid', padding='same')(decoded)
print(outputs.shape)
auto_encoder = keras.Model(inputs, outputs)
(None, 28, 28, 1)
(None, 14, 14, 16)
(None, 28, 28, 16)
(None, 28, 28, 1)
auto_encoder.compile(optimizer=keras.optimizers.Adam(),
                    loss=keras.losses.BinaryCrossentropy())
keras.utils.plot_model(auto_encoder, show_shapes=True)
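
Note that keras.utils.plot_model needs the pydot package and the graphviz binaries installed. If they are unavailable, a plain text summary (shown below as an optional extra, not part of the original code) reports the same layer and shape information:

# Print a text summary of layers and output shapes
auto_encoder.summary()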

3. Train and test

# Stop early if the training loss fails to improve for 2 consecutive epochs.
# Note: with epochs=1 and validation_freq=10, validation never actually runs here.
early_stop = keras.callbacks.EarlyStopping(patience=2, monitor='loss')
auto_encoder.fit(x_train, x_train, batch_size=64, epochs=1, validation_split=0.1, validation_freq=10,
                 callbacks=[early_stop])
Train on 54000 samples, validate on 6000 samples
54000/54000 [==============================] - 31s 572us/sample - loss: 0.1007
<tensorflow.python.keras.callbacks.History at 0x7f0089d31fd0>
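
As an optional follow-up (not in the original notebook), the same binary cross-entropy reconstruction loss can be measured on the unseen test images; for an autoencoder the inputs and targets are both x_test.

# Reconstruction loss on the test set
test_loss = auto_encoder.evaluate(x_test, x_test, batch_size=64)
print('test reconstruction loss:', test_loss)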
import matplotlib.pyplot as plt
decoded_imgs = auto_encoder.predict(x_test)  # reconstructed test images
n = 5
plt.figure(figsize=(10, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i+1)
    plt.imshow(tf.reshape(x_test[i+1],(28, 28)))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + n+1)
    plt.imshow(tf.reshape(decoded_imgs[i+1], (28, 28)))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
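
If only the compressed 14x14x16 representation is needed (for example as features for another task), a standalone encoder can be sliced out of the same graph. This is a sketch that assumes the inputs and code tensors from section 2 are still in scope; it is not part of the original tutorial.

# Standalone encoder reusing the trained layers:
# maps a 28x28x1 image to the 14x14x16 code after MaxPool2D.
encoder = keras.Model(inputs, code)
codes = encoder.predict(x_test[:5])
print(codes.shape)  # expected: (5, 14, 14, 16)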

References:

buomsoo-kim/Easy-deep-learning-with-Keras

