Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django框架中运行Python应用全攻略
Jul 17 Python
详细介绍Python的鸭子类型
Sep 12 Python
在django中使用自定义标签实现分页功能
Jul 04 Python
Python中Scrapy爬虫图片处理详解
Nov 29 Python
python使用xpath中遇到:到底是什么?
Jan 04 Python
Numpy中stack(),hstack(),vstack()函数用法介绍及实例
Jan 09 Python
详解如何将python3.6软件的py文件打包成exe程序
Oct 09 Python
Python函数中不定长参数的写法
Feb 13 Python
Python从list类型、range()序列简单认识类(class)【可迭代】
May 31 Python
Python关于__name__属性的含义和作用详解
Feb 19 Python
python3爬虫中多线程的优势总结
Nov 24 Python
python 机器学习的标准化、归一化、正则化、离散化和白化
Apr 16 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
十大催泪虐心动漫电影,有几部你还没看
2020/03/04 日漫
php分页代码学习示例分享
2014/02/20 PHP
PHP大文件及断点续传下载实现代码
2020/08/18 PHP
capacityFixed 基于jquery的类似于新浪微博新消息提示的定位框
2011/05/24 Javascript
JavaScript SetInterval与setTimeout使用方法详解
2013/11/15 Javascript
Extjs grid添加一个图片状态或者按钮的方法
2014/04/03 Javascript
关于Javascript 对象(object)的prototype
2014/05/09 Javascript
json实现前后台的相互传值详解
2015/01/05 Javascript
教你使用javascript简单写一个页面模板引擎
2015/05/05 Javascript
基于jQuery实现复选框是否选中进行答题提示
2015/12/10 Javascript
基于javascript实现随机颜色变化效果
2016/01/14 Javascript
JS动态改变浏览器标题的方法
2016/04/06 Javascript
浅谈js基本数据类型和typeof
2016/08/09 Javascript
浅谈JS运算符&amp;&amp;和|| 及其优先级
2016/08/10 Javascript
jQuery内容过滤选择器用法示例
2016/09/09 Javascript
jquery判断页面网址是否有效的两种方法
2016/12/11 Javascript
vue.js中Vue-router 2.0基础实践教程
2017/05/08 Javascript
AngularJs定时器$interval 和 $timeout详解
2017/05/25 Javascript
angular json对象push到数组中的方法
2018/02/27 Javascript
AngularJS标签页tab选项卡切换功能经典实例详解
2018/05/16 Javascript
详解一次Vue低版本安卓白屏问题的解决过程
2019/05/30 Javascript
Ant-design-vue Table组件customRow属性的使用说明
2020/10/28 Javascript
[05:03]显微镜下的DOTA2第十期——Ti3豪之超神幽鬼
2014/06/23 DOTA
Python实现单词拼写检查
2015/04/25 Python
Python for Informatics 第11章之正则表达式(四)
2016/04/21 Python
Python调用系统底层API播放wav文件的方法
2017/08/11 Python
利用python操作SQLite数据库及文件操作详解
2017/09/22 Python
python 运用Django 开发后台接口的实例
2018/12/11 Python
Python for循环及基础用法详解
2019/11/08 Python
解决python虚拟环境切换无效的问题
2020/04/30 Python
总经理助理的职责
2014/03/14 职场文书
企业业务员岗位职责
2014/03/14 职场文书
班干部竞选演讲稿
2014/04/24 职场文书
学生违反校规检讨书
2014/10/28 职场文书
超市工作总结范文2014
2014/12/19 职场文书
本科毕业论文指导教师评语
2014/12/30 职场文书