Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作


Posted in Python onMarch 03, 2021

【源码GitHub地址】:点击进入

1. 问题描述

之前写了一篇关于《pytorch Dataset, DataLoader产生自定义的训练数据》的博客,但存在一个问题,我们不能在Dataset做一些数据清理,如果我们传递给Dataset数据,本身存在问题,那么迭代过程肯定出错的。

比如我把很多图片路径都传递给Dataset,如果图片路径都是正确的,且图片都存在也没有损坏,那显然运行是没有问题的;

但倘若传递给Dataset的图片路径有些图片是不存在,这时你通过Dataset读取图片数据,然后再迭代返回,就会出现类似如下的错误:

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp> return [default_collate(samples) for samples in transposed]

File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate

raise TypeError((error_msg_fmt.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>

2. 一般的解决方法

一般的解决方法也很简单粗暴,就是在传递数据给Dataset前,就做数据清理,把不存在的图片,损坏的数据都提前清理掉。

是的,这个是最简单粗暴的。

3. 另一种解决方法:自定义返回数据的规则:collate_fn()校对函数

我们希望不管传递什么处理给Dataset,Dataset都进行处理,如果不存在或者异常,就返回None,而在DataLoader时,对于不存为None的数据,都去除掉。

这样就保证在迭代过程中,DataLoader获得batch数据都是正确的。

比如读取batch_size=5的图片数据,如果其中有1个(或者多个)图片是不存在,那么返回的batch应该把不存在的数据过滤掉,即返回5-1=4大小的batch的数据。

是的,我要实现的就是这个功能:返回的batch数据会自定清理掉不合法的数据。

3.1 Pytorch数据处理函数:Dataset和 DataLoader

Pytorch有两个数据处理函数:Dataset和 DataLoader

from torch.utils.data import Dataset, DataLoader

其中Dataset用于定义数据的读取和预处理操作,而DataLoader用于加载并产生批训练数据。

torch.utils.data.DataLoader参数说明:

DataLoader(object)可用参数:

1、dataset(Dataset) 传入的数据集

2、batch_size(int, optional) 每个batch有多少个样本

3、shuffle(bool, optional) 在每个epoch开始的时候,对数据进行重新排序

4、sampler(Sampler, optional) 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

5、batch_sampler(Sampler, optional) 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

6、num_workers (int, optional) 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

7、collate_fn (callable, optional) 将一个list的sample组成一个mini-batch的函数

8、pin_memory (bool, optional) 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

9、drop_last (bool, optional) 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。 如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

10、timeout(numeric, optional) 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

11、worker_init_fn (callable, optional) 每个worker初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

我们要用到的是collate_fn()回调函数

3.2 自定义collate_fn()函数:

torch.utils.data.DataLoader的collate_fn()用于设置batch数据拼接方式,默认是default_collate函数,但当batch中含有None等数据时,默认的default_collate校队方法会出现错误。因此,我们需要自定义collate_fn()函数:

方法也很简单:只需在原来的default_collate函数中添加下面几句代码:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了。

# 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 if isinstance(batch, list):
 batch = [(image, image_id) for (image, image_id) in batch if image is not None]
 if batch==[]:
 return (None,None)

dataset_collate.py:

# -*-coding: utf-8 -*-
"""
 @Project: pytorch-learning-tutorials
 @File : dataset_collate.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2019-06-07 17:09:13
"""
 
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import torch
import re
from torch._six import container_abcs, string_classes, int_classes 
_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""
 
np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
 
numpy_type_map = {
 'float64': torch.DoubleTensor,
 'float32': torch.FloatTensor,
 'float16': torch.HalfTensor,
 'int64': torch.LongTensor,
 'int32': torch.IntTensor,
 'int16': torch.ShortTensor,
 'int8': torch.CharTensor,
 'uint8': torch.ByteTensor,
}
 
def collate_fn(batch):
 '''
 collate_fn (callable, optional): merges a list of samples to form a mini-batch.
 该函数参考touch的default_collate函数,也是DataLoader的默认的校对方法,当batch中含有None等数据时,
 默认的default_collate校队方法会出现错误
 一种的解决方法是:
 判断batch中image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 :param batch:
 :return:
 '''
 r"""Puts each data field into a tensor with outer dimension batch size"""
 # 这里添加:判断image是否为None,如果为None,则在原来的batch中清除掉,这样就可以在迭代中避免出错了
 if isinstance(batch, list):
 batch = [(image, image_id) for (image, image_id) in batch if image is not None]
 if batch==[]:
 return (None,None)
 
 elem_type = type(batch[0])
 if isinstance(batch[0], torch.Tensor):
 out = None
 if _use_shared_memory:
  # If we're in a background process, concatenate directly into a
  # shared memory tensor to avoid an extra copy
  numel = sum([x.numel() for x in batch])
  storage = batch[0].storage()._new_shared(numel)
  out = batch[0].new(storage)
 return torch.stack(batch, 0, out=out)
 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
  and elem_type.__name__ != 'string_':
 elem = batch[0]
 if elem_type.__name__ == 'ndarray':
  # array of string classes and object
  if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  raise TypeError(error_msg_fmt.format(elem.dtype))
 
  return collate_fn([torch.from_numpy(b) for b in batch])
 if elem.shape == (): # scalars
  py_type = float if elem.dtype.name.startswith('float') else int
  return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
 elif isinstance(batch[0], float):
 return torch.tensor(batch, dtype=torch.float64)
 elif isinstance(batch[0], int_classes):
 return torch.tensor(batch)
 elif isinstance(batch[0], string_classes):
 return batch
 elif isinstance(batch[0], container_abcs.Mapping):
 return {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
 elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
 return type(batch[0])(*(collate_fn(samples) for samples in zip(*batch)))
 elif isinstance(batch[0], container_abcs.Sequence):
 transposed = zip(*batch)#ok
 return [collate_fn(samples) for samples in transposed]
 
 raise TypeError((error_msg_fmt.format(type(batch[0]))))

测试方法:

# -*-coding: utf-8 -*-
"""
 @Project: pytorch-learning-tutorials
 @File : dataset.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2019-03-07 18:45:06
"""
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from utils import dataset_collate
import os
import cv2
from PIL import Image
def read_image(path,mode='RGB'):
 '''
 :param path:
 :param mode: RGB or L
 :return:
 '''
 return Image.open(path).convert(mode)
 
class TorchDataset(Dataset):
 def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256, repeat=1, transform=None):
 '''
 :param filename: 数据文件TXT:格式:imge_name.jpg label1_id labe2_id
 :param image_dir: 图片路径:image_dir+imge_name.jpg构成图片的完整路径
 :param resize_height 为None时,不进行缩放
 :param resize_width 为None时,不进行缩放,
    PS:当参数resize_height或resize_width其中一个为None时,可实现等比例缩放
 :param repeat: 所有样本数据重复次数,默认循环一次,当repeat为None时,表示无限循环<sys.maxsize
 :param transform:预处理
 '''
 self.image_dir = image_dir
 self.image_id_list=image_id_list
 self.len = len(image_id_list)
 self.repeat = repeat
 self.resize_height = resize_height
 self.resize_width = resize_width
 self.transform= transform
 
 def __getitem__(self, i):
 index = i % self.len
 # print("i={},index={}".format(i, index))
 image_id = self.image_id_list[index]
 image_path = os.path.join(self.image_dir, image_id)
 img = self.load_data(image_path)
 
 if img is None:
  return None,image_id
 img = self.data_preproccess(img)
 return img,image_id
 
 def __len__(self):
 if self.repeat == None:
  data_len = 10000000
 else:
  data_len = len(self.image_id_list) * self.repeat
 return data_len
 
 def load_data(self, path):
 '''
 加载数据
 :param path:
 :param resize_height:
 :param resize_width:
 :param normalization: 是否归一化
 :return:
 '''
 try:
  image = read_image(path)
 except Exception as e:
  image=None
  print(e)
 # image = image_processing.read_image(path)#用opencv读取图像
 return image
 
 def data_preproccess(self, data):
 '''
 数据预处理
 :param data:
 :return:
 '''
 if self.transform is not None:
  data = self.transform(data)
 return data
 
if __name__=='__main__':
 
 resize_height = 224
 resize_width = 224
 image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"]
 image_dir="../dataset/test_images/images"
 # 相关预处理的初始化
 '''class torchvision.transforms.ToTensor把shape=(H,W,C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray数据
 # 转换成shape=(C,H,W)的像素数据,并且被归一化到[0.0, 1.0]的torch.FloatTensor类型。
 '''
 train_transform = transforms.Compose([
 transforms.Resize(size=(resize_height, resize_width)),
 # transforms.RandomHorizontalFlip(),#随机翻转图像
 transforms.RandomCrop(size=(resize_height, resize_width), padding=4), # 随机裁剪
 transforms.ToTensor(), # 吧shape=(H,W,C)->换成shape=(C,H,W),并且归一化到[0.0, 1.0]的torch.FloatTensor类型
 # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#给定均值(R,G,B) 方差(R,G,B),将会把Tensor正则化
 ])
 
 epoch_num=2 #总样本循环次数
 batch_size=5 #训练时的一组数据的大小
 train_data_nums=10
 max_iterate=int((train_data_nums+batch_size-1)/batch_size*epoch_num) #总迭代次数
 
 train_data = TorchDataset(image_id_list=image_id_list,
    image_dir=image_dir,
    resize_height=resize_height,
    resize_width=resize_width,
    repeat=1,
    transform=train_transform)
 # 使用默认的default_collate会报错
 # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
 # 使用自定义的collate_fn
 train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, collate_fn=dataset_collate.collate_fn)
 
 
 # [1]使用epoch方法迭代,TorchDataset的参数repeat=1
 for epoch in range(epoch_num):
 for step,(batch_image, batch_label) in enumerate(train_loader):
  if batch_image is None and batch_label is None:
  print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
  continue
  image=batch_image[0,:]
  image=image.numpy()#image=np.array(image)
  image = image.transpose(1, 2, 0) # 通道由[c,h,w]->[h,w,c]
  cv2.imshow("image",image)
  cv2.waitKey(2000)
  print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label))
  # batch_x, batch_y = Variable(batch_x), Variable(batch_y)

