pytorch下的unsqueeze和squeeze的用法说明


Posted in Python onFebruary 06, 2021

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加

import torch 
a=torch.rand(2,3,1)       
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) 
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze().size())    #torch.Size([2, 3]) 
print(a.squeeze(0).size())   #torch.Size([2, 3, 1])
 
print(a.squeeze(-1).size())   #torch.Size([2, 3])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())   #torch.Size([2, 3, 1])
print(a.squeeze(1).size())   #torch.Size([2, 3, 1])
print(a.squeeze(2).size())   #torch.Size([2, 3])
print(a.squeeze(3).size())   #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
 
print(a.unsqueeze().size())   #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())    #torch.Size([2, 3])

补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结

学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。

1、unsqueeze()和squeeze()

torch.unsqueeze(input, dim,out=None) → Tensor

unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])

也可以直接使用链式编程:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])

tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。

torch.squeeze(input, dim=None, out=None) → Tensor

squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。

同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])

pytorch下的unsqueeze和squeeze的用法说明

图片中的维度信息就不一样,红框中的括号层数不同。

注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。

a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])

2、expand()

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。

torch.Tensor.expand(*sizes) → Tensor

返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。

a=torch.randn(1,1,3,768)
print(a) 
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

可以看到b和c的维度是一样的

pytorch下的unsqueeze和squeeze的用法说明

第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。

3、repeat()

repeat(*sizes)

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])

b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])

c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])

a:

pytorch下的unsqueeze和squeeze的用法说明

b

pytorch下的unsqueeze和squeeze的用法说明

c

pytorch下的unsqueeze和squeeze的用法说明

4、view()

tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)

当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768

另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)

结果如下:使用-1以后,就会自动得到其他维度维。

pytorch下的unsqueeze和squeeze的用法说明

需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。

5、cat()

cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。

其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列

dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接

a=torch.randn(4,3)
b=torch.randn(4,3)
 
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)

还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。

tensors=[]
for i in range(10):
  tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)

结果:

torch.Size([40, 3])
torch.Size([4, 30])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Windows8下安装Python的BeautifulSoup
Jan 22 Python
Python调用命令行进度条的方法
May 05 Python
python 使用get_argument获取url query参数
Apr 28 Python
python之virtualenv的简单使用方法(必看篇)
Nov 25 Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 Python
详谈python3 numpy-loadtxt的编码问题
Apr 29 Python
详解Django中间件执行顺序
Jul 16 Python
python3利用venv配置虚拟环境及过程中的小问题小结
Aug 01 Python
Python中实现单例模式的n种方式和原理
Nov 14 Python
在Python中COM口的调用方法
Jul 03 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 Python
Python如何根据时间序列数据作图
May 12 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 #Python
解决pycharm不能自动保存在远程linux中的问题
Feb 06 #Python
Python第三方库安装缓慢的解决方法
Feb 06 #Python
python中threading和queue库实现多线程编程
Feb 06 #Python
Python3爬虫ChromeDriver的安装实例
Feb 06 #Python
解决pycharm修改代码后第一次运行不生效的问题
Feb 06 #Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 #Python
You might like
PHP概述.
2006/10/09 PHP
PHP获取表单所有复选框的值的方法
2014/08/28 PHP
PHP经典面试题集锦
2015/03/19 PHP
PHP实现用户登录的案例代码
2018/05/10 PHP
学习ExtJS table布局
2009/10/08 Javascript
js性能优化 如何更快速加载你的JavaScript页面
2012/03/17 Javascript
Jquery带搜索框的下拉菜单
2013/05/06 Javascript
如何实现textarea里的不同文本显示不同颜色
2014/01/20 Javascript
jQuery判断checkbox(复选框)是否被选中以及全选、反选实现代码
2014/02/21 Javascript
浅谈jQuery this和$(this)的区别及获取$(this)子元素对象的方法
2016/11/29 Javascript
微信公众号开发 自定义菜单跳转页面并获取用户信息实例详解
2016/12/08 Javascript
jQuery中值得注意的trigger方法浅析
2016/12/12 Javascript
js实现自定义进度条效果
2017/03/15 Javascript
vue-cli如何快速构建vue项目
2017/04/26 Javascript
浅谈React深度编程之受控组件与非受控组件
2017/12/26 Javascript
webpack配置打包后图片路径出错的解决
2018/04/26 Javascript
说说如何利用 Node.js 代理解决跨域问题
2019/04/22 Javascript
vue插槽slot的简单理解与用法实例分析
2020/03/14 Javascript
JS实现点星星消除小游戏
2020/03/24 Javascript
[52:00]2018DOTA2亚洲邀请赛 4.1 小组赛 A组加赛 LGD vs Optic
2018/04/02 DOTA
Python多进程编程技术实例分析
2014/09/16 Python
Python爬虫实现抓取京东店铺信息及下载图片功能示例
2018/08/07 Python
python 同时运行多个程序的实例
2019/01/07 Python
对Python中DataFrame选择某列值为XX的行实例详解
2019/01/29 Python
wxPython实现整点报时
2019/11/18 Python
Python自动化办公Excel模块openpyxl原理及用法解析
2020/11/05 Python
CSS3中伪元素::before和::after的用法示例
2017/09/18 HTML / CSS
俄罗斯Sportmarket体育在线商店:用于旅游和户外活动
2019/11/12 全球购物
毕业生找工作的求职信范文
2013/12/24 职场文书
预备党员思想汇报范文
2013/12/29 职场文书
护士自我评价
2014/02/01 职场文书
力学专业求职信
2014/07/23 职场文书
要账委托书范本
2014/09/15 职场文书
2014年技术部工作总结
2014/12/12 职场文书
网络研修心得体会
2016/01/08 职场文书
Go语言入门exec的基本使用
2022/05/20 Golang