pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过文件头判断文件类型
Oct 30 Python
Python字符串拼接六种方法介绍
Dec 18 Python
Pandas 数据处理,数据清洗详解
Jul 10 Python
wxPython的安装与使用教程
Aug 31 Python
Python中collections模块的基本使用教程
Dec 07 Python
python 判断矩阵中每行非零个数的方法
Jan 26 Python
Python转换时间的图文方法
Jul 01 Python
使用Python获取当前工作目录和执行命令的位置
Mar 09 Python
Python 在 VSCode 中使用 IPython Kernel 的方法详解
Sep 05 Python
pycharm 2020 1.1的安装流程
Sep 29 Python
10个顶级Python实用库推荐
Mar 04 Python
pytorch常用数据类型所占字节数对照表一览
May 17 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
两个开源的Php输出Excel文件类
2010/02/08 PHP
MyEclipse常用配置图文教程
2014/09/11 PHP
PHP中trim()函数简单使用指南
2015/04/16 PHP
thinkPHP js文件中U方法不被解析问题的解决方法
2016/12/05 PHP
PHP单例模式模拟Java Bean实现方法示例
2018/12/07 PHP
PHP通过GD库实现验证码功能示例
2019/02/23 PHP
PHP中“=>
2019/03/01 PHP
event.srcElement+表格应用
2006/08/29 Javascript
javascript parseInt 函数分析(转)
2009/03/21 Javascript
js window.onload 加载多个函数的方法
2009/11/02 Javascript
javascript 处理事件绑定的一些兼容写法
2009/12/24 Javascript
JS实现按比例缩放图片的方法(附C#版代码)
2015/12/08 Javascript
Javascript实现鼠标框选操作  不是点击选取
2016/04/14 Javascript
JS取数字小数点后两位或n位的简单方法
2016/10/24 Javascript
bootstrap模态框消失问题的解决方法
2016/12/02 Javascript
完美解决node.js中使用https请求报CERT_UNTRUSTED的问题
2017/01/08 Javascript
Node.js  事件循环详解及实例
2017/08/06 Javascript
Vue 配合eiement动态路由,权限验证的方法
2018/09/26 Javascript
jquery实现直播弹幕效果
2019/11/28 jQuery
小程序自定义导航栏兼容适配所有机型(附完整案例)
2020/04/26 Javascript
vue+element使用动态加载路由方式实现三级菜单页面显示的操作
2020/08/04 Javascript
浅谈vue.watch的触发条件是什么
2020/11/07 Javascript
python通过urllib2获取带有中文参数url内容的方法
2015/03/13 Python
Python批量转换文件编码格式
2015/05/17 Python
在Django中编写模版节点及注册标签的方法
2015/07/20 Python
浅谈flask源码之请求过程
2018/07/26 Python
python读取图像矩阵文件并转换为向量实例
2020/06/18 Python
CSS实现聊天气泡效果
2020/04/26 HTML / CSS
HTML5 embed标签定义和用法详解
2014/05/09 HTML / CSS
美国玩具公司:U.S.Toy
2018/05/19 全球购物
Weblogc domain问题
2014/01/27 面试题
Linux的文件类型
2012/03/07 面试题
2014银行领导班子群众路线对照检查材料思想汇报
2014/09/17 职场文书
市场部岗位职责范本
2015/04/15 职场文书
初中体育课教学反思
2016/02/16 职场文书
小学音乐课歌曲《堆雪人》教学反思
2016/02/18 职场文书