PyTorch的自适应池化Adaptive Pooling实例


Posted in Python onJanuary 03, 2020

简介

自适应池化Adaptive Pooling是PyTorch含有的一种池化层,在PyTorch的中有六种形式:

自适应最大池化Adaptive Max Pooling:

torch.nn.AdaptiveMaxPool1d(output_size)
torch.nn.AdaptiveMaxPool2d(output_size)
torch.nn.AdaptiveMaxPool3d(output_size)

自适应平均池化Adaptive Average Pooling:

torch.nn.AdaptiveAvgPool1d(output_size)
torch.nn.AdaptiveAvgPool2d(output_size)
torch.nn.AdaptiveAvgPool3d(output_size)

具体可见官方文档。

官方给出的例子:
>>> # target output size of 5x7
>>> m = nn.AdaptiveMaxPool2d((5,7))
>>> input = torch.randn(1, 64, 8, 9)
>>> output = m(input)
>>> output.size()
torch.Size([1, 64, 5, 7])

>>> # target output size of 7x7 (square)
>>> m = nn.AdaptiveMaxPool2d(7)
>>> input = torch.randn(1, 64, 10, 9)
>>> output = m(input)
>>> output.size()
torch.Size([1, 64, 7, 7])

>>> # target output size of 10x7
>>> m = nn.AdaptiveMaxPool2d((None, 7))
>>> input = torch.randn(1, 64, 10, 9)
>>> output = m(input)
>>> output.size()
torch.Size([1, 64, 10, 7])

Adaptive Pooling特殊性在于,输出张量的大小都是给定的output_size output\_sizeoutput_size。例如输入张量大小为(1, 64, 8, 9),设定输出大小为(5,7),通过Adaptive Pooling层,可以得到大小为(1, 64, 5, 7)的张量。

原理

PyTorch的自适应池化Adaptive Pooling实例

>>> inputsize = 9
>>> outputsize = 4

>>> input = torch.randn(1, 1, inputsize)
>>> input
tensor([[[ 1.5695, -0.4357, 1.5179, 0.9639, -0.4226, 0.5312, -0.5689, 0.4945, 0.1421]]])

>>> m1 = nn.AdaptiveMaxPool1d(outputsize)
>>> m2 = nn.MaxPool1d(kernel_size=math.ceil(inputsize / outputsize), stride=math.floor(inputsize / outputsize), padding=0)
>>> output1 = m1(input)
>>> output2 = m2(input)

>>> output1
tensor([[[1.5695, 1.5179, 0.5312, 0.4945]]]) torch.Size([1, 1, 4])
>>> output2
tensor([[[1.5695, 1.5179, 0.5312, 0.4945]]]) torch.Size([1, 1, 4])

通过实验发现:

PyTorch的自适应池化Adaptive Pooling实例

下面是Adaptive Average Pooling的c++源码部分。

