Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python time模块用法实例详解
Sep 11 Python
Python列表(list)常用操作方法小结
Feb 02 Python
Pycharm 操作Django Model的简单运用方法
May 23 Python
pygame实现简易飞机大战
Sep 11 Python
浅析Python与Mongodb数据库之间的操作方法
Jul 01 Python
python 随机森林算法及其优化详解
Jul 11 Python
python实现知乎高颜值图片爬取
Aug 12 Python
python统计指定目录内文件的代码行数
Sep 19 Python
python 解决cv2绘制中文乱码问题
Dec 23 Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 Python
Python之Matplotlib绘制热力图和面积图
Apr 13 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
WINDOWS服务器安装多套PHP的另类解决方案
2006/10/09 PHP
php urlencode()与urldecode()函数字符编码原理详解
2011/12/06 PHP
php IP转换整形(ip2long)的详解
2013/06/06 PHP
php截取指定2个字符之间字符串的方法
2015/04/15 PHP
PHP中的switch语句的用法实例详解
2015/10/21 PHP
PHP HTTP 认证实例详解
2016/11/03 PHP
如何离线执行php任务
2017/02/21 PHP
解决PHP上传非标准格式的图片pjpeg失败的方法
2017/03/12 PHP
图片按比例缩放函数
2006/06/26 Javascript
JS关键字变色实现思路及代码
2013/02/21 Javascript
js控制的回到页面顶端goTop的代码实现
2013/03/20 Javascript
JavaScript原型链示例分享
2014/01/26 Javascript
JavaScript实现动态添加,删除行的方法实例详解
2015/07/02 Javascript
在for循环中length值是否需要缓存
2015/07/27 Javascript
详解JavaScript异步编程中jQuery的promise对象的作用
2016/05/03 Javascript
前端分页功能的实现以及原理(jQuery)
2017/01/22 Javascript
详解Vue中过度动画效果应用
2017/05/25 Javascript
基于对象合并功能的实现示例
2017/10/10 Javascript
微信小程序 组件的外部样式externalClasses使用详解
2019/09/06 Javascript
浅析vue中的provide / inject 有什么用处
2019/11/10 Javascript
vue-cli3项目打包后自动化部署到服务器的方法
2020/09/16 Javascript
Python简单删除列表中相同元素的方法示例
2017/06/12 Python
机器学习python实战之决策树
2017/11/01 Python
通过Python实现一个简单的html页面
2020/05/16 Python
Python虚拟环境的创建和包下载过程分析
2020/06/19 Python
YSL圣罗兰美妆美国官网:Yves Saint Lauret US
2016/11/21 全球购物
优衣库英国官网:UNIQLO英国
2016/12/25 全球购物
俄罗斯的精英皮具:Wittchen
2018/01/29 全球购物
巴基斯坦购物网站:Goto
2019/03/11 全球购物
如何高效率的查找一个月以内的数据
2012/04/15 面试题
个人自我鉴定
2013/11/07 职场文书
企业承诺书怎么写
2014/05/24 职场文书
2014最新党员批评与自我批评材料
2014/09/24 职场文书
公司客户答谢酒会祝酒词
2015/08/11 职场文书
幼儿园园长新年寄语
2015/08/17 职场文书
Android Studio实现简易进制转换计算器
2022/05/20 Java/Android