PyTorch中model.zero_grad()和optimizer.zero_grad()用法


Posted in Python onJune 24, 2020

废话不多说,直接上代码吧~

model.zero_grad()
optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):
 """Sets gradients of all model parameters to zero."""
 for p in self.parameters():
  if p.grad is not None:
  p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

PyTorch中model.zero_grad()和optimizer.zero_grad()用法

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Windows服务器下用Apache和mod_wsgi配置Python应用的教程
May 06 Python
Python中的深拷贝和浅拷贝详解
Jun 03 Python
Python科学画图代码分享
Nov 29 Python
Python实现的视频播放器功能完整示例
Feb 01 Python
python实现微信远程控制电脑
Feb 22 Python
Python中循环引用(import)失败的解决方法
Apr 22 Python
Python爬取视频(其实是一篇福利)过程解析
Aug 01 Python
Python3.7.0 Shell添加清屏快捷键的实现示例
Mar 23 Python
Django 多对多字段的更新和插入数据实例
Mar 31 Python
利用python中的matplotlib打印混淆矩阵实例
Jun 16 Python
django模型类中,null=True,blank=True用法说明
Jul 09 Python
python中doctest库实例用法
Dec 31 Python
Pytorch实现将模型的所有参数的梯度清0
Jun 24 #Python
你需要学会的8个Python列表技巧
Jun 24 #Python
pytorch实现查看当前学习率
Jun 24 #Python
在pytorch中动态调整优化器的学习率方式
Jun 24 #Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
You might like
基于python发送邮件的乱码问题的解决办法
2013/04/25 PHP
yii 框架实现按天,月,年,自定义时间段统计数据的方法分析
2020/04/04 PHP
javascript 数据类型转换(parseInt,parseFloat)
2010/07/20 Javascript
常见效果实现之返回顶部(结合淡入、淡出、减速滚动)
2012/01/04 Javascript
js渐变显示渐变消失示例代码
2013/08/01 Javascript
浅析JavaScript中的类型和对象
2013/11/29 Javascript
jQuery中outerHeight()方法用法实例
2015/01/19 Javascript
Angular.js中$apply()和$digest()的深入理解
2016/10/13 Javascript
Bootstrap基本插件学习笔记之模态对话框(16)
2016/12/08 Javascript
关于javascript事件响应的基础语法总结(必看篇)
2016/12/26 Javascript
JavaScript实现移动端轮播效果
2017/06/06 Javascript
使用Require.js封装原生js轮播图的实现代码
2017/06/15 Javascript
使用mint-ui实现省市区三级联动效果的示例代码
2018/02/09 Javascript
使用异步组件优化Vue应用程序的性能
2019/04/28 Javascript
Angular6使用forRoot() 注册单一实例服务问题
2019/08/27 Javascript
微信小程序学习总结(二)样式、属性、模板操作分析
2020/06/04 Javascript
vue-cli3项目打包后自动化部署到服务器的方法
2020/09/16 Javascript
[02:56]DOTA2亚洲邀请赛 VG出场战队巡礼
2015/02/07 DOTA
Python字典简介以及用法详解
2016/11/15 Python
python爬取哈尔滨天气信息
2018/07/14 Python
应用OpenCV和Python进行SIFT算法的实现详解
2019/08/21 Python
dpn网络的pytorch实现方式
2020/01/14 Python
一文了解python 3 字符串格式化 F-string 用法
2020/03/04 Python
python使用opencv resize图像不进行插值的操作
2020/07/05 Python
html5自带表单验证体验优化及提示气泡修改功能
2017/09/12 HTML / CSS
ASOS亚洲:ASOS Asia
2018/03/04 全球购物
随机分配座位,共50个学生,使学号相邻的同学座位不能相邻
2014/01/18 面试题
档案室主任岗位职责
2014/02/12 职场文书
党建工作先进材料
2014/05/02 职场文书
整改通知书格式
2015/04/22 职场文书
技能培训通讯稿
2015/07/18 职场文书
pytorch 中autograd.grad()函数的用法说明
2021/05/12 Python
浅析Redis Sentinel 与 Redis Cluster
2021/06/24 Redis
Java实现扫雷游戏详细代码讲解
2022/05/25 Java/Android
NoSQL优缺点与MongoDB数据库简介
2022/06/05 MongoDB
CSS浮动引起的高度塌陷问题
2022/08/05 HTML / CSS