输出结果说明:

batch_size=5,输入图片列表image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"] ,其中"ddd.jpg","111.jpg"是不存在的,resize_width=224,正常情况下返回的数据应该是torch.Size([5, 3, 224, 224]),但由于"ddd.jpg","111.jpg"不存在,被过滤掉了,所以第一个batch的维度变为torch.Size([3, 3, 224, 224])

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'

[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'

batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')

batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python实现查找excel里某一列重复数据并且剔除后打印的方法
May 26 Python
举例讲解Python设计模式编程中对抽象工厂模式的运用
Mar 02 Python
浅谈Python实现Apriori算法介绍
Dec 20 Python
pandas 实现将重复表格去重,并重新转换为表格的方法
Apr 18 Python
实例分析python3实现并发访问水平切分表
Sep 29 Python
神经网络相关之基础概念的讲解
Dec 29 Python
python处理大日志文件
Jul 23 Python
python join方法使用详解
Jul 30 Python
Python telnet登陆功能实现代码
Apr 16 Python
完美解决Django2.0中models下的ForeignKey()问题
May 19 Python
利用python做表格数据处理
Apr 13 Python
python 进阶学习之python装饰器小结
Sep 04 Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
解决pytorch 数据类型报错的问题
Mar 03 #Python
python反编译教程之2048小游戏实例
Mar 03 #Python
python 如何读、写、解析CSV文件
Mar 03 #Python
聊聊python在linux下与windows下导入模块的区别说明
Mar 03 #Python
python 递归相关知识总结
Mar 03 #Python
You might like
PHP 日志缩略名的创建函数代码
2010/05/26 PHP
php下将多个数组合并成一个数组的方法与实例代码
2011/02/03 PHP
PHP5函数小全(分享)
2013/06/06 PHP
php中Socket创建与监听实现方法
2015/01/05 PHP
javaScript 利用闭包模拟对象的私有属性
2011/12/29 Javascript
JS打开层/关闭层/移动层动画效果的实例代码
2013/05/11 Javascript
文本框回车提交与禁止提交示例
2013/09/27 Javascript
3种不同方式的焦点图轮播特效分享
2013/10/30 Javascript
jquery实现实时改变网页字体大小、字体背景色和颜色的方法
2015/08/05 Javascript
JavaScript常用函数工具集:lao-utils
2016/03/01 Javascript
JavaScript制作简单的日历效果
2016/03/10 Javascript
BootStrap网页中代码显示用法详解
2016/10/21 Javascript
AngularJS自定义过滤器用法经典实例总结
2018/05/17 Javascript
vue form 表单提交后刷新页面的方法
2018/09/04 Javascript
cdn模式下vue的基本用法详解
2018/10/07 Javascript
layui动态加载多表头的实例
2019/09/05 Javascript
Vue使用虚拟dom进行渲染view的方法
2019/12/26 Javascript
vue如何在用户要关闭当前网页时弹出提示的实现
2020/05/31 Javascript
[02:43]DOTA2英雄基础教程 半人马战行者
2014/01/13 DOTA
[57:16]2014 DOTA2华西杯精英邀请赛 5 25 LGD VS VG第二场
2014/05/26 DOTA
Python使用scrapy抓取网站sitemap信息的方法
2015/04/08 Python
python anaconda 安装 环境变量 升级 以及特殊库安装的方法
2017/06/21 Python
python绘制条形图方法代码详解
2017/12/19 Python
Anaconda下安装mysql-python的包实例
2018/06/11 Python
对pandas数据判断是否为NaN值的方法详解
2018/11/06 Python
Django 自定义404 500等错误页面的实现
2020/03/08 Python
Django中的模型类设计及展示示例详解
2020/05/29 Python
Python实现自动签到脚本的示例代码
2020/08/19 Python
Python爬取微信小程序通用方法代码实例详解
2020/09/29 Python
Giglio英国站:意大利奢侈品购物网
2018/03/06 全球购物
入党积极分子思想汇报范文
2014/01/05 职场文书
会计师职业生涯规划范文
2014/02/18 职场文书
授权委托书(完整版)
2014/09/10 职场文书
统计员岗位职责
2015/02/11 职场文书
乡镇团委工作总结2015
2015/05/26 职场文书
Spring Data JPA框架Repository自定义实现
2022/04/28 Java/Android