浅谈pytorch中stack和cat的及to_tensor的坑


Posted in Python onMay 20, 2021

初入计算机视觉遇到的一些坑

1.pytorch中转tensor

x=np.random.randint(10,100,(10,10,10))
x=TF.to_tensor(x)
print(x)

这个函数会对输入数据进行自动归一化,比如有时候我们需要将0-255的图片转为numpy类型的数据,则会自动转为0-1之间

2.stack和cat之间的差别

stack

x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。

cat

z=torch.cat((x,y))
print(z.size())
#torch.Size([2, 2, 3])

cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。

其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会,有一个很简单的例子来说明一下,比如要训练一个检测模型,label是一些标记点,eg:[x1,y1,x2,y2]

送入网络的加上batchsize则时Size:[batchsize,4],如果我已经有了两堆数据,data1:Size[128,4],data2:Size[128,4],需要将这两个数据合在一起的话目标data:Size[256,4]。

显然我们要做的是:torch.cat((data1,data2))

如果我们的数据是这样:有100个label,每一个label被放进一个list(data)中,[[x1,y1,x2,y2],[x1,y1,x2,y2],...]其中data是一个list长度为100,而list中每一个元素是张图片的标签,size为[4]我们需要将他们合一起成为一Size:[100,4]的的数据。

显然我们要做的是torch.stack(data)。而且torch.stack的输入参数为list类型!

补充:pytorch中的cat、stack、tranpose、permute、unsqeeze

pytorch中提供了对tensor常用的变换操作。

cat 连接

对数据沿着某一维度进行拼接。cat后数据的总维数不变。

比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

代码如下:

import torch
torch.manual_seed(1)
x = torch.randn(2,3)
y = torch.randn(1,3)
print(x,y)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 1x3]

将两个tensor拼在一起:

torch.cat((x,y),0)

结果:

0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
-1.5228 0.3817 -1.0276
[torch.FloatTensor of size 3x3]

更灵活的拼法:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)
print(torch.cat((x,x),0))
print(torch.cat((x,x),1))

结果

// x
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

// torch.cat((x,x),0)
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 4x3]

// torch.cat((x,x),1)
0.6614 0.2669 0.0617 0.6614 0.2669 0.0617
0.6213 -0.4519 -0.1661 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x6]

stack,增加新的维度进行堆叠

而stack则会增加新的维度。

如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。

见代码:

a = torch.ones([1,2])
b = torch.ones([1,2])
c= torch.stack([a,b],0) // 第0个维度stack

输出:

(0 ,.,.) =
1 1

(1 ,.,.) =
1 1
[torch.FloatTensor of size 2x1x2]

c= torch.stack([a,b],1) // 第1个维度stack

输出:


(0 ,.,.) =

1 1

1 1

[torch.FloatTensor of size 1x2x2]

transpose ,两个维度互换

代码如下:

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

0.6614 0.2669 0.0617

0.6213 -0.4519 -0.1661

[torch.FloatTensor of size 2x3]

将x的维度互换

x.transpose(0,1)

结果

0.6614 0.6213

0.2669 -0.4519

0.0617 -0.1661

[torch.FloatTensor of size 3x2]

permute,多个维度互换,更灵活的transpose

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

代码如下:

x = torch.randn(2,3,4)
print(x.size())
x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2
print(x_p.size())

结果:

torch.Size([2, 3, 4])

torch.Size([3, 2, 4])

squeeze 和 unsqueeze

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。

unsqueeze(dim_n),增加dim_n维度,元素数量为1。

上代码:

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python yield使用方法示例
Dec 04 Python
python查找第k小元素代码分享
Dec 18 Python
python实现ipsec开权限实例
Nov 11 Python
Django实现图片文字同时提交的方法
May 26 Python
Python合并字典键值并去除重复元素的实例
Dec 18 Python
Python将多份excel表格整理成一份表格
Jan 03 Python
使用numba对Python运算加速的方法
Oct 15 Python
使用 Visual Studio Code(VSCode)搭建简单的Python+Django开发环境的方法步骤
Dec 17 Python
python GUI库图形界面开发之PyQt5下拉列表框控件QComboBox详细使用方法与实例
Feb 27 Python
PyQt5 控件字体样式等设置的实现
May 13 Python
使用Python爬取小姐姐图片(beautifulsoup法)
Feb 11 Python
python上下文管理的使用场景实例讲解
Mar 03 Python
pytorch实现手写数字图片识别
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
You might like
实用函数4
2007/11/08 PHP
浅谈PHP中关于foreach使用引用变量的坑
2016/11/14 PHP
利用jQuery插件扩展识别浏览器内核与外壳的类型和版本的实现代码
2011/10/22 Javascript
Firefox/Chrome/Safari的中可直接使用$/$$函数进行调试
2012/02/13 Javascript
JS对象与JSON格式数据相互转换
2012/02/20 Javascript
根据json字符串生成Html的一种方式
2013/01/09 Javascript
button没写type=button会导致点击时提交
2014/03/06 Javascript
JavaScript中的replace()方法使用详解
2015/06/06 Javascript
JavaScript入门基础
2015/08/12 Javascript
jQuery Validate表单验证插件 添加class属性形式的校验
2016/01/18 Javascript
AngularJS实现网站换肤实例
2021/02/19 Javascript
基于Vue实例对象的数据选项
2017/08/09 Javascript
详解vue-cli与webpack结合如何处理静态资源
2017/09/19 Javascript
vue的状态管理模式vuex
2017/11/30 Javascript
利用ES6实现单例模式及其应用详解
2017/12/09 Javascript
[原创]jQuery实现合并/追加数组并去除重复项的方法
2018/04/11 jQuery
js canvas实现画图、滤镜效果
2018/11/27 Javascript
用 js 写一个 js 解释器过程详解
2019/08/02 Javascript
js实现的订阅发布者模式简单示例
2020/03/14 Javascript
将Emacs打造成强大的Python代码编辑工具
2015/11/20 Python
Python Requests 基础入门
2016/04/07 Python
Python 模板引擎的注入问题分析
2017/01/01 Python
Python图片裁剪实例代码(如头像裁剪)
2017/06/21 Python
django rest framework之请求与响应(详解)
2017/11/06 Python
pygame游戏之旅 调用按钮实现游戏开始功能
2018/11/21 Python
对python判断是否回文数的实例详解
2019/02/08 Python
pyqt5 comboBox获得下标、文本和事件选中函数的方法
2019/06/14 Python
基于torch.where和布尔索引的速度比较
2020/01/02 Python
Python reques接口测试框架实现代码
2020/07/28 Python
python判断字符串以什么结尾的实例方法
2020/09/18 Python
css3学习心得分享
2013/08/19 HTML / CSS
架构师岗位职责
2013/11/18 职场文书
小学教育毕业生自荐信
2013/11/18 职场文书
法定代表人身份证明书
2014/09/10 职场文书
志愿者个人总结
2015/03/03 职场文书
详解Alibaba Java诊断工具Arthas查看Dubbo动态代理类
2022/04/08 Java/Android