Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现读取目录所有文件的文件名并保存到txt文件代码
Nov 22 Python
Python和GO语言实现的消息摘要算法示例
Mar 10 Python
python实现线程池的方法
Jun 30 Python
利用Python实现Windows下的鼠标键盘模拟的实例代码
Jul 13 Python
Python获取当前页面内所有链接的四种方法对比分析
Aug 19 Python
基于Python的文件类型和字符串详解
Dec 21 Python
解决Pycharm中import时无法识别自己写的程序方法
May 18 Python
用Python PIL实现几个简单的图片特效
Jan 18 Python
Python可视化mhd格式和raw格式的医学图像并保存的方法
Jan 24 Python
Python Numpy 自然数填充数组的实现
Nov 28 Python
python设置表格边框的具体方法
Jul 17 Python
python 贪心算法的实现
Sep 18 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
如何使用Strace调试工具
2013/06/03 PHP
在Yii框架中使用PHP模板引擎Twig的例子
2014/06/13 PHP
ThinkPHP实现多数据库连接的解决方法
2014/07/01 PHP
Zend Framework教程之配置文件application.ini解析
2016/03/10 PHP
PDO::query讲解
2019/01/29 PHP
张孝祥JavaScript学习阶段性总结(2)--(X)HTML学习
2007/02/03 Javascript
Javascript读取cookie函数代码
2010/10/16 Javascript
js的隐含参数(arguments,callee,caller)使用方法
2014/01/28 Javascript
js对文章内容进行分页示例代码
2014/03/05 Javascript
jquery中one()方法的用法实例
2015/01/16 Javascript
深入理解JQuery循环绑定事件
2016/06/02 Javascript
vue货币过滤器的实现方法
2017/04/01 Javascript
ES6数组的扩展详解
2017/04/25 Javascript
JavaScript条件判断_动力节点Java学院整理
2017/06/26 Javascript
vue awesome swiper异步加载数据出现的bug问题
2018/07/03 Javascript
vue实现滑动到底部加载更多效果
2020/10/27 Javascript
详解小程序横屏方案对比
2020/06/28 Javascript
Python获取当前函数名称方法实例分享
2018/01/18 Python
Python unittest模块用法实例分析
2018/05/25 Python
面向初学者的Python编辑器Mu
2018/10/08 Python
Python解决线性代数问题之矩阵的初等变换方法
2018/12/12 Python
pandas计数 value_counts()的使用
2019/06/24 Python
python读写配置文件操作示例
2019/07/03 Python
Python K最近邻从原理到实现的方法
2019/08/15 Python
Python requests模块安装及使用教程图解
2020/06/30 Python
pycharm激活码2020最新分享适用pycharm2020最新版亲测可用
2020/11/22 Python
Python爬虫实现selenium处理iframe作用域问题
2021/01/27 Python
HomeAway澳大利亚:预订你的度假屋,公寓、度假村、别墅等
2019/02/20 全球购物
世界上第一个创建了罩杯系统的美国内衣品牌:Maidenform
2019/03/23 全球购物
《学棋》教后反思
2014/04/14 职场文书
小学五年级学生评语
2014/04/22 职场文书
敬老院献爱心活动总结
2014/07/08 职场文书
评先进个人材料
2014/12/29 职场文书
生日寿星公答谢词
2015/09/29 职场文书
利用Python实现模拟登录知乎
2022/05/25 Python
Mysql数据库group by原理详解
2022/07/07 MySQL