pytorch如何冻结某层参数的实现


Posted in Python onJanuary 10, 2020

在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我们想要冻结linear1层,需要做如下操作:

model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False

 最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。

filter(function, iterable)

  • function: 判断函数
  • iterable: 可迭代对象

filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。

该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。

filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python设计模式之单例模式实例
Apr 26 Python
python实现的多线程端口扫描功能示例
Jan 21 Python
利用python爬取散文网的文章实例教程
Jun 18 Python
Python复数属性和方法运算操作示例
Jul 21 Python
python字典DICT类型合并详解
Aug 17 Python
python调用c++返回带成员指针的类指针实例
Dec 12 Python
Python3将jpg转为pdf文件的方法示例
Dec 13 Python
Python3 集合set入门基础
Feb 10 Python
Python3 hashlib密码散列算法原理详解
Mar 30 Python
Python 实现一行输入多个数字(用空格隔开)
Apr 29 Python
Django启动时找不到mysqlclient问题解决方案
Nov 11 Python
python 实现德洛内三角剖分的操作
Apr 22 Python
python标识符命名规范原理解析
Jan 10 #Python
pytorch1.0中torch.nn.Conv2d用法详解
Jan 10 #Python
pytorch 利用lstm做mnist手写数字识别分类的实例
Jan 10 #Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
You might like
关于mysql 字段的那个点为是定界符
2007/01/15 PHP
利用PHP实现智能文件类型检测的实现代码
2011/08/02 PHP
PHP处理Oracle的CLOB实例
2014/11/03 PHP
php如何实现只替换一次或N次
2015/10/29 PHP
简介WordPress中用于获取首页和站点链接的PHP函数
2015/12/17 PHP
PHPCMS2008广告模板SQL注入漏洞修复
2016/10/11 PHP
JAVASCRIPT HashTable
2007/01/22 Javascript
javascript在一段文字中的光标处插入其他文字
2007/08/26 Javascript
javascript 多级checkbox选择效果
2009/08/20 Javascript
jQuery chili图片远处放大插件
2009/11/30 Javascript
AngularJS基础知识笔记之表格
2015/05/10 Javascript
JavaScript返回上一页的三种方法及区别介绍
2015/07/04 Javascript
论Bootstrap3和Foundation5网格系统的异同
2016/05/16 Javascript
javascript之Array 数组对象详解
2016/06/07 Javascript
浅谈Node.js:理解stream
2016/12/08 Javascript
js数组与字符串常用方法总结
2017/01/13 Javascript
Node.JS中事件轮询(Event Loop)的解析
2017/02/25 Javascript
jQuery实现的粘性滚动导航栏效果实例【附源码下载】
2017/10/19 jQuery
webpack4.x开发环境配置详解
2018/08/04 Javascript
javascript中的event loop事件循环详解
2018/12/14 Javascript
js验证身份证号码记录的方法
2019/04/26 Javascript
bootstrap实现tab选项卡切换
2020/08/09 Javascript
以Flask为例讲解Python的框架的使用方法
2015/04/29 Python
Python2中的raw_input() 与 input()
2015/06/12 Python
python更新列表的方法
2015/07/28 Python
Python中的条件判断语句基础学习教程
2016/02/07 Python
Python3计算三角形的面积代码
2017/12/18 Python
python+matplotlib绘制简单的海豚(顶点和节点的操作)
2018/01/02 Python
python异常处理和日志处理方式
2019/12/24 Python
英国户外服装品牌:Craghoppers
2019/04/25 全球购物
EJB3.1都有哪些改进
2012/11/17 面试题
实习单位意见
2015/06/04 职场文书
狼牙山五壮士观后感
2015/06/09 职场文书
几款流行的HTML5 UI框架比较(小结)
2021/04/08 HTML / CSS
JPA如何使用entityManager执行SQL并指定返回类型
2021/06/15 Java/Android
动作冒险《Hell Is Us》将采用虚幻5 消灭怪物探索王国
2022/04/13 其他游戏