PyTorch加载自己的数据集实例详解


Posted in Python onMarch 18, 2020

数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力。 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能。为解决这一问题,PyTorch提供了几个高效便捷的工具, 以便使用者进行数据处理或增强等操作,同时可通过并行化加速数据加载。

数据集存放大致有以下两种方式:

(1)所有数据集放在一个目录下,文件名上附有标签名,数据集存放格式如下: root/cat_dog/cat.01.jpg

root/cat_dog/cat.02.jpg

........................

root/cat_dog/dog.01.jpg

root/cat_dog/dog.02.jpg

......................

(2)不同类别的数据集放在不同目录下,目录名就是标签,数据集存放格式如下:

root/ants/xxx.png

root/ants/xxy.jpeg

root/ants/xxz.png

................

root/bees/123.jpg

root/bees/nsdf3.png

root/bees/asd932_.png

..................

1.1 对第1种数据集的处理步骤

(1)生成包含各文件名的列表(List)

(2)定义Dataset的一个子类,该子类需要继承Dataset类,查看Dataset类的源码

(3)重写父类Dataset中的两个魔法方法: 一个是: __lent__(self),其功能是len(Dataset),返回Dataset的样本数。 另一个是__getitem__(self,index),其功能假设索引为i,使Dataset[i]返回第i个样本。

(4)使用torch.utils.data.DataLoader加载数据集Dataset.

1.2 实例详解

以下以cat-dog数据集为例,说明如何实现自定义数据集的加载。

1.2.1 数据集结构

所有数据集在cat-dog目录下:

.\cat_dog\cat.01.jpg

.\cat_dog\cat.02.jpg

.\cat_dog\cat.03.jpg

....................

.\cat_dog\dog.01.jpg

.\cat_dog\dog.02.jpg

....................

1.2.2 导入需要用到的模块

from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import oimport torch
from torchvision import transforms, utils
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

1.2.3定义加载自定义数据的类

class MyDataset(Dataset): #继承Dataset
 def __init__(self, path_dir, transform=None): #初始化一些属性
  self.path_dir = path_dir #文件路径,如'.\data\cat-dog'
  self.transform = transform #对图形进行处理,如标准化、截取、转换等
  self.images = os.listdir(self.path_dir)#把路径下的所有文件放在一个列表中
 
 def __len__(self):#返回整个数据集的大小
  return len(self.images)
 
 def __getitem__(self,index):#根据索引index返回图像及标签
  image_index = self.images[index]#根据索引获取图像文件名称
  img_path = os.path.join(self.path_dir, image_index)#获取图像的路径或目录
  img = Image.open(img_path).convert('RGB')# 读取图像
    
  # 根据目录名称获取图像标签(cat或dog)
  label = img_path.split('\\')[-1].split('.')[0]
  #把字符转换为数字cat-0,dog-1
  label = 1 if 'dog' in label else 0
  
  if self.transform is not None:
   img = self.transform(img)
  return img,label

1.2.4 实例化类

dataset = MyDataset('.\data\cat-dog',transform=None)
img, label = dataset[0] #将启动魔法方法__getitem__(0)
print(type(img))
<class 'PIL.Image.Image'>

1.2.5 查看图像形状

i=1
for img, label in dataset:
    if i
img的形状(500, 374),label的值0

img的形状(300, 280),label的值0

img的形状(489, 499),label的值0

img的形状(431, 410),label的值0

img的形状(300, 224),label的值0

从上面返回样本的形状来看:

(1)每张图片的大小不一样,如果需要取batch训练的神经网络来说很不友好。

(2)返回样本的数值较大,未归一化至[-1, 1]

为此需要对img进行转换,如何转换?只要使用torchvision中的transforms即可

1.2.6 对图像数据进行处理

这里使用torchvision中的transforms模块

from torchvision import transforms as T
transform = T.Compose([
 T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
 T.CenterCrop(224), # 从图片中间切出224*224的图片
 T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

1.2.7查看处理后的数据

dataset = MyDataset('.\data\cat-dog',transform=transform)
for img, label in dataset: 
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

图像img的形状torch.Size([3, 224, 224]),标签label的值0

图像数据预处理后:

tensor([[[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9451, 0.9451, 0.9451],

[ 0.9059, 0.9137, 0.9137, ..., 0.9529, 0.9529, 0.9529],

...,

[-0.4824, -0.5294, -0.5373, ..., -0.9216, -0.9294, -0.9451],

[-0.4980, -0.5529, -0.5608, ..., -0.9294, -0.9373, -0.9529],

[-0.4980, -0.5529, -0.5686, ..., -0.9529, -0.9608, -0.9608]],

[[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.7961, 0.7882, 0.7882],

[ 0.5686, 0.5765, 0.5765, ..., 0.8039, 0.7961, 0.7961],

...,

[-0.6078, -0.6471, -0.6549, ..., -0.9137, -0.9216, -0.9373],

[-0.6157, -0.6706, -0.6784, ..., -0.9216, -0.9294, -0.9451],

[-0.6157, -0.6706, -0.6863, ..., -0.9451, -0.9529, -0.9529]],

[[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2078, 0.2157, 0.2157],

[-0.0510, -0.0431, -0.0431, ..., 0.2157, 0.2235, 0.2235],

...,

[-0.9529, -0.9843, -0.9922, ..., -0.9529, -0.9608, -0.9765],

[-0.9686, -0.9922, -1.0000, ..., -0.9608, -0.9686, -0.9843],

[-0.9686, -0.9922, -1.0000, ..., -0.9843, -0.9922, -0.9922]]])

由此可知,数据已标准化、规范化。

1.2.8对数据集进行批量加载

使用DataLoader模块,对数据集dataset进行批量加载

#使用DataLoader加载数据
dataloader = DataLoader(dataset,batch_size=4,shuffle=True)
for batch_datas, batch_labels in dataloader:
 print(batch_datas.size(),batch_labels.size())
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([4, 3, 224, 224]) torch.Size([4])
torch.Size([2, 3, 224, 224]) torch.Size([2])

1.2.9随机查看一个批次的图像

import torchvision
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataloader)
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

2 对第2种数据集的处理

处理这种情况比较简单,可分为2步:

(1)使用datasets.ImageFolder读取、处理图像。

(2)使用.data.DataLoader批量加载数据集,示例如下:

import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
  transforms.RandomSizedCrop(224),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
 ])
