PyTorch 如何设置随机数种子使结果可复现


Posted in Python onMay 12, 2021

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

补充:pytorch 固定随机数种子踩过的坑

1.初步固定

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

2.继续添加如下代码:

tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
详解Python中的循环语句的用法
Apr 09 Python
详解Python编程中包的概念与管理
Oct 16 Python
Python中的异常处理相关语句基础学习笔记
Jul 11 Python
Python开发中爬虫使用代理proxy抓取网页的方法示例
Sep 26 Python
浅析Python3爬虫登录模拟
Feb 07 Python
Python数据结构之哈夫曼树定义与使用方法示例
Apr 22 Python
Python实现合并同一个文件夹下所有txt文件的方法示例
Apr 26 Python
Python3分析处理声音数据的例子
Aug 27 Python
对Keras中predict()方法和predict_classes()方法的区别说明
Jun 09 Python
Pycharm调试程序技巧小结
Aug 08 Python
Vs Code中8个好用的python 扩展插件
Oct 12 Python
浅谈python中的多态
Jun 15 Python
Python Parser的用法
May 12 #Python
pytorch MSELoss计算平均的实现方法
May 12 #Python
Django如何创作一个简单的最小程序
May 12 #Python
Pytorch中TensorBoard及torchsummary的使用详解
pytorch 一行代码查看网络参数总量的实现
May 12 #Python
pytorch查看网络参数显存占用量等操作
May 12 #Python
Python入门之使用pandas分析excel数据
May 12 #Python
You might like
星际争霸中的对战模式介绍
2020/03/04 星际争霸
PHP中MD5函数使用实例代码
2008/06/07 PHP
支持中文字母数字、自定义字体php验证码代码
2012/02/27 PHP
PHP中的命名空间详细介绍
2015/07/02 PHP
ThinkPHP5.0框架验证码功能实现方法【基于第三方扩展包】
2019/03/11 PHP
php解析非标准json、非规范json的方式实例
2020/12/10 PHP
基于jquery的模态div层弹出效果
2010/08/21 Javascript
JS函数验证总结(方便js客户端输入验证)
2010/10/29 Javascript
在chrome浏览器中,防止input[text]和textarea在聚焦时出现黄色边框的解决方法
2011/05/24 Javascript
Javascript实现简单的富文本编辑器附演示
2014/06/16 Javascript
jQuery.holdReady()方法用法实例
2014/12/27 Javascript
js设置document.domain实现跨域的注意点分析
2015/05/21 Javascript
Hammer.js+轮播原理实现简洁的滑屏功能
2016/02/02 Javascript
js 获取经纬度的实现方法
2016/06/20 Javascript
微信小程序 触控事件详细介绍
2016/10/17 Javascript
详解基于vue-cli优化的webpack配置
2017/11/06 Javascript
vue自定义filters过滤器
2018/04/26 Javascript
原生js实现form表单序列化的方法
2018/08/02 Javascript
详解Vue.js自定义tipOnce指令用法实例
2018/12/19 Javascript
vue自定义表单生成器form-create使用详解
2019/07/19 Javascript
一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息
2018/04/17 Python
Python超越函数积分运算以及绘图实现代码
2019/11/20 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
2020/01/16 Python
css3动画效果抖动解决方法
2018/09/03 HTML / CSS
css3打造一款漂亮的卡哇伊按钮
2013/03/20 HTML / CSS
使用HTML5里的classList操作CSS类
2016/06/28 HTML / CSS
CSS3 画基本图形,圆形、椭圆形、三角形等
2016/09/20 HTML / CSS
我的求职计划书
2014/01/10 职场文书
三年大学自我鉴定
2014/01/16 职场文书
护士在校生自荐信
2014/02/01 职场文书
副护士长竞聘演讲稿
2014/04/30 职场文书
2014年稽查工作总结
2014/12/20 职场文书
迎新生欢迎词2015
2015/07/16 职场文书
Vue+TypeScript中处理computed方式
2022/04/02 Vue.js
vue3使用vuedraggable实现拖拽功能
2022/04/06 Vue.js
详细介绍MySQL中limit和offset的用法
2022/05/06 MySQL