pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用第三方模块的教程
Apr 27 Python
在Python中处理XML的教程
Apr 29 Python
Python编程入门之Hello World的三种实现方式
Nov 13 Python
Python3中的真除和Floor除法用法分析
Mar 16 Python
python中print()函数的“,”与java中System.out.print()函数中的“+”功能详解
Nov 24 Python
浅谈python 里面的单下划线与双下划线的区别
Dec 01 Python
Python采集代理ip并判断是否可用和定时更新的方法
May 07 Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 Python
修改python plot折线图的坐标轴刻度方法
Dec 13 Python
Flask框架实现的前端RSA加密与后端Python解密功能详解
Aug 13 Python
Python调用shell cmd方法代码示例解析
Jun 18 Python
python的launcher用法知识点总结
Aug 07 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
PHP下几种删除目录的方法总结
2007/08/19 PHP
PHP的5个安全措施小结
2012/07/17 PHP
PHP中文编码小技巧
2014/12/25 PHP
PHP缓冲区用法总结
2016/02/14 PHP
Laravel 5.2 文档 数据库 ―― 起步介绍
2019/10/21 PHP
初窥JQuery(二)事件机制(2)
2010/12/06 Javascript
javascript字符串拼接的效率问题
2010/12/25 Javascript
jQuery初学:find()方法及children方法的区别分析
2011/01/31 Javascript
P3P Header解决Cookie跨域的问题
2013/03/12 Javascript
js 去掉空格实例 Trim() LTrim() RTrim()
2014/01/07 Javascript
angularjs的一些优化小技巧
2014/12/06 Javascript
javascript实现倒计时N秒后网页自动跳转代码
2014/12/11 Javascript
Javascript进制转换实例分析
2015/05/14 Javascript
javascript删除数组重复元素的方法汇总
2015/06/24 Javascript
详解Bootstrap创建表单的三种格式(一)
2016/01/04 Javascript
bootstrap手风琴折叠示例代码分享
2017/05/22 Javascript
karma+webpack搭建vue单元测试环境的方法示例
2018/05/24 Javascript
JavaScript中变量提升与函数提升经典实例分析
2018/07/26 Javascript
clipboard在vue中的使用的方法示例
2018/10/19 Javascript
手把手教你如何使用nodejs编写cli命令行
2018/11/05 NodeJs
js实现弹出框的拖拽效果实例代码详解
2019/04/16 Javascript
Python的面向对象思想分析
2015/01/14 Python
python、PyTorch图像读取与numpy转换实例
2020/01/13 Python
pyinstaller打包单文件时--uac-admin选项不起作用怎么办
2020/04/15 Python
Python使用pycharm导入pymysql教程
2020/09/16 Python
Python之字典对象的几种创建方法
2020/09/30 Python
matplotlib之pyplot模块之标题(title()和suptitle())
2021/02/22 Python
GANT英国官方网上商店:甘特衬衫
2018/02/06 全球购物
加拿大快时尚零售商:Ardene
2018/02/14 全球购物
美国第一大药店连锁机构:Walgreens(沃尔格林)
2019/10/10 全球购物
考试退步检讨书
2014/01/15 职场文书
大学活动邀请函
2014/01/28 职场文书
公司门卫岗位职责
2014/03/15 职场文书
房屋出租协议书范本(标准版)
2014/09/24 职场文书
再婚婚前财产协议书范本
2014/10/19 职场文书
SQL Server中的逻辑函数介绍
2022/05/25 SQL Server