Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中使用Tkinter模块创建GUI程序实例
Jan 14 Python
从Python的源码来解析Python下的freeblock
May 11 Python
Python检测字符串中是否包含某字符集合中的字符
May 21 Python
详解Swift中属性的声明与作用
Jun 30 Python
python爬取拉勾网职位数据的方法
Jan 24 Python
Python爬虫实例扒取2345天气预报
Mar 04 Python
使用Django启动命令行及执行脚本的方法
May 29 Python
Python3 关于pycharm自动导入包快捷设置的方法
Jan 16 Python
深入浅析Python中的迭代器
Jun 04 Python
浅谈keras中自定义二分类任务评价指标metrics的方法以及代码
Jun 11 Python
python中pathlib模块的基本用法与总结
Aug 17 Python
弄清Pytorch显存的分配机制
Dec 10 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
两个开源的Php输出Excel文件类
2010/02/08 PHP
解析php时间戳与日期的转换
2013/06/06 PHP
php创建和删除目录函数介绍和递归删除目录函数分享
2014/11/18 PHP
搭建Vim为自定义的PHP开发工具的一些技巧
2015/12/11 PHP
PHP实现求两个字符串最长公共子串的方法示例
2017/11/17 PHP
PHP实现递归的三种方法
2020/07/04 PHP
不错的asp中显示新闻的功能
2006/10/13 Javascript
jquery利用ajax调用后台方法实例
2013/08/23 Javascript
JavaScript函数模式详解
2014/11/07 Javascript
jQuery使用addClass()方法给元素添加多个class样式
2015/03/26 Javascript
基于JS实现省市联动效果代码分享
2016/06/06 Javascript
javascript正则表达式模糊匹配IP地址功能示例
2017/01/06 Javascript
js实现弹窗暗层效果
2017/01/16 Javascript
原生js封装自定义滚动条
2017/03/24 Javascript
bootstrap 弹出框modal添加垂直方向滚轴效果
2018/07/09 Javascript
vue代码分割的实现(codesplit)
2018/11/13 Javascript
使用Sonarqube扫描Javascript代码的示例
2018/12/26 Javascript
对layui中table组件工具栏的使用详解
2019/09/19 Javascript
vue 子组件修改data或调用操作
2020/08/07 Javascript
[00:57]英雄,你的补给到了!
2020/11/13 DOTA
python删除过期文件的方法
2015/05/29 Python
详解Python中的日志模块logging
2015/06/19 Python
如何准确判断请求是搜索引擎爬虫(蜘蛛)发出的请求
2015/10/13 Python
初探利用Python进行图文识别(OCR)
2019/02/26 Python
Python3.4解释器用法简单示例
2019/03/22 Python
Python pip配置国内源的方法
2020/02/14 Python
对python pandas中 inplace 参数的理解
2020/06/27 Python
细说CSS3中的选择符
2008/10/17 HTML / CSS
一个不错的HTML5 Canvas多层点击事件监听实例
2014/04/29 HTML / CSS
Spartoo葡萄牙鞋类网站:线上销售鞋履与时尚配饰
2017/01/11 全球购物
世界汽车零件:World Car Parts
2019/09/04 全球购物
小米乌克兰网上商店:Xiaomi.UA
2019/10/29 全球购物
欠条范文
2015/07/03 职场文书
2016中秋节广告语
2016/01/28 职场文书
html2 canvas svg不能识别的解决方案
2021/06/03 HTML / CSS
死磕 java同步系列之synchronized解析
2021/06/28 Java/Android