Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过自定义isnumber函数判断字符串是否为数字的方法
Apr 23 Python
python实现数组插入新元素的方法
May 22 Python
Python的Django框架中的数据库配置指南
Jul 17 Python
python生成验证码图片代码分享
Jan 28 Python
Python实现翻转数组功能示例
Jan 12 Python
对Python3中bytes和HexStr之间的转换详解
Dec 04 Python
解决Pycharm界面的子窗口不见了的问题
Jan 17 Python
Tensorflow实现酸奶销量预测分析
Jul 19 Python
python 字典有序并写入json文件过程解析
Sep 30 Python
PYTHON发送邮件YAGMAIL的简单实现解析
Oct 28 Python
python删除指定列或多列单个或多个内容实例
Jun 28 Python
pycharm 实现复制一行的快捷键
Jan 15 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
php 流程控制switch的简单实例
2016/06/07 PHP
laravel-admin 中列表筛选方法
2019/10/03 PHP
Javascript 构造函数 实例分析
2008/11/26 Javascript
JavaScript 仿关机效果的图片层
2008/12/26 Javascript
jquery 输入框数字限制插件
2009/11/10 Javascript
两个select之间option的互相添加操作(jquery实现)
2009/11/12 Javascript
实现web打印的各种方法介绍及实现代码
2013/01/09 Javascript
推荐6款基于jQuery实现图片效果插件
2014/12/07 Javascript
jQuery内容过滤选择器用法分析
2015/02/10 Javascript
Bootstrap自定义文件上传下载样式
2016/05/26 Javascript
AngularJS国际化详解及示例代码
2016/08/18 Javascript
详解Vue 中 extend 、component 、mixins 、extends 的区别
2017/12/20 Javascript
nodejs 最新版安装npm 的使用详解
2018/01/18 NodeJs
Vue开发之watch监听数组、对象、变量操作分析
2019/04/25 Javascript
实例分析javascript中的异步
2020/06/02 Javascript
openlayers4实现点动态扩散
2020/08/17 Javascript
python中实现指定时间调用函数示例代码
2017/09/08 Python
python jieba分词并统计词频后输出结果到Excel和txt文档方法
2018/02/11 Python
python入门前的第一课 python怎样入门
2018/03/06 Python
python增加图像对比度的方法
2019/07/12 Python
在python中将list分段并保存为array类型的方法
2019/07/15 Python
Django REST Framework之频率限制的使用
2019/09/29 Python
python命令 -u参数用法解析
2019/10/24 Python
利用python3 的pygame模块实现塔防游戏
2019/12/30 Python
python with (as)语句实例详解
2020/02/04 Python
对python pandas中 inplace 参数的理解
2020/06/27 Python
Python pip 常用命令汇总
2020/10/19 Python
爱尔兰电子产品购物网站:Komplett.ie
2018/04/04 全球购物
攀岩、滑雪、徒步旅行装备:Black Diamond Equipment
2019/08/16 全球购物
卡骆驰英国官网:Crocs英国
2019/08/22 全球购物
电子商务专业学生职业生涯规划
2014/03/07 职场文书
2015年高中班主任工作总结
2015/04/30 职场文书
2015年幼儿园卫生保健工作总结
2015/05/12 职场文书
你会写请假条吗?
2019/06/26 职场文书
2019年公司卫生管理制度样本
2019/08/21 职场文书
mysql中int(3)和int(10)的数值范围是否相同
2021/10/16 MySQL