pytorch下的unsqueeze和squeeze的用法说明


Posted in Python onFebruary 06, 2021

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加

import torch 
a=torch.rand(2,3,1)       
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) 
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze().size())    #torch.Size([2, 3]) 
print(a.squeeze(0).size())   #torch.Size([2, 3, 1])
 
print(a.squeeze(-1).size())   #torch.Size([2, 3])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())   #torch.Size([2, 3, 1])
print(a.squeeze(1).size())   #torch.Size([2, 3, 1])
print(a.squeeze(2).size())   #torch.Size([2, 3])
print(a.squeeze(3).size())   #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
 
print(a.unsqueeze().size())   #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())    #torch.Size([2, 3])

补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结

学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。

1、unsqueeze()和squeeze()

torch.unsqueeze(input, dim,out=None) → Tensor

unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])

也可以直接使用链式编程:

a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])

tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。

torch.squeeze(input, dim=None, out=None) → Tensor

squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。

同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])

pytorch下的unsqueeze和squeeze的用法说明

图片中的维度信息就不一样,红框中的括号层数不同。

注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。

a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])

2、expand()

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。

torch.Tensor.expand(*sizes) → Tensor

返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。

a=torch.randn(1,1,3,768)
print(a) 
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

可以看到b和c的维度是一样的

pytorch下的unsqueeze和squeeze的用法说明

第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。

3、repeat()

repeat(*sizes)

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。

a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])

b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])

c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])

a:

pytorch下的unsqueeze和squeeze的用法说明

b

pytorch下的unsqueeze和squeeze的用法说明

c

pytorch下的unsqueeze和squeeze的用法说明

4、view()

tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)

当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768

另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。

word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)

结果如下:使用-1以后,就会自动得到其他维度维。

pytorch下的unsqueeze和squeeze的用法说明

需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。

5、cat()

cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。

其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列

dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接

a=torch.randn(4,3)
b=torch.randn(4,3)
 
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)

还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。

tensors=[]
for i in range(10):
  tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)

结果:

torch.Size([40, 3])
torch.Size([4, 30])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
零基础写python爬虫之神器正则表达式
Nov 06 Python
Python实现的Google IP 可用性检测脚本
Apr 23 Python
编写Python小程序来统计测试脚本的关键字
Mar 12 Python
Python设置默认编码为utf8的方法
Jul 01 Python
flask框架视图函数用法示例
Jul 19 Python
selenium+python自动化测试之环境搭建
Jan 23 Python
Python实现程序判断季节的代码示例
Jan 28 Python
Python操作excel的方法总结(xlrd、xlwt、openpyxl)
Sep 02 Python
python正则表达式实例代码
Mar 03 Python
python绘制汉诺塔
Mar 01 Python
Python中Selenium对Cookie的操作方法
Jul 09 Python
python可视化分析绘制带趋势线的散点图和边缘直方图
Jun 25 Python
一文带你掌握Pyecharts地理数据可视化的方法
Feb 06 #Python
解决pycharm不能自动保存在远程linux中的问题
Feb 06 #Python
Python第三方库安装缓慢的解决方法
Feb 06 #Python
python中threading和queue库实现多线程编程
Feb 06 #Python
Python3爬虫ChromeDriver的安装实例
Feb 06 #Python
解决pycharm修改代码后第一次运行不生效的问题
Feb 06 #Python
Python tkinter之ComboBox(下拉框)的使用简介
Feb 05 #Python
You might like
php 字符转义 注意事项
2009/05/27 PHP
smarty模板局部缓存方法使用示例
2014/06/17 PHP
说说掌握JavaScript语言的思想前提想学习js的朋友可以看看
2009/04/01 Javascript
ajax与302响应代码测试
2013/10/23 Javascript
js使用心得分享
2015/01/13 Javascript
JavaScript定时显示广告代码分享
2015/03/02 Javascript
Javascript中String的常用方法实例分析
2015/06/13 Javascript
jQuery实现可以编辑的表格实例详解【附demo源码下载】
2016/07/09 Javascript
JavaScript运动框架 多值运动(四)
2017/05/18 Javascript
php中and 和 &&出坑指南
2018/07/13 Javascript
Vue.js 事件修饰符的使用教程
2018/11/01 Javascript
vue使用websocket的方法实例分析
2019/06/22 Javascript
JavaScript中var的重要性实例分析
2019/07/09 Javascript
Emberjs 通过 axios 下载文件的方法
2019/09/03 Javascript
全面解析js中的原型,原型对象,原型链
2021/01/25 Javascript
[04:29]2014DOTA2国际邀请赛 主赛事第三日TOPPLAY
2014/07/21 DOTA
Python装饰器的函数式编程详解
2015/02/27 Python
使用Python编写简单的画图板程序的示例教程
2015/12/08 Python
理解Python垃圾回收机制
2016/02/12 Python
Python处理PDF及生成多层PDF实例代码
2017/04/24 Python
Flask框架URL管理操作示例【基于@app.route】
2018/07/23 Python
Django 后台获取文件列表 InMemoryUploadedFile的例子
2019/08/07 Python
Python 装饰器@,对函数进行功能扩展操作示例【开闭原则】
2019/10/17 Python
Python利用全连接神经网络求解MNIST问题详解
2020/01/14 Python
Python切割图片成九宫格的示例代码
2020/03/10 Python
什么是python的函数体
2020/06/19 Python
解决Python 写文件报错TypeError的问题
2020/10/23 Python
迪士尼西班牙官方网上商店:ShopDisney西班牙
2020/02/02 全球购物
生物科学系大学生的自我评价
2013/12/20 职场文书
实习生评语
2014/04/26 职场文书
实习公司领导推荐函
2014/05/21 职场文书
处级领导班子全部召开专题民主生活会情况汇报
2014/09/27 职场文书
2016年感恩节寄语
2015/12/07 职场文书
导游词之张家口
2019/12/13 职场文书
深入理解python多线程编程
2021/04/18 Python
教你用eclipse连接mysql数据库
2021/04/22 MySQL