浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现xls(xlsx)转成csv
Apr 10 Python
深入理解python多进程编程
Jun 12 Python
python实现折半查找和归并排序算法
Apr 14 Python
Python爬虫之模拟知乎登录的方法教程
May 25 Python
解决python3中自定义wsgi函数,make_server函数报错的问题
Nov 21 Python
浅谈pandas中Dataframe的查询方法([], loc, iloc, at, iat, ix)
Apr 10 Python
python使用插值法画出平滑曲线
Dec 15 Python
python命名空间(namespace)简单介绍
Aug 10 Python
Python lxml模块的基本使用方法分析
Dec 21 Python
python基于property()函数定义属性
Jan 22 Python
Python图像识别+KNN求解数独的实现
Nov 13 Python
如何用Django处理gzip数据流
Jan 29 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
PHP获取指定日期是星期几的实现方法
2016/11/30 PHP
php文件管理基本功能简单操作
2017/01/16 PHP
Yii2 队列 shmilyzxt/yii2-queue 简单概述
2017/08/02 PHP
PHP基于GD2函数库实现验证码功能示例
2019/01/27 PHP
PHP设计模式(七)组合模式Composite实例详解【结构型】
2020/05/02 PHP
js仿黑客帝国字母掉落效果代码分享
2020/11/08 Javascript
jQuery实现仿百度首页滑动伸缩展开的添加服务效果代码
2015/09/09 Javascript
直接拿来用的页面跳转进度条JS实现
2016/01/06 Javascript
jQuery+CSS实现一个侧滑导航菜单代码
2016/05/09 Javascript
浅谈Sublime Text 3运行JavaScript控制台
2016/06/06 Javascript
微信小程序(六):列表上拉加载下拉刷新示例
2017/01/13 Javascript
layui前段框架日期控件使用方法详解
2017/05/19 Javascript
JavaScript数组去重算法实例小结
2018/05/07 Javascript
nodejs基础之buffer缓冲区用法分析
2018/12/26 NodeJs
Vue2(三)实现子菜单展开收缩,带动画效果实现方法
2019/04/28 Javascript
Vue中img的src是动态渲染时不显示的解决
2019/11/14 Javascript
Vue通过配置WebSocket并实现群聊功能
2019/12/31 Javascript
[50:58]2018DOTA2亚洲邀请赛3月29日 小组赛A组OpTic VS Newbee
2018/03/30 DOTA
[53:49]LGD vs Fnatic 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
用Python实现命令行闹钟脚本实例
2016/09/05 Python
梯度下降法介绍及利用Python实现的方法示例
2017/07/12 Python
pandas 实现字典转换成DataFrame的方法
2018/07/04 Python
python 直接赋值和copy的区别详解
2019/08/07 Python
深入浅析python变量加逗号,的含义
2020/02/22 Python
python中的插入排序的简单用法
2021/01/19 Python
CSS3,线性渐变(linear-gradient)的使用总结
2017/01/09 HTML / CSS
资深生产主管自我评价
2013/09/22 职场文书
自我评价正确写法范文
2013/12/10 职场文书
校友会欢迎辞
2014/01/13 职场文书
八项规定整改措施
2014/02/12 职场文书
学生保证书
2015/01/16 职场文书
入党积极分子个人总结
2015/03/02 职场文书
时尚女魔头观后感
2015/06/04 职场文书
围城读书笔记
2015/06/26 职场文书
结婚典礼致辞
2015/07/28 职场文书
PostgreSQL通过oracle_fdw访问Oracle数据的实现步骤
2021/05/21 PostgreSQL