PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python实现图片批量剪切示例
Mar 25 Python
python从sqlite读取并显示数据的方法
May 08 Python
Python实现的最近最少使用算法
Jul 10 Python
Django之Mode的外键自关联和引用未定义的Model方法
Dec 15 Python
简单介绍python封装的基本知识
Aug 10 Python
使用Python刷淘宝喵币(低阶入门版)
Oct 30 Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 Python
在tensorflow中实现屏蔽输出的log信息
Feb 04 Python
关于python 跨域处理方式详解
Mar 28 Python
Python新手如何进行闭包时绑定变量操作
May 29 Python
python sleep和wait对比总结
Feb 03 Python
Python还能这么玩之只用30行代码从excel提取个人值班表
Jun 05 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
WordPress中用于获取文章信息以及分类链接的函数用法
2015/12/18 PHP
Yii安装与使用Excel扩展的方法
2016/07/13 PHP
php layui实现前端多图上传实例
2019/07/30 PHP
浅析PHP中的 inet_pton 网络函数
2019/12/16 PHP
js的闭包的一个示例说明
2008/11/18 Javascript
javascript+css 网页每次加载不同样式的实现方法
2009/12/27 Javascript
jquery.boxy弹出框(后隔N秒后自动隐藏/自动跳转)
2013/01/15 Javascript
jQuery实现限制textarea文本框输入字符数量的方法
2015/05/28 Javascript
jquery实现具有收缩功能的垂直导航菜单
2016/02/16 Javascript
jQuery 常用代码集锦(必看篇)
2016/05/16 Javascript
微信小程序 富文本转文本实例详解
2016/10/24 Javascript
深入浅析Vue组件开发
2016/11/25 Javascript
Bootstrap BootstrapDialog使用详解
2017/02/17 Javascript
微信小程序 后台登录(非微信账号)实例详解
2017/03/31 Javascript
微信扫码支付零云插件版实例详解
2017/04/26 Javascript
JS实现简单的天数计算器完整实例
2017/04/28 Javascript
angular 实现同步验证器跨字段验证的方法
2019/04/11 Javascript
微信小程序实现的canvas合成图片功能示例
2019/05/03 Javascript
关于Js中new操作符的作用详解
2021/02/21 Javascript
Python交换变量
2008/09/06 Python
videocapture库制作python视频高速传输程序
2013/12/23 Python
在Python的Django框架中显示对象子集的方法
2015/07/21 Python
python多线程下信号处理程序示例
2019/05/31 Python
python3格式化字符串 f-string的高级用法(推荐)
2020/03/04 Python
python 实现压缩和解压缩的示例
2020/09/22 Python
CSS3 开发工具收集
2010/04/17 HTML / CSS
Monnier Freres中文官网:法国领先的奢侈品配饰在线零售商
2017/11/01 全球购物
迪奥官网:Dior.com
2018/12/04 全球购物
审核会计岗位职责
2013/11/08 职场文书
反邪教警示教育方案
2014/05/13 职场文书
农业局党的群众路线教育实践活动整改方案
2014/09/20 职场文书
2015年世界急救日宣传活动方案
2015/05/06 职场文书
nginx配置proxy_pass中url末尾带/与不带/的区别详解
2021/03/31 Servers
浅谈Java实现分布式事务的三种方案
2021/06/11 Java/Android
Python使用psutil库对系统数据进行采集监控的方法
2021/08/23 Python
为什么MySQL8新特性会修改自增主键属性
2022/04/18 MySQL