PyTorch梯度裁剪避免训练loss nan的操作


Posted in Python onMay 24, 2021

近来在训练检测网络的时候会出现loss为nan的情况,需要中断重新训练,会很麻烦。因而选择使用PyTorch提供的梯度裁剪库来对模型训练过程中的梯度范围进行限制,修改之后,不再出现loss为nan的情况。

PyTorch中采用torch.nn.utils.clip_grad_norm_来实现梯度裁剪,链接如下:

https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

训练代码使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm为梯度的最大范数,也是梯度裁剪时主要设置的参数。

备注:网上有同学提醒在(强化学习)使用了梯度裁剪之后训练时间会大大增加。目前在我的检测网络训练中暂时还没有碰到这个问题,以后遇到再来更新。

补充:pytorch训练过程中出现nan的排查思路

1、最常见的就是出现了除0或者log0这种

看看代码中在这种操作的时候有没有加一个很小的数,但是这个数数量级要和运算的数的数量级要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面两条还不能解决nan的话

就按照下面的流程来判断。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么说明可能是在forward的过程中出现了第一条列举的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么说明forward过程没问题,可能是梯度爆炸,所以用梯度裁剪试试
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判断参数是不是nan, 如果不是判断step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判断,参数和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特别是梯度出现了Nan,考虑学习速率是否太大,调小学习速率或者换个优化器试试。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本设置超时机制系统时间的方法
Feb 21 Python
浅谈Python的垃圾回收机制
Dec 17 Python
基于python select.select模块通信的实例讲解
Sep 21 Python
通过python+selenium3实现浏览器刷简书文章阅读量
Dec 26 Python
python获取程序执行文件路径的方法(推荐)
Apr 26 Python
TensorFlow的权值更新方法
Jun 14 Python
对python_discover方法遍历所有执行的用例详解
Feb 13 Python
python自定义线程池控制线程数量的示例
Feb 22 Python
对python3 sort sorted 函数的应用详解
Jun 27 Python
Python 中使用 PyMySQL模块操作数据库的方法
Nov 10 Python
python 使用递归回溯完美解决八皇后的问题
Feb 26 Python
keras小技巧——获取某一个网络层的输出方式
May 23 Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
You might like
Windows下PHP5和Apache的安装与配置
2006/09/05 PHP
解析posix与perl标准的正则表达式区别
2013/06/17 PHP
php获取目标函数执行时间示例
2014/03/04 PHP
PHP的几个常用加密函数
2016/02/03 PHP
Zend Framework框架路由机制代码分析
2016/03/22 PHP
PHP扩展Swoole实现实时异步任务队列示例
2019/04/13 PHP
JS setCapture 区域外事件捕捉
2010/03/18 Javascript
JavaScript Date对象 日期获取函数
2010/12/19 Javascript
JS的replace方法详细介绍
2012/11/09 Javascript
基于javascript实现漂亮的页面过渡动画效果附源码下载
2015/10/26 Javascript
jQuery判断checkbox选中状态
2016/05/12 Javascript
JQuery 在文档中查找指定name的元素并移除的实现方法
2016/05/19 Javascript
jQuery UI仿淘宝搜索下拉列表功能
2017/01/10 Javascript
layui表格checkbox选择全选样式及功能的实例
2018/03/07 Javascript
在create-react-app中使用css modules的示例代码
2018/07/31 Javascript
layui 给数据表格加序号的方法
2018/08/20 Javascript
layui table去掉右侧滑动条的实现方法
2019/09/05 Javascript
js实现移动端tab切换时下划线滑动效果
2019/09/08 Javascript
Element Steps步骤条的使用方法
2020/07/26 Javascript
win与linux系统中python requests 安装
2016/12/04 Python
Tensorflow之构建自己的图片数据集TFrecords的方法
2018/02/07 Python
浅谈Python黑帽子取代netcat
2018/02/10 Python
python3实现mysql导出excel的方法
2019/07/31 Python
python 实现将Numpy数组保存为图像
2020/01/09 Python
利用Python实现某OA系统的自动定位功能
2020/05/27 Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
2020/06/15 Python
Python 解析简单的XML数据
2020/07/24 Python
瑞士香水购物网站:Parfumcity.ch
2017/01/14 全球购物
Carter’s OshKosh加拿大:购买婴幼儿服装和童装
2018/11/27 全球购物
德国W家官网,可直邮中国的母婴商城:Windeln.de
2021/03/03 全球购物
信息部岗位职责
2013/11/12 职场文书
给导游的表扬信
2014/01/10 职场文书
学生党员公开承诺书
2014/05/28 职场文书
忠诚教育学习心得体会
2016/01/23 职场文书
游戏开发中如何使用CocosCreator进行音效处理
2021/04/14 Javascript
Windows 64位 安装 mysql 8.0.28 图文教程
2022/04/19 MySQL