Pytorch 如何加速Dataloader提升数据读取速度


Posted in Python onMay 28, 2021

在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。

很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 Backpropagation 。

这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时,Memory-Usage 占用率很高,但是 GPU-Util 时常为 0% ,如下图所示:

Pytorch 如何加速Dataloader提升数据读取速度

如何解决这种问题呢?

在 Nvidia 提出的分布式框架 Apex 里面,我们在源码里面找到了一个简单的解决方案:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256 ​

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

我们能看到 Nvidia 是在读取每次数据返回给网络的时候,预读取下一次迭代需要的数据,

那么对我们自己的训练代码只需要做下面的改造:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

#-------------升级后---------

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

这样子我们的 Dataloader 就像打了鸡血一样提高了效率很多,如下图:

Pytorch 如何加速Dataloader提升数据读取速度

当然,最好的解决方案还是从硬件上,把读取速度慢的机械硬盘换成 NVME 固态吧~

补充:Pytorch设置多线程进行dataloader时影响GPU运行

使用PyTorch设置多线程(threads)进行数据读取时,其实是假的多线程,他是开了N个子进程(PID是连续的)进行模拟多线程工作。

以载入cocodataset为例

DataLoader

dataloader = torch.utils.data.DataLoader(COCODataset(config["train_path"],
                                                     (config["img_w"], config["img_h"]),
                                                     is_training=True),
                                         batch_size=config["batch_size"],
                                         shuffle=True, num_workers=32, pin_memory=True)

numworkers就是指定多少线程的参数,原为32。

检查GPU是否运行该程序

查看运行在gpu上的所有程序:

fuser -v /dev/nvidia*

如果没有返回,则该程序并没有在GPU上运行

指定GPU运行

将num_workers改成0即可

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用内置函数总结
Feb 08 Python
python读取word文档的方法
May 09 Python
Python编程实现双击更新所有已安装python模块的方法
Jun 05 Python
django 创建过滤器的实例详解
Aug 14 Python
解决csv.writer写入文件有多余的空行问题
Jul 06 Python
Sanic框架应用部署方法详解
Jul 18 Python
Python统计python文件中代码,注释及空白对应的行数示例【测试可用】
Jul 25 Python
python多进程使用及线程池的使用方法代码详解
Oct 24 Python
Python PyInstaller安装和使用教程详解
Jan 08 Python
python列表推导和生成器表达式知识点总结
Jan 10 Python
执行Python程序时模块报错问题
Mar 26 Python
Python下划线5种含义代码实例解析
Jul 10 Python
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
解决Python中的modf()函数取小数部分不准确问题
May 28 #Python
利用Python+OpenCV三步去除水印
python实现自定义日志的具体方法
May 28 #Python
You might like
MVC模式的PHP实现
2006/10/09 PHP
PHP连接和操作MySQL数据库基础教程
2014/09/29 PHP
标准PHP的AES加密算法类
2015/03/12 PHP
php从给定url获取文件扩展名的方法
2015/03/14 PHP
解决PHP程序运行时:Fatal error: Maximum execution time of 30 seconds exceeded in的错误提示
2016/11/25 PHP
在Yii2特定页面如何禁用调试工具栏Debug Toolbar详解
2017/08/07 PHP
详解PHP素材图片上传、下载功能
2019/04/12 PHP
关于laravel 子查询 & join的使用
2019/10/16 PHP
使用PHP+Redis实现延迟任务,实现自动取消订单功能
2019/11/21 PHP
JavaScript 原型链学习总结
2010/10/29 Javascript
jquery事件机制扩展插件 jquery鼠标右键事件。
2011/12/26 Javascript
关于IE中getElementsByClassName不能用的问题解决方法
2013/08/26 Javascript
jquery使用hide方法隐藏指定id的元素
2015/03/30 Javascript
PHP+jQuery实现随意拖动层并即时保存拖动位置
2015/04/30 Javascript
JavaScript数据结构与算法之链表
2016/01/29 Javascript
JavaScript设计模式开发中组合模式的使用教程
2016/05/18 Javascript
Form表单按回车自动提交表单的实现方法
2016/11/18 Javascript
利用JS实现页面删除并重新排序功能
2016/12/09 Javascript
AngularJS基于ui-route实现深层路由的方法【路由嵌套】
2016/12/14 Javascript
JavaScript装饰器函数(Decorator)实例详解
2017/03/30 Javascript
Vue 2.0中生命周期与钩子函数的一些理解
2017/05/09 Javascript
AngularJS实现的生成随机数与猜数字大小功能示例
2017/12/25 Javascript
使用vue.js在页面内组件监听scroll事件的方法
2018/09/11 Javascript
微信小程序自定义tabBar组件开发详解
2020/09/24 Javascript
vue 实现强制类型转换 数字类型转为字符串
2019/11/07 Javascript
JavaScript 正则应用详解【模式、欲查、反向引用等】
2020/05/13 Javascript
vue:el-input输入时限制输入的类型操作
2020/08/05 Javascript
Python多进程并发与多线程并发编程实例总结
2018/02/08 Python
收集的7个CSS3代码生成工具
2010/04/17 HTML / CSS
有关HTML5页面在iPhoneX适配问题
2017/11/13 HTML / CSS
美国知名保健品网站:LuckyVitamin(支持中文)
2017/08/09 全球购物
商务英语求职自荐信范文
2013/12/24 职场文书
物业管理工作方案
2014/05/10 职场文书
节约能源标语
2014/06/17 职场文书
领导班子三严三实对照检查材料
2014/09/25 职场文书
上级领导检查欢迎词
2015/09/30 职场文书