Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实例之wxpython中Frame使用方法
Jun 09 Python
Python单元测试框架unittest使用方法讲解
Apr 13 Python
Django中的CACHE_BACKEND参数和站点级Cache设置
Jul 23 Python
最大K个数问题的Python版解法总结
Jun 16 Python
Python利用BeautifulSoup解析Html的方法示例
Jul 30 Python
python使用pyqt写带界面工具的示例代码
Oct 23 Python
Django视图之ORM数据库查询操作API的实例
Oct 27 Python
Python实现PS图像抽象画风效果的方法
Jan 23 Python
python3.6环境安装+pip环境配置教程图文详解
Jun 20 Python
django xadmin中form_layout添加字段显示方式
Mar 30 Python
Django 解决阿里云部署同步数据库报错的问题
May 14 Python
python利用faker库批量生成测试数据
Oct 15 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
FCKeditor添加自定义按钮
2008/03/27 PHP
PHP里的中文变量说明
2011/07/23 PHP
PHP中spl_autoload_register()和__autoload()区别分析
2014/05/10 PHP
基于PHP常用文件函数和目录函数整理
2017/08/17 PHP
基于PHP实现短信验证码发送次数限制
2020/07/11 PHP
Jquery chosen动态设置值实例介绍
2013/08/08 Javascript
jQuery 设置 CSS 属性示例介绍
2014/01/16 Javascript
javascript 对象数组根据对象object key的值排序
2015/03/09 Javascript
jquery选择器简述
2015/08/31 Javascript
前端js文件合并的三种方式推荐
2016/05/19 Javascript
AngularJS基础 ng-include 指令示例讲解
2016/08/01 Javascript
js实现添加可信站点、修改activex安全设置,禁用弹出窗口阻止程序
2016/08/17 Javascript
bootstrap侧边栏圆点导航
2017/01/11 Javascript
Bootstrap实现的经典栅格布局效果实例【附demo源码】
2017/03/30 Javascript
JS+canvas动态绘制饼图的方法示例
2017/09/12 Javascript
使用SVG基本操作API的实例讲解
2017/09/14 Javascript
WebSocket的简单介绍及应用
2019/05/23 Javascript
jQuery实现tab栏切换效果
2020/12/22 jQuery
编写Python脚本使得web页面上的代码高亮显示
2015/04/24 Python
在Mac OS系统上安装Python的Pillow库的教程
2015/11/20 Python
Python中的time模块与datetime模块用法总结
2016/06/30 Python
python爬取指定微信公众号文章
2018/12/20 Python
Python2和Python3之间的str处理方式导致乱码的讲解
2019/01/03 Python
pycharm 添加解释器的方法步骤
2020/08/31 Python
浅析Python模块之间的相互引用问题
2021/02/26 Python
工程管理专业个人求职信范文
2013/12/07 职场文书
优秀中学生事迹材料
2014/01/31 职场文书
毕业生如何写自我鉴定
2014/03/15 职场文书
五一活动标语
2014/06/30 职场文书
工厂见习报告范文
2014/10/31 职场文书
2014年艾滋病防治工作总结
2014/12/10 职场文书
出国留学自荐信模板
2015/03/06 职场文书
结婚典礼致辞
2015/07/28 职场文书
关于golang高并发的实现与注意事项说明
2021/05/08 Golang
如何在向量化NumPy数组上进行移动窗口
2021/05/18 Python
Python基于百度AI实现抓取表情包
2021/06/27 Python