PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
为Python的web框架编写MVC配置来使其运行的教程
Apr 30 Python
在Python的Django框架中调用方法和处理无效变量
Jul 15 Python
python xml解析实例详解
Nov 14 Python
Python实现变量数值交换及判断数组是否含有某个元素的方法
Sep 18 Python
Python遍历pandas数据方法总结
Feb 09 Python
使用DataFrame删除行和列的实例讲解
Apr 08 Python
Django中使用第三方登录的示例代码
Aug 20 Python
10 行 Python 代码教你自动发送短信(不想回复工作邮件妙招)
Oct 11 Python
python实现桌面壁纸切换功能
Jan 21 Python
python简单贪吃蛇开发
Jan 28 Python
Python编写通讯录通过数据库存储实现模糊查询功能
Jul 18 Python
使用Python拟合函数曲线
Apr 14 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
PHP获取文件夹内文件数的方法
2015/03/12 PHP
Apache连接PHP后无法启动问题解决思路
2015/06/18 PHP
PHP Smarty模版简单使用方法
2016/03/30 PHP
JS图片浏览组件PhotoLook的公开属性方法介绍和进阶实例代码
2010/11/09 Javascript
js解析与序列化json数据(三)json的解析探讨
2013/02/01 Javascript
javascript获取网页中指定节点的父节点、子节点的方法小结
2013/04/24 Javascript
javascript中encodeURI和decodeURI方法使用介绍
2013/05/06 Javascript
JS远程获取网页源代码实例
2013/09/05 Javascript
JS对文本框值的判断示例
2014/03/10 Javascript
四种参数传递的形式——URL,超链接,js,form表单
2015/07/24 Javascript
js实现跨域的几种方法汇总(图片ping、JSONP和CORS)
2015/10/25 Javascript
bootstrap快速制作后台界面
2016/12/05 Javascript
canvas实现图像截取功能
2017/02/06 Javascript
vue2实现数据请求显示loading图
2017/11/28 Javascript
webpack打包非模块化js的方法
2018/10/24 Javascript
微信小程序缓存过期时间的使用详情
2019/05/12 Javascript
webpack4.0+vue2.0利用批处理生成前端单页或多页应用的方法
2019/06/28 Javascript
js之切换全屏和退出全屏实现代码实例
2019/09/09 Javascript
Python中join和split用法实例
2015/04/14 Python
在python3环境下的Django中使用MySQL数据库的实例
2017/08/29 Python
使用sklearn进行对数据标准化、归一化以及将数据还原的方法
2018/07/11 Python
Python实现捕获异常发生的文件和具体行数
2020/04/25 Python
深入了解Python装饰器的高级用法
2020/08/13 Python
python 通过exifread读取照片信息
2020/12/24 Python
简单的HTML5初步入门教程
2015/09/29 HTML / CSS
伊利莎白雅顿官网:Elizabeth Arden
2016/10/10 全球购物
美国摄影爱好者购物网站:Focus Camera
2016/10/21 全球购物
倩碧英国官网:Clinique英国
2018/08/10 全球购物
周仰杰(JIMMY CHOO)法国官方网站:闻名世界的鞋子品牌
2019/09/27 全球购物
电子商务个人自荐信
2013/12/12 职场文书
五水共治一句话承诺
2014/05/30 职场文书
公共机构节能宣传周活动总结
2014/07/09 职场文书
三八妇女节慰问信
2015/02/14 职场文书
幼儿园春季开学通知
2015/07/16 职场文书
python多线程方法详解
2022/01/18 Python
Redis+AOP+自定义注解实现限流
2022/06/28 Redis