博客
关于我
自编码器模型详解与实现(采用tensorflow2.x实现)
阅读量:783 次
发布时间:2019-03-25

本文共 7897 字,大约阅读时间需要 26 分钟。

自编码器模型详解与实现(采用TensorFlow 2.x 实现)

1. 自编码器与潜变量学习概述

自编码器是一种在无监督学习中广泛应用的深度学习模型,由Geoffrey Hinton等人于1980年代首次提出。它的核心目标是通过压缩高维输入空间到低维潜变量Representation(以下简称“潜”),并在解码阶段将这些潜还原为原始高维输入。这种能力使其在图像处理、材质分离等领域具有重要应用价值。

在图像处理领域,自编码器可以类比于数据压缩与解压过程。例如,就如JPEG将高分辨率图像压缩为小文件格式一样,自编码器则可以将原始图像压缩为低维潜变量,再通过解码器还原回高分辨率图像。这使得自编码器成为一种高效的图像压缩与恢复工具。

2. 自编码器架构详解

2.1 编码器设计

编码器负责将高维输入通过一系列的神经网络层压缩为低维潜变量。我们以MNIST数据集为例,构建一个适用于28x28x1输入尺寸的编码器。潜变量的维度设置为低于输入维度的超参数,这里采用10维。

以下是编码器的实现代码框架:

def Encoder(z_dim):    inputs = layers.Input(shape=[28, 28, 1])    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Flatten()(x)    out = Dense(z_dim, activation='relu')(x)    return Model(inputs=inputs, outputs=out, name='encoder')

编码器主要包含卷积层和全连接层。卷积层用于提取高层次的特征,同时通过调整卷积核的步长(如2)实现特征图的下采样,逐步减少输入的高维信息。全连接层则负责将多个特征图融合到低维潜变量空间中。

2.2 解码器设计

解码器的任务是将低维潜变量还原为高维图像。其结构与编码器相似,但需在解码过程中通过卷积层和上采样操作逐步还原特征图。

以下是解码器的实现代码框架:

