pytorch 实现多个Dataloader同时训练


Posted in Python onMay 29, 2021

看代码吧~

pytorch 实现多个Dataloader同时训练

如果两个dataloader的长度不一样,那就加个:

from itertools import cycle

仅使用zip,迭代器将在长度等于最小数据集的长度时耗尽。 但是,使用cycle时,我们将再次重复最小的数据集,除非迭代器查看最大数据集中的所有样本。

pytorch 实现多个Dataloader同时训练

补充:pytorch技巧:自定义数据集 torch.utils.data.DataLoader 及Dataset的使用

本博客中有可直接运行的例子,便于直观的理解,在torch环境中运行即可。

1. 数据传递机制

在 pytorch 中数据传递按一下顺序:

1、创建 datasets ,也就是所需要读取的数据集。

2、把 datasets 传入DataLoader。

3、DataLoader迭代产生训练数据提供给模型。

2. torch.utils.data.Dataset

Pytorch提供两种数据集:

Map式数据集 Iterable式数据集。其中Map式数据集继承torch.utils.data.Dataset,Iterable式数据集继承torch.utils.data.IterableDataset。

本文只介绍 Map式数据集。

一个Map式的数据集必须要重写 __getitem__(self, index)、 __len__(self) 两个方法,用来表示从索引到样本的映射(Map)。 __getitem__(self, index)按索引映射到对应的数据, __len__(self)则会返回这个数据集的长度。

基本格式如下:

import torch.utils.data as data
class VOCDetection(data.Dataset):
    '''
    必须继承data.Dataset类
    '''
    def __init__(self):
        '''
        在这里进行初始化,一般是初始化文件路径或文件列表
        '''
        pass
    def __getitem__(self, index):
        '''
        1. 按照index,读取文件中对应的数据  (读取一个数据!!!!我们常读取的数据是图片,一般我们送入模型的数据成批的,但在这里只是读取一张图片,成批后面会说到)
        2. 对读取到的数据进行数据增强 (数据增强是深度学习中经常用到的,可以提高模型的泛化能力)
        3. 返回数据对 (一般我们要返回 图片,对应的标签) 在这里因为我没有写完整的代码,返回值用 0 代替
        '''
        return 0
    def __len__(self):
        '''
        返回数据集的长度
        '''
        return 0

可直接运行的例子:

import torch.utils.data as data
import numpy as np
x = np.array(range(80)).reshape(8, 10) # 模拟输入, 8个样本,每个样本长度为10
y = np.array(range(8))  # 模拟对应样本的标签, 8个标签 
class Mydataset(data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.idx = list()
        for item in x:
            self.idx.append(item)
        pass
    def __getitem__(self, index):
        input_data = self.idx[index] #可继续进行数据增强,这里没有进行数据增强操作
        target = self.y[index]
        return input_data, target
    def __len__(self):
        return len(self.idx)
datasets = Mydataset(x, y)  # 初始化
print(datasets.__len__())  # 调用__len__() 返回数据的长度
for i in range(len(y)):
    input_data, target = datasets.__getitem__(i)  # 调用__getitem__(index) 返回读取的数据对
    print('input_data%d =' % i, input_data)
    print('target%d = ' % i, target)

结果如下:

pytorch 实现多个Dataloader同时训练

3. torch.utils.data.DataLoader

PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。

该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。

torch.utils.data.DataLoader(onject)的可用参数如下:

1.dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。

2.batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)

3.shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)

4.sampler (Sampler, optional):从数据集中提取样本的策略。如果指定,“shuffle”必须为false。我没有用过,不太了解。

5.batch_sampler (Sampler, optional):和batch_size、shuffle等参数互斥,一般用默认。

6.num_workers:这个参数必须大于等于0,为0时默认使用主线程读取数据,其他大于0的数表示通过多个进程来读取数据,可以加快数据读取速度,一般设置为2的N次方,且小于batch_size(默认:0)

7.collate_fn (callable, optional): 合并样本清单以形成小批量。用来处理不同情况下的输入dataset的封装。

8.pin_memory (bool, optional):如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存中.

9.drop_last (bool, optional): 如果数据集大小不能被批大小整除,则设置为“true”以除去最后一个未完成的批。如果“false”那么最后一批将更小。(默认:false)

10.timeout(numeric, optional):设置数据读取时间限制,超过这个时间还没读取到数据的话就会报错。(默认:0)

11.worker_init_fn (callable, optional): 每个worker初始化函数(默认:None)

可直接运行的例子:

