博客
关于我
自编码器模型详解与实现(采用tensorflow2.x实现)
阅读量:783 次
发布时间:2019-03-25

本文共 7897 字,大约阅读时间需要 26 分钟。

自编码器模型详解与实现(采用TensorFlow 2.x 实现)

1. 自编码器与潜变量学习概述

自编码器是一种在无监督学习中广泛应用的深度学习模型,由Geoffrey Hinton等人于1980年代首次提出。它的核心目标是通过压缩高维输入空间到低维潜变量Representation(以下简称“潜”),并在解码阶段将这些潜还原为原始高维输入。这种能力使其在图像处理、材质分离等领域具有重要应用价值。

在图像处理领域,自编码器可以类比于数据压缩与解压过程。例如,就如JPEG将高分辨率图像压缩为小文件格式一样,自编码器则可以将原始图像压缩为低维潜变量,再通过解码器还原回高分辨率图像。这使得自编码器成为一种高效的图像压缩与恢复工具。

2. 自编码器架构详解

2.1 编码器设计

编码器负责将高维输入通过一系列的神经网络层压缩为低维潜变量。我们以MNIST数据集为例,构建一个适用于28x28x1输入尺寸的编码器。潜变量的维度设置为低于输入维度的超参数,这里采用10维。

以下是编码器的实现代码框架:

def Encoder(z_dim):    inputs = layers.Input(shape=[28, 28, 1])    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Flatten()(x)    out = Dense(z_dim, activation='relu')(x)    return Model(inputs=inputs, outputs=out, name='encoder')

编码器主要包含卷积层和全连接层。卷积层用于提取高层次的特征,同时通过调整卷积核的步长(如2)实现特征图的下采样,逐步减少输入的高维信息。全连接层则负责将多个特征图融合到低维潜变量空间中。

2.2 解码器设计

解码器的任务是将低维潜变量还原为高维图像。其结构与编码器相似,但需在解码过程中通过卷积层和上采样操作逐步还原特征图。

以下是解码器的实现代码框架:

