对Pytorch 中的contiguous理解说明


Posted in Python onMarch 03, 2021

最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。

在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。

这些操作是:

narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。

转置的tensor和原tensor的内存是共享的!

为了证明这一点,我们来看下面的代码:

x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233

可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。

也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。

注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:

RuntimeError: input is not contiguous

只要看到这个错误提示,加上contiguous()就好啦~

补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax

gather

torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
          [2,2,2],
          [2,2,2],
          [1,1,0]])
b=b.view(4,3) 
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))

输出:

tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])

squeeze

将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())

输出:

tensor([[[[-0.2320, 0.9513, 1.1613]]],
    [[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])

expand

扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
     [ 2.2106]],
 
    [[-1.9287],
     [ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
     [ 2.2106, 2.2106, 2.2106]],
 
    [[-1.9287, -1.9287, -1.9287],
     [ 0.8748, 0.8748, 0.8748]]])

sum

size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))

输出:

tensor(60)
tensor([[ 5, 10, 15],
    [ 5, 10, 15]])

contiguous

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。

可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous) 
print(a.contiguous().view(4,3))

输出:

<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1,  2,  3],
    [ 4,  8, 12],
    [ 1,  2,  3],
    [ 4,  8, 12]])

softmax

假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

对Pytorch 中的contiguous理解说明

import torch
import torch.nn.functional as F 
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)

输出:

tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])

max

返回最大值,或指定维度的最大值以及index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())

输出:

(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)

argmax

返回最大值的index

import torch
a=torch.tensor([[.1,.2,.3],
        [1.1,1.2,1.3],
        [2.1,2.2,2.3],
        [3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())

输出:

tensor([ 2, 2, 2, 2])
tensor(11)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python抓取京东价格分析京东商品价格走势
Jan 09 Python
pycharm 使用心得(五)断点调试
Jun 06 Python
python django集成cas验证系统
Jul 14 Python
python获取指定目录下所有文件名列表的方法
May 20 Python
python实现简单socket通信的方法
Apr 19 Python
解决Mac下使用python的坑
Aug 13 Python
Python pandas.DataFrame 找出有空值的行
Sep 09 Python
pyqt5 QlistView列表显示的实现示例
Mar 24 Python
python对XML文件的操作实现代码
Mar 27 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 Python
Django windows使用Apache实现部署流程解析
Oct 12 Python
python工具快速为音视频自动生成字幕(使用说明)
Jan 27 Python
Flask中jinja2的继承实现方法及实例
Mar 03 #Python
基于PyTorch中view的用法说明
Mar 03 #Python
Python 实现劳拉游戏的实例代码(四连环、重力四子棋)
Mar 03 #Python
对pytorch中x = x.view(x.size(0), -1) 的理解说明
Mar 03 #Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
Mar 03 #Python
Pytorch 中的optimizer使用说明
Mar 03 #Python
解决pytorch 的state_dict()拷贝问题
Mar 03 #Python
You might like
ThinkPHP跳转页success及error模板实例教程
2014/07/17 PHP
php中filter_input函数用法分析
2014/11/15 PHP
php中使用sftp教程
2015/03/30 PHP
试用php中oci8扩展
2015/06/18 PHP
PHP获取音频文件的相关信息
2015/06/22 PHP
yii通过小物件生成view的方法
2016/10/08 PHP
php设计模式之抽象工厂模式分析【星际争霸游戏案例】
2020/01/23 PHP
fireworks菜单生成器mm_menu.js在 IE 7.0 显示问题的解决方法
2009/10/20 Javascript
javascript基础知识大全 便于大家学习,也便于我自己查看
2012/08/17 Javascript
replace()方法查找字符使用示例
2013/10/28 Javascript
js生成动态表格并为每个单元格添加单击事件的方法
2014/04/14 Javascript
jquery实现html页面 div 假分页有原理有代码
2014/09/06 Javascript
JavaScript link方法入门实例(给字符串加上超链接)
2014/10/17 Javascript
微信小程序  Mustache语法详细介绍
2016/10/27 Javascript
javascript常用经典算法详解
2017/01/11 Javascript
浅谈vue.js中v-for循环渲染
2017/07/26 Javascript
vue2.0模拟锚点的实例
2018/03/14 Javascript
vue数据控制视图源码解析
2018/03/28 Javascript
微信小程序之分享页面如何返回首页的示例
2018/03/28 Javascript
新手快速入门微信小程序组件库 iView Weapp
2019/06/24 Javascript
[03:01]2014DOTA2国际邀请赛 DC:我是核弹粉,为Burning和国土祝福
2014/07/13 DOTA
python实现批量修改图片格式和尺寸
2018/06/07 Python
纯CSS3制作页面切换效果的实例代码
2019/05/30 HTML / CSS
canvas绘制太极图的实现示例
2020/04/29 HTML / CSS
局部内部类是否可以访问非final变量?
2013/04/20 面试题
法院实习人员自我鉴定
2013/09/26 职场文书
什么是岗位职责
2013/11/12 职场文书
资产经营总监岗位职责
2013/12/04 职场文书
公务员培训心得体会
2013/12/28 职场文书
市场营销管理毕业生自荐信
2014/03/03 职场文书
学校领导班子对照检查材料
2014/08/28 职场文书
2016大学自主招生推荐信范文
2015/03/23 职场文书
2015年监理工作总结范文
2015/04/07 职场文书
2016高考寄语或鼓励的话语
2015/12/04 职场文书
MySQL Server层四个日志的实现
2022/03/31 MySQL
【海涛解说】pis亲自推荐,其实你从来不会玩NW
2022/04/01 DOTA