import torch.utils.data as data
import numpy as np
x = np.array(range(80)).reshape(8, 10) # 模拟输入, 8个样本,每个样本长度为10
y = np.array(range(8))  # 模拟对应样本的标签, 8个标签
class Mydataset(data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.idx = list()
        for item in x:
            self.idx.append(item)
        pass
    def __getitem__(self, index):
        input_data = self.idx[index]
        target = self.y[index]
        return input_data, target
    def __len__(self):
        return len(self.idx)
if __name__ ==('__main__'):
    datasets = Mydataset(x, y)  # 初始化
    dataloader = data.DataLoader(datasets, batch_size=4, num_workers=2) 
    for i, (input_data, target) in enumerate(dataloader):
        print('input_data%d' % i, input_data)
        print('target%d' % i, target)

结果如下:(注意看类别,DataLoader把数据封装为Tensor)

pytorch 实现多个Dataloader同时训练

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 参数列表中的self 显式不等于冗余
Dec 01 Python
介绍Python中的fabs()方法的使用
May 14 Python
在Python中操作列表之List.pop()方法的使用
May 21 Python
Python中pow()和math.pow()函数用法示例
Feb 11 Python
Python实现k-means算法
Feb 23 Python
python打包压缩、读取指定目录下的指定类型文件
Apr 12 Python
python 2.7 检测一个网页是否能正常访问的方法
Dec 26 Python
python将控制台输出保存至文件的方法
Jan 07 Python
Python三元运算与lambda表达式实例解析
Nov 30 Python
15行Python代码实现免费发送手机短信推送消息功能
Feb 27 Python
Python xlrd/xlwt 创建excel文件及常用操作
Sep 24 Python
matplotlib bar()实现百分比堆积柱状图
Feb 24 Python
python 如何做一个识别率百分百的OCR
基于PyTorch实现一个简单的CNN图像分类器
May 29 #Python
python 爬取华为应用市场评论
python 开心网和豆瓣日记爬取的小爬虫
May 29 #Python
Python趣味挑战之实现简易版音乐播放器
新手必备Python开发环境搭建教程
Keras多线程机制与flask多线程冲突的解决方案
May 28 #Python
You might like
Sony CFR 320 修复改造
2020/03/14 无线电
php 利用socket发送HTTP请求(GET,POST)
2015/08/24 PHP
CodeIgniter分页类pagination使用方法示例
2016/03/28 PHP
php mysql实现mysql_select_db选择数据库
2016/12/30 PHP
PHP去除空数组且数组键名重置的讲解
2019/02/28 PHP
PHP中的Iterator迭代对象属性详解
2019/04/12 PHP
PHP实现的62进制转10进制,10进制转62进制函数示例
2019/06/06 PHP
解决在laravel中auth建立时候遇到的问题
2019/10/15 PHP
js onpropertychange输入框 事件获取属性
2009/03/26 Javascript
javascript的数据类型、字面量、变量介绍
2012/05/23 Javascript
js实现的类似于asp数据字典的数据类型代码实例
2014/09/03 Javascript
jQuery检测鼠标左键和右键点击的方法
2015/03/17 Javascript
浅谈Angularjs link和compile的使用区别
2016/10/21 Javascript
bootstrap学习使用(导航条、下拉菜单、轮播、栅格布局等)
2016/12/01 Javascript
原生JS实现导航下拉菜单效果
2020/11/25 Javascript
JS改变页面颜色源码分享
2018/02/24 Javascript
JavaScript适配器模式原理与用法实例详解
2020/03/09 Javascript
微信小程序中data-key属性之数据传输(经验总结)
2020/08/22 Javascript
举例详解Python中的split()函数的使用方法
2015/04/07 Python
Python安装使用命令行交互模块pexpect的基础教程
2016/05/12 Python
Django中的forms组件实例详解
2018/11/08 Python
如何在Django配置文件里配置session链接
2019/08/06 Python
用python3 urllib破解有道翻译反爬虫机制详解
2019/08/14 Python
对YOLOv3模型调用时候的python接口详解
2019/08/26 Python
使用jTopo给Html5 Canva中绘制的元素添加鼠标事件
2014/05/15 HTML / CSS
西班牙最大的婴儿用品网上商店:Bebitus
2019/05/30 全球购物
DeinDesign德国:设计自己的手机壳
2019/12/14 全球购物
Linux面试题LINUX系统类
2015/11/25 面试题
北大研究生linux应用求职信
2013/10/29 职场文书
中学教师培训制度
2014/01/31 职场文书
有多年工作经验的自我评价
2014/03/02 职场文书
民生工程实施方案
2014/03/22 职场文书
指导教师评语
2014/04/26 职场文书
年检委托书
2014/08/30 职场文书
学习党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
起诉离婚协议书样本
2014/11/25 职场文书