PyTorch中Tensor的维度变换实现


Posted in Python onAugust 18, 2019

对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。

维度查看:torch.Tensor.size()

查看当前 tensor 的维度

举个例子:

>>> import torch
>>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]])
>>> a.size()
torch.Size([1, 3, 2])

张量变形:torch.Tensor.view(*args) → Tensor

返回一个有相同数据但大小不同的 tensor。 返回的 tensor 必须有与原 tensor 相同的数据和相同数目的元素,但可以有不同的大小。一个 tensor 必须是连续的 contiguous() 才能被查看。

举个例子:

>>> x = torch.randn(2, 9)
>>> x.size()
torch.Size([2, 9])
>>> x
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038, 0.5166, 0.9774,
     0.3455],
    [-0.2306, 0.4217, 1.2874, -0.3618, 1.7872, -0.9012, 0.8073, -1.1238,
     -0.3405]])
>>> y = x.view(3, 6)
>>> y.size()
torch.Size([3, 6])
>>> y
tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038],
    [ 0.5166, 0.9774, 0.3455, -0.2306, 0.4217, 1.2874],
    [-0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]])
>>> z = x.view(2, 3, 3)
>>> z.size()
torch.Size([2, 3, 3])
>>> z
tensor([[[-1.6833, -0.4100, -1.5534],
     [-0.6229, -1.0310, -0.8038],
     [ 0.5166, 0.9774, 0.3455]],

    [[-0.2306, 0.4217, 1.2874],
     [-0.3618, 1.7872, -0.9012],
     [ 0.8073, -1.1238, -0.3405]]])

可以看到 x 和 y 、z 中数据的数量和每个数据的大小都是相等的,只是尺寸或维度数量发生了改变。

压缩 / 解压张量:torch.squeeze()、torch.unsqueeze()

  • torch.squeeze(input, dim=None, out=None)

将输入张量形状中的 1 去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定 dim 时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B),squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

举个例子:

>>> x = torch.randn(3, 1, 2)
>>> x
tensor([[[-0.1986, 0.4352]],

    [[ 0.0971, 0.2296]],

    [[ 0.8339, -0.5433]]])
>>> x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度
torch.Size([3, 2])
>>> x.squeeze()
tensor([[-0.1986, 0.4352],
    [ 0.0971, 0.2296],
    [ 0.8339, -0.5433]])
>>> torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素
torch.Size([3, 1, 2])
>>> torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用
torch.Size([3, 2])

可以看到如果加参数,只有维度中尺寸为 1 的位置才会消失

  • torch.unsqueeze(input, dim, out=None)

返回一个新的张量,对输入的制定位置插入维度 1

返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果 dim 为负,则将会被转化 dim+input.dim()+1

接着用上面的数据举个例子:

>>> x.unsqueeze(0).size()
torch.Size([1, 3, 1, 2])
>>> x.unsqueeze(0)
tensor([[[[-0.1986, 0.4352]],

     [[ 0.0971, 0.2296]],

     [[ 0.8339, -0.5433]]]])
>>> x.unsqueeze(-1).size()
torch.Size([3, 1, 2, 1])
>>> x.unsqueeze(-1)
tensor([[[[-0.1986],
     [ 0.4352]]],


    [[[ 0.0971],
     [ 0.2296]]],


    [[[ 0.8339],
     [-0.5433]]]])

可以看到在指定的位置,增加了一个维度。

扩大张量:torch.Tensor.expand(*sizes) → Tensor

返回 tensor 的一个新视图,单个维度扩大为更大的尺寸。 tensor 也可以扩大为更高维,新增加的维度将附在前面。 扩大 tensor 不需要分配新内存,只是仅仅新建一个 tensor 的视图,其中通过将 stride 设为 0,一维将会扩展位更高维。任何一个一维的在不分配新内存情况下可扩展为任意的数值。

举个例子:

>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]])
>>> x.expand(3, -1)
tensor([[1.],
    [2.],
    [3.]])

原数据是 3 行 1 列,扩大后变为 3 行 4 列,方法中填 -1 的效果与 1 一样,只有尺寸为 1 才可以扩大,如果不为 1 就无法改变,而且尺寸不为 1 的维度必须要和原来一样填写进去。

