pytorch 实现在测试的时候启用dropout


Posted in Python onMay 27, 2021

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:

def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()

下面是完整demo代码:

# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(8, 8)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.fc(x)
        x = self.dropout(x)
        return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
    if type(m) == nn.Dropout:
        m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现在pickling的时候压缩的方法
Sep 25 Python
Python中的多行注释文档编写风格汇总
Jun 16 Python
Python数据分析之真实IP请求Pandas详解
Nov 18 Python
深入理解Python中的*重复运算符
Oct 28 Python
python基础教程项目二之画幅好画
Apr 02 Python
Python 通配符删除文件的实例
Apr 24 Python
python os.path模块常用方法实例详解
Sep 16 Python
Python argparse模块应用实例解析
Nov 15 Python
Python3如何在Windows和Linux上打包
Feb 25 Python
关于Python Tkinter Button控件command传参问题的解决方式
Mar 04 Python
Python基于pandas绘制散点图矩阵代码实例
Jun 04 Python
python Socket网络编程实现C/S模式和P2P
Jun 22 Python
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
tensorflow中的数据类型dtype用法说明
May 26 #Python
详解Python魔法方法之描述符类
May 26 #Python
You might like
求帮忙修改个php curl模拟post请求内容后并下载文件的解决思路
2015/09/20 PHP
php用户注册信息验证正则表达式
2015/11/12 PHP
php+jQuery实现的三级导航栏下拉菜单显示效果
2017/08/10 PHP
基于jQuery的ajax功能实现web service的json转化
2009/08/29 Javascript
关于firefox的ElementTraversal 接口 使用说明
2010/11/11 Javascript
javascript自执行函数之伪命名空间封装法
2010/12/25 Javascript
EasyUI的treegrid组件动态加载数据问题的解决办法
2011/12/11 Javascript
swtich/if...else的替代语句
2015/08/16 Javascript
js实现无限级树形导航列表效果代码
2015/09/23 Javascript
AngularJS Toaster使用详解
2017/02/24 Javascript
JS操作xml对象转换为Json对象示例
2017/03/25 Javascript
bootstrap弹出层的多种触发方式
2017/05/10 Javascript
AngularJS入门教程二:在路由中传递参数的方法分析
2017/05/27 Javascript
详解vue-router 路由元信息
2017/09/13 Javascript
mint-ui 时间插件使用及获取选择值的方法
2018/02/09 Javascript
Vue 中使用 CSS Modules优雅方法
2018/04/09 Javascript
[原创]jQuery实现合并/追加数组并去除重复项的方法
2018/04/11 jQuery
利用webpack理解CommonJS和ES Modules的差异区别
2020/06/16 Javascript
Vue页面手动刷新,实现导航栏激活项还原到初始状态
2020/08/06 Javascript
听歌识曲--用python实现一个音乐检索器的功能
2016/11/15 Python
Python实现的质因式分解算法示例
2018/05/03 Python
对python使用http、https代理的实例讲解
2018/05/07 Python
python利用小波分析进行特征提取的实例
2019/01/09 Python
python flask几分钟实现web服务的例子
2019/07/26 Python
Jupyter notebook 远程配置及SSL加密教程
2020/04/14 Python
基于Python词云分析政府工作报告关键词
2020/06/02 Python
如何使用python记录室友的抖音在线时间
2020/06/29 Python
西班牙英格列斯百货法国官网:El Corte Inglés法国
2017/07/09 全球购物
英国最受欢迎的母婴精品品牌:JoJo Maman BéBé
2021/02/17 全球购物
试解释COMMIT操作和ROLLBACK操作的语义
2014/07/25 面试题
银行实习生自我鉴定范文
2013/09/19 职场文书
中专药剂专业应届毕的自我评价
2013/12/27 职场文书
接受捐赠答谢词
2014/01/27 职场文书
教师岗位聘任书范文
2014/03/29 职场文书
竞选班长演讲稿500字
2014/08/22 职场文书
公诉意见书范文
2015/06/05 职场文书