template <typename scalar_t>
 static void adaptive_avg_pool2d_out_frame(
      scalar_t *input_p,
      scalar_t *output_p,
      int64_t sizeD,
      int64_t isizeH,
      int64_t isizeW,
      int64_t osizeH,
      int64_t osizeW,
      int64_t istrideD,
      int64_t istrideH,
      int64_t istrideW)
 {
  int64_t d;
 #pragma omp parallel for private(d)
  for (d = 0; d < sizeD; d++)
  {
   /* loop over output */
   int64_t oh, ow;
   for(oh = 0; oh < osizeH; oh++)
   {
    int istartH = start_index(oh, osizeH, isizeH);
    int iendH  = end_index(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for(ow = 0; ow < osizeW; ow++)
    {
     int istartW = start_index(ow, osizeW, isizeW);
     int iendW  = end_index(ow, osizeW, isizeW);
     int kW = iendW - istartW;

     /* local pointers */
     scalar_t *ip = input_p  + d*istrideD + istartH*istrideH + istartW*istrideW;
     scalar_t *op = output_p + d*osizeH*osizeW + oh*osizeW + ow;

     /* compute local average: */
     scalar_t sum = 0;
     int ih, iw;
     for(ih = 0; ih < kH; ih++)
     {
      for(iw = 0; iw < kW; iw++)
      {
       scalar_t val = *(ip + ih*istrideH + iw*istrideW);
       sum += val;
      }
     }

     /* set output to local average */
     *op = sum / kW / kH;
    }
   }
  }
}

以上这篇PyTorch的自适应池化Adaptive Pooling实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 文件操作实现代码
Oct 07 Python
Python实现基于HTTP文件传输实例
Nov 08 Python
使用tensorflow实现线性回归
Sep 08 Python
Python 保存矩阵为Excel的实现方法
Jan 28 Python
利用Python实现微信找房机器人实例教程
Mar 10 Python
如何用C代码给Python写扩展库(Cython)
May 17 Python
如何在Cloud Studio上执行Python代码?
Aug 09 Python
vscode 配置 python3开发环境的方法
Sep 19 Python
关于matplotlib-legend 位置属性 loc 使用说明
May 16 Python
浅谈keras中自定义二分类任务评价指标metrics的方法以及代码
Jun 11 Python
如何验证python安装成功
Jul 06 Python
python 如何用terminal输入参数
May 25 Python
pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解
Jan 03 #Python
pytorch AvgPool2d函数使用详解
Jan 03 #Python
使用pyhon绘图比较两个手机屏幕大小(实例代码)
Jan 03 #Python
Python基础之函数原理与应用实例详解
Jan 03 #Python
对Pytorch中Tensor的各种池化操作解析
Jan 03 #Python
Python基础之高级变量类型实例详解
Jan 03 #Python
关于Pytorch MaxUnpool2d中size操作方式
Jan 03 #Python
You might like
用来给图片加水印的PHP类
2008/04/09 PHP
PHP 出现乱码和Sessions验证问题的解决方法!
2008/12/06 PHP
php array_flip() 删除数组重复元素
2009/01/14 PHP
PHP 多维数组排序实现代码
2009/08/05 PHP
php开发留言板的CRUD(增,删,改,查)操作
2012/04/19 PHP
PHP laravel中的多对多关系实例详解
2017/06/07 PHP
PHP连接sftp并下载文件的方法教程
2018/08/26 PHP
前台js改变Session的值(用ajax实现)
2012/12/28 Javascript
JS动态修改表格cellPadding和cellSpacing的方法
2015/03/31 Javascript
JS实现网页Div层Clone拖拽效果
2015/09/26 Javascript
JS组件Bootstrap Table使用方法详解
2016/02/02 Javascript
JS中如何实现点击a标签返回页面顶部的问题
2017/01/19 Javascript
[01:45:05]VGJ.T vs Newbee Supermajor 败者组 BO3 第二场 6.6
2018/06/07 DOTA
python 不关闭控制台的实现方法
2011/10/23 Python
Python减少循环层次和缩进的技巧分析
2016/03/15 Python
python3使用scrapy生成csv文件代码示例
2017/12/28 Python
Python读取txt内容写入xls格式excel中的方法
2018/10/11 Python
python 实现视频流下载保存MP4的方法
2019/01/09 Python
详解DeBug Python神级工具PySnooper
2019/07/03 Python
Django分页功能的实现代码详解
2019/07/29 Python
python selenium实现发送带附件的邮件代码实例
2019/12/10 Python
pytorch1.0中torch.nn.Conv2d用法详解
2020/01/10 Python
Python脚本去除文件的只读性操作
2020/03/05 Python
pygame实现飞机大战
2020/03/11 Python
纯CSS实现的大小渐变、渐远效果
2014/04/15 HTML / CSS
基于CSS3实现的黑色个性导航菜单效果
2015/09/14 HTML / CSS
CSS3 实现弹幕的示例代码
2017/08/07 HTML / CSS
英国领先的奢侈品零售商之一:CRUISE
2016/12/02 全球购物
公司前台辞职报告
2014/01/19 职场文书
优秀本科毕业生自荐信
2014/07/04 职场文书
第二批党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
2015年度销售个人工作总结
2015/03/31 职场文书
学校远程教育工作总结
2015/08/11 职场文书
《中华上下五千年》读后感3篇
2019/11/29 职场文书
MySQL 如何分析查询性能
2021/05/12 MySQL
深入讲解Vue中父子组件通信与事件触发
2022/03/22 Vue.js