pytorch中的model.eval()和BN层的使用


Posted in Python onMay 22, 2021

看代码吧~

class ConvNet(nn.module):
    def __init__(self, num_class=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
                                    nn.BatchNorm2d(32),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
         
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        print(out.size())
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。

评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。

训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。

补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案

当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。

解决方案是冻住bn

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)

这样可以获得稳定输出的结果。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现拼接多张图片的方法
Dec 01 Python
整理Python最基本的操作字典的方法
Apr 24 Python
python学习笔记之调用eval函数出现invalid syntax错误问题
Oct 18 Python
python 实现自动远程登陆scp文件实例代码
Mar 13 Python
Python中类的创建和实例化操作示例
Feb 27 Python
彻底理解Python中的yield关键字
Apr 01 Python
python找出因数与质因数的方法
Jul 25 Python
Django模板语言 Tags使用详解
Sep 09 Python
python使用sklearn实现决策树的方法示例
Sep 12 Python
Python模块汇总(常用第三方库)
Oct 07 Python
Django-migrate报错问题解决方案
Apr 21 Python
Python基于unittest实现测试用例执行
Nov 25 Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
Python基础之进程详解
如何在C++中调用Python
May 21 #Python
You might like
用PHP开发GUI
2006/10/09 PHP
一个程序下载的管理程序(一)
2006/10/09 PHP
php 无限级缓存的类的扩展
2009/03/16 PHP
php获取本地图片文件并生成xml文件输出具体思路
2013/04/27 PHP
PHP中的Iterator迭代对象属性详解
2019/04/12 PHP
javascript基于jQuery的表格悬停变色/恢复,表格点击变色/恢复,点击行选Checkbox
2008/08/05 Javascript
jquery实现的超出屏幕时把固定层变为定位层的代码
2010/02/23 Javascript
如何确保JavaScript的执行顺序 之实战篇
2011/03/03 Javascript
jQuery的文档处理程序详解
2016/05/10 Javascript
基于Bootstrap重置输入框内容按钮插件
2016/05/12 Javascript
JavaScript基础语法之js表达式
2016/06/07 Javascript
js实现移动端轮播图效果
2020/12/09 Javascript
React学习笔记之列表渲染示例详解
2017/08/22 Javascript
python使用append合并两个数组的方法
2015/04/28 Python
python使用os.listdir和os.walk获得文件的路径的方法
2017/12/16 Python
python opencv之分水岭算法示例
2018/02/24 Python
python使用turtle绘制分形树
2018/06/22 Python
python之Flask实现简单登录功能的示例代码
2018/12/24 Python
实例详解Matlab 与 Python 的区别
2019/04/26 Python
Python中面向对象你应该知道的一下知识
2019/07/10 Python
Python学习笔记之Django创建第一个数据库模型的方法
2019/08/07 Python
Django单元测试中Fixtures的使用方法
2020/02/26 Python
Python3 pickle对象串行化代码实例解析
2020/03/23 Python
Python 执行矩阵与线性代数运算
2020/08/01 Python
福克斯租车:Fox Rent A Car
2017/04/13 全球购物
Big Green Smile德国网上商店:提供各种天然产品
2018/05/23 全球购物
Hawes & Curtis官网:英国经典品牌
2019/07/27 全球购物
公共汽车、火车和飞机票的通用在线预订和销售平台:INFOBUS
2019/11/30 全球购物
财务方面个人工作的自我评价
2013/12/28 职场文书
药学专业学生的自我评价分享
2014/02/06 职场文书
客户接待方案
2014/02/26 职场文书
中职三好学生事迹材料
2014/08/24 职场文书
群众路线自我剖析材料
2014/10/08 职场文书
学生不参加考试检讨书
2015/02/19 职场文书
学雷锋团日活动总结
2015/05/06 职场文书
为Centos安装指定版本的Docker
2022/04/01 Servers