PyTorch实现重写/改写Dataset并载入Dataloader


Posted in Python onJuly 14, 2020

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

class ImageLoader(Dataset):
  def __init__(self, file_path, transform=None):
    super(ImageLoader,self).__init__()
    self.file_path = file_path
    self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
    self.image_names = os.listdir(self.file_path) # 文件名的列表
    
  def __getitem__(self,idx):
    image = self.image_names[idx]
    image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    	image= self.transform(image)
    return image
         
  def __len__(self):
    return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

array([[[ 66, 59, 53],
    [ 66, 59, 53],
    [ 66, 59, 53],
    ...,
    [ 59, 54, 48],
    [ 59, 54, 48],
    [ 59, 54, 48]],
    ...,
    [153, 141, 129],
    [158, 146, 134],
    [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

# 直接改成pytorch中的tensor下的float格式 
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。

train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
收集的几个Python小技巧分享
Nov 22 Python
python 字符串转列表 list 出现\ufeff的解决方法
Jun 22 Python
Python实现自动上京东抢手机
Feb 06 Python
Python爬虫实现抓取京东店铺信息及下载图片功能示例
Aug 07 Python
python列表每个元素同增同减和列表元素去空格的实例
Jul 20 Python
详解python中的数据类型和控制流
Aug 08 Python
Python 装饰器原理、定义与用法详解
Dec 07 Python
对python中 math模块下 atan 和 atan2的区别详解
Jan 17 Python
python库skimage给灰度图像染色的方法示例
Apr 27 Python
python 实现简易的记事本
Nov 30 Python
python多线程和多进程关系详解
Dec 14 Python
python脚本框架webpy模板控制结构
Nov 20 Python
python实现将中文日期转换为数字日期
Jul 14 #Python
Python实时监控网站浏览记录实现过程详解
Jul 14 #Python
python3 中时间戳、时间、日期的转换和加减操作
Jul 14 #Python
python转化excel数字日期为标准日期操作
Jul 14 #Python
Python 实现将某一列设置为str类型
Jul 14 #Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 #Python
利用python对excel中一列的时间数据更改格式操作
Jul 14 #Python
You might like
使用 MySQL 开始 PHP 会话
2006/12/21 PHP
UCenter Home二次开发指南
2009/05/28 PHP
php和数据库结合的一个简单的web实例 代码分析 (php初学者)
2011/07/28 PHP
php中取得文件的后缀名?
2012/02/20 PHP
php中一个有意思的日期逻辑处理
2012/03/25 PHP
PHPThumb图片处理实例
2014/05/03 PHP
php的GD库imagettftext函数解决中文乱码问题
2015/01/24 PHP
修改Laravel5.3中的路由文件与路径
2016/08/10 PHP
PHP网站自动化配置的实现方法(必看)
2017/05/27 PHP
PHP操作路由器实现方法示例
2019/04/27 PHP
Laravel框架基础语法与知识点整理【模板变量、输出、include引入子视图等】
2019/12/03 PHP
jQuery 连续列表实现代码
2009/12/21 Javascript
Jsonp 跨域的原理以及Jquery的解决方案
2010/05/18 Javascript
div当滚动到页面顶部的时候固定在顶部实例代码
2013/05/27 Javascript
使用javascript过滤html的字符串(注释标记法)
2013/07/08 Javascript
js的.innerHTML = ""IE9下显示有错误的解决方法
2013/09/16 Javascript
js实现的全国省市二级联动下拉选择菜单完整实例
2015/08/17 Javascript
BootStrap点击下拉菜单项后显示一个新的输入框实现代码
2016/05/16 Javascript
javascript动画之磁性吸附效果篇
2016/12/09 Javascript
jQuey将序列化对象在前台显示地实现代码(方法总结)
2016/12/13 Javascript
AngularJS实现表单元素值绑定操作示例
2017/10/11 Javascript
vue单页应用在页面刷新时保留状态数据的方法
2018/09/21 Javascript
webpack4.0 入门实践教程
2018/10/08 Javascript
javascript导出csv文件(excel)的方法示例
2019/08/25 Javascript
javascript实现动态时钟的启动和停止
2020/07/29 Javascript
基于javascript canvas实现五子棋游戏
2020/07/08 Javascript
tensorflow 只恢复部分模型参数的实例
2020/01/06 Python
解决virtualenv -p python3 venv报错的问题
2021/02/05 Python
python实现不同数据库间数据同步功能
2021/02/25 Python
英国历史最悠久的DJ设备供应商:DJ Finance、DJ Warehouse、The DJ Shop
2019/09/04 全球购物
教师绩效工资方案
2014/02/01 职场文书
销售助理岗位职责
2014/02/21 职场文书
营销总监岗位职责范本
2014/02/26 职场文书
装修活动策划方案
2014/08/27 职场文书
学习雷锋主题班会
2015/08/14 职场文书
详解MindSpore自定义模型损失函数
2021/06/30 Python