浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中实现字符串类型与字典类型相互转换的方法
Aug 18 Python
进一步探究Python的装饰器的运用
May 05 Python
python中循环语句while用法实例
May 16 Python
Python中的urllib模块使用详解
Jul 07 Python
python实现斐波那契数列的方法示例
Jan 12 Python
Pycharm学习教程(5) Python快捷键相关设置
May 03 Python
Python字符串处理实例详解
May 18 Python
python使用fcntl模块实现程序加锁功能示例
Jun 23 Python
Python使用type动态创建类操作示例
Feb 29 Python
python实现ftp文件传输系统(案例分析)
Mar 20 Python
Python 实现国产SM3加密算法的示例代码
Sep 21 Python
Python爬虫获取op.gg英雄联盟英雄对位胜率的源码
Jan 29 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
PHP实现根据浏览器跳转不同语言页面代码
2013/08/02 PHP
php修改指定文件后缀的方法
2014/09/11 PHP
使用PHP接受文件并获得其后缀名的方法
2015/08/05 PHP
分享ThinkPHP3.2中关联查询解决思路
2015/09/20 PHP
PHP页面跳转操作实例分析(header方法)
2016/09/28 PHP
Javascript 不能释放内存.
2006/09/07 Javascript
JavaScript弹簧振子超简洁版 完全符合能量守恒,胡克定理
2009/10/25 Javascript
jquery图片上下tab切换效果
2011/03/18 Javascript
JavaScript程序员应该知道的45个实用技巧
2014/03/04 Javascript
jQuery截取指定长度字符串的实现原理及代码
2014/07/01 Javascript
jQuery制作效果超棒的手风琴折叠菜单
2015/04/03 Javascript
JS实现队列与堆栈的方法
2016/04/21 Javascript
基于jQuery实现数字滚动效果
2017/01/16 Javascript
JavaScript表单验证完美代码
2017/03/02 Javascript
jquery easyui dataGrid动态改变排序字段名的方法
2017/03/02 Javascript
基于Express框架使用POST传递Form数据
2019/08/10 Javascript
微信小程序获取当前位置和城市名
2019/11/13 Javascript
vue2路由方式--嵌套路由实现方法分析
2020/03/06 Javascript
Vue单文件组件开发实现过程详解
2020/07/30 Javascript
[00:33]2018DOTA2亚洲邀请赛TNC出场
2018/04/04 DOTA
python中随机函数random用法实例
2015/04/30 Python
Python通过DOM和SAX方式解析XML的应用实例分享
2015/11/16 Python
python中安装Scrapy模块依赖包汇总
2017/07/02 Python
使用PyCharm创建Django项目及基本配置详解
2018/10/24 Python
Python autoescape标签用法解析
2020/01/17 Python
python实现移动木板小游戏
2020/10/09 Python
耐克巴西官方网站:Nike巴西
2016/08/14 全球购物
行政文员实习自我鉴定范文
2014/09/14 职场文书
小学感恩节活动策划方案
2014/10/06 职场文书
基层干部个人对照检查及整改措施
2014/10/28 职场文书
2014年科室工作总结范文
2014/12/19 职场文书
蓬莱阁导游词
2015/02/04 职场文书
部门优秀员工推荐信
2015/03/24 职场文书
三严三实·严以用权心得体会
2016/01/12 职场文书
机关干部作风整顿心得体会
2016/01/22 职场文书
Python实现打乒乓小游戏
2021/09/25 Python