浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中列表元素连接方法join用法实例
Apr 07 Python
Django日志模块logging的配置详解
Feb 14 Python
Python实现上下班抢个顺风单脚本
Feb 07 Python
python如何读写csv数据
Mar 21 Python
Python使用add_subplot与subplot画子图操作示例
Jun 01 Python
tensorflow实现图像的裁剪和填充方法
Jul 27 Python
解决python ogr shp字段写入中文乱码的问题
Dec 31 Python
python实现提取COCO,VOC数据集中特定的类
Mar 10 Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
Apr 22 Python
pytorch查看通道数 维数 尺寸大小方式
May 26 Python
Python如何给函数库增加日志功能
Aug 04 Python
python3.9.1环境安装的方法(图文)
Feb 02 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
通过文字传递创建的图形按钮
2006/10/09 PHP
使用adodb lite解决问题
2006/12/31 PHP
php生成的html meta和link标记在body标签里 顶部有个空行
2010/05/18 PHP
php中simplexml_load_file函数用法实例
2014/11/12 PHP
php版微信支付api.mch.weixin.qq.com域名解析慢原因与解决方法
2016/10/12 PHP
setTimeout和setInterval的浏览器兼容性分析
2007/02/27 Javascript
javascript preload&lazy load
2010/05/13 Javascript
js变量以及其作用域详解
2020/07/18 Javascript
S2SH整合JQuery+Ajax实现登录验证功能实现代码
2013/01/30 Javascript
深入document.write()与HTML4.01的非成对标签的详解
2013/05/08 Javascript
中止javascript执行的方法
2014/02/14 Javascript
浅析BootStrap中Modal(模态框)使用心得
2016/12/24 Javascript
javascript设计模式之策略模式学习笔记
2017/02/15 Javascript
JavaScript箭头函数_动力节点Java学院整理
2017/06/28 Javascript
微信小程序的生命周期的详解
2017/10/19 Javascript
详解js 创建对象的几种方法
2019/03/08 Javascript
Vue+Element实现动态生成新表单并添加验证功能
2019/05/23 Javascript
koa+jwt实现token验证与刷新功能
2019/05/30 Javascript
vue实现记事本功能
2019/06/26 Javascript
详解在Python和IPython中使用Docker
2015/04/28 Python
pygame播放音乐的方法
2015/05/19 Python
Python实现短网址ShortUrl的Hash运算实例讲解
2015/08/10 Python
使用pyecharts无法import Bar的解决方案
2020/04/23 Python
python 动态生成变量名以及动态获取变量的变量名方法
2019/01/20 Python
Python中remove漏删和索引越界问题的解决
2020/03/18 Python
基于django micro搭建网站实现加水印功能
2020/05/22 Python
PyCharm+Miniconda3安装配置教程详解
2021/02/16 Python
美国相机和电子产品零售商:Beach Camera
2020/11/26 全球购物
网站推广策划方案
2014/06/04 职场文书
财务会计专业求职信
2014/06/09 职场文书
效能风暴心得体会
2014/09/04 职场文书
法定代表人授权委托书
2014/09/19 职场文书
开展党的群众路线教育实践活动工作总结
2014/11/05 职场文书
作弊检讨书
2015/01/27 职场文书
2016年安全生产先进个人事迹材料
2016/02/29 职场文书
浅谈node.js中间件有哪些类型
2021/04/29 Javascript