Keras: model实现固定部分layer,训练部分layer操作


Posted in Python onJune 28, 2020

需求:Resnet50做调优训练,将最后分类数目由1000改为500。

问题:网上下载了resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5,更改了Resnet50后,由于所有层均参加训练,导致训练速度慢。实际上只需要训练最后3层,前面的层都不需要训练。

解决办法:

①将模型拆分为两个模型,一个为前面的notop部分,一个为最后三层,然后利用model的trainable属性设置只有后一个model训练,最后将两个模型合并起来。

②不用拆分,遍历模型的所有层,将前面层的trainable设置为False即可。代码如下:

for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

注意事项:

①尽量不要这样:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

因为容易出错。。。

②加载notop参数时注意by_name=True.

补充知识:Keras关于训练冻结部分层

设置冻结层有两种方式。

(不推荐)是在搭建网络时,直接将某层的trainable设置为false,例如:

layers.Conv2D(filters1, (1, 1), trainable=False)(input_tensor)

在网络搭建完成时,遍历model.layer,然后将layer.trainable设置为False:

# 冻结网络倒数的3层
for layer in model.layers[:-3]:
 print(layer.trainable)
 layer.trainable = False

也可以根据layer.name来确定哪些层需要冻结,例如冻结最后一层和RNN层:

for layer in model.layers:
 layerName=str(layer.name)
 if layerName.startswith("RNN_") or layerName.startswith("Final_"):
 layer.trainable=False

可以在实例化之后将网络层的 trainable 属性设置为 True 或 False。为了使之生效,在修改 trainable 属性之后,需要在模型上调用 compile()。

这是一个例子

x = Input(shape=(32,))
layer = Dense(32)
layer.trainable = False
y = layer(x)
 
frozen_model = Model(x, y)
# 在下面的模型中,训练期间不会更新层的权重
frozen_model.compile(optimizer='rmsprop', loss='mse')
 
layer.trainable = True
trainable_model = Model(x, y)
# 使用这个模型,训练期间 `layer` 的权重将被更新
# (这也会影响上面的模型,因为它使用了同一个网络层实例)
trainable_model.compile(optimizer='rmsprop', loss='mse')
 
frozen_model.fit(data, labels) # 这不会更新 `layer` 的权重
trainable_model.fit(data, labels) # 这会更新 `layer` 的权重

在网络搭建时,可以考虑最后一个分类层命名和分类数量关联,这样当费雷数量方式变化时,model.load_weight(“weight.h5”,by_name=True)不会加载最后一层

以上这篇Keras: model实现固定部分layer,训练部分layer操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之简单入门说明(变量和控制语言使用方法)
Mar 25 Python
Python SQLAlchemy基本操作和常用技巧(包含大量实例,非常好)
May 06 Python
用pickle存储Python的原生对象方法
Apr 28 Python
Python读取文件内容的三种常用方式及效率比较
Oct 07 Python
Python爬虫实例_利用百度地图API批量获取城市所有的POI点
Jan 10 Python
分析运行中的 Python 进程详细解析
Jun 22 Python
Python面向对象之继承原理与用法案例分析
Dec 31 Python
Python定时任务APScheduler原理及实例解析
May 30 Python
Selenium webdriver添加cookie实现过程详解
Aug 12 Python
Python读取图像并显示灰度图的实现
Dec 01 Python
python drf各类组件的用法和作用
Jan 12 Python
python 实现IP子网计算
Feb 18 Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
You might like
php学习之 认清变量的作用范围
2010/01/26 PHP
php高级编程-函数-郑阿奇
2011/07/04 PHP
使用composer安装使用thinkphp6.0框架问题【视频教程】
2019/10/01 PHP
php封装的page分页类完整实例代码
2020/02/01 PHP
用document.documentElement取代document.body的原因分析
2009/11/12 Javascript
Javascript实现关联数据(Linked Data)查询及注意细节
2013/02/22 Javascript
jQuery实现类似淘宝网图片放大效果的方法
2015/07/08 Javascript
JavaScript实现点击文本自动定位到下拉框选中操作
2016/06/15 Javascript
JavaScript BASE64算法实现(完美解决中文乱码)
2017/01/10 Javascript
微信小程序 轮播图swiper详解及实例(源码下载)
2017/01/11 Javascript
vue2 如何实现div contenteditable=“true”(类似于v-model)的效果
2017/02/08 Javascript
微信小程序使用navigateTo数据传递的实例
2017/09/26 Javascript
JS基于对象的特性实现去除数组中重复项功能详解
2017/11/17 Javascript
vue结合Echarts实现点击高亮效果的示例
2018/03/17 Javascript
jQuery实现表单动态添加与删除数据操作示例
2018/07/03 jQuery
JQuery Ajax动态加载Table数据的实例讲解
2018/08/09 jQuery
d3绘制基本的柱形图的实现代码
2018/12/12 Javascript
小程序开发中如何使用async-await并封装公共异步请求的方法
2019/01/20 Javascript
ES6知识点整理之函数数组参数的默认值及其解构应用示例
2019/04/17 Javascript
vue 引用自定义ttf、otf、在线字体的方法
2019/05/09 Javascript
Vue代码整洁之去重方法整理
2019/08/06 Javascript
JS数据类型STRING使用实例解析
2019/12/18 Javascript
Vue中axios拦截器如何单独配置token
2019/12/27 Javascript
微信小程序实现吸顶效果
2020/01/08 Javascript
Vue切换组件实现返回后不重置数据,保留历史设置操作
2020/07/21 Javascript
[59:08]Ti4 冒泡赛第二天 NEWBEE vs Titan 2
2014/07/15 DOTA
简介二分查找算法与相关的Python实现示例
2015/08/26 Python
老生常谈python的私有公有属性(必看篇)
2017/06/09 Python
python3应用windows api对后台程序窗口及桌面截图并保存的方法
2019/08/27 Python
基于Python爬虫采集天气网实时信息
2020/06/05 Python
如何基于Python按行合并两个txt
2020/11/03 Python
python 实现简单的计算器(gui界面)
2020/11/11 Python
python实现马丁策略回测3000只股票的实例代码
2021/01/22 Python
使用CSS3制作响应式导航菜单的方法
2015/07/12 HTML / CSS
档案信息化建设方案
2014/05/16 职场文书
SQL实战演练之网上商城数据库商品类别数据操作
2021/10/24 MySQL