Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python BeautifulSoup使用方法详解
Nov 21 Python
Python使用matplotlib实现在坐标系中画一个矩形的方法
May 20 Python
python判断图片宽度和高度后删除图片的方法
May 22 Python
浅谈Django REST Framework限速
Dec 12 Python
Python 查看文件的读写权限方法
Jan 23 Python
Pyqt实现无边框窗口拖动以及窗口大小改变
Apr 19 Python
Python使用logging模块实现打印log到指定文件的方法
Sep 05 Python
对python特殊函数 __call__()的使用详解
Jul 02 Python
python 字符串常用方法汇总详解
Sep 16 Python
使用Python串口实时显示数据并绘图的例子
Dec 26 Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 Python
Python读取yaml文件的详细教程
Jul 21 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
使用PHP维护文件系统
2006/10/09 PHP
php部分常见问题总结
2008/03/27 PHP
php strlen mb_strlen计算中英文混排字符串长度
2009/07/10 PHP
Fatal error: Allowed memory size of 134217728 bytes exhausted (tried to allocate 2611816 bytes)
2014/11/08 PHP
php查询ip所在地的方法
2014/12/05 PHP
php微信开发之自定义菜单完整流程
2016/10/08 PHP
THINKPHP3.2使用soap连接webservice的解决方法
2017/12/13 PHP
Js+XML 操作
2006/09/20 Javascript
js 鼠标点击事件及其它捕获
2009/06/04 Javascript
JS清空多文本框、文本域示例代码
2014/02/24 Javascript
jquery实现全屏滚动
2015/12/28 Javascript
JS验证全角与半角及相互转化的介绍
2017/05/18 Javascript
基于代数方程库Algebra.js解二元一次方程功能示例
2017/06/09 Javascript
Puppeteer环境搭建的详细步骤
2018/09/21 Javascript
微信小程序模板消息推送的两种实现方式
2019/08/27 Javascript
VueX模块的具体使用(小白教程)
2020/06/05 Javascript
在vue中使用cookie记住用户上次选择的实例(本次例子中为下拉框)
2020/09/11 Javascript
vue实现广告栏上下滚动效果
2020/11/26 Vue.js
JS实现页面侧边栏效果探究
2021/01/08 Javascript
Python实现Linux中的du命令
2017/06/12 Python
python logging重复记录日志问题的解决方法
2018/07/12 Python
用Python写一个模拟qq聊天小程序的代码实例
2019/03/06 Python
解决pycharm中导入自己写的.py函数出错问题
2020/02/12 Python
浅谈numpy中函数resize与reshape,ravel与flatten的区别
2020/06/18 Python
法国娇韵诗官方旗舰店:Clarins是来自法国的天然护肤品牌
2018/06/30 全球购物
英国曼彻斯特宠物用品品牌:Bunty Pet Products
2019/07/27 全球购物
迪斯尼假期(欧洲、中东及非洲):Disney Holidays EMEA
2021/02/15 全球购物
货车司机岗位职责
2014/03/18 职场文书
亲子活动总结
2014/04/26 职场文书
遵纪守法演讲稿
2014/05/23 职场文书
2014年教师节座谈会发言稿
2014/09/10 职场文书
个人汇报材料范文
2014/12/30 职场文书
微信搭讪开场白
2015/05/28 职场文书
跑出一片天观后感
2015/06/08 职场文书
儿童诗两首教学反思
2016/02/23 职场文书
如何正确理解python装饰器
2021/06/15 Python