浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基于csv模块实现读取与写入csv数据的方法
Jan 18 Python
Python通过调用mysql存储过程实现更新数据功能示例
Apr 03 Python
解决Pycharm后台indexing导致不能run的问题
Jun 27 Python
python实现人工智能Ai抠图功能
Sep 05 Python
python绘制随机网络图形示例
Nov 21 Python
Python:slice与indices的用法
Nov 25 Python
Pyspark读取parquet数据过程解析
Mar 27 Python
Jupyter Notebook 文件默认目录的查看以及更改步骤
Apr 14 Python
Python图像阈值化处理及算法比对实例解析
Jun 19 Python
Python timeit模块原理及使用方法
Oct 10 Python
Django程序的优化技巧
Apr 29 Python
基于Python实现将列表数据生成折线图
Mar 23 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
解析php中eclipse 用空格替换 tab键
2013/06/24 PHP
PH P5.2至5.5、5.6的新增功能详解
2014/07/14 PHP
php cli配置文件问题分析
2015/10/15 PHP
微信自定义菜单的创建/查询/取消php示例代码
2016/08/05 PHP
PHP中$GLOBALS与global的区别详解
2019/03/21 PHP
搭建PhpStorm+PhpStudy开发环境的超详细教程
2020/09/17 PHP
JS中字符问题(二进制/十进制/十六进制及ASCII码之间的转换)
2008/11/03 Javascript
在vs2010中调试javascript代码方法
2011/02/11 Javascript
简单实用的反馈表单无刷新提交带验证
2013/11/15 Javascript
javascript实现可键盘控制的抽奖系统
2016/03/10 Javascript
使用JS正则表达式 替换括号,尖括号等
2016/11/29 Javascript
node.js发送邮件email的方法详解
2017/01/06 Javascript
cropper js基于vue的图片裁剪上传功能的实现代码
2018/03/01 Javascript
js实现点击按钮复制文本功能
2020/07/20 Javascript
vue-resource请求实现http登录拦截或者路由拦截的方法
2018/07/11 Javascript
vue-socket.io接收不到数据问题的解决方法
2020/05/13 Javascript
JS运算符优先级与表达式示例详解
2020/09/04 Javascript
[02:41]DOTA2亚洲邀请赛小组赛第三日 赛事回顾
2015/02/01 DOTA
python输入错误密码用户锁定实现方法
2017/11/27 Python
使用Python来开发微信功能
2018/06/13 Python
python 定义给定初值或长度的list方法
2018/06/23 Python
Python分割指定页数的pdf文件方法
2018/10/26 Python
python实现基于信息增益的决策树归纳
2018/12/18 Python
解决pyinstaller打包发布后的exe文件打开控制台闪退的问题
2019/06/21 Python
Python 使用matplotlib模块模拟掷骰子
2019/08/08 Python
Python selenium键盘鼠标事件实现过程详解
2020/07/28 Python
python中的列表和元组区别分析
2020/12/30 Python
一篇文章带你学习CSS3图片边框
2020/11/04 HTML / CSS
AmazeUI底部导航栏与分享按钮的示例代码
2020/08/18 HTML / CSS
个人求职简历的自我评价
2013/10/19 职场文书
优秀导游先进事迹材料
2014/01/25 职场文书
工业设计毕业生自荐信
2014/04/13 职场文书
七一建党日演讲稿
2014/09/05 职场文书
党员教师个人对照检查材料范文
2014/09/25 职场文书
廉政承诺书2015
2015/04/28 职场文书
2016年读书月活动总结范文
2016/04/06 职场文书