对Pytorch 中的contiguous理解说明


Posted in Python onMarch 03, 2021

最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。

转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。

注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

gather

torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
          [2,2,2],
          [2,2,2],
          [1,1,0]])
b=b.view(4,3) 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

输出:

tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])

squeeze

将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())

输出:

tensor([[[[-0.2320, 0.9513, 1.1613]]],
    [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])

expand

扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
     [ 2.2106]],
 
    [[-1.9287],
     [ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
     [ 2.2106, 2.2106, 2.2106]],
 
    [[-1.9287, -1.9287, -1.9287],
     [ 0.8748, 0.8748, 0.8748]]])

sum

size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

输出:

tensor(60)
tensor([[ 5, 10, 15],
    [ 5, 10, 15]])

contiguous

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。

可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous) 
print(a.contiguous().view(4,3))

输出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1,  2,  3],
    [ 4,  8, 12],
    [ 1,  2,  3],
    [ 4,  8, 12]])

softmax

假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

对Pytorch 中的contiguous理解说明

import torch
import torch.nn.functional as F 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

输出:

tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])

max

返回最大值,或指定维度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

输出:

(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)

argmax

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2, 2, 2, 2])
tensor(11)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python读写Excel文件的实例
Nov 01 Python
Python中使用gzip模块压缩文件的简单教程
Apr 08 Python
在Python中使用Neo4j数据库的教程
Apr 16 Python
python select.select模块通信全过程解析
Sep 20 Python
pandas实现选取特定索引的行
Apr 20 Python
pandas读取CSV文件时查看修改各列的数据类型格式
Jul 07 Python
python多线程共享变量的使用和效率方法
Jul 16 Python
Python集成开发工具Pycharm的安装和使用详解
Mar 18 Python
Django中使用Json返回数据的实现方法
Jun 03 Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 Python
python的launcher用法知识点总结
Aug 07 Python
python 如何利用argparse解析命令行参数
Sep 11 Python
Flask中jinja2的继承实现方法及实例
Mar 03 #Python
基于PyTorch中view的用法说明
Mar 03 #Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 #Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 #Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
You might like
中英文字符串翻转函数
2008/12/09 PHP
php5 non-thread-safe和thread-safe这两个版本的区别分析
2010/03/13 PHP
laravel 去掉index.php伪静态的操作方法
2019/10/12 PHP
静态页面的值传递(三部曲)
2006/09/25 Javascript
CLASS_CONFUSION JS混淆 全源码
2007/12/12 Javascript
最近项目写了一些js,水平有待提高
2009/01/31 Javascript
jQery使网页在显示器上居中显示适用于任何分辨率
2014/06/09 Javascript
jquery中添加属性和删除属性
2015/06/03 Javascript
jquery.qtip提示信息插件用法简单实例
2016/06/17 Javascript
javascript运算符——逻辑运算符全面解析
2016/06/27 Javascript
js简单获取表单中单选按钮值的方法
2016/08/23 Javascript
如何理解jQuery中的ajaxSubmit方法
2017/03/13 Javascript
详解vue2路由vue-router配置(懒加载)
2017/04/08 Javascript
详解如何用webpack打包一个网站应用项目
2017/07/12 Javascript
Angular4 中内置指令的基本用法
2017/07/31 Javascript
vue-cli的工程模板与构建工具详解
2018/09/27 Javascript
vue单页面在微信下只能分享落地页的解决方案
2019/04/15 Javascript
python实现查询苹果手机维修进度
2015/03/16 Python
python简单验证码识别的实现方法
2019/05/10 Python
Python+PyQT5的子线程更新UI界面的实例
2019/06/14 Python
在python里面运用多继承方法详解
2019/07/01 Python
python list多级排序知识点总结
2019/10/23 Python
浅谈python 调用open()打开文件时路径出错的原因
2020/06/05 Python
Python实现冒泡排序算法的完整实例
2020/11/04 Python
Python 删除List元素的三种方法remove、pop、del
2020/11/16 Python
django使用多个数据库的方法实例
2021/03/04 Python
Html5 滚动穿透的方法
2019/05/13 HTML / CSS
贝斯特韦斯特酒店集团官网:Best Western
2019/01/03 全球购物
全球高级音频和视频专家:HiDef Lifestyle
2019/08/02 全球购物
乌克兰在线药房:Аптека24
2019/10/30 全球购物
电气工程及自动化专业自荐书范文
2013/12/18 职场文书
保安拾金不昧表扬信
2014/01/15 职场文书
自我介绍演讲稿
2014/01/15 职场文书
亲属关系公证书
2014/04/08 职场文书
辞职信模板(中英文版)
2015/02/27 职场文书
2019年励志签名:致拼搏路上的自己
2019/10/11 职场文书