pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之集合(set)
Sep 24 Python
在树莓派2或树莓派B+上安装Python和OpenCV的教程
Mar 30 Python
python 计算两个日期相差多少个月实例代码
May 24 Python
python仿evething的文件搜索器实例代码
May 13 Python
从列表或字典创建Pandas的DataFrame对象的方法
Jul 06 Python
python3 深浅copy对比详解
Aug 12 Python
Python下利用BeautifulSoup解析HTML的实现
Jan 17 Python
win10安装python3.6的常见问题
Jul 01 Python
Python3如何实现Win10桌面自动切换
Aug 11 Python
Python从MySQL数据库中面抽取试题,生成试卷
Jan 14 Python
使用Python+Appuim 清理微信的方法
Jan 26 Python
python3 sqlite3限制条件查询的操作
Apr 07 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
让codeigniter与swfupload整合的最佳解决方案
2014/06/12 PHP
PHP实现递归无限级分类
2015/10/22 PHP
PHP6新特性分析
2016/03/03 PHP
tp5修改(实现即点即改)
2019/10/18 PHP
javascript 面向对象思想 附源码
2009/07/07 Javascript
JQuery 遮罩层实现(mask)实现代码
2010/01/09 Javascript
jQuery CSS()方法改变现有的CSS样式表
2014/09/09 Javascript
jQuery判断一个元素是否可见的方法
2015/06/05 Javascript
javascript文本模板用法实例
2015/07/31 Javascript
js和jQuery设置Opacity半透明 兼容IE6
2016/05/24 Javascript
AngularJS基础 ng-if 指令用法
2016/08/01 Javascript
浅谈JavaScript事件绑定的常用方法及其优缺点分析
2016/11/01 Javascript
原生JavaScript实现Tooltip浮动提示框特效
2017/03/07 Javascript
JavaScript模块化之使用requireJS按需加载
2017/04/12 Javascript
Angular中自定义Debounce Click指令防止重复点击
2017/07/26 Javascript
Kindeditor单独调用多图上传实例
2017/07/31 Javascript
node koa2实现上传图片并且同步上传到七牛云存储
2017/07/31 Javascript
使用vue中的混入mixin优化表单验证插件问题
2019/07/02 Javascript
解决layui数据表格Date日期格式的回显Object的问题
2019/09/19 Javascript
[40:31]Secret vs Alliacne 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
浅谈Python中数据解析
2015/05/05 Python
Python修改MP3文件的方法
2015/06/15 Python
详解Pytorch 使用Pytorch拟合多项式(多项式回归)
2018/05/24 Python
Python3.5面向对象与继承图文实例详解
2019/04/24 Python
Html5之自定义属性(data-,dataset)
2019/11/19 HTML / CSS
写一个函数,求一个字符串的长度。在main函数中输入字符串,并输出其长度
2015/11/18 面试题
大学生四个方面的自我评价
2013/09/19 职场文书
高三自我鉴定范文
2013/10/19 职场文书
《美丽的黄昏》教学反思
2014/02/28 职场文书
婚礼司仪主持词
2014/03/14 职场文书
读书演讲主持词
2014/03/18 职场文书
学校庆元旦歌咏比赛主持词
2014/03/18 职场文书
质量标语大全
2014/06/12 职场文书
敬老月活动总结
2014/08/28 职场文书
义卖募捐活动总结
2015/05/09 职场文书
新员工实习期个人工作总结
2015/10/15 职场文书