pytorch 自定义数据集加载方法


Posted in Python onAugust 18, 2019

pytorch 官网给出的例子中都是使用了已经定义好的特殊数据集接口来加载数据,而且其使用的数据都是官方给出的数据。如果我们有自己收集的数据集,如何用来训练网络呢?此时需要我们自己定义好数据处理接口。幸运的是pytroch给出了一个数据集接口类(torch.utils.data.Dataset),可以方便我们继承并实现自己的数据集接口。

torch.utils.data

torch的这个文件包含了一些关于数据集处理的类。

class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。

class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。

class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。

class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 __iter__ 方法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。

class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。

class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。

class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。

class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。

class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

自定义数据集

自己定义的数据集需要继承抽象类class torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__。

整个代码仅供参考。在__init__中是初始化了该类的一些基本参数;__getitem__中是真正读取数据的地方,迭代器通过索引来读取数据集中数据,因此只需要这一个方法中加入读取数据的相关功能即可;__len__给出了整个数据集的尺寸大小,迭代器的索引范围是根据这个函数得来的。

import torch

class myDataset(torch.nn.data.Dataset):
 def __init__(self, dataSource)
  self.dataSource = dataSource

 def __getitem__(self, index):
  element = self.dataSource[index]
  return element
 def __len__(self):
  return len(self.dataSource)

train_data = myDataset(dataSource)

自定义数据集加载器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。

dataset (Dataset) ? 需要加载的数据集(可以是自定义或者自带的数据集)。

batch_size ? batch的大小(可选项,默认值为1)。

shuffle ? 是否在每个epoch中shuffle整个数据集, 默认值为False。

sampler ? 定义从数据中抽取样本的策略. 如果指定了, shuffle参数必须为False。

num_workers ? 表示读取样本的线程数, 0表示只有主线程。

collate_fn ? 合并一个样本列表称为一个batch。

pin_memory ? 是否在返回数据之前将张量拷贝到CUDA。

drop_last (bool, optional) ? 设置是否丢弃最后一个不完整的batch,默认为False。

timeout ? 用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。应该为非负整数。

train_loader=torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)

以上这篇pytorch 自定义数据集加载方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用Python获取操作系统信息实例
Sep 02 Python
Python实现学生成绩管理系统
Apr 05 Python
Python模块WSGI使用详解
Feb 02 Python
Django中的CBV和FBV示例介绍
Feb 25 Python
Windows下安装Django框架的方法简明教程
Mar 28 Python
详解分布式任务队列Celery使用说明
Nov 29 Python
python矩阵的转置和逆转实例
Dec 12 Python
Django 创建/删除用户的示例代码
Jul 24 Python
python3使用print打印带颜色的字符串代码实例
Aug 22 Python
python logging日志模块原理及操作解析
Oct 12 Python
Pycharm+Python工程,引用子模块的实现
Mar 09 Python
记录模型训练时loss值的变化情况
Jun 16 Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
You might like
百度ping方法使用示例 自动ping百度
2014/01/26 PHP
PHP用户管理中常用接口调用实例及解析(含源码)
2017/03/09 PHP
javascript while语句和do while语句的区别分析
2007/12/08 Javascript
IE JS无提示关闭窗口不提示的方法
2010/04/29 Javascript
基于jquery的内容循环滚动小模块(仿新浪微博未登录首页滚动微博显示)
2011/03/28 Javascript
Ajax同步与异步传输的示例代码
2013/11/21 Javascript
js之ActiveX控件使用说明 new ActiveXObject()
2014/03/03 Javascript
IE中鼠标经过option触发mouseout的解决方法
2015/01/29 Javascript
jQuery制作仿Mac Lion OS滚动条效果
2015/02/10 Javascript
分享五个有用的jquery小技巧
2015/10/08 Javascript
jQuery动态增减行的实例代码解析(推荐)
2016/12/05 Javascript
在Vant的基础上封装下拉日期控件的代码示例
2018/12/05 Javascript
vue draggable resizable gorkys与v-chart使用与总结
2019/09/05 Javascript
vue 重塑数组之修改数组指定index的值操作
2020/08/09 Javascript
js实现抽奖功能
2020/11/24 Javascript
python解析html开发库pyquery使用方法
2014/02/07 Python
研究Python的ORM框架中的SQLAlchemy库的映射关系
2015/04/25 Python
对python For 循环的三种遍历方式解析
2019/02/01 Python
Python异常处理知识点总结
2019/02/18 Python
Python 在OpenCV里实现仿射变换—坐标变换效果
2019/08/30 Python
哈工大自然语言处理工具箱之ltp在windows10下的安装使用教程
2020/05/07 Python
Python pip安装模块提示错误解决方案
2020/05/22 Python
Django实现内容缓存实例方法
2020/06/30 Python
HTML5 Canvas的性能提高技巧经验分享
2013/07/02 HTML / CSS
Exoticca英国:以最优惠的价格提供豪华异国情调旅行
2018/10/18 全球购物
类的核心特性有哪些
2014/01/01 面试题
高分子材料与工程专业个人求职信
2013/12/15 职场文书
竞争上岗演讲稿
2014/01/05 职场文书
北京大学自荐信范文
2014/01/28 职场文书
《湘夫人》教学反思
2014/02/21 职场文书
教师查摆问题及整改措施
2014/10/11 职场文书
地球上的星星观后感
2015/06/02 职场文书
高中同学会致辞
2015/08/01 职场文书
小学语文继续教育研修日志
2015/11/13 职场文书
Python matplotlib绘制条形统计图 处理多个实验多组观测值
2022/04/21 Python
GoFrame框架数据校验之校验结果Error接口对象
2022/06/21 Golang