def Decoder(z_dim):    inputs = layers.Input(shape=[z_dim])    x = Dense(7*7*64, activation='relu')(x)    x = Reshape((7,7,64))(x)    x = Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    x = Conv2D(filters=32, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    out = Conv2D(filters=1, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')(x)    return Model(inputs=inputs, outputs=out, name='decoder')

解码器通过卷积层在低维空间中生成特征图,并结合上采样操作逐渐将特征图还原为原始图像尺寸。上采样方法包括卷积核转置(例如UpSampling2D)或仿射变换,但后者是不训练的参数,通常不适合深度学习模型。

3. 自编码器模型构建

将编码器和解码器组合,构建完整的自编码器模型:

z_dim = 10encoder = Encoder(z_dim)decoder = Decoder(z_dim)model_input = encoder.inputmodel_output = decoder(encoder.output)autoencoder = Model(model_input, model_output)
3.1 模型训练

为了训练模型,我们采用MSE(均方误差)损失函数,旨在最小化编码器输出与解码器预测值之间的差异。同时,使用一些训练回调(如ModelCheckpointEarlyStopping)来优化训练过程。

autoencoder.compile(loss='mse', optimizer='rmsprop', lr=3e-4)

训练过程中,我们需要分成训练集和验证集,定期保存最佳模型参数以防止过拟合。

4. 从潜变量生成图像

自编码器的潜变量具有潜在的生成能力。比如,如果我们定义另一个解码器仅使用潜变量生成图像,可以利用这个能力进行高效的图像生成。

z_dim = 2  # 定义更低维的潜变量autoencoder_2 = Autoencoder(z_dim=2)

通过对潜变量进行采样,可以生成大量不同样本。如上图所示,我们采用2维潜变量空间,生成500个样本,散布在二维平面上。通过观察标签分布图,可以发现某些类别的潜变量代表性较强,而另一些类别则相对模糊。

更进一步地,我们可以通过滑动窗口或交互式工具(如下图所示),进行潜变量的可视化和探索。

from ipywidgets import interact, interact_manual@interactdef explore_latent_variable(z1=(-5,5,0.1), z2=(-5,5,0.1)):    z_samples = np.array([[z1, z2] for z2 in np.arange(-5,5,0.1)] for z1 in np.arange(-5,5,0.1))    images = autoencoder_2.decoder.predict(z_samples)    plt.figure(figsize=(2,2))    plt.imshow(images[0,:,:,0], cmap='gray')

完整代码示例

import tensorflow as tffrom tensorflow.keras import layers, Modelfrom tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Reshape, Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLUfrom tensorflow.keras.activations import relufrom tensorflow.keras.models import Sequential, load_modelfrom tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingimport tensorflow_datasets as tfdsimport numpy as npimport matplotlib.pyplot as pltimport warningswarnings.filterwarnings('ignore')print(tf.__version__)# 加载MNIST数据集(ds_train, ds_test), ds_info = tfds.load(    'mnist',    split=['train', 'test'],    shuffle_files=True,    as_supervised=True,    with_info=True)# 预处理数据def preprocess(image, label):    image = tf.cast(image, tf.float32)    image = image / 255.    return image, imageds_train = ds_train.cache().shuffle(ds_info.splits['train'].num_examples).batch(batch_size, drop_remainder=True)ds_test = ds_test.cache().batch(batch_size, drop_remainder=True).prefetch(batch_size)def Encoder(z_dim):    inputs = layers.Input(shape=[28, 28, 1])    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=2, padding='same', activation='relu')(x)    x = Conv2D(filters=8, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = Flatten()(x)    out = Dense(z_dim, activation='relu')(x)    return Model(inputs=inputs, outputs=out, name='encoder')def Decoder(z_dim):    inputs = layers.Input(shape=[z_dim])    x = Dense(7*7*64, activation='relu')(x)    x = Reshape((7,7,64))(x)    x = Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    x = Conv2D(filters=32, kernel_size=(3,3), strides=1, padding='same', activation='relu')(x)    x = UpSampling2D((2,2))(x)    out = Conv2D(filters=1, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')(x)    return Model(inputs=inputs, outputs=out, name='decoder')class Autoencoder:    def __init__(self, z_dim):        self.encoder = Encoder(z_dim)        self.decoder = Decoder(z_dim)        self.model_input = self.encoder.input        self.model_output = self.decoder(self.model_input)        self.model = Model(self.model_input, self.model_output)autoencoder = Autoencoder(z_dim=10)# 训练设置model_path = 'autoencoder.h5'checkpoint = ModelCheckpoint(model_path,                         monitor="val_loss",                        verbose=1,                        save_best_only=True,                         mode="auto",                        save_weights_only=False)early = EarlyStopping(monitor="val_loss",                      mode="auto",                      patience=5)callbacks_list = [checkpoint, early]autoencoder.model.compile(loss='mse',                        optimizer='rmsprop',                        lr=3e-4)autoencoder.model.fit(ds_train,                       validation_data=ds_test,                       epochs=100,                       callbacks=callbacks_list)# 加载预训练模型autoencoder.model = load_model(model_path)images, labels = next(iter(ds_test))outputs = autoencoder.model.predict(images)# 显示恢复后的图像plt.figure(figsize=(10, 2))for i in range(0, 64, 2):    plt.figure(figsize=(5, 2))    for j in range(2):        ax = plt.subplot(2, 5, j + i*2)  # 调整图像位置        ax.imshow(images[i, j], cmap='gray')        ax.axis('off')    plt.show()autoencoder_2 = Autoencoder(z_dim=2)model_path_2 = 'autoencoder_2.h5' checkpoint_2 = ModelCheckpoint(model_path_2,                             monitor="val_loss",                             verbose=1,                             save_best_only=True,                              mode="auto",                             save_weights_only=False) early_2 = EarlyStopping(monitor="val_loss",                        mode="auto",                        patience=5) callbacks_list_2 = [checkpoint_2, early_2]autoencoder_2.model.compile(loss="mse",                        optimizer='rmsprop',                        lr=1e-3)autoencoder_2.model.fit(ds_train,                       validation_data=ds_test,                       epochs=50,                       callbacks=callbacks_list_2)images_2, labels_2 = next(iter(ds_test))# 观察潜变量分布encoder_outputs = autoencoder_2.encoder.predict(images_2)plt.figure(figsize=(8, 8))plt.scatter(encoder_outputs[:, 0], encoder_outputs[:, 1], c=labels_2, cmap='RdYlBu', s=3)plt.colorbar()plt.show()z_samples = np.array([[z1, z2] for z1 in np.arange(-5,5,1.) for z2 in np.arange(-5,5,1.)])decoded_images = autoencoder_2.decoder.predict(z_samples)plt.figure(figsize=(10, 10))for i in range(100):    plt.figure(figsize=(5,5))    for j in range(10):        ax = plt.subplot(10, 10, i*10 + j + 1)        ax.imshow(decoded_images[i, j], cmap='gray')        ax.axis('off')plt.show()

结论

通过以上实现,我们成功构建并训练了一个自编码器模型,能够将MNIST数据集中的影像压缩为低维潜变量并还原回高分辨率图像。这种模型不仅能够实现图像压缩,还可以用于图像去噪、风格迁移等多种任务。通过探索潜变量空间,我们还可以发现输入数据中的潜在特征分布,以进一步提升模型性能和应用效果。

转载地址:http://mynuk.baihongyu.com/

你可能感兴趣的文章
MTTR、MTBF、MTTF的大白话理解
查看>>
mt_rand
查看>>
mysql -存储过程
查看>>
mysql /*! 50100 ... */ 条件编译
查看>>
mudbox卸载/完美解决安装失败/如何彻底卸载清除干净mudbox各种残留注册表和文件的方法...
查看>>
mysql 1264_关于mysql 出现 1264 Out of range value for column 错误的解决办法
查看>>
mysql 1593_Linux高可用(HA)之MySQL主从复制中出现1593错误码的低级错误
查看>>
mysql 5.6 修改端口_mysql5.6.24怎么修改端口号
查看>>
MySQL 8.0 恢复孤立文件每表ibd文件
查看>>
MySQL 8.0开始Group by不再排序
查看>>
mysql ansi nulls_SET ANSI_NULLS ON SET QUOTED_IDENTIFIER ON 什么意思
查看>>
multi swiper bug solution
查看>>
MySQL Binlog 日志监听与 Spring 集成实战
查看>>
MySQL binlog三种模式
查看>>
multi-angle cosine and sines
查看>>
Mysql Can't connect to MySQL server
查看>>
mysql case when 乱码_Mysql CASE WHEN 用法
查看>>
Multicast1
查看>>
mysql client library_MySQL数据库之zabbix3.x安装出现“configure: error: Not found mysqlclient library”的解决办法...
查看>>
MySQL Cluster 7.0.36 发布
查看>>