Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之序列详解
Aug 29 Python
解决uWSGI的编码问题详解
Mar 24 Python
python ftp 按目录结构上传下载的实现代码
Sep 12 Python
关于python下cv.waitKey无响应的原因及解决方法
Jan 10 Python
Python字符串对象实现原理详解
Jul 01 Python
python全栈知识点总结
Jul 01 Python
python实现桌面托盘气泡提示
Jul 29 Python
用python拟合等角螺线的实现示例
Dec 27 Python
pytorch 中pad函数toch.nn.functional.pad()的用法
Jan 08 Python
keras之权重初始化方式
May 21 Python
Python爬虫HTPP请求方法有哪些
Jun 03 Python
如何完美的建立一个python项目
Oct 09 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
在php和MySql中计算时间差的方法
2011/04/22 PHP
谈谈关于php的优点与缺点
2013/04/11 PHP
Yii核心组件AssetManager原理分析
2014/12/02 PHP
CI框架AR操作(数组形式)实现插入多条sql数据的方法
2016/05/18 PHP
PHP实现批量重命名某个文件夹下所有文件的方法
2017/09/04 PHP
jQuery 对象中的类数组操作
2009/04/27 Javascript
JavaScript单元测试ABC
2012/04/12 Javascript
如何使用jQuery Draggable和Droppable实现拖拽功能
2013/07/05 Javascript
禁止页面刷新让F5快捷键及右键都无效
2014/01/22 Javascript
JavaScript中判断两个字符串是否相等的方法
2015/07/07 Javascript
图片旋转、鼠标滚轮缩放、镜像、切换图片js代码
2020/12/13 Javascript
文本框只能输入数字的实现方法(兼容IE火狐)
2016/06/25 Javascript
详解jQuery插件开发方式
2016/11/22 Javascript
vue+element+Java实现批量删除功能
2019/04/08 Javascript
JavaScript随机数的组合问题案例分析
2020/05/16 Javascript
python中合并两个文本文件并按照姓名首字母排序的例子
2014/04/25 Python
Python中你应该知道的一些内置函数
2017/03/31 Python
python操作excel的方法
2018/08/16 Python
python实现图片识别汽车功能
2018/11/30 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
python实现翻译word表格小程序
2020/02/27 Python
python新手学习可变和不可变对象
2020/06/11 Python
使用CSS3的rem属性制作响应式页面布局的要点解析
2016/05/24 HTML / CSS
《长城和运河》教学反思
2014/04/14 职场文书
舞蹈教育学专业求职信
2014/06/29 职场文书
高中国旗下的演讲稿
2014/08/28 职场文书
公司离职证明范本(汇总)
2014/09/10 职场文书
婚前协议书范本两则
2014/10/16 职场文书
二年级上册数学教学计划
2015/01/20 职场文书
南极大冒险观后感
2015/06/05 职场文书
基石观后感
2015/06/12 职场文书
优秀共产党员事迹材料2016
2016/02/29 职场文书
2019年行政人事个人工作总结范本!
2019/07/19 职场文书
浅谈MySQL表空间回收的正确姿势
2021/10/05 MySQL
Python闭包的定义和使用方法
2022/04/11 Python
Android开发EditText禁止输入监听及InputFilter字符过滤
2022/06/10 Java/Android