解决Pytorch dataloader时报错每个tensor维度不一样的问题


Posted in Python onMay 28, 2021

使用pytorch的dataloader报错:

RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1

1. 问题描述

报错定位:位于定义dataset的代码中

def __getitem__(self, index):
 ...
 return y    #此处报错

报错内容

File "D:\python\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [2] at entry 0 and [1] at entry 1

把前一行的报错带上能够更清楚地明白问题在哪里.

2.问题分析

从报错可以看到,是在代码中执行torch.stack时发生了报错.因此必须要明白在哪里执行了stack操作.

通过调试可以发现,在通过loader加载一个batch数据的时候,是通过每一次给一个随机的index取出相应的向量.那么最终要形成一个batch的数据就必须要进行拼接操作,而torch.stack就是进行这里所说的拼接.

再来看看具体报的什么错: 说是stack的向量维度不同. 这说明在每次给出一个随机的index,返回的y向量的维度应该是相同的,而我们这里是不同的.

这样解决方法也就明确了:使返回的向量y的维度固定下来.

3.问题出处

为什么我会出现这样的一个问题,是因为我的特征向量中存在multi-hot特征.而为了节省空间,我是用一个列表存储这个特征的.示例如下:

feature=[[1,3,5],
  [0,2],
  [1,2,5,8]]

这就导致了我每次返回的向量的维度是不同的.因此可以采用向量补全的方法,把不同长度的向量补全成等长的.

# 把所有向量的长度都补为6
 multi = np.pad(multi, (0, 6-multi.shape[0]), 'constant', constant_values=(0, -1))

4.总结

在构建dataset重写的__getitem__方法中要返回相同长度的tensor.

可以使用向量补全的方法来解决这个问题.

补充:pytorch学习笔记:torch.utils.data下的TensorDataset和DataLoader的使用

一、TensorDataset

对给定的tensor数据(样本和标签),将它们包装成dataset。注意,如果是numpy的array,或者Pandas的DataFrame需要先转换成Tensor。

'''
data_tensor (Tensor) - 样本数据
target_tensor (Tensor) - 样本目标(标签)
'''
 dataset=torch.utils.data.TensorDataset(data_tensor, 
                                        target_tensor)

下面举个例子:

我们先定义一下样本数据和标签数据,一共有1000个样本

import torch
import numpy as np
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = torch.tensor(np.random.normal(0, 1, 
                       (num_examples, num_inputs)), 
                       dtype=torch.float)

labels = true_w[0] * features[:, 0] + \
         true_w[1] * features[:, 1] + true_b

labels += torch.tensor(np.random.normal(0, 0.01, 
                       size=labels.size()), 
                       dtype=torch.float)

print(features.shape)
print(labels.shape)

'''
输出:torch.Size([1000, 2])
     torch.Size([1000])
'''

然后我们使用TensorDataset来生成数据集

import torch.utils.data as Data
# 将训练数据的特征和标签组合
dataset = Data.TensorDataset(features, labels)

二、DataLoader

数据加载器,组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。它可以对我们上面所说的数据集Dataset作进一步的设置。

dataset (Dataset) ? 加载数据的数据集。

batch_size (int, optional) ? 每个batch加载多少个样本(默认: 1)。

shuffle (bool, optional) ? 设置为True时会在每个epoch重新打乱数据(默认: False).

sampler (Sampler, optional) ? 定义从数据集中提取样本的策略。如果指定,则shuffle必须设置成False。

num_workers (int, optional) ? 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

pin_memory:内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。

drop_last (bool, optional) ? 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

timeout:是用来设置数据读取的超时时间的,如果超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。

data_iter=torch.utils.data.DataLoader(dataset, batch_size=1, 
                            shuffle=False, sampler=None, 
                            batch_sampler=None, num_workers=0, 
                            collate_fn=None, pin_memory=False, 
                            drop_last=False, timeout=0, 
                            worker_init_fn=None, 
                            multiprocessing_context=None)

上面对一些重要常用的参数做了说明,其中有一个参数是sampler,下面我们对它有哪些具体取值再做一下说明。只列出几个常用的取值:

torch.utils.data.sampler.SequentialSampler(dataset)

样本元素按顺序采样,始终以相同的顺序。

torch.utils.data.sampler.RandomSampler(dataset)

样本元素随机采样,没有替换。

torch.utils.data.sampler.SubsetRandomSampler(indices)

样本元素从指定的索引列表中随机抽取,没有替换。

下面就来看一个例子,该例子使用的dataset就是上面所生成的dataset

data_iter=Data.DataLoader(dataset, 
                          batch_size=10, 
                          shuffle=False,
sampler=torch.utils.data.sampler.RandomSampler(dataset))

for X, y in data_iter:
    print(X,"\n", y)
    break

