Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python的Django框架中的通用视图
May 04 Python
Python中操作符重载用法分析
Apr 29 Python
python定时利用QQ邮件发送天气预报的实例
Nov 17 Python
flask session组件的使用示例
Dec 25 Python
Python将字符串常量转化为变量方法总结
Mar 17 Python
OpenCV-Python 摄像头实时检测人脸代码实例
Apr 30 Python
django框架实现一次性上传多个文件功能示例【批量上传】
Jun 19 Python
基于python 微信小程序之获取已存在模板消息列表
Aug 05 Python
selenium+Chrome滑动验证码破解二(某某网站)
Dec 17 Python
python获取依赖包和安装依赖包教程
Feb 13 Python
浅谈ROC曲线的最佳阈值如何选取
Feb 28 Python
Python爬虫 简单介绍一下Xpath及使用
Apr 26 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
用php实现的下载css文件中的图片的代码
2010/02/08 PHP
调试一段PHP程序时遇到的三个问题
2012/01/17 PHP
解析php开发中的中文编码问题
2013/08/08 PHP
PHP防止post重复提交数据的简单例子
2014/06/07 PHP
在一个浏览器里呈现所有浏览器测试结果的前端测试工具的思路
2010/03/02 Javascript
JQuery动态给table添加、删除行 改进版
2011/01/19 Javascript
js中arguments的用法(实例讲解)
2013/11/30 Javascript
js实现鼠标悬停图片上时滚动文字说明的方法
2015/02/17 Javascript
jQuery+PHP实现动态数字展示特效
2015/03/14 Javascript
js实现类似jquery里animate动画效果的方法
2015/04/10 Javascript
深入浅析JavaScript中的3DES
2016/08/24 Javascript
flag和jq on 的绑定多个对象和方法(必看)
2017/02/27 Javascript
layui使用label标签的方法
2019/09/14 Javascript
react实现同页面三级跳转路由布局
2019/09/26 Javascript
[02:43]DOTA2英雄基础教程 圣堂刺客
2013/12/09 DOTA
[01:38]完美世界高校联赛决赛花絮
2018/12/02 DOTA
Python中random模块用法实例分析
2015/05/19 Python
bat和python批量重命名文件的实现代码
2016/05/19 Python
Python编程实现二分法和牛顿迭代法求平方根代码
2017/12/04 Python
python构建深度神经网络(续)
2018/03/10 Python
python3获取url文件大小示例代码
2019/09/18 Python
Python 脚本拉取 Docker 镜像问题
2019/11/10 Python
简单了解Python3 bytes和str类型的区别和联系
2019/12/19 Python
python中图像通道分离与合并实例
2020/01/17 Python
如何基于python实现不邻接植花
2020/05/01 Python
python里的单引号和双引号的有什么作用
2020/06/17 Python
Prometheus开发中间件Exporter过程详解
2020/11/30 Python
Python hashlib和hmac模块使用方法解析
2020/12/08 Python
五分钟学会怎么用Pygame做一个简单的贪吃蛇
2021/01/06 Python
物业电工岗位职责
2013/11/20 职场文书
幼儿园中班教学反思
2014/02/10 职场文书
高中生学期学习自我评价
2014/02/24 职场文书
教室标语大全
2014/06/21 职场文书
印刷技术专业自荐信
2014/09/18 职场文书
教师党的群众路线教育实践活动个人整改方案
2014/10/31 职场文书
国际贸易实训报告
2014/11/05 职场文书