Pytorch数据读取之Dataset和DataLoader知识总结


Posted in Python onMay 23, 2021

一、前言

确保安装

  • scikit-image
  • numpy

二、Dataset

一个例子:

# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
 
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
 
 
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
# 主函数
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:

#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader

DataLoader(dataset,
           batch_size=1,
           shuffle=False,
           sampler=None,
           batch_sampler=None,
           num_works=0,
           clollate_fn=None,
           pin_memory=False,
           drop_last=False,
           timeout=0,
           worker_init_fn=None,
           multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:

import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
 
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.IntTensor(self.Label[index])
        return data, label
 
if __name__ == '__main__':
    dataset = subDataset(Data, Label)
    print(dataset)
    print('dataset大小为:', dataset.__len__())
    print(dataset.__getitem__(0))
    print(dataset[0])
 
    #创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
 
    dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
    for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
        print('i:', i)
        data, label = item #数据是一个元组
        print('data:', data)
        print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

到此这篇关于Pytorch数据读取之Dataset和DataLoader知识总结的文章就介绍到这了,更多相关详解Dataset和DataLoader内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python3 selenium 切换窗口的几种方法小结
May 21 Python
python 脚本生成随机 字母 + 数字密码功能
May 26 Python
Python的argparse库使用详解
Oct 09 Python
浅谈pandas用groupby后对层级索引levels的处理方法
Nov 06 Python
Python实现深度遍历和广度遍历的方法
Jan 22 Python
django多文件上传,form提交,多对多外键保存的实例
Aug 06 Python
python中class的定义及使用教程
Sep 18 Python
Python 常用日期处理 -- calendar 与 dateutil 模块的使用
Sep 02 Python
python入门教程之基本算术运算符
Nov 13 Python
基于python制作简易版学生信息管理系统
Apr 20 Python
如何用python清洗文件中的数据
Jun 18 Python
利用python做数据拟合详情
Nov 17 Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
教你利用python实现企业微信发送消息
python基础之文件处理知识总结
May 23 #Python
You might like
兼容IE、FireFox、Chrome等浏览器的xml处理函数js代码
2011/11/30 Javascript
查询json的数据结构的8种方式简介
2014/03/10 Javascript
jquery的each方法使用示例分享
2014/03/25 Javascript
jQuery带时间的日期控件代码分享
2015/08/26 Javascript
跟我学习javascript的作用域与作用域链
2015/11/19 Javascript
ReactNative页面跳转实例代码
2016/09/27 Javascript
详解js产生对象的3种基本方式(工厂模式,构造函数模式,原型模式)
2017/01/09 Javascript
JavaScript函数表达式详解及实例
2017/05/05 Javascript
解决axios会发送两次请求,有个OPTIONS请求的问题
2018/10/25 Javascript
vue微信分享插件使用方法详解
2020/02/18 Javascript
vue使用openlayers实现移动点动画
2020/09/24 Javascript
[04:52]第二届DOTA2亚洲邀请赛主赛事第一天比赛集锦:OG娜迦海妖放大配合谜团大中3人
2017/04/02 DOTA
Python中还原JavaScript的escape函数编码后字符串的方法
2014/08/22 Python
Windows和Linux下Python输出彩色文字的方法教程
2017/05/02 Python
python执行系统命令后获取返回值的几种方式集合
2018/05/12 Python
Python 动态导入对象,importlib.import_module()的使用方法
2019/08/28 Python
Django中间件拦截未登录url实例详解
2019/09/03 Python
python实现图片转换成素描和漫画格式
2020/08/19 Python
python创建文本文件的简单方法
2020/08/30 Python
详解python中的三种命令行模块(sys.argv,argparse,click)
2020/12/15 Python
css3 iphone玻璃透明气泡完美实现
2013/03/20 HTML / CSS
CSS3 二级导航菜单的制作的示例
2018/04/02 HTML / CSS
Nordgreen台湾官网:极简北欧设计手表
2019/08/21 全球购物
英国最大的滑板品牌选择:Route One
2019/09/22 全球购物
分解成质因数(如435234=251*17*17*3*2,据说是华为笔试题)
2014/07/16 面试题
如何将无状态会话Bean发布为WEB服务,只有无状态会话Bean可以发布为WEB服务?
2015/12/03 面试题
大专自我鉴定范文
2013/10/23 职场文书
大学生就业意向书范文
2014/04/01 职场文书
大学生村官个人对照检查材料(群众路线)
2014/09/26 职场文书
党性分析材料格式
2014/12/19 职场文书
企业爱心捐款倡议书
2015/04/27 职场文书
投资合作意向书范本
2015/05/08 职场文书
和领导吃饭祝酒词
2015/08/11 职场文书
2016幼儿教师自荐信范文
2016/01/28 职场文书
golang中实现给gif、png、jpeg图片添加文字水印
2021/04/26 Golang
Javascript中Microtask和Macrotask鲜为人知的知识点
2022/04/02 Javascript