pytorch对可变长度序列的处理方法详解


Posted in Python onDecember 08, 2018

主要是用函数torch.nn.utils.rnn.PackedSequence()和torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这三个函数的用法。

1、torch.nn.utils.rnn.PackedSequence()

NOTE: 这个类的实例不能手动创建。它们只能被 pack_padded_sequence() 实例化。

PackedSequence对象包括:

一个data对象:一个torch.Variable(令牌的总数,每个令牌的维度),在这个简单的例子中有五个令牌序列(用整数表示):(18,1)

一个batch_sizes对象:每个时间步长的令牌数列表,在这个例子中为:[6,5,2,4,1]

用pack_padded_sequence函数来构造这个对象非常的简单:

pytorch对可变长度序列的处理方法详解

如何构造一个PackedSequence对象(batch_first = True)

PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

2、torch.nn.utils.rnn.pack_padded_sequence()

这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable。

参数说明:

input (Variable) ? 变长序列 被填充后的 batch

lengths (list[int]) ? Variable 中 每个序列的长度。

batch_first (bool, optional) ? 如果是True,input的形状应该是B*T*size。

返回值:

一个PackedSequence 对象。

3、torch.nn.utils.rnn.pad_packed_sequence()

填充packed_sequence。

上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

返回的Varaible的值的size是 T×B×*, T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。

Batch中的元素将会以它们长度的逆序排列。

参数说明:

sequence (PackedSequence) ? 将要被填充的 batch

batch_first (bool, optional) ? 如果为True,返回的数据的格式为 B×T×*。

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

例子:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
batch_size = 2
max_length = 3
hidden_size = 2
n_layers =1
 
tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
 
# pack it
pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
 
# initialize
rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
 
#forward
out, _ = rnn(pack, h0)
 
# unpack
unpacked = nn_utils.rnn.pad_packed_sequence(out)
print('111',unpacked)

输出:

111 (Variable containing:
(0 ,.,.) =
 0.5406 0.3584
 -0.1403 0.0308
 
(1 ,.,.) =
 -0.6855 -0.9307
 0.0000 0.0000
[torch.FloatTensor of size 2x2x2]
, [2, 1])

以上这篇pytorch对可变长度序列的处理方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中二维阵列的变换实例
Oct 09 Python
详解详解Python中writelines()方法的使用
May 25 Python
Python中super关键字用法实例分析
May 28 Python
Python抽象类的新写法
Jun 18 Python
win10下tensorflow和matplotlib安装教程
Sep 19 Python
pyqt5之将textBrowser的内容写入txt文档的方法
Jun 21 Python
详细整理python 字符串(str)与列表(list)以及数组(array)之间的转换方法
Aug 30 Python
tensorflow生成多个tfrecord文件实例
Feb 17 Python
Python新手学习标准库模块命名
May 29 Python
python使用建议与技巧分享(二)
Aug 17 Python
Python 里最强的地图绘制神器
Mar 01 Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 Python
pytorch 转换矩阵的维数位置方法
Dec 08 #Python
pytorch 调整某一维度数据顺序的方法
Dec 08 #Python
Python操作mongodb数据库的方法详解
Dec 08 #Python
Opencv+Python 色彩通道拆分及合并的示例
Dec 08 #Python
python-opencv颜色提取分割方法
Dec 08 #Python
使用python将图片按标签分入不同文件夹的方法
Dec 08 #Python
对python的输出和输出格式详解
Dec 08 #Python
You might like
如何提高MYSQL数据库的查询统计速度 select 索引应用
2007/04/11 PHP
php另类上传图片的方法(PHP用Socket上传图片)
2013/10/30 PHP
PHP实现的redis主从数据库状态检测功能示例
2017/07/20 PHP
PHP删除数组中指定下标的元素方法
2018/02/03 PHP
javascript radio 联动效果
2009/03/04 Javascript
jquery自动完成插件(autocomplete)应用之PHP版
2009/12/15 Javascript
屏蔽Flash右键信息的js代码
2010/01/17 Javascript
JavaScript中使用replace结合正则实现replaceAll的效果
2010/06/04 Javascript
基于JQuery 选择器使用说明介绍
2013/04/18 Javascript
JS不间断向上滚动效果代码
2013/12/25 Javascript
jquery实现pager控件示例
2014/04/09 Javascript
javascript中获取class的简单实现
2016/07/12 Javascript
BootStrap 附加导航组件
2016/07/22 Javascript
JavaScript比较当前时间是否在指定时间段内的方法
2016/08/02 Javascript
两行代码轻松搞定JavaScript日期验证
2016/08/03 Javascript
详解tween.js 中文使用指南
2018/01/05 Javascript
Javascript将图片的绝对路径转换为base64编码的方法
2018/01/11 Javascript
Angularjs实现控制器之间通信方式实例总结
2018/03/27 Javascript
vue-quill-editor+plupload富文本编辑器实例详解
2018/10/19 Javascript
python实现ipsec开权限实例
2014/11/11 Python
使用pyecharts在jupyter notebook上绘图
2020/04/23 Python
Django admin实现图书管理系统菜鸟级教程完整实例
2017/12/12 Python
python Flask 装饰器顺序问题解决
2018/08/08 Python
通过PHP与Python代码对比的语法差异详解
2019/07/10 Python
Python使用ffmpy将amr格式的音频转化为mp3格式的例子
2019/08/08 Python
详解python安装matplotlib库三种失败情况
2020/07/28 Python
Bally美国官网:经典瑞士鞋履、手袋及配饰奢侈品牌
2018/05/18 全球购物
网游商务专员求职信
2013/10/15 职场文书
教师自荐信范文
2013/12/09 职场文书
大学本科生的个人自我评价
2013/12/09 职场文书
大学学生会竞选演讲稿
2014/04/25 职场文书
幼儿教师师德演讲稿
2014/05/06 职场文书
文明寝室申报材料
2014/05/12 职场文书
安全口号大全
2014/06/21 职场文书
研修心得体会
2014/09/04 职场文书
分享7个 Python 实战项目练习
2022/03/03 Python