Pytorch技巧:DataLoader的collate_fn参数使用详解


Posted in Python onJanuary 08, 2020

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)
batch = 3

loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size=batch, # 批大小
 # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
 collate_fn=lambda x:(
  torch.cat(
   [x[i][j].unsqueeze(0) for i in range(len(x))], 0
   ).unsqueeze(0) for j in range(len(x[0]))
  )
 )

for (i,j) in loader:
 print(i)
 print(j)

输出结果:

tensor([[[ 0, 1, 2],
   [ 1, 2, 3],
   [ 2, 3, 4]]], dtype=torch.int32)
tensor([[[ 0],
   [ 1],
   [ 2]]], dtype=torch.int32)
tensor([[[ 3, 4, 5],
   [ 4, 5, 6],
   [ 5, 6, 7]]], dtype=torch.int32)
tensor([[[ 3],
   [ 4],
   [ 5]]], dtype=torch.int32)
tensor([[[ 6, 7, 8],
   [ 7, 8, 9],
   [ 8, 9, 10]]], dtype=torch.int32)
tensor([[[ 6],
   [ 7],
   [ 8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],
  [ 1, 2, 3],
  [ 2, 3, 4]], dtype=torch.int32)
tensor([[ 0],
  [ 1],
  [ 2]], dtype=torch.int32)
tensor([[ 3, 4, 5],
  [ 4, 5, 6],
  [ 5, 6, 7]], dtype=torch.int32)
tensor([[ 3],
  [ 4],
  [ 5]], dtype=torch.int32)
tensor([[ 6, 7, 8],
  [ 7, 8, 9],
  [ 8, 9, 10]], dtype=torch.int32)
tensor([[ 6],
  [ 7],
  [ 8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:
 print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]
[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]
[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ).unsqueeze(0) for j in range(len(x[0]))
 )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的setuptools框架下生成egg的教程
Apr 13 Python
Python中设置变量访问权限的方法
Apr 27 Python
Python 实现「食行生鲜」签到领积分功能
Sep 26 Python
Python实现统计英文文章词频的方法分析
Jan 28 Python
python 求一个列表中所有元素的乘积实例
Jun 11 Python
pandas分区间,算频率的实例
Jul 04 Python
讲解Python3中NumPy数组寻找特定元素下标的两种方法
Aug 04 Python
解决Django 在ForeignKey中出现 non-nullable field错误的问题
Aug 06 Python
Django接收自定义http header过程详解
Aug 23 Python
详解python路径拼接os.path.join()函数的用法
Oct 09 Python
命令行运行Python脚本时传入参数的三种方式详解
Oct 11 Python
python3实现名片管理系统(控制台版)
Nov 29 Python
Pytorch DataLoader 变长数据处理方式
Jan 08 #Python
pytorch实现用CNN和LSTM对文本进行分类方式
Jan 08 #Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
You might like
使用Apache的htaccess防止图片被盗链的解决方法
2013/04/27 PHP
Laravel中使用自己编写类库的3种方法
2015/02/10 PHP
在openSUSE42.1下编译安装PHP7 的方法
2015/12/24 PHP
phpmyadmin下载、安装、配置教程
2017/05/16 PHP
JavaScript中的作用域链和闭包
2012/06/30 Javascript
jQuery之字体大小的设置方法
2014/02/27 Javascript
jquery实现文本框数量加减功能的例子分享
2014/05/10 Javascript
用Jquery选择器计算table中的某一列某一行的合计
2014/08/13 Javascript
javascript属性访问表达式用法分析
2015/04/25 Javascript
JavaScript转换与解析JSON方法实例详解
2015/11/24 Javascript
解决vue项目报错webpackJsonp is not defined问题
2018/03/14 Javascript
javascript使用正则实现去掉字符串前面的所有0
2018/07/23 Javascript
微信小程序实现底部导航
2018/11/05 Javascript
深入解析ES6中的promise
2018/11/08 Javascript
用Electron写个带界面的nodejs爬虫的实现方法
2019/01/29 NodeJs
layui对工具条进行选择性的显示方法
2019/09/19 Javascript
selenium+java中用js来完成日期的修改
2019/10/31 Javascript
[09:43]DOTA2每周TOP10 精彩击杀集锦vol.5
2014/06/25 DOTA
Python随机生成信用卡卡号的实现方法
2015/05/14 Python
Python中的集合类型知识讲解
2015/08/19 Python
Python3 模块、包调用&amp;路径详解
2017/10/25 Python
python微信聊天机器人改进版(定时或触发抓取天气预报、励志语录等,向好友推送)
2019/04/25 Python
pycharm设置python文件模板信息过程图解
2020/03/10 Python
树莓派升级python的具体步骤
2020/07/05 Python
html5触摸事件判断滑动方向的实现
2018/06/05 HTML / CSS
基于 HTML5 Canvas实现 的交互式地铁线路图
2018/03/05 HTML / CSS
Html5 canvas实现粒子时钟的示例代码
2018/09/06 HTML / CSS
英语国培研修感言
2014/02/13 职场文书
规划编制实施方案
2014/03/15 职场文书
建房协议书
2014/04/11 职场文书
无子女夫妻离婚协议书(4篇)
2014/10/20 职场文书
群众路线教育实践活动整改落实情况汇报
2014/10/28 职场文书
企业投资意向书
2015/05/09 职场文书
学习经验交流会总结
2015/11/02 职场文书
详解Python中的进程和线程
2021/06/23 Python
详解解Django 多对多表关系的三种创建方式
2021/08/23 Python