浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python模块学习 re 正则表达式
May 19 Python
python笔记(1) 关于我们应不应该继续学习python
Oct 24 Python
python解析发往本机的数据包示例 (解析数据包)
Jan 16 Python
python抓取网页图片并放到指定文件夹
Apr 24 Python
Python获取SQLite查询结果表列名的方法
Jun 21 Python
pyhanlp安装介绍和简单应用
Feb 22 Python
利用python实现逐步回归
Feb 24 Python
Django 解决由save方法引发的错误
May 21 Python
python能否java成为主流语言吗
Jun 22 Python
如何利用python发送邮件
Sep 26 Python
Python实现查询剪贴板自动匹配信息的思路详解
Jul 09 Python
详解Golang如何实现支持随机删除元素的堆
Sep 23 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
日本十大最佳动漫,全都是二次元的神级作品
2019/10/05 日漫
德生BCL3000的电路分析和打磨
2021/03/02 无线电
PHP中date()日期函数有关参数整理
2011/07/19 PHP
浅谈PHP与C#的值类型指向区别的详解
2013/05/21 PHP
使用CodeIgniter的类库做图片上传
2014/06/12 PHP
从sohu弄下来的flash中展示图片的代码
2007/04/27 Javascript
ExtJS 2.0实用简明教程 之获得ExtJS
2009/04/29 Javascript
js实现单一html页面两套css切换代码
2013/04/11 Javascript
extjs 如何给column 加上提示
2014/07/29 Javascript
基于Bootstrap重置输入框内容按钮插件
2016/05/12 Javascript
AngularJS基础 ng-hide 指令用法及示例代码
2016/08/01 Javascript
Angular的事件和表单详解
2016/12/26 Javascript
jquery实现的table排序功能示例
2017/03/10 Javascript
js读取本地文件的实例
2017/12/22 Javascript
基于IView中on-change属性的使用详解
2018/03/15 Javascript
Vue绑定内联样式问题
2018/10/17 Javascript
详解微信小程序中组件通讯
2018/10/30 Javascript
[03:36]DOTA2完美大师赛coL战队趣味视频——我演你猜
2017/11/23 DOTA
Python中利用sqrt()方法进行平方根计算的教程
2015/05/15 Python
python在线编译器的简单原理及简单实现代码
2018/02/02 Python
Python 实现12306登录功能实例代码
2018/02/09 Python
python3.6使用pymysql连接Mysql数据库
2018/05/25 Python
pycharm下配置pyqt5的教程(anaconda虚拟环境下+tensorflow)
2020/03/25 Python
matplotlib实现数据实时刷新的示例代码
2021/01/05 Python
使用HTML5捕捉音频与视频信息概述及实例
2018/08/22 HTML / CSS
Antonioli美国在线商店:时尚前卫奢华
2019/07/29 全球购物
Lentiamo丹麦:购买便宜的隐形眼镜
2021/01/13 全球购物
如何实现一个自定义类的序列化
2012/05/22 面试题
文秘专业自荐信
2013/10/14 职场文书
新员工欢迎词
2014/01/12 职场文书
幼儿教师个人总结
2015/02/05 职场文书
大连星海广场导游词
2015/02/10 职场文书
2015年小学体育教师工作总结
2015/10/23 职场文书
python将图片转为矢量图的方法步骤
2021/03/30 Python
python ansible自动化运维工具执行流程
2021/06/24 Python
AJAX实现指定部分页面刷新效果
2021/10/16 Javascript