记录模型训练时loss值的变化情况


Posted in Python onJune 16, 2020

记录训练过程中的每一步的loss变化

if verbose and step % verbose == 0:
 sys.stdout.write('\r{} / {} : loss = {}'.format(
  step, total_steps, np.mean(total_loss)))
 sys.stdout.flush()
 if verbose:
 sys.stdout.write('\r') 
 sys.stdout.flush()

一般我们在训练神经网络模型的时候,都是每隔多少步,输出打印一下loss或者每一步打印一下loss,今天发现了另一种记录loss变化的方法,就是用

sys.stdout.write('\r{} / {} : loss = {}')

如图上的代码,可以记录每一个在每个epoch中记录用一行输出就可以记录每个step的loss值变化,

\r就是输出不会换行,因此如果你想同一样输出多次,在需要输出的字符串对象里面加上"\r",就可以回到行首了。

sys.stdout.flush() #一秒输出了一个数字

具体的实现就是下面的图:

记录模型训练时loss值的变化情况

这样在每个epoch中也可以观察loss变化,但是只需要打印一行,而不是每一行都输出。

补充知识:训练模型中损失(loss)异常分析

前言

训练模型过程中随时都要注意目标函数值(loss)的大小变化。一个正常的模型loss应该随训练轮数(epoch)的增加而缓慢下降,然后趋于稳定。虽然在模型训练的初始阶段,loss有可能会出现大幅度震荡变化,但是只要数据量充分,模型正确,训练的轮数足够长,模型最终会达到收敛状态,接近最优值或者找到了某个局部最优值。在模型实际训练过程中,可能会得到一些异常loss值,如loss等于nan;loss值忽大忽小,不收敛等。

下面根据自己使用Pythorh训练模型的经验,分析出一些具体原因和给出对应的解决办法。

一、输入数据

1. 数据的预处理

输入到模型的数据一般都是经过了预处理的,如用pandas先进行数据处理,尤其要注意空值,缺失值,异常值。

缺失值:数值类型(NaN),对象类型(None, NaN),时间类型(NaT)

空值:""

异常值:不再正常区间范围的值

例如对缺失值可以进行判断df.isnull()或者df.isna();丢弃df.dropna();填充df.fillna()等操作。

输入到模型中的数据一般而言都是数值类型的值,一定要保证不能出现NaN, numpy中的nan是一种特殊的float,该值数值运算的结果是不正常的,所以可能会导致loss值等于nan。可以用numpy.any(numpy.isnan(x))检查一下input和target。

2. 数据的读写

例如使用Pandas读取.csv类型的数据得到的DataFrame会添加默认的index,再写回到磁盘会多一列。如果用其他读取方式再读入,可能会导致数据有问题,读取到NaN。

import pandas as pd
 
Output = pd.read_csv('./data/diabetes/Output.csv')
trainOutput, testOutput = Output[:6000], Output[6000:]
trainOutput.to_csv('./data/diabetes/trainOutput.csv')
testOutput.to_csv('./data/diabetes/testOutput.csv')

记录模型训练时loss值的变化情况

3. 数据的格式

Pythorch中的 torch.utils.data.Dataset 类是一个表示数据集的抽象类。自己数据集的类应该继承自 Dataset 并且重写__len__方法和__getitem__方法:

__len__ : len(dataset) 返回数据集的大小

__getitem__ :用以支持索引操作, dataset[idx]能够返回第idx个样本数据

然后使用torch.utils.data.DataLoader 这个迭代器(iterator)来遍历所有的特征。具体可以参见这里

在构造自己Dataset类时,需要注意返回的数据格式和类型,一般不会出现NaN的情况但是可能会导致数据float, int, long这几种类型的不兼容,注意转换。

二、学习率

基于梯度下降的优化方法,当学习率太高时会导致loss值不收敛,太低则下降缓慢。需要对学习率等超参数进行调参如使用网格搜索,随机搜索等。

三、除零错

对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决。类似于计算概率时进行的平滑修正,下面的代码片段中loss使用交叉混合熵(CossEntropy),计算3分类问题的AUC值,为了避免概率计算出现NaN而采取了相应的平滑处理。

from sklearn.metrics import roc_auc_score
 
model_ft, y_true, losslists = test_model(model_ft, criterion, optimizer)
n_class = 3
y_one_hot = np.eye(n_class)[y_true.reshape(-1)]
# solve divide zero errot
eps = 0.0000001
y_scores = losslists / (losslists.sum(axis=1, keepdims=True)+eps)
#print(y_scores)
#print(np.isnan(y_scores))
"""
metrics.roc_auc_score(y_one_hot, y_pred)
"""
print("auc: ")
roc_auc_score(y_one_hot, y_scores)

