浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python如何判断数独是否合法
Sep 08 Python
梯度下降法介绍及利用Python实现的方法示例
Jul 12 Python
基于Django用户认证系统详解
Feb 21 Python
pyside+pyqt实现鼠标右键菜单功能
Dec 08 Python
Python3.5装饰器原理及应用实例详解
Apr 30 Python
python日期相关操作实例小结
Jun 24 Python
如何使用Python标准库进行性能测试
Jun 25 Python
使用pandas读取文件的实现
Jul 31 Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 Python
详解python中*号的用法
Oct 21 Python
linux 下python多线程递归复制文件夹及文件夹中的文件
Jan 02 Python
ansible-playbook实现自动部署KVM及安装python3的详细教程
May 11 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
php SQL防注入代码集合
2008/04/25 PHP
php图片上传类 附调用方法
2016/05/15 PHP
JSON 客户端和服务器端的格式转换
2009/08/27 Javascript
Javascript学习笔记4 Eval函数
2010/01/11 Javascript
在一个js文件里远程调用jquery.js会在ie8下的一个奇怪问题
2010/11/28 Javascript
13 个JavaScript 性能提升技巧分享
2012/07/26 Javascript
密码强度检测效果实现原理与代码
2013/01/04 Javascript
Javascript查询DBpedia小应用实例学习
2013/03/07 Javascript
ExtJS4如何自动生成控制grid的列显示、隐藏的checkbox
2014/05/02 Javascript
JavaScript正则表达式之multiline属性的应用
2015/06/16 Javascript
Bootstrap 最常用的JS插件系列总结(图片轮播、标签切换等)
2016/07/14 Javascript
js正则表达式验证密码强度【推荐】
2017/03/03 Javascript
JavaScript实现左右下拉框动态增删示例
2017/03/09 Javascript
bootstrap table实现单击单元格可编辑功能
2017/03/28 Javascript
vue-cli中的webpack配置详解
2017/09/25 Javascript
Vue实现简易翻页效果源码分享
2018/11/08 Javascript
vue监听用户输入和点击功能
2019/09/27 Javascript
javascript json对象小技巧之键名作为变量用法分析
2019/11/11 Javascript
vue实现前端分页完整代码
2020/06/17 Javascript
详解Django缓存处理中Vary头部的使用
2015/07/24 Python
Python3 实现随机生成一组不重复数并按行写入文件
2018/04/09 Python
python 利用文件锁单例执行脚本的方法
2019/02/19 Python
Python文件打开方式实例详解【a、a+、r+、w+区别】
2019/03/30 Python
Django的用户模块与权限系统的示例代码
2019/07/24 Python
python使用 request 发送表单数据操作示例
2019/09/25 Python
pycharm中导入模块错误时提示Try to run this command from the system terminal
2020/03/26 Python
解决python中显示图片的plt.imshow plt.show()内存泄漏问题
2020/04/24 Python
Python面向对象多态实现原理及代码实例
2020/09/16 Python
《小动物过冬》教学反思
2014/04/17 职场文书
高中课程设置方案
2014/05/28 职场文书
校园元旦活动总结
2014/07/09 职场文书
老人节标语大全
2014/10/08 职场文书
学校政风行风自查自纠报告
2014/10/21 职场文书
CSS3 天气图标动画效果
2021/04/06 HTML / CSS
vue中的可拖拽宽度div的实现示例
2022/04/08 Vue.js
Sql Server 行数据的某列值想作为字段列显示的方法
2022/04/20 SQL Server