博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
自编码器
阅读量:3898 次
发布时间:2019-05-23

本文共 5778 字,大约阅读时间需要 19 分钟。

在深度学习中,自编码器是非常有用的一种无监督学习模型。自编码器由encoder和decoder组成,前者将原始表示编码成隐层表示,后者将隐层表示解码成原始表示,训练目标为最小化重构误差,它将输入数据压缩到一个潜在表示空间里面,然后再根据这个表示空间将数据进行重构得到最后的输出数据。编码器和解码器都是用神经网络构建的,整个网络的构建方式和普通的神经网络类似,通过最小化输入和输出之间的差异来得到最好的网络。而且一般而言,隐层的特征维度低于原始特征维度。自编码器只是一种思想,在具体实现中,encoder和decoder可以由多种深度学习模型构成,例如全连接层、卷积层或LSTM等,以下使用Keras来实现用于图像去噪的卷积自编码器。

使用Keras来实现自编码器,encoder和decoder使用CNN来实现。且获取数据集MNIST,观察加载的数据结果

from keras.datasets import mnistimport numpy as npimport matplotlib.pyplot as plt(x_train, _), (x_test, _) = mnist.load_data()print(x_train[0])plt.imshow(x_train[0],cmap=plt.cm.binary) # 显示黑白图像plt.show()

结果如下:

print(x_train[0])
[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 3. 18.
18. 18. 126. 136. 175. 26. 166. 255. 247. 127. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 30. 36. 94. 154. 170. 253.
253. 253. 253. 253. 225. 172. 253. 242. 195. 64. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 49. 238. 253. 253. 253. 253. 253.
253. 253. 253. 251. 93. 82. 82. 56. 39. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 18. 219. 253. 253. 253. 253. 253.
198. 182. 247. 241. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 80. 156. 107. 253. 253. 205.
11. 0. 43. 154. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 14. 1. 154. 253. 90.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 139. 253. 190.
2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 11. 190. 253.
70. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 35. 241.
225. 160. 108. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 81.
240. 253. 253. 119. 25. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
45. 186. 253. 253. 150. 27. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 16. 93. 252. 253. 187. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 249. 253. 249. 64. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
46. 130. 183. 253. 253. 207. 2. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 39. 148.
229. 253. 253. 253. 250. 182. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 24. 114. 221. 253.
253. 253. 253. 201. 78. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 23. 66. 213. 253. 253. 253.
253. 198. 81. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 18. 171. 219. 253. 253. 253. 253. 195.
80. 9. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 55. 172. 226. 253. 253. 253. 253. 244. 133. 11.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 136. 253. 253. 253. 212. 135. 132. 16. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
在这里插入图片描述
将MNIST像素点值转化到0-1区间,并且重塑为N×1×28×28的四维tensor。

x_train = x_train.astype('float32') / 255.x_test = x_test.astype('float32') / 255.x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

添加噪声,即叠加一个随机的高斯白噪声,并限制加噪之后的值仍处于0-1区间。

noise_factor = 0.5x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) x_train_noisy = np.clip(x_train_noisy, 0., 1.)x_test_noisy = np.clip(x_test_noisy, 0., 1.)

看一下测试集前十张图片加噪以后的结果

n = 10plt.figure(figsize=(20, 2))for i in range(n):    ax = plt.subplot(1, n, i + 1)    plt.imshow(x_test_noisy[i].reshape(28, 28))    plt.gray()    ax.get_xaxis().set_visible(False)    ax.get_yaxis().set_visible(False)plt.show()

在这里插入图片描述

定义模型的输入。接下来定义encoder部分,由两个32×3×3的卷积层和两个2×2的最大池化层组成

input_img = Input(shape=(28, 28, 1,)) # N * 28 * 28 * 1x = Conv2D(32, (3, 3), padding='same', activation='relu')(input_img) # 28 * 28 * 32x = MaxPooling2D((2, 2), padding='same')(x) # 14 * 14 * 32x = Conv2D(32, (3, 3), padding='same', activation='relu')(x) # 14 * 14 * 32encoded = MaxPooling2D((2, 2), padding='same')(x) # 7 * 7 * 32

定义decoder部分,由两个32×3×3的卷积层和两个2×2的上采样层组成。

# 7 * 7 * 32x = Conv2D(32, (3, 3), padding='same', activation='relu')(encoded) # 7 * 7 * 32x = UpSampling2D((2, 2))(x) # 14 * 14 * 32x = Conv2D(32, (3, 3), padding='same', activation='relu')(x) # 14 * 14 * 32x = UpSampling2D((2, 2))(x) # 28 * 28 * 32decoded = Conv2D(1, (3, 3), padding='same', activation='sigmoid')(x) # 28 * 28 * 1

将输入和输出连接起来,构成autoencoder并compile。

autoencoder = Model(input_img, decoded)autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

使用x_train作为输入和输出来训练我们的autoencoder,并使用x_test进行validation。

autoencoder.fit(x_train_noisy, x_train,                epochs=100,                batch_size=128,                shuffle=True,                validation_data=(x_test_noisy, x_test))

使用autoencoder对x_test预测,并将预测结果绘制出来,和原始加噪图像进行对比

n = 10plt.figure(figsize=(20, 4))for i in range(n):    # display original    ax = plt.subplot(2, n, i + 1)    plt.imshow(x_test_noisy[i].reshape(28, 28))    plt.gray()    ax.get_xaxis().set_visible(False)    ax.get_yaxis().set_visible(False)     # display reconstruction    ax = plt.subplot(2, n, i + 1 + n)    plt.imshow(decoded_imgs[i].reshape(28, 28))    plt.gray()    ax.get_xaxis().set_visible(False)    ax.get_yaxis().set_visible(False)plt.show()

在这里插入图片描述

在这里插入图片描述

转载地址:http://bmyen.baihongyu.com/

你可能感兴趣的文章
How To Set Up vsftpd on Ubuntu 12.04
查看>>
实例演示如何使用WordPress自定义字段
查看>>
在 WordPress 指定页面加载指定 JavaScript 或 CSS 代码
查看>>
Apache配置多个监听端口和不同的网站目录的简单方法
查看>>
Linux 搭建 discuz 论坛
查看>>
如何在discuz帖子中插入视频
查看>>
怎么更改织梦网站logo和默认广告
查看>>
织梦系统如何插入优酷视频?
查看>>
Discuz设置特定用户组不启用验证码发帖权限
查看>>
百度云服务器 CentOS 图形界面支持
查看>>
为什么要使用R语言?历数R的优势与缺点
查看>>
[小技巧] Linux 下查询图片的大小
查看>>
Linus Torvalds说那些对人工智能奇点深信不疑的人显然磕了药
查看>>
[小技巧] svn: 不能解析 URL
查看>>
USB_ModeSwitch 介绍
查看>>
大公司和小公司的抢人战,孰胜孰负?
查看>>
通过make编译多文件的内核模块
查看>>
如何调试Javascript代码
查看>>
皮克斯宣布开源Universal Scene Description
查看>>
复盘:一个创业项目的失败之路
查看>>