pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现socket端口重定向示例
Feb 10 Python
详解Python中的Descriptor描述符类
Jun 14 Python
python安装mysql-python简明笔记(ubuntu环境)
Jun 25 Python
Python Django使用forms来实现评论功能
Aug 17 Python
Python读csv文件去掉一列后再写入新的文件实例
Dec 28 Python
Django 中间键和上下文处理器的使用
Mar 17 Python
python通过txt文件批量安装依赖包的实现步骤
Aug 13 Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 Python
torch 中各种图像格式转换的实现方法
Dec 26 Python
python实现引用其他路径包里面的模块
Mar 09 Python
python使用PIL剪切和拼接图片
Mar 23 Python
使用Python爬取小姐姐图片(beautifulsoup法)
Feb 11 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
学习discuz php 引入文件的方法DISCUZ_ROOT
2009/06/21 PHP
PHP form 表单传参明细研究
2009/07/17 PHP
php 文件上传系统手记
2009/10/26 PHP
php排序算法(冒泡排序,快速排序)
2012/10/09 PHP
PHP使用ffmpeg给视频增加字幕显示的方法
2015/03/12 PHP
php使用glob函数遍历文件和目录详解
2016/09/23 PHP
PHP 爬取网页的主要方法
2018/07/13 PHP
用javascript实现自定义标签
2007/05/08 Javascript
$()JS小技巧
2007/07/21 Javascript
JavaScript 创建运动框架的实现代码
2013/05/08 Javascript
如何解决Jquery库及其他库之间的$命名冲突
2013/09/15 Javascript
jquery定时滑出可最小化的底部提示层特效代码
2013/10/02 Javascript
在JavaScript中操作时间之getMonth()方法的使用
2015/06/10 Javascript
js获取地址栏参数的两种方法
2017/06/27 Javascript
bootstrap精简教程_动力节点Java学院整理
2017/07/14 Javascript
Angular2学习笔记之数据绑定的示例代码
2018/01/03 Javascript
vue-socket.io接收不到数据问题的解决方法
2020/05/13 Javascript
vue-cli3访问public文件夹静态资源报错的解决方式
2020/09/02 Javascript
vue3.0 项目搭建和使用流程
2021/03/04 Vue.js
用Python输出一个杨辉三角的例子
2014/06/13 Python
python学习教程之使用py2exe打包
2017/09/24 Python
python shell根据ip获取主机名代码示例
2017/11/25 Python
详解用Python练习画个美队盾牌
2019/03/23 Python
详解Python 字符串相似性的几种度量方法
2019/08/29 Python
python Jupyter运行时间实例过程解析
2019/12/13 Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
2020/09/21 Python
Saucony澳大利亚官网:美国跑鞋品牌,运动鞋中的劳斯莱斯
2018/05/05 全球购物
预备党员的自我评价
2014/03/12 职场文书
银行内勤岗位职责
2014/04/09 职场文书
代理协议书
2014/04/22 职场文书
大学生党校培训心得体会
2014/09/11 职场文书
小石潭记导游词
2015/02/03 职场文书
MySQL创建索引需要了解的
2021/04/08 MySQL
浅谈Python响应式类库RxPy
2021/06/14 Python
MySQL创建管理LIST分区
2022/04/13 MySQL
如何解决flex文本溢出问题小结
2022/07/15 HTML / CSS