Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用matplotlib绘图无法显示中文问题的解决方法
Mar 14 Python
python 生成图形验证码的方法示例
Nov 11 Python
python字符串Intern机制详解
Jul 01 Python
Django集成celery发送异步邮件实例
Dec 17 Python
在django admin详情表单显示中添加自定义控件的实现
Mar 11 Python
Python如何操作office实现自动化及win32com.client的运用
Apr 01 Python
python实现音乐播放和下载小程序功能
Apr 26 Python
Python基于pandas爬取网页表格数据
May 11 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 Python
如何更换python默认编辑器的背景色
Aug 10 Python
python3代码输出嵌套式对象实例详解
Dec 03 Python
python中subplot大小的设置步骤
Jun 28 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
PHP经典的给图片加水印程序
2006/12/06 PHP
使用adodb lite解决问题
2006/12/31 PHP
关于php连接mssql:pdo odbc sql server
2011/07/20 PHP
php采用curl模仿登录人人网发布动态的方法
2014/11/07 PHP
PHP SplObjectStorage使用实例
2015/05/12 PHP
PHP实现QQ登录实例代码
2016/01/14 PHP
js中prototype用法详细介绍
2013/11/14 Javascript
Javascript常用小技巧汇总
2015/06/24 Javascript
Jquery ajax 同步阻塞引起的UI线程阻塞问题
2015/11/17 Javascript
jQuery实现表格元素动态创建功能
2017/01/09 Javascript
JS组件系列之MVVM组件 vue 30分钟搞定前端增删改查
2017/04/28 Javascript
详解node中创建服务进程
2017/05/09 Javascript
jQuery实现动态给table赋值的方法示例
2017/07/04 jQuery
详解vue中axios的使用与封装
2019/03/20 Javascript
vue cli3.0 引入eslint 结合vscode使用
2019/05/27 Javascript
vue 使用axios 数据请求第三方插件的使用教程详解
2019/07/05 Javascript
解决layui的radio属性或别的属性没显示出来的问题
2019/09/26 Javascript
vue基于v-charts封装双向条形图的实现代码
2019/12/09 Javascript
在Django的视图中使用form对象的方法
2015/07/18 Python
简单讲解Python中的闭包
2015/08/11 Python
详解Python自建logging模块
2018/01/29 Python
Python中%是什么意思?python中百分号如何使用?
2018/03/20 Python
python存储16bit和32bit图像的实例
2018/12/05 Python
python使用opencv实现马赛克效果示例
2019/09/28 Python
Python猜数字算法题详解
2020/03/01 Python
关于Python字符串显示u...的解决方式
2020/03/06 Python
python如何用matplotlib创建三维图表
2021/01/26 Python
css3 实现元素弧线运动的示例代码
2020/04/24 HTML / CSS
基于HTML5 WebGL的3D机房的示例
2018/03/16 HTML / CSS
英国电信商店:BT Shop
2019/12/17 全球购物
如何写毕业求职自荐信
2013/11/06 职场文书
电子专业求职信
2014/06/19 职场文书
小学数学教研活动总结
2014/07/01 职场文书
2016北大自主招生自荐信模板
2016/01/28 职场文书
 分享一个Python 遇到数据库超好用的模块
2022/04/06 Python