keras自动编码器实现系列之卷积自动编码器操作


Posted in Python onJuly 03, 2020

图片的自动编码很容易就想到用卷积神经网络做为编码-解码器。在实际的操作中,

也经常使用卷积自动编码器去解决图像编码问题,而且非常有效。

下面通过**keras**完成简单的卷积自动编码。 编码器有堆叠的卷积层和池化层(max pooling用于空间降采样)组成。 对应的解码器由卷积层和上采样层组成。

@requires_authorization
# -*- coding:utf-8 -*-

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K
import os

## 网络结构 ##
input_img = Input(shape=(28,28,1)) # Tensorflow后端, 注意要用channel_last
# 编码器部分
x = Conv2D(16, (3,3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8,(3,3), activation='relu', padding='same')(x)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2,2), padding='same')(x)

# 解码器部分
x = Conv2D(8, (3,3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x) 
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# 得到编码层的输出
encoder_model = Model(inputs=autoencoder.input, outputs=autoencoder.get_layer('encoder_out').output)

## 导入数据, 使用常用的手写识别数据集
def load_mnist(dataset_name):
'''
load the data
'''
  data_dir = os.path.join("./data", dataset_name)
  f = np.load(os.path.join(data_dir, 'mnist.npz'))
  train_data = f['train'].T
  trX = train_data.reshape((-1, 28, 28, 1)).astype(np.float32)
  trY = f['train_labels'][-1].astype(np.float32)
  test_data = f['test'].T
  teX = test_data.reshape((-1, 28, 28, 1)).astype(np.float32)
  teY = f['test_labels'][-1].astype(np.float32)

  # one-hot 
  # y_vec = np.zeros((len(y), 10), dtype=np.float32)
  # for i, label in enumerate(y):
  #   y_vec[i, y[i]] = 1
  # keras.utils里带的有one-hot的函数, 就直接用那个了
  return trX / 255., trY, teX/255., teY

# 开始导入数据
x_train, _ , x_test, _= load_mnist('mnist')

# 可视化训练结果, 我们打开终端, 使用tensorboard
# tensorboard --logdir=/tmp/autoencoder # 注意这里是打开一个终端, 在终端里运行

# 训练模型, 并且在callbacks中使用tensorBoard实例, 写入训练日志 http://0.0.0.0:6006
from keras.callbacks import TensorBoard
autoencoder.fit(x_train, x_train,
        epochs=50,
        batch_size=128,
        shuffle=True,
        validation_data=(x_test, x_test),
        callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])

# 重建图片
import matplotlib.pyplot as plt 
decoded_imgs = autoencoder.predict(x_test)
encoded_imgs = encoder_model.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
  k = i + 1
  # 画原始图片
  ax = plt.subplot(2, n, k)
  plt.imshow(x_test[k].reshape(28, 28))
  plt.gray()
  ax.get_xaxis().set_visible(False)
  # 画重建图片
  ax = plt.subplot(2, n, k + n)
  plt.imshow(decoded_imgs[i].reshape(28, 28))
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

# 编码得到的特征
n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
  k = i + 1
  ax = plt.subplot(1, n, k)
  plt.imshow(encoded[k].reshape(4, 4 * 8).T)
  plt.gray()
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
plt.show()

补充知识:keras搬砖系列-单层卷积自编码器

考试成绩出来了,竟然有一门出奇的差,只是有点意外。

觉得应该不错的,竟然考差了,它估计写了个随机数吧。

头文件

from keras.layers import Input,Dense
from keras.models import Model 
from keras.datasets import mnist
import numpy as np 
import matplotlib.pyplot as plt

导入数据

(X_train,_),(X_test,_) = mnist.load_data()
 
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))

这里的X_train和X_test的维度分别为(60000L,784L),(10000L,784L)

这里进行了归一化,将所有的数值除上255.

设定编码的维数与输入数据的维数

encoding_dim = 32

input_img = Input(shape=(784,))

构建模型

encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
 
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
 
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))

模型编译

autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')

模型训练

autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))

预测

encoded_imgs = encoder.predict(X_test)

decoded_imgs = deconder.predict(encoded_imgs)

数据可视化

n = 10
for i in range(n):
 ax = plt.subplot(2,n,i+1)
 plt.imshow(X_test[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
 ax = plt.subplot(2,n,i+1+n)
 plt.imshow(decoded_imgs[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
plt.show()

完成代码

from keras.layers import Input,Dense
from keras.models import Model 
from keras.datasets import mnist
import numpy as np 
import matplotlib.pyplot as plt 
 
(X_train,_),(X_test,_) = mnist.load_data()
 
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))
 
encoding_dim = 32
input_img = Input(shape=(784,))
 
encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
 
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
 
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))
 
autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')
autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))
 
encoded_imgs = encoder.predict(X_test)
decoded_imgs = deconder.predict(encoded_imgs)
 
##via
n = 10
for i in range(n):
 ax = plt.subplot(2,n,i+1)
 plt.imshow(X_test[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
 
 ax = plt.subplot(2,n,i+1+n)
 plt.imshow(decoded_imgs[i].reshape(28,28))
 plt.gray()
 ax.get_xaxis().set_visible(False)
 ax.get_yaxis().set_visible(False)
plt.show()

以上这篇keras自动编码器实现系列之卷积自动编码器操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python之yield表达式学习
Sep 02 Python
在Python中操作字典之update()方法的使用
May 22 Python
对pandas的层次索引与取值的新方法详解
Nov 06 Python
浅谈Pandas:Series和DataFrame间的算术元素
Dec 22 Python
Python设计模式之原型模式实例详解
Jan 18 Python
pandas把所有大于0的数设置为1的方法
Jan 26 Python
python中update的基本使用方法详解
Jul 17 Python
详解Python是如何实现issubclass的
Jul 24 Python
python selenium爬取斗鱼所有直播房间信息过程详解
Aug 09 Python
关于Python Tkinter Button控件command传参问题的解决方式
Mar 04 Python
tensorflow指定CPU与GPU运算的方法实现
Apr 21 Python
python使用多线程+socket实现端口扫描
May 28 Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
Python 3.10 的首个 PEP 诞生,内置类型 zip() 迎来新特性(推荐)
Jul 03 #Python
You might like
dede全站URL静态化改造[070414更正]
2007/04/17 PHP
PHP 函数语法介绍一
2009/06/14 PHP
php开发留言板的CRUD(增,删,改,查)操作
2012/04/19 PHP
CI(CodeIgniter)框架配置
2014/06/10 PHP
利用phpexcel对数据库数据的导入excel(excel筛选)、导出excel
2017/04/27 PHP
thinkphp框架实现路由重定义简化url访问地址的方法分析
2020/04/04 PHP
PHP标准库 (SPL)――Countable用法示例
2020/06/05 PHP
基于jquery的loading效果实现代码
2010/11/05 Javascript
js切换div css注意的细节
2012/12/10 Javascript
JS随机生成不重复数据的实例方法
2013/07/17 Javascript
jQuery filter函数使用方法
2014/05/19 Javascript
jquery+easeing实现仿flash的载入动画
2015/03/10 Javascript
jQuery图片旋转插件jQueryRotate.js用法实例(附demo下载)
2016/01/21 Javascript
JS实现中文汉字按拼音排序的方法
2017/10/09 Javascript
微信小程序页面跳转功能之从列表的item项跳转到下一个页面的方法
2017/11/27 Javascript
探索Vue高阶组件的使用
2018/01/08 Javascript
Vue 实现分页与输入框关键字筛选功能
2020/01/02 Javascript
详解JavaScript作用域、作用域链和闭包的用法
2020/09/03 Javascript
详解Python的单元测试
2015/04/28 Python
python3制作捧腹网段子页爬虫
2017/02/12 Python
Python实现二维数组输出为图片
2018/04/03 Python
如何用Python做一个微信机器人自动拉群
2019/07/03 Python
Django处理Ajax发送的Get请求代码详解
2019/07/29 Python
Python 实现加密过的PDF文件转WORD格式
2020/02/04 Python
python实现猜数游戏
2020/03/27 Python
使用 CSS3 中@media 实现网页自适应的示例代码
2020/03/24 HTML / CSS
Ajxa常见问题都有哪些
2014/03/26 面试题
实习生求职自荐信
2014/02/07 职场文书
交通事故协议书范文
2014/10/23 职场文书
优秀党务工作者先进事迹材料
2014/12/25 职场文书
涨价通知
2015/04/23 职场文书
2015年小学二年级班主任工作总结
2015/05/21 职场文书
汽车修理厂管理制度
2015/08/05 职场文书
药品销售员2015年终工作总结
2015/10/22 职场文书
2016年度师德标兵先进事迹材料
2016/02/26 职场文书
导游词之山西祁县乔家大院
2019/10/14 职场文书