Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类的基础入门知识
Nov 24 Python
python paramiko实现ssh远程访问的方法
Dec 03 Python
Python内置的字符串处理函数详细整理(覆盖日常所用)
Aug 19 Python
探索Python3.4中新引入的asyncio模块
Apr 08 Python
Python OpenCV 直方图的计算与显示的方法示例
Feb 08 Python
对numpy中数组元素的统一赋值实例
Apr 04 Python
Python根据已知邻接矩阵绘制无向图操作示例
Jun 23 Python
python去除拼音声调字母,替换为字母的方法
Nov 28 Python
从0开始的Python学习014面向对象编程(推荐)
Apr 02 Python
由Python编写的MySQL管理工具代码实例
Apr 09 Python
python实现五子棋游戏
Jun 18 Python
Python队列、进程间通信、线程案例
Oct 25 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
PHP Zip压缩 在线对文件进行压缩的函数
2010/05/26 PHP
Node.js开发指南中的简单实例(mysql版)
2013/09/17 Javascript
Javascript中查找不以XX字符结尾的单词示例代码
2013/10/15 Javascript
使用javascript实现判断当前浏览器
2015/04/14 Javascript
jQuery UI设置固定日期选择特效代码分享
2015/08/27 Javascript
Jquery+Ajax+PHP+MySQL实现分类列表管理(下)
2015/10/28 Javascript
分享一些常用的jQuery动画事件和动画函数
2015/11/27 Javascript
JavaScript时间操作之年月日星期级联操作
2016/01/15 Javascript
Angularjs实现分页和分页算法的示例代码
2016/12/23 Javascript
JS设计模式之状态模式概念与用法分析
2018/02/05 Javascript
vue自定义移动端touch事件之点击、滑动、长按事件
2018/07/10 Javascript
微信小程序发送短信验证码完整实例
2019/01/07 Javascript
一步快速解决微信小程序中textarea层级太高遮挡其他组件
2019/03/04 Javascript
浅析Vue中拆分视图层代码的5点建议
2019/08/15 Javascript
JavaScript实现省市联动效果
2019/11/22 Javascript
[03:40]DOTA2英雄梦之声_第01期_炼金术士
2014/06/23 DOTA
[40:03]Liquid vs Optic 2018国际邀请赛淘汰赛BO3 第一场 8.21
2018/08/22 DOTA
Python中优化NumPy包使用性能的教程
2015/04/23 Python
Python入门必须知道的11个知识点
2018/03/21 Python
浅谈django orm 优化
2018/08/18 Python
Python绘制正余弦函数图像的方法
2018/08/28 Python
Python实现的逻辑回归算法示例【附测试csv文件下载】
2018/12/28 Python
实例讲解Python3中abs()函数
2019/02/19 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
2020/04/09 Python
python 批量下载bilibili视频的gui程序
2020/11/20 Python
美国高端寝具品牌:Coyuchi
2017/02/08 全球购物
意大利包包和行李箱销售网站:Bagaglio.it
2021/03/02 全球购物
女儿十岁生日答谢词
2014/01/27 职场文书
通用自荐信范文
2014/03/14 职场文书
大学校务公开实施方案
2014/03/31 职场文书
学校创先争优活动总结
2014/08/28 职场文书
2014班子成员自我剖析材料思想汇报
2014/10/01 职场文书
校园会短篇的广播稿
2014/10/21 职场文书
小学见习报告
2014/10/31 职场文书
检举信的写法
2019/04/10 职场文书
Android Canvas绘制文字横纵向对齐
2022/06/05 Java/Android