Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的闭包总结
Sep 18 Python
详解Python 序列化Serialize 和 反序列化Deserialize
Aug 20 Python
Python引用传值概念与用法实例小结
Oct 07 Python
Python列表删除的三种方法代码分享
Oct 31 Python
python实现俄罗斯方块
Jun 26 Python
Python 实现某个功能每隔一段时间被执行一次的功能方法
Oct 14 Python
Python3爬虫学习之爬虫利器Beautiful Soup用法分析
Dec 12 Python
对python制作自己的数据集实例讲解
Dec 12 Python
Python使用post及get方式提交数据的实例
Jan 24 Python
python turtle库画一个方格和圆实例
Jun 27 Python
Django中的用户身份验证示例详解
Aug 07 Python
Python 判断时间是否在时间区间内的实例
May 16 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
PHP 七大优势分析
2009/06/23 PHP
javascript 学习之旅 (1)
2009/02/05 Javascript
超链接的禁用属性Disabled使用示例
2014/07/31 Javascript
jQuery中add()方法用法实例
2015/01/08 Javascript
ajax+jQuery实现级联显示地址的方法
2015/05/06 Javascript
javascript实现数组中的内容随机输出
2015/08/11 Javascript
js 中文汉字转Unicode、Unicode转中文汉字、ASCII转换Unicode、Unicode转换ASCII、中文转换
2016/12/06 Javascript
vue单页面应用打开新窗口显示跳转页面的实例
2018/09/21 Javascript
从零开始用electron手撸一个截屏工具的示例代码
2018/10/10 Javascript
深入学习TypeScript 、React、 Redux和Ant-Design的最佳实践
2019/06/17 Javascript
解决vue+router路由跳转不起作用的一项原因
2020/07/19 Javascript
JS页面动态绘图工具SVG,Canvas,VML介简介
2020/10/16 Javascript
[01:02:34]TFT vs VGJ.T Supermajor 败者组 BO3 第二场 6.5
2018/06/06 DOTA
[49:27]LGD vs OG 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
[01:38]女王驾到——至宝魔廷新尊技能&特效展示
2020/06/16 DOTA
Python实现的Google IP 可用性检测脚本
2015/04/23 Python
将Python的Django框架与认证系统整合的方法
2015/07/24 Python
结合Python的SimpleHTTPServer源码来解析socket通信
2016/06/27 Python
使用Python快速制作可视化报表的方法
2019/02/03 Python
python+selenium 点击单选框-radio的实现方法
2019/09/03 Python
Python计算两个矩形重合面积代码实例
2019/09/16 Python
python字符串,元组,列表,字典互转代码实例详解
2020/02/14 Python
Python任务调度模块APScheduler使用
2020/04/15 Python
用python实现一个简单的验证码
2020/12/09 Python
世界上最大的乐谱选择:Sheet Music Plus
2020/01/18 全球购物
向全球直邮输送天然健康产品:iHerb.com
2020/05/03 全球购物
STP协议的主要用途是什么?为什么要用STP
2012/12/20 面试题
年会活动策划方案
2014/01/23 职场文书
勤奋学习演讲稿
2014/05/10 职场文书
2014年教师政治学习材料
2014/06/02 职场文书
运动会开幕词
2015/01/28 职场文书
社区端午节活动总结
2015/02/11 职场文书
电信营业员岗位职责
2015/04/14 职场文书
python数字类型和占位符详情
2022/03/13 Python
Consul在linux环境的集群部署
2022/04/08 Servers
解决Mysql报错 Table 'mysql.user' doesn't exist
2022/05/06 MySQL