Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用OpenCV库来进行简单的气象学遥感影像计算
Feb 19 Python
深入解答关于Python的11道基本面试题
Apr 01 Python
13个最常用的Python深度学习库介绍
Oct 28 Python
3分钟学会一个Python小技巧
Nov 23 Python
Python子类继承父类构造函数详解
Feb 19 Python
Python3模拟登录操作实例分析
Mar 12 Python
Python3.5装饰器典型案例分析
Apr 30 Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 Python
使用Python爬虫库BeautifulSoup遍历文档树并对标签进行操作详解
Jan 25 Python
Anaconda和ipython环境适配的实现
Apr 22 Python
如何利用pycharm进行代码更新比较
Nov 04 Python
Python连续赋值需要注意的一些问题
Jun 03 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
调频问题解答
2021/03/01 无线电
PHP新手上路(十四)
2006/10/09 PHP
php通过array_unshift函数添加多个变量到数组前端的方法
2015/03/18 PHP
php生成固定长度纯数字编码的方法
2015/07/09 PHP
php微信公众号开发之欢迎老朋友
2018/10/20 PHP
PHP常用函数之根据生日计算年龄功能示例
2019/10/21 PHP
jQuery提交表单ajax查询实例代码
2012/10/07 Javascript
javascript获取URL参数与参数值的示例代码
2013/12/20 Javascript
jQuery实现锚点scoll效果实例分析
2015/03/10 Javascript
全面解析Bootstrap中form、navbar的使用方法
2016/05/30 Javascript
JavaScript中style.left与offsetLeft的使用及区别详解
2016/06/08 Javascript
Ionic2系列之使用DeepLinker实现指定页面URL
2016/11/21 Javascript
在Vue组件中使用 TypeScript的方法
2018/02/28 Javascript
JS使用遮罩实现点击某区域以外时弹窗的弹出与关闭功能示例
2018/07/31 Javascript
Vue中引入svg图标的两种方式
2021/01/14 Vue.js
[01:17:55]VGJ.T vs Mineski 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/20 DOTA
[52:05]EG vs OG 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
python条件变量之生产者与消费者操作实例分析
2017/03/22 Python
详解opencv中画圆circle函数和椭圆ellipse函数
2019/12/27 Python
python ETL工具 pyetl
2020/06/07 Python
使用html5 canvas 画时钟代码实例分享
2015/11/11 HTML / CSS
粉红色的鲸鱼:Vineyard Vines
2018/02/17 全球购物
英国复古皮包品牌:Beara Beara
2018/07/18 全球购物
Aosom西班牙:家具在线商店
2020/06/11 全球购物
销售员自我评价怎么写
2013/09/19 职场文书
社区优秀志愿者材料
2014/02/02 职场文书
高级编程求职信模板
2014/02/16 职场文书
公司会计主管岗位责任制
2014/03/01 职场文书
工作批评与自我批评范文
2014/10/16 职场文书
捐助倡议书
2015/01/19 职场文书
入党自荐书范文
2015/03/05 职场文书
2015年派出所工作总结
2015/04/24 职场文书
写给女朋友的保证书
2015/05/09 职场文书
毕业设计致谢词
2015/05/14 职场文书
班级元旦晚会开幕词
2016/03/04 职场文书
分家协议书范本
2016/03/22 职场文书