四、loss函数

loss函数代码编写不正确或者已经编写好的loss函数API使用不清楚

五、某些易错代码

Pytorch在进行自动微分的时候,默认梯度是会累加的,所以需要在每个epoch的每个batch中对梯度清零,否则可能会导致loss值不收敛。不要忘记添加如下代码

optimizer.zero_grad()

以上这篇记录模型训练时loss值的变化情况就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
C#返回当前系统所有可用驱动器符号的方法
Apr 18 Python
Python素数检测实例分析
Jun 15 Python
Python内置函数 next的具体使用方法
Nov 24 Python
python操作xlsx文件的包openpyxl实例
May 03 Python
python使用Plotly绘图工具绘制散点图、线形图
Apr 02 Python
python中利用numpy.array()实现俩个数值列表的对应相加方法
Aug 26 Python
Python中*args和**kwargs的区别详解
Sep 17 Python
Django 多对多字段的更新和插入数据实例
Mar 31 Python
keras输出预测值和真实值方式
Jun 27 Python
Python3基于plotly模块保存图片表格
Aug 03 Python
详解Python中string模块除去Str还剩下什么
Nov 30 Python
详解Python牛顿插值法
May 11 Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
Keras在训练期间可视化训练误差和测试误差实例
Jun 16 #Python
如何在Windows中安装多个python解释器
Jun 16 #Python
使用pyplot.matshow()函数添加绘图标题
Jun 16 #Python
浅谈matplotlib中FigureCanvasXAgg的用法
Jun 16 #Python
You might like
PHP 导出数据到淘宝助手CSV的方法分享
2010/02/27 PHP
WordPress主题制作中自定义头部的相关PHP函数解析
2016/01/08 PHP
Laravel实现批量更新多条数据
2020/04/06 PHP
javascript下有关dom以及xml节点访问兼容问题
2007/11/26 Javascript
JavaScript 给汉字排序实例代码
2008/06/28 Javascript
jQuery Ajax之load()方法
2009/10/12 Javascript
jquery checkbox全选、取消全选实现代码
2010/03/05 Javascript
jquery的$getjson调用并获取远程的JSON字符串问题
2012/12/10 Javascript
jQuery插件实现表格隔行换色且感应鼠标高亮行变色
2013/09/22 Javascript
Area 区域实现post提交数据的js写法
2014/04/22 Javascript
node.js中的fs.fchown方法使用说明
2014/12/16 Javascript
javascript实现任务栏消息提示的简单实例
2016/05/31 Javascript
Javascript实现图片不间断滚动的代码
2016/06/22 Javascript
浅析js的模块化编写 require.js
2016/12/07 Javascript
Vue.js实现实例搜索应用功能详细代码
2017/08/24 Javascript
解决angularjs service中依赖注入$scope报错的问题
2018/10/02 Javascript
vue通过指令(directives)实现点击空白处收起下拉框
2018/12/06 Javascript
详解vue开发中调用微信jssdk的问题
2019/04/16 Javascript
基于JS实现table导出Excel并保留样式
2020/05/19 Javascript
[00:55]2015国际邀请赛中国区预选赛5月23日——28日约战上海
2015/05/25 DOTA
[35:43]2018DOTA2亚洲邀请赛 4.1 小组赛B组 paiN vs Effect
2018/04/03 DOTA
haskell实现多线程服务器实例代码
2013/11/26 Python
python实现泊松图像融合
2018/07/26 Python
Python通过cv2读取多个USB摄像头
2019/08/28 Python
html5中如何将图片的绝对路径转换成文件对象
2018/01/11 HTML / CSS
ProBikeKit美国官网:自行车套件,跑步和铁人三项套件
2016/10/13 全球购物
基本款天堂:Everlane
2017/05/13 全球购物
英国街头品牌:Bee Inspired Clothing
2018/02/12 全球购物
Bed Bath & Beyond加拿大官网:购买床上用品、浴巾、厨房电器等
2019/10/04 全球购物
简单介绍Object类的功能、常用方法
2013/10/02 面试题
服装设计专业自荐书范文
2013/12/30 职场文书
戒赌保证书
2015/05/11 职场文书
会议主持词开场白
2015/05/28 职场文书
2016年敬老月活动总结
2016/04/05 职场文书
闭幕词的写作格式与范文!
2019/06/24 职场文书
关于CentOS 8 搭建MongoDB4.4分片集群的问题
2021/10/24 MongoDB