hymenoptera_dataset = datasets.ImageFolder(root='.\catdog\train',
           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,

总结

到此这篇关于PyTorch加载自己的数据集实例详解的文章就介绍到这了,更多相关PyTorch加载 数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python采用Django开发自己的博客系统
Sep 29 Python
Python中的左斜杠、右斜杠(正斜杠和反斜杠)
Aug 30 Python
Python守护进程和脚本单例运行详解
Jan 06 Python
Python 中字符串拼接的多种方法
Jul 30 Python
使用python list 查找所有匹配元素的位置实例
Jun 11 Python
对Django外键关系的描述
Jul 26 Python
pycharm 安装JPype的教程
Aug 08 Python
使用Python实现图像标记点的坐标输出功能
Aug 14 Python
python读取tif图片时保留其16bit的编码格式实例
Jan 13 Python
Python安装tar.gz格式文件方法详解
Jan 19 Python
python检查目录文件权限并修改目录文件权限的操作
Mar 11 Python
教你怎么用PyCharm为同一服务器配置多个python解释器
May 31 Python
Python进程间通信multiprocess代码实例
Mar 18 #Python
python实现超级玛丽游戏
Mar 18 #Python
python实现超级马里奥
Mar 18 #Python
Python开发企业微信机器人每天定时发消息实例
Mar 17 #Python
10个python3常用排序算法详细说明与实例(快速排序,冒泡排序,桶排序,基数排序,堆排序,希尔排序,归并排序,计数排序)
Mar 17 #Python
Python Selenium安装及环境配置的实现
Mar 17 #Python
详解python环境安装selenium和手动下载安装selenium的方法
Mar 17 #Python
You might like
PHP.MVC的模板标签系统(一)
2006/09/05 PHP
一个PHP分页类的代码
2011/05/18 PHP
PHP中的一些常用函数收集
2015/05/26 PHP
PHP基于双向链表与排序操作实现的会员排名功能示例
2017/12/26 PHP
Javascript 函数中的参数使用分析
2010/03/27 Javascript
javascript根据时间生成m位随机数最大13位
2014/10/30 Javascript
jquery实现可自动判断位置的弹出层效果代码
2015/10/12 Javascript
js实现自动轮换选项卡
2017/01/13 Javascript
Vue 仿百度搜索功能实现代码
2017/02/16 Javascript
React 子组件向父组件传值的方法
2017/07/24 Javascript
vue实现留言板todolist功能
2017/08/16 Javascript
浅谈vue-router 路由传参的方法
2017/12/27 Javascript
基于vue-cli npm run build之后vendor.js文件过大的解决方法
2018/09/27 Javascript
vue使用一些外部插件及样式的配置代码
2019/11/18 Javascript
vuex实现购物车的增加减少移除
2020/06/28 Javascript
适用于 Vue 的播放器组件Vue-Video-Player操作
2020/11/16 Javascript
[19:15]DK战队纪录片
2014/09/02 DOTA
python实现倒计时的示例
2014/02/14 Python
python3.6.3+opencv3.3.0实现动态人脸捕获
2018/05/25 Python
Python中最大递归深度值的探讨
2019/03/05 Python
Python Flask框架扩展操作示例
2019/05/03 Python
python中通过pip安装库文件时出现“EnvironmentError: [WinError 5] 拒绝访问”的问题及解决方案
2020/08/11 Python
西班牙在线宠物食品和配件商店:bitiba
2019/10/11 全球购物
内业资料员岗位职责
2014/01/04 职场文书
行政文秘岗位职责范本
2014/02/10 职场文书
运动会稿件300字
2014/02/14 职场文书
党的群众路线教育实践活动批评与自我批评范文
2014/10/16 职场文书
2014年学生会生活部工作总结
2014/11/07 职场文书
2014年创先争优工作总结
2014/12/11 职场文书
小学工作总结2015
2015/05/04 职场文书
小学教育见习总结
2015/06/23 职场文书
医院感染管理制度
2015/08/05 职场文书
Pytest之测试命名规则的使用
2021/04/16 Python
Python3接口性能测试实例代码
2021/06/20 Python
对象析构函数__del__在Python中何时使用
2022/03/22 Python
CSS使用SVG实现动态分布的圆环发散路径动画
2022/12/24 HTML / CSS