重复张量:torch.Tensor.repeat(*sizes)

沿着指定的维度重复 tensor。 不同于 expand(),本函数复制的是 tensor 中的数据。

举个例子:

>>> x = torch.Tensor([1, 2, 3])
>>> x.size()
torch.Size([3])
>>> x.repeat(4, 2)
    [1., 2., 3., 1., 2., 3.],
    [1., 2., 3., 1., 2., 3.],
    [1., 2., 3., 1., 2., 3.]])
>>> x.repeat(4, 2).size()
torch.Size([4, 6])

原数据为 1 行 3 列,按行方向扩大为原来的 4 倍,列方向扩大为原来的 2 倍,变为了 4 行 6 列。

变化时可以看成是把原数据作成一个整体,再按指定的维度和尺寸重复,变成一个 4 行 2 列的矩阵,其中的每一个单位都是相同的,再把原数据放到每个单位中。

矩阵转置:torch.t(input, out=None) → Tensor

输入一个矩阵(2维张量),并转置0, 1维。 可以被视为函数 transpose(input, 0, 1) 的简写函数。

举个例子:

>>> x = torch.randn(3, 5)
>>> x
tensor([[-1.0752, -0.9706, -0.8770, -0.4224, 0.9776],
    [ 0.2489, -0.2986, -0.7816, -0.0823, 1.1811],
    [-1.1124, 0.2160, -0.8446, 0.1762, -0.5164]])
>>> x.t()
tensor([[-1.0752, 0.2489, -1.1124],
    [-0.9706, -0.2986, 0.2160],
    [-0.8770, -0.7816, -0.8446],
    [-0.4224, -0.0823, 0.1762],
    [ 0.9776, 1.1811, -0.5164]])
>>> torch.t(x) # 另一种用法
tensor([[-1.0752, 0.2489, -1.1124],
    [-0.9706, -0.2986, 0.2160],
    [-0.8770, -0.7816, -0.8446],
    [-0.4224, -0.0823, 0.1762],
    [ 0.9776, 1.1811, -0.5164]])

必须要是 2 维的张量,也就是矩阵,才可以使用。

维度置换:torch.transpose()、torch.Tensor.permute()

  • torch.transpose(input, dim0, dim1, out=None) → Tensor

返回输入矩阵 input 的转置。交换维度 dim0 和 dim1。 输出张量与输入张量共享内存,所以改变其中一个会导致另外一个也被修改。

举个例子:

>>> x = torch.randn(2, 4, 3)
>>> x
tensor([[[-1.2502, -0.7363, 0.5534],
     [-0.2050, 3.1847, -1.6729],
     [-0.2591, -0.0860, 0.4660],
     [-1.2189, -1.1206, 0.0637]],

    [[ 1.4791, -0.7569, 2.5017],
     [ 0.0098, -1.0217, 0.8142],
     [-0.2414, -0.1790, 2.3506],
     [-0.6860, -0.2363, 1.0481]]])
>>> torch.transpose(x, 1, 2).size()
torch.Size([2, 3, 4])
>>> torch.transpose(x, 1, 2)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
     [-0.7363, 3.1847, -0.0860, -1.1206],
     [ 0.5534, -1.6729, 0.4660, 0.0637]],

    [[ 1.4791, 0.0098, -0.2414, -0.6860],
     [-0.7569, -1.0217, -0.1790, -0.2363],
     [ 2.5017, 0.8142, 2.3506, 1.0481]]])
>>> torch.transpose(x, 0, 1).size()
torch.Size([4, 2, 3])
>>> torch.transpose(x, 0, 1)
tensor([[[-1.2502, -0.7363, 0.5534],
     [ 1.4791, -0.7569, 2.5017]],

    [[-0.2050, 3.1847, -1.6729],
     [ 0.0098, -1.0217, 0.8142]],

    [[-0.2591, -0.0860, 0.4660],
     [-0.2414, -0.1790, 2.3506]],

    [[-1.2189, -1.1206, 0.0637],
     [-0.6860, -0.2363, 1.0481]]])

可以对多维度的张量进行转置

  • torch.Tensor.permute(dims)

将 tensor 的维度换位

接着用上面的数据举个例子:

