PyTorch中Tensor的拼接与拆分的实现


Posted in Python onAugust 18, 2019

拼接张量:torch.cat() 、torch.stack()

  1. torch.cat(inputs, dimension=0) → Tensor

在给定维度上对输入的张量序列 seq 进行连接操作

举个例子:

>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 0) # 在 0 维(纵向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 1) # 在 1 维(横向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039, -0.1997, -0.6900, 0.7039, -0.1997, -0.6900,
     0.7039],
    [ 0.0268, -1.0140, -2.9764, 0.0268, -1.0140, -2.9764, 0.0268, -1.0140,
     -2.9764]])
>>> y1 = torch.randn(5, 3, 6)
>>> y2 = torch.randn(5, 3, 6)
>>> torch.cat([y1, y2], 2).size()
torch.Size([5, 3, 12])
>>> torch.cat([y1, y2], 1).size()
torch.Size([5, 6, 6])

对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。

  • torch.stack(sequence, dim=0)

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状

举个例子:

>>> x1 = torch.randn(2, 3)
>>> x2 = torch.randn(2, 3)
>>> torch.stack((x1, x2), 0).size() # 在 0 维插入一个维度,进行区分拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 1).size() # 在 1 维插入一个维度,进行组合拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 2).size()
torch.Size([2, 3, 2])
>>> torch.stack((x1, x2), 0)
tensor([[[-0.3499, -0.6124, 1.4332],
     [ 0.1516, -1.5439, -0.1758]],

    [[-0.4678, -1.1430, -0.5279],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 1)
tensor([[[-0.3499, -0.6124, 1.4332],
     [-0.4678, -1.1430, -0.5279]],

    [[ 0.1516, -1.5439, -0.1758],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 2)
tensor([[[-0.3499, -0.4678],
     [-0.6124, -1.1430],
     [ 1.4332, -0.5279]],

    [[ 0.1516, -0.4917],
     [-1.5439, -0.6504],
     [-0.1758, 2.2512]]])

把相同形状的张量合并,并根据提供的维度序列在相应位置插入维度,方法会根据位置来排列数据。代码中,根据第 0 维和第 1 维来进行合并时,虽然合并后的张量维度和尺寸相等,但是数据的位置并不是相同的。

拆分张量:torch.split()、torch.chunk()

  • torch.split(tensor, split_size, dim=0)

将输入张量分割成相等形状的 chunks(如果可分)。 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。

举个例子:

>>> x = torch.randn(3, 10, 6)
>>> a, b, c = x.split(1, 0) # 在 0 维进行间隔维 1 的拆分
>>> a.size(), b.size(), c.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> d, e = x.split(2, 0) # 在 0 维进行间隔维 2 的拆分
>>> d.size(), e.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上以间隔 1 来拆分时,其中 x 在 0 维度上的尺寸为 3,就可以分成 3 份。

把张量在 0 维度上以间隔 2 来拆分时,只能分成 2 份,且只能把前面部分先以间隔 2 来拆分,后面不足 2 的部分就直接作为一个分块。

  • torch.chunk(tensor, chunks, dim=0)

在给定维度(轴)上将输入张量进行分块儿

直接用上面的数据来举个例子:

>>> l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份
>>> l.size(), m.size(), n.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份
>>> u.size(), v.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上拆分成 3 部分时,因为尺寸正好为 3,所以每个分块的间隔相等,都为 1。

把张量在 0 维度上拆分成 2 部分时,无法平均分配,以上面的结果来看,可以看成是,用 0 维度的尺寸除以需要拆分的份数,把余数作为最后一个分块的间隔大小,再把前面的分块以相同的间隔拆分。

在某一维度上拆分的份数不能比这一维度的尺寸大

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接MySQL、MongoDB、Redis、memcache等数据库的方法
Nov 15 Python
python操作mongodb根据_id查询数据的实现方法
May 20 Python
python使用wmi模块获取windows下的系统信息 监控系统
Oct 27 Python
python cx_Oracle模块的安装和使用详细介绍
Feb 13 Python
小白如何入门Python? 制作一个网站为例
Mar 06 Python
python 使用re.search()筛选后 选取部分结果的方法
Nov 28 Python
解决Python运行文件出现out of memory框的问题
Dec 03 Python
Python用字典构建多级菜单功能
Jul 11 Python
django rest framework vue 实现用户登录详解
Jul 29 Python
基于python3.7利用Motor来异步读写Mongodb提高效率(推荐)
Apr 29 Python
python 基于opencv操作摄像头
Dec 24 Python
python自动打开浏览器下载zip并提取内容写入excel
Jan 04 Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
You might like
PHP 获取文件路径(灵活应用__FILE__)
2013/02/15 PHP
PHP中的str_repeat函数在JavaScript中的实现
2013/09/16 PHP
php关闭warning问题的解决方法
2016/05/17 PHP
php简单统计中文个数的方法
2016/09/30 PHP
PHP 多任务秒级定时器的实现方法
2018/05/13 PHP
详解PHP 二维数组排序保持键名不变
2019/03/06 PHP
jquery 必填项判断表单是否为空的方法
2008/09/14 Javascript
父元素与子iframe相互获取变量和元素对象的具体实现
2013/10/15 Javascript
JS测试显示屏分辨率以及屏幕尺寸的方法
2013/11/22 Javascript
js使用removeChild方法动态删除div元素
2014/08/01 Javascript
javascript基于DOM实现省市级联下拉框的方法
2015/05/14 Javascript
javascript实现点击单选按钮链接转向对应网址的方法
2015/08/12 Javascript
JavaScript高级程序设计(第三版)学习笔记6、7章
2016/03/11 Javascript
JS前端笔试题分析
2016/12/19 Javascript
bootstrap导航条实现代码
2016/12/28 Javascript
ajax接收后台数据在html页面显示
2017/02/19 Javascript
jQuery 判断元素整理汇总
2017/02/28 Javascript
JS实现的透明度渐变动画效果示例
2018/04/28 Javascript
vue-for循环嵌套操作示例
2019/01/28 Javascript
vue实现二级导航栏效果
2019/10/19 Javascript
python中字典dict常用操作方法实例总结
2015/04/04 Python
Python实现的数据结构与算法之快速排序详解
2015/04/22 Python
介绍Python中的文档测试模块
2015/04/28 Python
python 调用钉钉机器人的方法
2019/02/20 Python
Django获取该数据的上一条和下一条方法
2019/08/12 Python
在notepad++中实现直接运行python代码
2019/12/18 Python
jupyter notebook 重装教程
2020/04/16 Python
完美解决keras 读取多个hdf5文件进行训练的问题
2020/07/01 Python
使用CSS实现弹性视频html5案例实践
2012/12/26 HTML / CSS
保安部任务及岗位职责
2014/02/25 职场文书
爱耳日宣传活动总结
2014/07/05 职场文书
校外活动方案
2014/08/28 职场文书
2015年防灾减灾工作总结
2015/07/24 职场文书
2020优秀员工演讲稿(三篇)
2019/10/17 职场文书
常用的文件对应的MIME类型汇总
2022/04/26 HTML / CSS
JS实现简单九宫格抽奖
2022/06/28 Javascript