Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 随机生成中文验证码的实例代码
Mar 20 Python
Python编写电话薄实现增删改查功能
May 07 Python
详解python实现线程安全的单例模式
Mar 05 Python
解决matplotlib库show()方法不显示图片的问题
May 24 Python
python实现pdf转换成word/txt纯文本文件
Jun 07 Python
Django框架之登录后自定义跳转页面的实现方法
Jul 18 Python
Flask框架模板渲染操作简单示例
Jul 31 Python
Python使用random模块生成随机数操作实例详解
Sep 17 Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 Python
快速解决jupyter notebook启动需要密码的问题
Apr 21 Python
给numpy.array增加维度的超简单方法
Jun 02 Python
Python实现信息轰炸工具(再也不怕说不过别人了)
Jun 11 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
php 带逗号千位符数字的处理方法
2012/01/10 PHP
PHP5.4中json_encode中文转码的变化小结
2013/01/30 PHP
php实现指定字符串中查找子字符串的方法
2015/03/17 PHP
jQuery 跨域访问问题解决方法
2009/12/02 Javascript
xml文档转换工具,附图表例子(hta)
2010/11/17 Javascript
javascript 学习笔记(八)javascript对象
2011/04/12 Javascript
javascript向flash swf文件传递参数值注意细节
2012/12/11 Javascript
JSON 数字排序多字段排序介绍
2013/09/18 Javascript
让javascript加载速度倍增的方法(解决JS加载速度慢的问题)
2014/12/12 Javascript
JavaScript版的TwoQueues缓存模型
2014/12/29 Javascript
jquery实现下拉框功能效果【实例代码】
2016/05/06 Javascript
jQuery基础知识点总结(DOM操作)
2016/06/01 Javascript
AngularJS使用ng-options指令实现下拉框
2016/08/23 Javascript
JavaScript九九乘法口诀表的简单实现
2016/10/04 Javascript
Linux系统中利用node.js提取Word(doc/docx)及PDF文本的内容
2017/06/17 Javascript
利用Javascript获取选择文本所在的句子详解
2017/12/03 Javascript
微信小程序如何获取用户手机号
2018/01/26 Javascript
利用npm 安装删除模块的方法
2018/05/15 Javascript
详解用vue2.x版本+adminLTE开源框架搭建后台应用模版
2019/03/15 Javascript
微信二次分享报错invalid signature问题及解决方法
2019/04/01 Javascript
详解VSCode配置启动Vue项目
2019/05/14 Javascript
vue+echarts+datav大屏数据展示及实现中国地图省市县下钻功能
2020/11/16 Javascript
使用Python的Treq on Twisted来进行HTTP压力测试
2015/04/16 Python
Python抓取淘宝下拉框关键词的方法
2015/07/08 Python
Python中内置数据类型list,tuple,dict,set的区别和用法
2015/12/14 Python
详解PANDAS 数据合并与重塑(join/merge篇)
2019/07/09 Python
Python读取图像并显示灰度图的实现
2020/12/01 Python
css3实现多个元素依次显示效果
2017/12/12 HTML / CSS
美国床垫和床上用品公司:Nest Bedding
2017/06/12 全球购物
俄罗斯药房连锁店:ASNA
2020/06/20 全球购物
IBatis持久层技术
2016/07/18 面试题
班主任工作经验材料
2014/02/02 职场文书
手机银行营销方案
2014/03/14 职场文书
年终工作总结范文2014
2014/11/27 职场文书
春季运动会开幕词
2015/01/28 职场文书
世界卫生日宣传活动总结
2015/02/09 职场文书