>>> x.size()
torch.Size([2, 4, 3])
>>> x.permute(2, 0, 1).size()
torch.Size([3, 2, 4])
>>> x.permute(2, 0, 1)
tensor([[[-1.2502, -0.2050, -0.2591, -1.2189],
     [ 1.4791, 0.0098, -0.2414, -0.6860]],

    [[-0.7363, 3.1847, -0.0860, -1.1206],
     [-0.7569, -1.0217, -0.1790, -0.2363]],

    [[ 0.5534, -1.6729, 0.4660, 0.0637],
     [ 2.5017, 0.8142, 2.3506, 1.0481]]])

直接在方法中填入各个维度的索引,张量就会交换指定维度的尺寸,不限于两两交换。

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxPython窗口的继承机制实例分析
Sep 28 Python
PyQt5 pyqt多线程操作入门
May 05 Python
python通过Windows下远程控制Linux系统
Jun 20 Python
Flask框架URL管理操作示例【基于@app.route】
Jul 23 Python
详解python中的线程与线程池
May 10 Python
python 执行终端/控制台命令的例子
Jul 12 Python
一篇文章弄懂Python中的可迭代对象、迭代器和生成器
Aug 12 Python
Python猜数字算法题详解
Mar 01 Python
python+opencv边缘提取与各函数参数解析
Mar 09 Python
Pytorch转onnx、torchscript方式
May 25 Python
python使用opencv resize图像不进行插值的操作
Jul 05 Python
Python3+Django get/post请求实现教程详解
Feb 16 Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
You might like
PHP Memcached应用实现代码
2010/02/08 PHP
linux下使用ThinkPHP需要注意大小写导致的问题
2011/08/02 PHP
php数组函数序列 之shuffle()和array_rand() 随机函数使用介绍
2011/10/29 PHP
浅谈apache和nginx的rewrite的区别
2013/02/22 PHP
浅析PHP 按位与或 (^ 、&)
2013/06/21 PHP
PHP实现抓取迅雷VIP账号的方法
2015/07/30 PHP
PHP htmlspecialchars()函数用法与实例讲解
2019/03/08 PHP
JQuery 动画卷页 返回顶部 动画特效(兼容Chrome)
2010/02/15 Javascript
jquery实用代码片段集合
2010/08/12 Javascript
关于URL中的特殊符号使用介绍
2011/11/03 Javascript
JavaScript对象创建及继承原理实例解剖
2013/02/28 Javascript
JavaScript中实现依赖注入的思路分享
2015/01/15 Javascript
jquery密码强度校验
2015/12/02 Javascript
解决jquery无法找到其他父级子集问题的方法
2016/05/10 Javascript
JavaScript字符集编码与解码详谈
2017/02/02 Javascript
nodejs基础应用
2017/02/03 NodeJs
原生js封装运动框架的示例讲解
2017/10/01 Javascript
Vue.set() this.$set()引发的视图更新思考及注意事项
2018/08/30 Javascript
vue的hash值原理也是table切换实例代码
2020/12/14 Vue.js
[00:30]明星选手化身超级英雄!2018DOTA2亚洲邀请赛全明星赛来临!
2018/04/06 DOTA
pycharm 使用心得(八)如何调用另一文件中的函数
2014/06/06 Python
python判断all函数输出结果是否为true的方法
2020/12/03 Python
利用python绘制正态分布曲线
2021/01/04 Python
使用css3 属性如何丰富图片样式(圆角 阴影 渐变)
2012/11/22 HTML / CSS
HTML5新增form控件和表单属性实例代码详解
2019/05/15 HTML / CSS
计算s=f(f(-1.4))的值
2014/05/06 面试题
环境科学专业优秀毕业生自荐书
2014/02/03 职场文书
医院总经理岗位职责
2014/02/04 职场文书
公安学专业求职信
2014/07/27 职场文书
解约证明模板
2015/06/19 职场文书
毕业班工作总结
2015/08/10 职场文书
pandas中DataFrame数据合并连接(merge、join、concat)
2021/05/30 Python
Python编解码问题及文本文件处理方法详解
2021/06/20 Python
MySQL约束超详解
2021/09/04 MySQL
python turtle绘制多边形和跳跃和改变速度特效
2022/03/16 Python
HTML5基础学习之文本标签控制
2022/03/25 HTML / CSS