浅谈pytorch grad_fn以及权重梯度不更新的问题


Posted in Python onAugust 20, 2019

前提:我训练的是二分类网络,使用语言为pytorch

Varibale包含三个属性:

data:存储了Tensor,是本体的数据

grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致

grad_fn:指向Function对象,用于反向传播的梯度计算之用

在构建网络时,刚开始的错误为:没有可以grad_fn属性的变量。

百度后得知要对需要进行迭代更新的变量设置requires_grad=True ,操作如下:

train_pred = Variable(train_pred.float(), requires_grad=True)`

这样设置之后网络是跑起来了,但是准确率一直没有提升,很明显可以看出网络什么都没学到。

我输出 model.parameters() (网络内部的权重和偏置)查看,发现它的权重并没有更新,一直是同一个值,至此可以肯定网络什么都没学到,还是迭代那里出了问题。

询问同门后发现问题不在这里。

计算loss时,target与train_pred的size不匹配,我以以下操作修改了train_pred,使两者尺寸一致,才导致了上述问题。

train_pred = model(data)
  train_pred = torch.max(train_pred, 1)[1].data.squeeze()
  train_pred = Variable(train_pred.float(), requires_grad=False)
  train_loss = F.binary_cross_entropy(validation_pred.float(), target)
  train_loss.backward()

对train_pred多次处理后,它已无法正确地反向传播,实际上应该更改target,使其与train_pred size一致。

重点!!!要想loss正确反向传播,应直接将model(data)传入loss函数。

最终修改代码如下:

for batch_idx, (data, target) in enumerate(train_loader):
  # Get Samples
  label = target.view(target.size(0), 1).long()
  target_onehot = torch.zeros(data.shape[0], args.num_classes).scatter_(1, label, 1)
  data, target_onehot = Variable(data.cuda()), Variable(target_onehot.cuda().float())
  
  model.zero_grad()

  # Predict
  train_pred = model(data)
  train_loss = F.binary_cross_entropy(train_pred, target_onehot)
  train_loss.backward()
  optimizer.step()

以上这篇浅谈pytorch grad_fn以及权重梯度不更新的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单的通用表达式求10乘阶示例
Mar 03 Python
全面解析Python的While循环语句的使用方法
Oct 13 Python
Python协程的用法和例子详解
Sep 09 Python
django加载本地html的方法
May 27 Python
python 将md5转为16字节的方法
May 29 Python
配置 Pycharm 默认 Test runner 的图文教程
Nov 30 Python
处理Selenium3+python3定位鼠标悬停才显示的元素
Jul 31 Python
Python中字典与恒等运算符的用法分析
Aug 22 Python
Python3操作MongoDB增册改查等方法详解
Feb 10 Python
python使用turtle库绘制奥运五环
Feb 24 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 Python
为什么python比较流行
Jun 19 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 #Python
python中用logging实现日志滚动和过期日志删除功能
Aug 20 #Python
python3中替换python2中cmp函数的实现
Aug 20 #Python
python 并发编程 多路复用IO模型详解
Aug 20 #Python
关于pytorch中网络loss传播和参数更新的理解
Aug 20 #Python
对pytorch中的梯度更新方法详解
Aug 20 #Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
You might like
php smarty 二级分类代码和模版循环例子
2011/06/16 PHP
php使用GD库创建图片缩略图的方法
2015/06/10 PHP
基于laravel制作APP接口(API)
2016/03/15 PHP
实例分析基于PHP微信网页获取用户信息
2017/11/24 PHP
Prototype使用指南之enumerable.js
2007/01/10 Javascript
JavaScript中的new的使用方法与注意事项
2007/05/16 Javascript
7个Javascript地图脚本整理
2009/10/20 Javascript
Nodejs关于gzip/deflate压缩详解
2015/03/04 NodeJs
smartcrop.js智能图片裁剪库
2015/10/14 Javascript
BootStrap 智能表单实战系列(十)自动完成组件的支持
2016/06/13 Javascript
AngularJs  Using $location详解及示例代码
2016/09/02 Javascript
浅谈jquery页面初始化的4种方式
2016/11/27 Javascript
jquery.rotate.js实现可选抽奖次数和中奖内容的转盘抽奖代码
2017/08/23 jQuery
ES6中定义类和对象的方法示例
2019/07/31 Javascript
微信小程序实现拼图小游戏
2020/10/22 Javascript
Python编程判断一个正整数是否为素数的方法
2017/04/14 Python
手把手教你用python抢票回家过年(代码简单)
2018/01/21 Python
对numpy 数组和矩阵的乘法的进一步理解
2018/04/04 Python
Python基于lxml模块解析html获取页面内所有叶子节点xpath路径功能示例
2018/05/16 Python
Python实现的简单线性回归算法实例分析
2018/12/26 Python
python for 循环获取index索引的方法
2019/02/01 Python
Linux下通过python获取本机ip方法示例
2019/09/06 Python
python调用c++返回带成员指针的类指针实例
2019/12/12 Python
python如何查看网页代码
2020/06/07 Python
python基于socket模拟实现ssh远程执行命令
2020/12/05 Python
利用CSS3实现文字折纸效果实例代码
2018/07/10 HTML / CSS
HTML5边玩边学(3)像素和颜色
2010/09/21 HTML / CSS
英国内衣连锁店:Boux Avenue
2018/01/24 全球购物
Linux机考试题
2015/10/16 面试题
个人简历中自我评价
2014/02/11 职场文书
社区道德讲堂实施方案
2014/03/21 职场文书
员工辞职信怎么写
2015/02/27 职场文书
社会实践单位意见
2015/06/05 职场文书
信息技术远程培训心得体会
2016/01/09 职场文书
python实现Thrift服务端的方法
2021/04/20 Python
JavaScript offset实现鼠标坐标获取和窗口内模块拖动
2021/05/30 Javascript