Keras 使用 Lambda层详解


Posted in Python onJune 10, 2020

我就废话不多说了,大家还是直接看代码吧!

from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
 
def deconv(x):
  height = x.get_shape()[1].value
  width = x.get_shape()[2].value
  
  new_height = height*2
  new_width = width*2
  
  x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  
  return x_resized
 
def Generator(scope='generator'):
  imgs_noise = Input(shape=inputs_shape)
  x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
  x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
  x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
 
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
  x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
  x2 = Add()([x1, x])
 
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
  x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
  x4 = Add()([x3, x2])
 
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
  x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
  x6 = Add()([x5, x4])
 
  x = MaxPool2D(pool_size=(2,2))(x6)
 
  x = Lambda(deconv)(x)
  x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
  x = Lambda(deconv)(x)
  x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
 
  x = Lambda(lambda x: x+1)(x)
  y = Lambda(lambda x: x*127.5)(x)
  
  model = Model(inputs=imgs_noise, outputs=y)
  model.summary()
  
  return model
 
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报

TypeError: can't pickle _thread.RLock objects

Keras 使用 Lambda层详解

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇Keras 使用 Lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用swapCase()方法转换大小写的教程
May 20 Python
python使用MySQLdb访问mysql数据库的方法
Aug 03 Python
浅谈numpy中linspace的用法 (等差数列创建函数)
Jun 07 Python
浅谈Scrapy框架普通反爬虫机制的应对策略
Dec 28 Python
JSON文件及Python对JSON文件的读写操作
Oct 07 Python
CentOS下Python3的安装及创建虚拟环境的方法
Nov 28 Python
python 实现提取某个索引中某个时间段的数据方法
Feb 01 Python
python绘制双Y轴折线图以及单Y轴双变量柱状图的实例
Jul 08 Python
python中resample函数实现重采样和降采样代码
Feb 25 Python
解决pycharm中opencv-python导入cv2后无法自动补全的问题(不用作任何文件上的修改)
Mar 05 Python
Python基于httpx模块实现发送请求
Jul 07 Python
详解Pycharm安装及Django安装配置指南
Sep 15 Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
You might like
php 无限分类的树类代码
2009/12/03 PHP
php在程序中将网页生成word文档并提供下载的代码
2012/10/09 PHP
分享一个php 的异常处理程序
2014/06/22 PHP
CodeIgniter表单验证方法实例详解
2016/03/03 PHP
PHP遍历目录文件的常用方法小结
2017/02/03 PHP
javascript:以前写的xmlhttp池,代码
2008/05/18 Javascript
js 绑定带参数的事件以及手动触发事件
2010/04/27 Javascript
快速排序 php与javascript的不同之处
2011/02/22 Javascript
修改jQuery Validation里默认的验证方法
2012/02/14 Javascript
JS:window.onload的使用介绍
2013/11/13 Javascript
一个奇葩的最短的 IE 版本判断JS脚本
2014/05/28 Javascript
页面加载完后自动执行一个方法的js代码
2014/09/06 Javascript
jQuery实现鼠标划过添加和删除class的方法
2015/06/26 Javascript
JavaScript之AOP编程实例
2015/07/17 Javascript
js格式化时间的方法
2015/12/18 Javascript
基于Vue.js实现简单搜索框
2020/03/26 Javascript
vue项目搭建以及全家桶的使用详细教程(小结)
2018/12/19 Javascript
实例详解带参数的 npm script
2019/05/28 Javascript
JS函数基本定义与用法示例
2020/01/15 Javascript
vue实现登录拦截
2020/06/29 Javascript
记一次vue跨域的解决
2020/10/21 Javascript
Python下线程之间的共享和释放示例
2015/05/04 Python
基于Python 中函数的 收集参数 机制
2019/12/21 Python
Python opencv相机标定实现原理及步骤详解
2020/04/09 Python
屈臣氏马来西亚官网:Watsons马来西亚
2019/06/15 全球购物
体育纪念品、亲笔签名的体育收藏品:Steiner Sports
2020/07/31 全球购物
如何保障Web服务器安全
2014/05/05 面试题
大专应届生个人简历的自我评价
2013/10/15 职场文书
电大毕业自我鉴定
2014/02/03 职场文书
科级干部考察材料
2014/02/15 职场文书
安全责任书范本
2014/04/15 职场文书
诉讼授权委托书范本
2014/10/05 职场文书
监护人证明
2015/06/19 职场文书
详解Python牛顿插值法
2021/05/11 Python
SQL之各种join小结详细讲解
2021/08/04 MySQL
python数据可视化使用pyfinance分析证券收益示例详解
2021/11/20 Python