def Decoder(z_dim):    inputs = layers.Input(shape=[z_dim])    x = Dense(7*7*64, activation='relu')(x)    x = Reshape((7,7,64))(x)    x = Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    x = Conv2D(filters=32, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    out = Conv2D(filters=1, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')(x)    return Model(inputs=inputs, outputs=out, name='decoder')

解码器通过卷积层在低维空间中生成特征图,并结合上采样操作逐渐将特征图还原为原始图像尺寸。上采样方法包括卷积核转置(例如UpSampling2D)或仿射变换,但后者是不训练的参数,通常不适合深度学习模型。

3. 自编码器模型构建

将编码器和解码器组合,构建完整的自编码器模型:

z_dim = 10encoder = Encoder(z_dim)decoder = Decoder(z_dim)model_input = encoder.inputmodel_output = decoder(encoder.output)autoencoder = Model(model_input, model_output)
3.1 模型训练

为了训练模型,我们采用MSE(均方误差)损失函数,旨在最小化编码器输出与解码器预测值之间的差异。同时,使用一些训练回调(如ModelCheckpointEarlyStopping)来优化训练过程。

autoencoder.compile(loss='mse', optimizer='rmsprop', lr=3e-4)

训练过程中,我们需要分成训练集和验证集,定期保存最佳模型参数以防止过拟合。

4. 从潜变量生成图像

自编码器的潜变量具有潜在的生成能力。比如,如果我们定义另一个解码器仅使用潜变量生成图像,可以利用这个能力进行高效的图像生成。

z_dim = 2  # 定义更低维的潜变量autoencoder_2 = Autoencoder(z_dim=2)

通过对潜变量进行采样,可以生成大量不同样本。如上图所示,我们采用2维潜变量空间,生成500个样本,散布在二维平面上。通过观察标签分布图,可以发现某些类别的潜变量代表性较强,而另一些类别则相对模糊。

更进一步地,我们可以通过滑动窗口或交互式工具(如下图所示),进行潜变量的可视化和探索。

from ipywidgets import interact, interact_manual@interactdef explore_latent_variable(z1=(-5,5,0.1), z2=(-5,5,0.1)):    z_samples = np.array([[z1, z2] for z2 in np.arange(-5,5,0.1)] for z1 in np.arange(-5,5,0.1))    images = autoencoder_2.decoder.predict(z_samples)    plt.figure(figsize=(2,2))    plt.imshow(images[0,:,:,0], cmap='gray')

完整代码示例

import tensorflow as tffrom tensorflow.keras import layers, Modelfrom tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Reshape, Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLUfrom tensorflow.keras.activations import relufrom tensorflow.keras.models import Sequential, load_modelfrom tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingimport tensorflow_datasets as tfdsimport numpy as npimport matplotlib.pyplot as pltimport warningswarnings.filterwarnings('ignore')print(tf.__version__)# 加载MNIST数据集(ds_train, ds_test), ds_info = tfds.load(    'mnist',    split=['train', 'test'],    shuffle_files=True,    as_supervised=True,    with_info=True)# 预处理数据def preprocess(image, label):    image = tf.cast(image, tf.float32)    image = image / 255.    return image, imageds_train = ds_train.cache().shuffle(ds_info.splits['train'].num_examples).batch(batch_size, drop_remainder=True)ds_test = ds_test.cache().batch(batch_size, drop_remainder=True).prefetch(batch_size)def Encoder(z_dim):    inputs = layers.Input(shape=[28, 28, 1])    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Flatten()(x)    out = Dense(z_dim, activation='relu')(x)    return Model(inputs=inputs, outputs=out, name='encoder')def Decoder(z_dim):    inputs = layers.Input(shape=[z_dim])    x = Dense(7*7*64, activation='relu')(x)    x = Reshape((7,7,64))(x)    x = Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    x = Conv2D(filters=32, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    out = Conv2D(filters=1, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')(x)    return Model(inputs=inputs, outputs=out, name='decoder')class Autoencoder:    def __init__(self, z_dim):        self.encoder = Encoder(z_dim)        self.decoder = Decoder(z_dim)        self.model_input = self.encoder.input        self.model_output = self.decoder(self.model_input)        self.model = Model(self.model_input, self.model_output)autoencoder = Autoencoder(z_dim=10)# 训练设置model_path = 'autoencoder.h5'checkpoint = ModelCheckpoint(model_path,                         monitor="val_loss",                        verbose=1,                        save_best_only=True,                         mode="auto",                        save_weights_only=False)early = EarlyStopping(monitor="val_loss",                      mode="auto",                      patience=5)callbacks_list = [checkpoint, early]autoencoder.model.compile(loss='mse',                        optimizer='rmsprop',                        lr=3e-4)autoencoder.model.fit(ds_train,                       validation_data=ds_test,                       epochs=100,                       callbacks=callbacks_list)# 加载预训练模型autoencoder.model = load_model(model_path)images, labels = next(iter(ds_test))outputs = autoencoder.model.predict(images)# 显示恢复后的图像plt.figure(figsize=(10, 2))for i in range(0, 64, 2):    plt.figure(figsize=(5, 2))    for j in range(2):        ax = plt.subplot(2, 5, j + i*2)  # 调整图像位置        ax.imshow(images[i, j], cmap='gray')        ax.axis('off')    plt.show()autoencoder_2 = Autoencoder(z_dim=2)model_path_2 = 'autoencoder_2.h5' checkpoint_2 = ModelCheckpoint(model_path_2,                             monitor="val_loss",                             verbose=1,                             save_best_only=True,                              mode="auto",                             save_weights_only=False) early_2 = EarlyStopping(monitor="val_loss",                        mode="auto",                        patience=5) callbacks_list_2 = [checkpoint_2, early_2]autoencoder_2.model.compile(loss="mse",                        optimizer='rmsprop',                        lr=1e-3)autoencoder_2.model.fit(ds_train,                       validation_data=ds_test,                       epochs=50,                       callbacks=callbacks_list_2)images_2, labels_2 = next(iter(ds_test))# 观察潜变量分布encoder_outputs = autoencoder_2.encoder.predict(images_2)plt.figure(figsize=(8, 8))plt.scatter(encoder_outputs[:, 0], encoder_outputs[:, 1], c=labels_2, cmap='RdYlBu', s=3)plt.colorbar()plt.show()z_samples = np.array([[z1, z2] for z1 in np.arange(-5,5,1.) for z2 in np.arange(-5,5,1.)])decoded_images = autoencoder_2.decoder.predict(z_samples)plt.figure(figsize=(10, 10))for i in range(100):    plt.figure(figsize=(5,5))    for j in range(10):        ax = plt.subplot(10, 10, i*10 + j + 1)        ax.imshow(decoded_images[i, j], cmap='gray')        ax.axis('off')plt.show()

结论

通过以上实现,我们成功构建并训练了一个自编码器模型,能够将MNIST数据集中的影像压缩为低维潜变量并还原回高分辨率图像。这种模型不仅能够实现图像压缩,还可以用于图像去噪、风格迁移等多种任务。通过探索潜变量空间,我们还可以发现输入数据中的潜在特征分布,以进一步提升模型性能和应用效果。

转载地址:http://mynuk.baihongyu.com/

你可能感兴趣的文章
mysql 多字段删除重复数据,保留最小id数据
查看>>
MySQL 多表联合查询:UNION 和 JOIN 分析
查看>>
MySQL 大数据量快速插入方法和语句优化
查看>>
mysql 如何给SQL添加索引
查看>>
mysql 字段区分大小写
查看>>
mysql 字段合并问题(group_concat)
查看>>
mysql 字段类型类型
查看>>
MySQL 字符串截取函数,字段截取,字符串截取
查看>>
MySQL 存储引擎
查看>>
mysql 存储过程 注入_mysql 视图 事务 存储过程 SQL注入
查看>>
MySQL 存储过程参数:in、out、inout
查看>>
mysql 存储过程每隔一段时间执行一次
查看>>
mysql 存在update不存在insert
查看>>
Mysql 学习总结(86)—— Mysql 的 JSON 数据类型正确使用姿势
查看>>
Mysql 学习总结(87)—— Mysql 执行计划(Explain)再总结
查看>>
Mysql 学习总结(88)—— Mysql 官方为什么不推荐用雪花 id 和 uuid 做 MySQL 主键
查看>>
Mysql 学习总结(89)—— Mysql 库表容量统计
查看>>
mysql 实现主从复制/主从同步
查看>>
mysql 审核_审核MySQL数据库上的登录
查看>>
mysql 导入 sql 文件时 ERROR 1046 (3D000) no database selected 错误的解决
查看>>