PyTorch中Tensor的拼接与拆分的实现


Posted in Python onAugust 18, 2019

拼接张量:torch.cat() 、torch.stack()

  1. torch.cat(inputs, dimension=0) → Tensor

在给定维度上对输入的张量序列 seq 进行连接操作

举个例子:

>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 0) # 在 0 维(纵向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764],
    [-0.1997, -0.6900, 0.7039],
    [ 0.0268, -1.0140, -2.9764]])
>>> torch.cat((x, x, x), 1) # 在 1 维(横向)进行拼接
tensor([[-0.1997, -0.6900, 0.7039, -0.1997, -0.6900, 0.7039, -0.1997, -0.6900,
     0.7039],
    [ 0.0268, -1.0140, -2.9764, 0.0268, -1.0140, -2.9764, 0.0268, -1.0140,
     -2.9764]])
>>> y1 = torch.randn(5, 3, 6)
>>> y2 = torch.randn(5, 3, 6)
>>> torch.cat([y1, y2], 2).size()
torch.Size([5, 3, 12])
>>> torch.cat([y1, y2], 1).size()
torch.Size([5, 6, 6])

对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。

  • torch.stack(sequence, dim=0)

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状

举个例子:

>>> x1 = torch.randn(2, 3)
>>> x2 = torch.randn(2, 3)
>>> torch.stack((x1, x2), 0).size() # 在 0 维插入一个维度,进行区分拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 1).size() # 在 1 维插入一个维度,进行组合拼接
torch.Size([2, 2, 3])
>>> torch.stack((x1, x2), 2).size()
torch.Size([2, 3, 2])
>>> torch.stack((x1, x2), 0)
tensor([[[-0.3499, -0.6124, 1.4332],
     [ 0.1516, -1.5439, -0.1758]],

    [[-0.4678, -1.1430, -0.5279],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 1)
tensor([[[-0.3499, -0.6124, 1.4332],
     [-0.4678, -1.1430, -0.5279]],

    [[ 0.1516, -1.5439, -0.1758],
     [-0.4917, -0.6504, 2.2512]]])
>>> torch.stack((x1, x2), 2)
tensor([[[-0.3499, -0.4678],
     [-0.6124, -1.1430],
     [ 1.4332, -0.5279]],

    [[ 0.1516, -0.4917],
     [-1.5439, -0.6504],
     [-0.1758, 2.2512]]])

把相同形状的张量合并,并根据提供的维度序列在相应位置插入维度,方法会根据位置来排列数据。代码中,根据第 0 维和第 1 维来进行合并时,虽然合并后的张量维度和尺寸相等,但是数据的位置并不是相同的。

拆分张量:torch.split()、torch.chunk()

  • torch.split(tensor, split_size, dim=0)

将输入张量分割成相等形状的 chunks(如果可分)。 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。

举个例子:

>>> x = torch.randn(3, 10, 6)
>>> a, b, c = x.split(1, 0) # 在 0 维进行间隔维 1 的拆分
>>> a.size(), b.size(), c.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> d, e = x.split(2, 0) # 在 0 维进行间隔维 2 的拆分
>>> d.size(), e.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上以间隔 1 来拆分时,其中 x 在 0 维度上的尺寸为 3,就可以分成 3 份。

把张量在 0 维度上以间隔 2 来拆分时,只能分成 2 份,且只能把前面部分先以间隔 2 来拆分,后面不足 2 的部分就直接作为一个分块。

  • torch.chunk(tensor, chunks, dim=0)

在给定维度(轴)上将输入张量进行分块儿

直接用上面的数据来举个例子:

>>> l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份
>>> l.size(), m.size(), n.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份
>>> u.size(), v.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))

把张量在 0 维度上拆分成 3 部分时,因为尺寸正好为 3,所以每个分块的间隔相等,都为 1。

把张量在 0 维度上拆分成 2 部分时,无法平均分配,以上面的结果来看,可以看成是,用 0 维度的尺寸除以需要拆分的份数,把余数作为最后一个分块的间隔大小,再把前面的分块以相同的间隔拆分。

在某一维度上拆分的份数不能比这一维度的尺寸大

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则表达式修复网站文章字体不统一的解决方法
Feb 21 Python
python实现将pvr格式转换成pvr.ccz的方法
Apr 28 Python
Python实现更改图片尺寸大小的方法(基于Pillow包)
Sep 19 Python
解决出现Incorrect integer value: '' for column 'id' at row 1的问题
Oct 29 Python
Python字典,函数,全局变量代码解析
Dec 18 Python
wxPython的安装与使用教程
Aug 31 Python
Python代码打开本地.mp4格式文件的方法
Jan 03 Python
python 获取微信好友列表的方法(微信web)
Feb 21 Python
python批量修改图片尺寸,并保存指定路径的实现方法
Jul 04 Python
Python使用sklearn实现的各种回归算法示例
Jul 04 Python
简单了解python的一些位运算技巧
Jul 13 Python
Python3使用xlrd、xlwt处理Excel方法数据
Feb 28 Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
You might like
基于mysql的bbs设计(一)
2006/10/09 PHP
postfixadmin忘记密码后的修改密码方法详解
2016/07/20 PHP
PHP PDOStatement::bindParam讲解
2019/01/30 PHP
网页javascript精华代码集
2007/01/24 Javascript
Js获取事件对象代码
2010/08/05 Javascript
Dom与浏览器兼容性说明
2010/10/25 Javascript
也说JavaScript中String类的replace函数
2011/09/22 Javascript
Jquery图片延迟加载插件jquery.lazyload.js的使用方法
2014/05/21 Javascript
AngularJS基础 ng-mouseleave 指令详解
2016/08/02 Javascript
Vue.js自定义指令的用法与实例解析
2017/01/18 Javascript
JavaScript制作简易计算器(不用eval)
2017/02/05 Javascript
JS实现选定指定HTML元素对象中指定文本内容功能示例
2017/02/13 Javascript
关于react-router的几种配置方式详解
2017/07/24 Javascript
vue内置指令详解
2018/04/03 Javascript
vuejs中监听窗口关闭和窗口刷新事件的方法
2018/09/21 Javascript
LayUi使用switch开关,动态的去控制它是否被启用的方法
2019/09/21 Javascript
JavaScript 接口原理与用法实例详解
2020/05/12 Javascript
[06:36]吞吞映像top1
2014/06/20 DOTA
[41:21]夜魇凡尔赛茶话会 第三期02:看图识人
2021/03/11 DOTA
Python2.7.10以上pip更新及其他包的安装教程
2018/06/12 Python
使用Python更换外网IP的方法
2018/07/09 Python
Python django使用多进程连接mysql错误的解决方法
2018/10/08 Python
详解python编译器和解释器的区别
2019/06/24 Python
对python特殊函数 __call__()的使用详解
2019/07/02 Python
Python中无限循环需要什么条件
2020/05/27 Python
keras的backend 设置 tensorflow,theano操作
2020/06/30 Python
奇怪的鱼:Weird Fish
2018/03/18 全球购物
System.Array.CopyTo()和System.Array.Clone()有什么区别
2016/06/20 面试题
小学新教师培训方案
2014/02/03 职场文书
新教师培训方案
2014/06/08 职场文书
2016年大学生就业指导课心得体会
2015/10/09 职场文书
2016教师节感恩话语
2015/12/09 职场文书
施工安全责任协议书
2016/03/23 职场文书
Python人工智能之混合高斯模型运动目标检测详解分析
2021/11/07 Python
python获取带有返回值的多线程
2022/05/02 Python
windows server 2012安装FTP并配置被动模式指定开放端口
2022/06/10 Servers