PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python轻松查到删除自己的微信好友
Jan 10 Python
浅谈Scrapy框架普通反爬虫机制的应对策略
Dec 28 Python
使用Python对微信好友进行数据分析
Jun 27 Python
win7+Python3.5下scrapy的安装方法
Jul 31 Python
Python 3 判断2个字典相同
Aug 06 Python
python实现一个函数版的名片管理系统过程解析
Aug 27 Python
python yield关键词案例测试
Oct 15 Python
Python实现快速排序的方法详解
Oct 25 Python
python Shapely使用指南详解
Feb 18 Python
Python实现自动打开电脑应用的示例代码
Apr 17 Python
Python3爬虫ChromeDriver的安装实例
Feb 06 Python
Python编程中Python与GIL互斥锁关系作用分析
Sep 15 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
php新建文件自动编号的思路与实现
2011/06/27 PHP
php用户注册信息验证正则表达式
2015/11/12 PHP
PHP+Ajax验证码验证用户登录
2016/07/20 PHP
详解Yii2 rules 的验证规则
2016/12/02 PHP
laravel 执行迁移回滚示例
2019/10/23 PHP
使用jQuery Ajax功能时需要注意的一个问题(内存溢出)
2012/05/30 Javascript
EASYUI TREEGRID异步加载数据实现方法
2012/08/22 Javascript
js拦截alert对话框另类应用
2013/01/16 Javascript
ext combobox动态加载数据库数据(附前后台)
2014/06/17 Javascript
JavaScript中setUTCMilliseconds()方法的使用详解
2015/06/12 Javascript
jQuery实现的淡入淡出二级菜单效果代码
2015/09/15 Javascript
sso跨域写cookie的一段js脚本(推荐)
2016/05/25 Javascript
BootStrap select2 动态改变值的方法
2017/02/10 Javascript
详解Node.js实现301、302重定向服务
2017/04/07 Javascript
VUE2 前端实现 静态二级省市联动选择select的示例
2018/02/09 Javascript
Javascript中弹窗confirm与prompt的区别
2018/10/26 Javascript
Vue 实现一个命令式弹窗组件功能
2019/09/25 Javascript
vue基本使用--refs获取组件或元素的实例
2019/11/07 Javascript
JavaScript如何实现防止重复的网络请求的示例
2021/01/28 Javascript
对于Python的Django框架部署的一些建议
2015/04/09 Python
pymongo为mongodb数据库添加索引的方法
2015/05/11 Python
python使用wmi模块获取windows下硬盘信息的方法
2015/05/15 Python
python实现基本进制转换的方法
2015/07/11 Python
对python append 与浅拷贝的实例讲解
2018/05/04 Python
python检测文件夹变化,并拷贝有更新的文件到对应目录的方法
2018/10/17 Python
Python中有几个关键字
2020/06/04 Python
Python创建临时文件和文件夹
2020/08/05 Python
python中pdb模块实例用法
2021/01/15 Python
详解matplotlib中pyplot和面向对象两种绘图模式之间的关系
2021/01/22 Python
英格兰橄榄球商店:England Rugby Store
2016/12/17 全球购物
在校实习生求职信
2014/06/18 职场文书
2015年禁毒宣传活动总结
2015/03/25 职场文书
2015年个人实习工作总结
2015/05/28 职场文书
2016小学教师读书心得体会
2016/01/13 职场文书
2019入党申请书范文3篇
2019/08/21 职场文书
ConstraintValidator类如何实现自定义注解校验前端传参
2021/06/18 Java/Android