Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理二进制数据的方法
Jun 03 Python
Python处理文本文件中控制字符的方法
Feb 07 Python
Pycharm学习教程(5) Python快捷键相关设置
May 03 Python
基于Python3 逗号代码 和 字符图网格(详谈)
Jun 22 Python
对python numpy数组中冒号的使用方法详解
Apr 17 Python
用python实现将数组元素按从小到大的顺序排列方法
Jul 02 Python
基于numpy中数组元素的切片复制方法
Nov 15 Python
浅析python参数的知识点
Dec 10 Python
使用Django开发简单接口实现文章增删改查
May 09 Python
使用Python制作表情包实现换脸功能
Jul 19 Python
Python加载数据的5种不同方式(收藏)
Nov 13 Python
Django怎么在admin后台注册数据库表
Nov 14 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
PHP概述.
2006/10/09 PHP
php修改指定文件后缀的方法
2014/09/11 PHP
PHP错误Warning:mysql_query()解决方法
2015/10/24 PHP
thinkPHP多表查询及分页功能实现方法示例
2017/07/03 PHP
php实现socket推送技术的示例
2017/12/20 PHP
PHP基于双向链表与排序操作实现的会员排名功能示例
2017/12/26 PHP
PHP读取文件或采集时解决中文乱码
2021/03/09 PHP
javaScript 页面自动加载事件详解
2014/02/10 Javascript
基于编写jQuery的无缝滚动插件
2014/08/02 Javascript
js上传图片及预览功能实例分析
2015/04/24 Javascript
bootstrap PrintThis打印插件使用详解
2017/02/20 Javascript
jQuery插件FusionCharts实现的3D帕累托图效果示例【附demo源码】
2017/03/25 jQuery
js数组的基本使用总结
2021/01/18 Javascript
[18:20]DOTA2 HEROS教学视频教你分分钟做大人-昆卡
2014/06/11 DOTA
[01:37]全新的一集《真视界》——TI7总决赛
2017/09/21 DOTA
CentOS安装pillow报错的解决方法
2016/01/27 Python
Python Django使用forms来实现评论功能
2016/08/17 Python
Python登录并获取CSDN博客所有文章列表代码实例
2017/12/28 Python
wx.CheckBox创建复选框控件并响应鼠标点击事件
2018/04/25 Python
对pandas replace函数的使用方法小结
2018/05/18 Python
Python爬取qq空间说说的实例代码
2018/08/17 Python
pycharm在调试python时执行其他语句的方法
2018/11/29 Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
2019/09/29 Python
浅谈pytorch、cuda、python的版本对齐问题
2020/01/15 Python
python 从list中随机取值的方法
2020/11/16 Python
英国第一豪华护肤品牌:Elemis
2017/10/12 全球购物
英国领先的亚洲旅游专家:Wendy Wu Tours
2018/01/21 全球购物
全球性的在线商店:Vogca
2019/05/10 全球购物
单位委托书范本
2014/04/04 职场文书
祖国在我心中演讲稿450字
2014/09/05 职场文书
小学科学教学计划
2015/01/21 职场文书
寻找最美乡村教师观后感
2015/06/18 职场文书
团组织关系介绍信
2019/06/24 职场文书
Mybatis-Plus进阶分页与乐观锁插件及通用枚举和多数据源详解
2022/03/21 Java/Android
Python 图片添加美颜效果
2022/04/28 Python
Linux安装Docker详细教程
2022/07/07 Servers