'''
输出:
tensor([[-1.6338,  0.8451],
        [ 0.7245, -0.7387],
        [ 0.4672,  0.2623],
        [-1.9082,  0.0980],
        [-0.3881,  0.5138],
        [-0.6983, -0.4712],
        [ 0.1400,  0.7489],
        [-0.7761, -0.4596],
        [-2.2700, -0.2532],
        [-1.2641, -2.8089]]) 

tensor([-1.9451,  8.1587,  4.2374,  0.0519,  1.6843,  4.3970,  
        1.9311,  4.1999,0.5253, 11.2277])
'''

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python XML RPC服务器端和客户端实例
Nov 22 Python
Python __setattr__、 __getattr__、 __delattr__、__call__用法示例
Mar 06 Python
Django中的“惰性翻译”方法的相关使用
Jul 27 Python
python利用正则表达式提取字符串
Dec 08 Python
基于Python __dict__与dir()的区别详解
Oct 30 Python
Python 数据处理库 pandas进阶教程
Apr 21 Python
Python实现的特征提取操作示例
Dec 03 Python
Pyqt5实现英文学习词典
Jun 24 Python
使用Python实现牛顿法求极值
Feb 10 Python
Python发起请求提示UnicodeEncodeError错误代码解决方法
Apr 21 Python
Keras官方中文文档:性能评估Metrices详解
Jun 15 Python
Python使用Pygame绘制时钟
Nov 29 Python
pytorch锁死在dataloader(训练时卡死)
Python趣味爬虫之用Python实现智慧校园一键评教
Pytorch 如何加速Dataloader提升数据读取速度
在前女友婚礼上,用Python破解了现场的WIFI还把名称改成了
pytorch DataLoader的num_workers参数与设置大小详解
May 28 #Python
Flask搭建一个API服务器的步骤
May 28 #Python
Python趣味挑战之给幼儿园弟弟生成1000道算术题
May 28 #Python
You might like
PHP 和 MySQL 基础教程(三)
2006/10/09 PHP
PHP采集利器 Snoopy 试用心得
2011/07/03 PHP
jquery修改属性值实例代码(设置属性值)
2014/01/06 Javascript
javascript运行机制之this详细介绍
2014/02/07 Javascript
input:checkbox多选框实现单选效果跟radio一样
2014/06/16 Javascript
浅谈javascript对象模型和function对象
2014/12/26 Javascript
javascript实现dom动态创建省市纵向列表菜单的方法
2015/05/14 Javascript
node.js中格式化数字增加千位符的几种方法
2015/07/03 Javascript
jquery实现可自动判断位置的弹出层效果代码
2015/10/12 Javascript
JS实现的论坛Ajax打分效果完整实例
2015/10/31 Javascript
基于JavaScript实现Json数据根据某个字段进行排序
2015/11/24 Javascript
详解JavaScript的AngularJS框架中的表达式与指令
2016/03/05 Javascript
AngularJS通过$http和服务器通信详解
2016/09/21 Javascript
js阻止冒泡和默认事件(默认行为)详解
2016/10/20 Javascript
js图片上传的封装代码
2017/08/01 Javascript
javascript将json格式数组下载为excel表格的方法
2017/12/22 Javascript
Vuex mutitons和actions初使用详解
2019/03/04 Javascript
JavaScript 作用域实例分析
2019/10/02 Javascript
JavaScript对象原型链原理详解
2020/02/05 Javascript
[38:39]KG vs Mineski 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
python写的一个文本编辑器
2014/01/23 Python
Python Tkinter实现简易计算器功能
2018/01/30 Python
对pytorch网络层结构的数组化详解
2018/12/08 Python
python opencv 读取本地视频文件 修改ffmpeg的方法
2019/01/26 Python
浅谈pandas筛选出表中满足另一个表所有条件的数据方法
2019/02/08 Python
两个元祖T1=('a', 'b'),T2=('c', 'd')使用匿名函数将其转变成[{'a': 'c'},{'b': 'd'}]的几种方法
2019/03/05 Python
python GUI库图形界面开发之PyQt5简单绘图板实例与代码分析
2020/03/08 Python
Windows下Anaconda安装、换源与更新的方法
2020/04/17 Python
JAVA SWT事件四种写法实例解析
2020/06/05 Python
打造经典复古风格的品牌:Alice + Olivia(爱丽丝+奥利维亚)
2016/09/07 全球购物
菲律宾酒店预订网站:Hotels.com菲律宾
2017/07/12 全球购物
环境科学毕业生自荐信
2013/11/21 职场文书
会计学个人自荐信模板
2013/12/13 职场文书
不打扫卫生检讨书
2014/02/12 职场文书
文秘个人求职信范文
2014/04/22 职场文书
小学校园广播稿集锦
2014/10/04 职场文书