Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中find()方法的使用
May 18 Python
Python优化技巧之利用ctypes提高执行速度
Sep 11 Python
对web.py设置favicon.ico的方法详解
Dec 04 Python
python实现基于信息增益的决策树归纳
Dec 18 Python
对IPython交互模式下的退出方法详解
Feb 16 Python
Python3.5面向对象编程图文与实例详解
Apr 24 Python
在Python中等距取出一个数组其中n个数的实现方式
Nov 27 Python
python函数声明和调用定义及原理详解
Dec 02 Python
Pytorch Tensor 输出为txt和mat格式方式
Jan 03 Python
Python3操作YAML文件格式方法解析
Apr 10 Python
关于Kotlin中SAM转换的那些事
Sep 15 Python
Python中Permission denied的解决方案
Apr 02 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
MOTOROLA 摩托罗拉 MODEL 66-XI五灯中波收音机
2021/03/02 无线电
咖啡因含量是由谁决定的?低因咖啡怎么来?低因咖啡适合什么人喝
2021/03/06 新手入门
推荐php模板技术[转]
2007/01/04 PHP
PHP简单系统数据添加以及数据删除模块源文件下载
2008/06/07 PHP
七款最流行的PHP本地服务器分享
2013/02/19 PHP
PHP图片处理之使用imagecopy函数添加图片水印实例
2014/11/19 PHP
PHP识别二维码的方法(php-zbarcode安装与使用)
2016/07/07 PHP
Discuz论坛密码与密保加密规则
2016/12/19 PHP
List the Codec Files on a Computer
2007/06/18 Javascript
js string 转 int 注意的问题小结
2013/08/15 Javascript
JavaScript获取图片的原始尺寸以宽度为例
2014/05/04 Javascript
判断iframe里的页面是否加载完成
2014/06/06 Javascript
angularjs学习笔记之双向数据绑定
2015/09/26 Javascript
微信小程序 swiper制作tab切换实现附源码
2017/01/21 Javascript
浅谈jQuery框架Ajax常用选项
2017/07/08 jQuery
Vue render深入开发讲解
2018/04/13 Javascript
ios设备中angularjs无法改变页面title的解决方法
2018/09/13 Javascript
js实现移动端tab切换时下划线滑动效果
2019/09/08 Javascript
移动端JS实现拖拽两种方法解析
2020/10/12 Javascript
[01:16:28]DOTA2-DPC中国联赛 正赛 iG vs Magma BO3 第二场 2月23日
2021/03/11 DOTA
Golang与python线程详解及简单实例
2017/04/27 Python
python实现定时自动备份文件到其他主机的实例代码
2018/02/23 Python
对python 命令的-u参数详解
2018/12/03 Python
python爬虫开发之Beautiful Soup模块从安装到详细使用方法与实例
2020/03/09 Python
python 制作本地应用搜索工具
2021/02/27 Python
Mytheresa中国官网:德国时尚奢侈品商城
2017/08/04 全球购物
英国现代家具和装饰网站:PN Home
2018/08/16 全球购物
Notino匈牙利:购买香水和化妆品
2019/04/12 全球购物
人事部主管岗位职责
2013/12/26 职场文书
2014个人反腐倡廉思想汇报
2014/09/15 职场文书
小时代观后感
2015/06/10 职场文书
《百分数的认识》教学反思
2016/02/19 职场文书
PHP遍历数组的6种方式总结
2021/11/17 PHP
各种货币符号快捷输入
2022/02/17 杂记
Android超详细讲解组件ScrollView的使用
2022/03/31 Java/Android
MySql中的json_extract函数处理json字段详情
2022/06/05 MySQL