踩坑:pytorch中eval模式下结果远差于train模式介绍


Posted in Python onJune 23, 2020

首先,eval模式和train模式得到不同的结果是正常的。我的模型中,eval模式和train模式不同之处在于Batch Normalization和Dropout。Dropout比较简单,在train时会丢弃一部分连接,在eval时则不会。Batch Normalization,在train时不仅使用了当前batch的均值和方差,也使用了历史batch统计上的均值和方差,并做一个加权平均(momentum参数)。在test时,由于此时batchsize不一定一致,因此不再使用当前batch的均值和方差,仅使用历史训练时的统计值。

我出bug的现象是,train模式下可以收敛,但一旦在测试中切换到了eval模式,结果就很差。如果在测试中仍沿用train模式,反而可以得到不错的结果。为了确保是程序bug而不是算法本身就不适合于预测,我在测试时再次使用了训练集,正常情况下此时应发生过拟合,正确率一定会很高,然而eval模式下正确率仍然很低。参照网上的一些说法(Performance highly degraded when eval() is activated in the test phase
),我调大了batchsize,降低了BN层的momentum,检查了是否存在不同层使用相同BN层的bug,均不见效。有一种方法说应在BN层设置track_running_stats为False,它虽然带来了好的效果,但实际上它只不过是不用eval模式,切回train模式罢了,所以也不对。

学习了在训练过程中,如何将BN层中统计的均值和方差输出。即在forward()中,

# bn是一个BN层,torch.nn.batch_normalization(...)
print(bn.running_mean)
print(bn.running_var)

同时学习了如何输出一个Tensor自身的均值和方差,即

# x是一个Tensor,dims是需要计算的维度
print(x.cpu().detach().numpy().mean(dims)
print(x.cpu().detach().numpy().var(dims)

观察每一层的输出结果,发现出现了很大的方差,才猛然意识到自己的输入数据没有做归一化(事后想想也确实如此,毕竟模型和训练方法都是github上参考别人的,出错概率很小;反而是自己写的DataSet部分,其实是最容易出错的)。给模型加上归一化后,eval和train的结果就没有问题了。

再次验证了我的观点:越是玄学的问题,越是傻逼的bug。

补充知识:Pytorch中的train和eval用法注意点

1.介绍

一般情况,model.train()是在训练的时候用到,model.eval()是在测试的时候用到

2.用法

如果模型中没有类似于BN这样的归一化或者Dropout,model.train()和model.eval()可以不要(建议写一下,比较安全),并且model.train()和model.eval()得到的效果是一样

如果模型中有类似于BN这样的归一化或者Dropout,并且程序需要边训练和边测试,最好就是用model.eval()测试完之后,后面补一个model.train()。

其中model.train()是保证BN用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接(结果是取了平均)

以上这篇踩坑:pytorch中eval模式下结果远差于train模式介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pyv8学习python和javascript变量进行交互
Dec 04 Python
Python Sql数据库增删改查操作简单封装
Apr 18 Python
python使用标准库根据进程名如何获取进程的pid详解
Oct 31 Python
将TensorFlow的模型网络导出为单个文件的方法
Apr 23 Python
python实现狄克斯特拉算法
Jan 17 Python
pandas按行按列遍历Dataframe的几种方式
Oct 23 Python
Python面向对象封装操作案例详解 II
Jan 02 Python
使用tensorflow框架在Colab上跑通猫狗识别代码
Apr 26 Python
Pandas中两个dataframe的交集和差集的示例代码
Dec 13 Python
教你怎么用Python实现多路径迷宫
Apr 29 Python
使用pycharm运行flask应用程序的详细教程
Jun 07 Python
python画条形图的具体代码
Apr 20 Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
Python使用Selenium实现淘宝抢单的流程分析
Jun 23 #Python
python2和python3哪个使用率高
Jun 23 #Python
python使用QQ邮箱实现自动发送邮件
Jun 22 #Python
浅谈keras中loss与val_loss的关系
Jun 22 #Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
You might like
php在服务器执行exec命令失败的解决方法
2012/03/03 PHP
用来解析.htgroup文件的PHP类
2012/09/05 PHP
PHP获取音频文件的相关信息
2015/06/22 PHP
统计PHP目录中的文件数方法
2019/03/05 PHP
laravel中的fillable和guarded属性详解
2019/10/23 PHP
JavaScript入门教程 Cookies
2009/01/31 Javascript
Javascript实现获取窗口的大小和位置代码分享
2014/12/04 Javascript
jQuery中hasClass()方法用法实例
2015/01/06 Javascript
基于vue.js实现图片轮播效果
2016/12/01 Javascript
Angular.js项目中使用gulp实现自动化构建以及压缩打包详解
2017/07/19 Javascript
详解Vue 全局引入bass.scss 处理方案
2018/03/26 Javascript
vue实现随机验证码功能的实例代码
2019/04/30 Javascript
JS实现普通轮播图特效
2020/01/01 Javascript
JS实现图片懒加载(lazyload)过程详解
2020/04/02 Javascript
jQuery HTML获取内容和属性操作实例分析
2020/05/20 jQuery
Javascript如何递归遍历本地文件夹
2020/08/06 Javascript
你不知道的 TypeScript 高级类型(小结)
2020/08/28 Javascript
[01:06:12]VP vs NIP 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
[46:12]完美世界DOTA2联赛循环赛 DM vs Matador BO2第一场 11.04
2020/11/04 DOTA
python网络编程学习笔记(五):socket的一些补充
2014/06/09 Python
Python 探针的实现原理
2016/04/23 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
2018/04/23 Python
Linux下python3.6.1环境配置教程
2018/09/26 Python
利用Python对文件夹下图片数据进行批量改名的代码实例
2019/02/21 Python
Python基本数据结构与用法详解【列表、元组、集合、字典】
2019/03/23 Python
python 下 CMake 安装配置 OPENCV 4.1.1的方法
2019/09/30 Python
pygame实现贪吃蛇游戏(下)
2019/10/29 Python
CSS3网格的三个新特性详解
2014/04/04 HTML / CSS
意大利制造的西装、衬衫和针对男士量身定制的服装:Lanieri
2018/04/08 全球购物
Burt’s Bees英国官网:世界领先的天然个人护理品牌
2020/08/17 全球购物
介绍java中初始化块的使用
2012/09/11 面试题
军训自我鉴定
2013/12/14 职场文书
给分销商的致歉信
2014/01/14 职场文书
道路建设实施方案
2014/03/18 职场文书
六查六看个人剖析材料
2014/10/14 职场文书
分享一些Java的常用工具
2021/06/11 Java/Android