pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python套接字流重定向实例汇总
Mar 03 Python
python解决方案:WindowsError: [Error 2]
Aug 28 Python
简单谈谈Python中的几种常见的数据类型
Feb 10 Python
Python实现爬取百度贴吧帖子所有楼层图片的爬虫示例
Apr 26 Python
Tensorflow使用tfrecord输入数据格式
Jun 19 Python
Python 获取中文字拼音首个字母的方法
Nov 28 Python
django框架基于模板 生成 excel(xls) 文件操作示例
Jun 19 Python
Python jieba库用法及实例解析
Nov 04 Python
Python的历史与优缺点整理
May 26 Python
python中如何进行连乘计算
May 28 Python
Python3中小括号()、中括号[]、花括号{}的区别详解
Nov 15 Python
告别网页搜索!教你用python实现一款属于自己的翻译词典软件
Jun 03 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
全国FM电台频率大全 - 24 贵州省
2020/03/11 无线电
MOTOROLA 摩托罗拉 MODEL 66-XI五灯中波收音机
2021/03/02 无线电
php在文件指定行中写入代码的方法
2012/05/23 PHP
WordPress中注册菜单与调用菜单的方法详解
2015/12/18 PHP
PHP读取word文档的方法分析【基于COM组件】
2017/08/01 PHP
setTimeout 不断吐食CPU的问题分析
2009/04/01 Javascript
JavaScript Object的extend是一个常用的功能
2009/12/02 Javascript
jQuery学习总结之元素的相对定位和选择器(持续更新)
2011/04/26 Javascript
javascript (用setTimeout而非setInterval)
2011/12/28 Javascript
模拟电子签章盖章效果的jQuery插件源码
2013/06/24 Javascript
javascript实用小函数使用介绍
2013/11/11 Javascript
理解AngularJs篇:30分钟快速掌握AngularJs
2016/12/23 Javascript
详解Angular4中路由Router类的跳转navigate
2017/06/09 Javascript
jQuery实现列表的增加和删除功能
2018/06/14 jQuery
深入理解JS中Number(),parseInt(),parseFloat()三者比较
2018/08/24 Javascript
解决vue-cli项目webpack打包后iconfont文件路径的问题
2018/09/01 Javascript
JavaScript类的继承多种实现方法
2020/05/30 Javascript
[51:20]完美世界DOTA2联赛PWL S2 Magma vs PXG 第一场 11.28
2020/12/01 DOTA
Python的爬虫程序编写框架Scrapy入门学习教程
2016/07/02 Python
对python GUI实现完美进度条的示例详解
2018/12/13 Python
神经网络相关之基础概念的讲解
2018/12/29 Python
Python时间序列处理之ARIMA模型的使用讲解
2019/04/02 Python
解决pycharm remote deployment 配置的问题
2019/06/27 Python
python中将两组数据放在一起按照某一固定顺序shuffle的实例
2019/07/15 Python
python生成任意频率正弦波方式
2020/02/25 Python
python 如何引入协程和原理分析
2020/11/30 Python
css3 利用transform打造走动的2D时钟
2020/10/20 HTML / CSS
HTML5中使用postMessage实现Ajax跨域请求的方法
2016/04/19 HTML / CSS
钉钉企业内部H5微应用开发详解
2020/05/12 HTML / CSS
蔻驰美国官网:COACH美国
2016/08/18 全球购物
尤为Wconcept中国官网:韩国设计师品牌服饰
2019/01/10 全球购物
街头时尚在线:JESSICABUURMAN
2019/06/16 全球购物
The North Face北面法国官网:美国著名户外品牌
2019/11/01 全球购物
会计专业应届生求职信
2013/11/24 职场文书
七年级作文之关于奶奶
2019/10/29 职场文书
分析MySQL抛出异常的几种常见解决方式
2021/05/18 MySQL