Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Web服务器框架 Tornado简介
Jul 16 Python
深入解析Python中的WSGI接口
May 11 Python
Python通过90行代码搭建一个音乐搜索工具
Jul 29 Python
Python缩进和冒号详解
Jun 01 Python
浅谈python中copy和deepcopy中的区别
Oct 23 Python
在dataframe两列日期相减并且得到具体的月数实例
Jul 03 Python
Django文件上传与下载(FileFlid)
Oct 06 Python
Python通过递归获取目录下指定文件代码实例
Nov 07 Python
一文轻松掌握python语言命名规范规则
Jun 18 Python
python中字典增加和删除使用方法
Sep 30 Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 Python
Python基础之条件语句详解
Jun 16 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
discuz图片顺序混乱解决方案
2015/07/29 PHP
PHP获取二叉树镜像的方法
2018/01/17 PHP
获取页面高度,窗口高度,滚动条高度等参数值getPageSize,getPageScroll
2006/09/22 Javascript
jquery ajax 登录验证实现代码
2009/09/23 Javascript
点击下载链接 弹出页面实现代码
2009/10/01 Javascript
JS表的模拟方法
2015/02/05 Javascript
angular.js之路由的选择方法
2016/09/24 Javascript
利用vue写todolist单页应用
2016/12/15 Javascript
bootstrap导航栏、下拉菜单、表单的简单应用实例解析
2017/01/06 Javascript
最常用的jQuery表单验证(简单)
2017/05/23 jQuery
jquery.validate.js 多个相同name的处理方式
2017/07/10 jQuery
对于js垃圾回收机制的理解
2017/09/14 Javascript
Javascript数组方法reduce的妙用之处分享
2019/06/10 Javascript
Vue 实现CLI 3.0 + momentjs + lodash打包时优化
2019/11/13 Javascript
JS实现动态无缝轮播
2020/01/11 Javascript
Vue获取微博授权URL代码实例
2020/11/04 Javascript
[03:21]辉夜杯主赛事 12月25日TOP5
2015/12/26 DOTA
Python生成随机密码
2015/03/10 Python
Pyhthon中使用compileall模块编译源文件为pyc文件
2015/04/28 Python
Python global全局变量函数详解
2018/09/18 Python
Python变量类型知识点总结
2019/02/18 Python
Python 异常处理Ⅳ过程图解
2019/10/18 Python
基于python实现蓝牙通信代码实例
2019/11/19 Python
Python内置方法实现字符串的秘钥加解密(推荐)
2019/12/09 Python
Python中SQLite如何使用
2020/05/27 Python
使用CSS3来代替JS实现交互
2017/08/10 HTML / CSS
Smallable英国家庭概念店:设计师童装及家居装饰
2017/07/05 全球购物
英国家庭珠宝商:T. H. Baker
2018/02/08 全球购物
Notino芬兰:购买香水和化妆品
2019/04/15 全球购物
高职教师岗位职责
2013/12/24 职场文书
2015年乡镇信访工作总结
2015/04/07 职场文书
2015年社区党务工作总结
2015/04/21 职场文书
同意报考公务员证明
2015/06/17 职场文书
导游词之江苏溱潼古镇
2019/11/27 职场文书
css实现文章分割线样式的多种方法总结
2021/04/21 HTML / CSS
利用js实现简单开关灯代码
2021/11/23 Javascript