Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows下安装python2和python3多版本教程
Mar 30 Python
python爬取淘宝商品详情页数据
Feb 23 Python
批量将ppt转换为pdf的Python代码 只要27行!
Feb 26 Python
Python创建普通菜单示例【基于win32ui模块】
May 09 Python
Python基于pandas实现json格式转换成dataframe的方法
Jun 22 Python
Python画柱状统计图操作示例【基于matplotlib库】
Jul 04 Python
python读取图片的方式,以及将图片以三维数组的形式输出方法
Jul 03 Python
python 一个figure上显示多个图像的实例
Jul 08 Python
Python搭建代理IP池实现获取IP的方法
Oct 27 Python
Python如何获取文件指定行的内容
May 27 Python
matplotlib之属性组合包(cycler)的使用
Feb 24 Python
pytorch 如何把图像数据集进行划分成train,test和val
May 31 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
PHP使用者状态管理功能的应用
2006/10/09 PHP
PHP中str_replace函数使用小结
2008/10/11 PHP
浅析PHP中的字符串编码转换(自动识别原编码)
2013/07/02 PHP
php调用MySQL存储过程的方法集合(推荐)
2013/07/03 PHP
php中的路径问题与set_include_path使用介绍
2014/02/11 PHP
thinkphp备份数据库的方法分享
2015/01/04 PHP
PHP的全局错误处理详解
2016/04/25 PHP
PHPExcel简单读取excel文件示例
2016/05/26 PHP
THINKPHP截取中文字符串函数实例代码
2017/03/20 PHP
php实现微信模拟登陆、获取用户列表及群发消息功能示例
2017/06/28 PHP
javascript 定义初始化数组函数
2009/09/07 Javascript
使用jQuery获取data-的自定义属性
2015/11/10 Javascript
JavaScript的Backbone.js框架环境搭建及Hellow world示例
2016/05/07 Javascript
html、css和jquery相结合实现简单的进度条效果实例代码
2016/10/24 Javascript
Vue2路由动画效果的实现代码
2017/07/10 Javascript
layui清空,重置表单数据的实例
2019/09/12 Javascript
微信小程序保持session会话的方法
2020/03/20 Javascript
[02:14]完美“圣”典2016风云人物:xiao8专访
2016/12/01 DOTA
python更新列表的方法
2015/07/28 Python
Python线程指南详细介绍
2017/01/05 Python
Python中pygal绘制雷达图代码分享
2017/12/07 Python
python+opencv识别图片中的圆形
2020/03/25 Python
Python Numpy,mask图像的生成详解
2020/02/19 Python
Python常用base64 md5 aes des crc32加密解密方法汇总
2020/11/06 Python
Django多个app urls配置代码实例
2020/11/26 Python
浅谈html5 响应式布局
2014/12/24 HTML / CSS
美国电视购物HSN官网:HSN
2016/09/07 全球购物
中国医药集团国药在线:国药网
2017/02/06 全球购物
马来西亚在线药房:RoyalePharma
2019/12/01 全球购物
2014新年寄语
2014/01/20 职场文书
一夜的工作教学反思
2014/02/08 职场文书
师德建设实施方案
2014/03/21 职场文书
一年级学生期末评语
2014/04/21 职场文书
2014年安全生产大检查方案
2014/05/13 职场文书
分析SQL窗口函数之排名窗口函数
2022/04/21 Oracle
html中两种获取标签内的值的方法
2022/06/10 HTML / CSS