浅谈PyTorch的可重复性问题(如何使实验结果可复现)


Posted in Python onFebruary 20, 2020

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1
 
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
 
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
  global GLOBAL_WORKER_ID
  GLOBAL_WORKER_ID = worker_id
  set_seed(GLOBAL_SEED + worker_id)
 
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中splitlines()方法的使用简介
May 20 Python
Python中的字典与成员运算符初步探究
Oct 13 Python
python文件操作相关知识点总结整理
Feb 22 Python
python实现简单登陆系统
Oct 18 Python
python模拟登陆,用session维持回话的实例
Dec 27 Python
Python设计模式之代理模式实例详解
Jan 19 Python
python实现浪漫的烟花秀
Jan 30 Python
python爬虫爬取幽默笑话网站
Oct 24 Python
flask框架json数据的拿取和返回操作示例
Nov 28 Python
Python 虚拟环境工作原理解析
Dec 24 Python
Python新建项目自动添加介绍和utf-8编码的方法
Dec 26 Python
python中 Flask Web 表单的使用方法
May 20 Python
pytorch 模型的train模式与eval模式实例
Feb 20 #Python
pytorch dataloader 取batch_size时候出现bug的解决方式
Feb 20 #Python
pytorch 使用加载训练好的模型做inference
Feb 20 #Python
pytorch中的inference使用实例
Feb 20 #Python
python encrypt 实现AES加密的实例详解
Feb 20 #Python
Python关于反射的实例代码分享
Feb 20 #Python
Python3监控疫情的完整代码
Feb 20 #Python
You might like
PHP实现用户认证及管理完全源码
2007/03/11 PHP
浅谈PHP eval()函数定义和用法
2016/06/21 PHP
PHP生成静态HTML文档实现代码
2016/06/23 PHP
javascript 操作select下拉列表框的一点小经验
2010/03/20 Javascript
用方法封装javascript的new操作符(一)
2010/12/25 Javascript
Node.js和PHP根据ip获取地理位置的方法
2014/03/14 Javascript
JavaScript操作XML文件之XML读取方法
2015/06/09 Javascript
javascript实现控制的多级下拉菜单
2015/07/05 Javascript
理解JavaScript表单的基础知识
2016/01/25 Javascript
jQuery1.9+中删除了live以后的替代方法
2016/06/17 Javascript
Angularjs使用directive自定义指令实现attribute继承的方法详解
2016/08/05 Javascript
bootstrap基础知识学习笔记
2016/11/02 Javascript
js初始化验证实例详解
2016/11/26 Javascript
angular2中router路由跳转navigate的使用与刷新页面问题详解
2017/05/07 Javascript
详解Angular CLI + Electron 开发环境搭建
2017/07/20 Javascript
js html实现计算器功能
2018/11/13 Javascript
js最实用string(字符串)类型的使用及截取与拼接详解
2019/04/26 Javascript
详解Python中的Numpy、SciPy、MatPlotLib安装与配置
2017/11/17 Python
python基础之包的导入和__init__.py的介绍
2018/01/08 Python
python读取TXT每行,并存到LIST中的方法
2018/10/26 Python
局域网内python socket实现windows与linux间的消息传送
2019/04/19 Python
Python实现微信好友的数据分析
2019/12/16 Python
python中图像通道分离与合并实例
2020/01/17 Python
Windows下PyCharm配置Anaconda环境(超详细教程)
2020/07/31 Python
CSS3动画之利用requestAnimationFrame触发重新播放功能
2019/09/11 HTML / CSS
html5弹跳球示例代码
2013/07/23 HTML / CSS
非洲NO.1网上商店:Jumia肯尼亚
2016/08/18 全球购物
英国乡村时尚和宠物用品专家:Pet & Country
2018/07/02 全球购物
2014年情人节活动方案
2014/02/16 职场文书
2014教育局对照检查材料思想汇报
2014/09/23 职场文书
个人整改方案范文
2014/10/25 职场文书
2014社会治安综合治理工作总结
2014/12/04 职场文书
乌镇导游词
2015/02/02 职场文书
小学教师个人工作总结2015
2015/04/20 职场文书
2016年秋季趣味运动会开幕词
2016/03/04 职场文书
Python anaconda安装库命令详解
2021/10/16 Python