Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随机生成一个6位的验证码代码分享
Mar 24 Python
详解Python中break语句的用法
May 14 Python
Python模块搜索概念介绍及模块安装方法介绍
Jun 03 Python
Python结巴中文分词工具使用过程中遇到的问题及解决方法
Apr 15 Python
python 全文检索引擎详解
Apr 25 Python
python模拟表单提交登录图书馆
Apr 27 Python
python excel使用xlutils类库实现追加写功能的方法
May 02 Python
快速解决pandas.read_csv()乱码的问题
Jun 15 Python
Flask-WTF表单的使用方法
Jul 12 Python
wxPython实现带颜色的进度条
Nov 19 Python
pycharm 实现复制一行的快捷键
Jan 15 Python
Python爬取酷狗MP3音频的步骤
Feb 26 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
让的PHP代码飞起来的40条小技巧(提升php效率)
2010/04/12 PHP
php 按指定元素值去除数组元素的实现方法
2011/11/04 PHP
yii添删改查实例
2015/11/16 PHP
Javascript 中文字符串处理额外注意事项
2009/11/15 Javascript
toString()一个会自动调用的方法
2010/02/08 Javascript
一起来写段JS drag拖动代码
2010/12/09 Javascript
基于jquery的滚动条滚动固定div(附演示下载)
2012/10/29 Javascript
ajax提交表单实现网页无刷新注册示例
2014/05/08 Javascript
jQuery弹出框代码封装DialogHelper
2015/01/30 Javascript
理解js回收机制通俗易懂版
2016/02/29 Javascript
jQuery实现的购物车物品数量加减功能代码
2016/11/16 Javascript
jQuery实现CheckBox全选、全不选功能
2017/01/11 Javascript
RequireJS 依赖关系的实例(推荐)
2017/01/21 Javascript
js中setTimeout的妙用--防止循环超时
2017/03/06 Javascript
JavaScript数组排序reverse()和sort()方法详解
2017/12/24 Javascript
老生常谈JS中的继承及实现代码
2018/07/06 Javascript
浅谈vuex actions和mutation的异曲同工
2018/12/13 Javascript
微信小程序实现张图片合成为一张并下载
2019/07/16 Javascript
在Python的Django框架的视图中使用Session的方法
2015/07/23 Python
Python中的字符串替换操作示例
2016/06/27 Python
python使用mysql数据库示例代码
2017/05/21 Python
linux环境下的python安装过程图解(含setuptools)
2017/11/22 Python
Python PyQt4实现QQ抽屉效果
2018/04/20 Python
Python玩转加密的技巧【推荐】
2019/05/13 Python
Python学习之路安装pycharm的教程详解
2020/06/17 Python
Html5写一个简单的俄罗斯方块小游戏
2019/12/03 HTML / CSS
美国牛仔品牌:True Religion
2018/11/16 全球购物
外语学院毕业生的自我鉴定
2013/11/28 职场文书
如何打造一封优秀的留学推荐信
2014/01/25 职场文书
社区工作感言
2014/02/21 职场文书
小学清明节活动方案
2014/03/08 职场文书
竞争与合作演讲稿
2014/05/12 职场文书
新闻编辑求职信
2014/07/13 职场文书
我的中国梦演讲稿高中篇
2014/08/19 职场文书
朋友离别感言
2015/08/04 职场文书
资产移交协议书
2016/03/24 职场文书