Pytorch DataLoader shuffle验证方式


Posted in Python onJune 02, 2021

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
               [2,5,6],
              [3,5,6],
              [4,5,6]])
data2 = np.array([[1,1,1],
                   [1,2,6],
                  [1,3,6],
                  [1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']
    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label
 
    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
                           batch_size=2,
                           shuffle = True)
 
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
                               transforms.RandomCrop(300), # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):    
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_file)
    
    def __getitem__(self, idx):
        im_name = os.path.join(root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        
        if self.transform:
            im = self.transform(im)
            
        return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
    plt.figure(figsize=(6, 6)) 
    for ind, i in enumerate(dataloader):
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind+1)
        plt.imshow(a)

Pytorch DataLoader shuffle验证方式

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取远程图片大小和尺寸的方法
Mar 26 Python
Python中列表的一些基本操作知识汇总
May 20 Python
Python运行报错UnicodeDecodeError的解决方法
Jun 07 Python
使用Python实现博客上进行自动翻页
Aug 23 Python
Python字符编码与函数的基本使用方法
Sep 30 Python
python多进程读图提取特征存npy
May 21 Python
Django 通过JS实现ajax过程详解
Jul 30 Python
详解python tkinter模块安装过程
Jan 06 Python
python中resample函数实现重采样和降采样代码
Feb 25 Python
Python自动重新加载模块详解(autoreload module)
Apr 01 Python
解决Jupyter因卸载重装导致的问题修复
Apr 10 Python
利用Python实现某OA系统的自动定位功能
May 27 Python
python 爬取吉首大学网站成绩单
python 批量压缩图片的脚本
Jun 02 #Python
python操作xlsx格式文件并读取
关于Numpy之repeat、tile的用法总结
Jun 02 #Python
Matlab如何实现矩阵复制扩充
Jun 02 #Python
给numpy.array增加维度的超简单方法
Jun 02 #Python
pytorch model.cuda()花费时间很长的解决
You might like
php function用法如何递归及return和echo区别
2014/03/07 PHP
php获取用户浏览器版本的方法
2015/01/03 PHP
php检查日期函数checkdate用法实例
2015/03/19 PHP
PHP中static关键字以及与self关键字的区别
2015/07/01 PHP
php基于mcrypt_encrypt和mcrypt_decrypt实现字符串加密解密的方法
2016/07/12 PHP
curl 出现错误的调试方法(必看)
2017/02/13 PHP
PHP中单例模式的使用场景与使用方法讲解
2019/03/18 PHP
asp javascript 实现关闭窗口时保存数据的办法
2007/11/24 Javascript
详解JavaScript中setSeconds()方法的使用
2015/06/11 Javascript
解决前端跨域问题方案汇总
2016/11/20 Javascript
JS获取浮动(float)元素的style.left值为空的快速解决办法
2017/02/19 Javascript
微信小程序Redux绑定实例详解
2017/06/07 Javascript
JS实现定时任务每隔N秒请求后台setInterval定时和ajax请求问题
2017/10/15 Javascript
JavaScript中Object值合并方法详解
2017/12/22 Javascript
vue中锚点的三种方法
2018/07/06 Javascript
详解webpack loader和plugin编写
2018/10/12 Javascript
jQuery 隐藏/显示效果函数用法实例分析
2020/05/20 jQuery
[00:12]2018DOTA2亚洲邀请赛SOLO赛 MidOne是否中单第一人?
2018/04/05 DOTA
Python中__call__用法实例
2014/08/29 Python
Python实现的读写json文件功能示例
2018/06/05 Python
python 统计一个列表当中的每一个元素出现了多少次的方法
2018/11/14 Python
Python学习笔记之For循环用法详解
2019/08/14 Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
2020/06/10 Python
深入了解Python 变量作用域
2020/07/24 Python
HTML5制作3D爱心动画教程 献给女友浪漫的礼物
2014/11/05 HTML / CSS
世界上最全面的草药补充剂和顶级品牌维生素网站:HerbsPro
2019/01/20 全球购物
清洁工表扬信
2014/01/08 职场文书
通信研究生自荐信
2014/02/01 职场文书
学生会干部自我鉴定2014
2014/09/18 职场文书
教师个人查摆剖析材料
2014/10/14 职场文书
收入证明范本
2015/06/12 职场文书
旅游安全责任协议书
2016/03/22 职场文书
三年级作文之小小梦想
2019/12/06 职场文书
Python中使用Lambda函数的5种用法
2021/04/01 Python
用几道面试题来看JavaScript执行机制
2021/04/30 Javascript
Python与C++中梯度方向直方图的实现
2022/03/17 Python