PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python查看zip包中文件及大小的方法
Jul 09 Python
Django中URLconf和include()的协同工作方法
Jul 20 Python
numpy使用技巧之数组过滤实例代码
Feb 03 Python
numpy使用fromstring创建矩阵的实例
Jun 15 Python
python3使用print打印带颜色的字符串代码实例
Aug 22 Python
python/Matplotlib绘制复变函数图像教程
Nov 21 Python
python实现扫雷游戏
Mar 03 Python
浅谈python量化 双均线策略(金叉死叉)
Jun 03 Python
python3跳出一个循环的实例操作
Aug 18 Python
详解python tkinter 图片插入问题
Sep 03 Python
15款Python编辑器的优缺点,别再问我“选什么编辑器”啦
Oct 19 Python
python 制作网站小说下载器
Feb 20 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
sourcesafe管理phpproj文件的补充说明(downmoon)
2009/04/11 PHP
php处理json时中文问题的解决方法
2011/04/12 PHP
url decode problem 解决方法
2011/12/26 PHP
总结的一些PHP开发中的tips(必看篇)
2017/03/24 PHP
PHP memcache在微信公众平台的应用方法示例
2017/09/13 PHP
JavaScript Event学习第三章 早期的事件处理程序
2010/02/07 Javascript
input 和 textarea 输入框最大文字限制的jquery插件
2011/10/27 Javascript
利用js正则表达式验证手机号,email地址,邮政编码
2014/01/23 Javascript
node.js中的buffer.Buffer.isEncoding方法使用说明
2014/12/14 Javascript
js实现根据身份证号自动生成出生日期
2015/12/15 Javascript
jquery及js实现动态加载js文件的方法
2016/01/21 Javascript
JavaScript实现前端分页控件
2017/04/19 Javascript
JavaScript初学者必看“new”
2017/06/12 Javascript
angularjs实现过滤并替换关键字小功能
2017/09/19 Javascript
vue-cli脚手架引入图片的几种方法总结
2018/03/13 Javascript
vue中父子组件注意事项,传值及slot应用技巧
2018/05/09 Javascript
Vue跨域请求问题解决方案过程解析
2020/08/07 Javascript
vue-cli —— 如何局部修改Element样式
2020/10/22 Javascript
Vite和Vue CLI的优劣
2021/01/30 Vue.js
[40:03]Liquid vs Optic 2018国际邀请赛淘汰赛BO3 第一场 8.21
2018/08/22 DOTA
Python使用微信SDK实现的微信支付功能示例
2017/06/30 Python
python的pdb调试命令的命令整理及实例
2017/07/12 Python
用python制作游戏外挂
2018/01/04 Python
MNIST数据集转化为二维图片的实现示例
2020/01/10 Python
TensorFlow——Checkpoint为模型添加检查点的实例
2020/01/21 Python
Python unittest工作原理和使用过程解析
2020/02/24 Python
pytorch简介
2020/11/11 Python
Html5剪切板功能的实现代码
2018/06/29 HTML / CSS
通过HTML5 Canvas API绘制弧线和圆形的教程
2016/03/14 HTML / CSS
机械工程及自动化专业求职信
2014/09/03 职场文书
违反单位工作制度检讨书
2014/10/25 职场文书
幼儿园老师个人总结
2015/02/28 职场文书
2015年法制宣传月活动总结
2015/03/26 职场文书
小区物业管理2015年度工作总结
2015/10/22 职场文书
python实现黄金分割法的示例代码
2021/04/28 Python
Python Pytorch查询图像的特征从集合或数据库中查找图像
2022/04/09 Python