Pytorch 卷积中的 Input Shape用法


Posted in Python onJune 29, 2020

先看Pytorch中的卷积

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

二维卷积层, 输入的尺度是(N, C_in,H,W),输出尺度(N,C_out,H_out,W_out)的计算方式

Pytorch 卷积中的 Input Shape用法

这里比较奇怪的是这个卷积层居然没有定义input shape,输入尺寸明明是:(N, C_in, H,W),但是定义中却只需要输入in_channel的size,就能完成卷积,那是不是说这样任意size的image都可以进行卷积呢?

然后我进行了下面这样的实验:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

  def __init__(self):
    super(Net, self).__init__()
    # 输入图像channel:1;输出channel:6;5x5卷积核
    self.conv1 = nn.Conv2d(1, 6, 5)
    self.conv2 = nn.Conv2d(6, 16, 5)
    # an affine operation: y = Wx + b
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # 2x2 Max pooling
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # If the size is a square you can only specify a single number
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    x = x.view(-1, self.num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

  def num_flat_features(self, x):
    size = x.size()[1:] # 除去批大小维度的其余维度
    num_features = 1
    for s in size:
      num_features *= s
    return num_features

net = Net()
print(net)

输出

Net(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)

官网Tutorial 说:这个网络(LeNet)的期待输入是32x32,我就比较奇怪他又没有设置Input shape或者Tensorflow里的Input层,怎么就知道(H,W) =(32, 32)。

输入:

input = torch.randn(1, 1, 32, 32)

output = Net(input)

没问题,但是

input = torch.randn(1, 1, 64, 64)

output = Net(input)

出现:mismatch Error

我们看一下卷积模型部分。

input:(1, 1, 32, 32) --> conv1(1, 6, 5) --> (1, 6, 28, 28) --> max_pool1(2, 2) --> (1, 6, 14, 14) --> conv2(6, 16, 5) -->(1, 16, 10, 10) --> max_pool2(2, 2) --> (1, 16, 5, 5)

然后是将其作为一个全连接网络的输入。Linear相当于tensorflow 中的Dense。所以当你的输入尺寸不为(32, 32)时,卷积得到最终feature map shape就不是(None, 16, 5, 5),而我们的第一个Linear层的输入为(None, 16 * 5 * 5),故会出现mismatch Error。

之所以会有这样一个问题还是因为keras model 必须提定义Input shape,而pytorch更像是一个流程化操作,具体看官网吧。

补充知识:pytorch 卷积 分组卷积 及其深度卷积

先来看看pytorch二维卷积的操作API

Pytorch 卷积中的 Input Shape用法

现在继续讲讲几个卷积是如何操作的。

一. 普通卷积

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

普通卷积时group默认为1 dilation=1(这里先暂时不讨论dilation)其余都正常的话,比如输入为Nx in_channel x high x width

输出为N x out_channel x high xwidth .还是来具体的数字吧,输入为64通道的特征图,输出为32通道的特征图,要想得到32通道的特征图就必须得有32种不同的卷积核。

也就是上面传入的参数out_channel。继续说说具体是怎么的得到的,对于每一种卷积核会和64种不同的特征图依次进行卷积,比如卷积核大小是3x3大小的,由于卷积权值共享,所以对于输入的一张特征图卷积时,只有3x3个参数,同一张特征图上进行滑窗卷积操作时参数是一样的,刚才说的第一种卷积核和输入的第一个特征图卷积完后再继续和后面的第2,3,........64个不同的特征图依次卷积(一种卷积核对于输入特征图来说,同一特征图上面卷积,参数一样,对于不同的特征图上卷积不一样),最后的参数是3x3x64。

此时输出才为一个特征图,因为现在才只使用了一种卷积核。一种核分别在局部小窗口里面和64个特征图卷积应该得到64个数,最后将64个数相加就可以得到一个数了,也就是输出一个特征图上对应于那个窗口的值,依次滑窗就可以得到完整的特征图了。

前面将了这么多才使用一种卷积核,那么现在依次类推使用32种不同的卷积核就可以得到输出的32通道的特征图。最终参数为64x3x3x32.

二.分组卷积

参数group=1时,就是和普通的卷积一样。现在假如group=4,前提是输入特征图和输出特征图必须是4的倍数。现在来看看是如何操作的。in_channel64分成4组,out_inchannel(也就是32种核)也分成4组,依次对应上面的普通卷方式,最终将每组输出的8个特征图依次concat起来,就是结果的out_channel

三. 深度卷积depthwise

此时group=in_channle,也就是对每一个输入的特征图分别用不同的卷积核卷积。out_channel必须是in_channel 的整数倍。

Pytorch 卷积中的 Input Shape用法

3.1 当k=1时,out_channel=in_channel ,每一个卷积核分别和每一个输入的通道进行卷积,最后在concat起来。参数总量为3x3x64。如果此时卷积完之后接着一个64个1x1大小的卷积核。就是谷歌公司于2017年的CVPR中在论文”Xception: deep learning with depthwise separable convolutions”中提出的结构。如下图

Pytorch 卷积中的 Input Shape用法

上图是将1x1放在depthwise前面,其实原理都一样。最终参数的个数是64x1x1+64x3x3。参数要小于普通的卷积方法64x3x3x64

3.2 当k是大于1的整数时,比如k=2

Pytorch 卷积中的 Input Shape用法

此时每一个输入的特征图对应k个卷积核,生成k特征图,最终生成的特征图个数就是k×in_channel .

以上这篇Pytorch 卷积中的 Input Shape用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现归并排序算法
Jun 05 Python
浅谈MySQL中的触发器
May 05 Python
在Python中使用PIL模块对图片进行高斯模糊处理的教程
May 05 Python
使用python实现省市三级菜单效果
Jan 20 Python
浅谈python中的__init__、__new__和__call__方法
Jul 18 Python
Python机器学习logistic回归代码解析
Jan 17 Python
利用Python代码实现数据可视化的5种方法详解
Mar 25 Python
windows下添加Python环境变量的方法汇总
May 14 Python
Python使用jsonpath-rw模块处理Json对象操作示例
Jul 31 Python
Window 64位下python3.6.2环境搭建图文教程
Sep 19 Python
使用python将mysql数据库的数据转换为json数据的方法
Jul 01 Python
keras:model.compile损失函数的用法
Jul 01 Python
Python闭包装饰器使用方法汇总
Jun 29 #Python
使用已经得到的keras模型识别自己手写的数字方式
Jun 29 #Python
Python接口测试环境搭建过程详解
Jun 29 #Python
python字典的值可以修改吗
Jun 29 #Python
python怎么自定义捕获错误
Jun 29 #Python
python打开文件的方式有哪些
Jun 29 #Python
解决tensorflow/keras时出现数组维度不匹配问题
Jun 29 #Python
You might like
php 阴历-农历-转换类代码
2012/01/16 PHP
php 类中的常量、静态属性、非静态属性的区别
2017/04/09 PHP
让网页根据不同IE版本显示不同的内容
2009/02/08 Javascript
javascript重复绑定事件造成的后果说明
2013/03/02 Javascript
ExtJS DOM元素操作经验分享
2013/08/28 Javascript
Google Dart编程语法和基本类型学习教程
2013/11/27 Javascript
基于jquery实现的文字向上跑动类似跑马灯的效果
2014/06/17 Javascript
javascript实现的登陆遮罩效果汇总
2015/11/09 Javascript
Vue.js实战之通过监听滚动事件实现动态锚点
2017/04/04 Javascript
JS简单实现自定义右键菜单实例
2017/05/31 Javascript
Angularjs使用过滤器完成排序功能
2017/09/20 Javascript
微信小程序上传图片实例
2018/05/28 Javascript
Vuex的初探与实战小结
2018/11/26 Javascript
Element Table的row-class-name无效与动态高亮显示选中行背景色
2018/11/30 Javascript
Node.js使用MongoDB的ObjectId作为查询条件的方法
2019/09/10 Javascript
[01:42:49]DOTA2-DPC中国联赛 正赛 iG vs PSG.LGD BO3 第一场 2月26日
2021/03/11 DOTA
Python logging管理不同级别log打印和存储实例
2018/01/19 Python
windows下python 3.6.4安装配置图文教程
2018/08/21 Python
Django框架会话技术实例分析【Cookie与Session】
2019/05/24 Python
Python递归函数 二分查找算法实现解析
2019/08/12 Python
使用pyshp包进行shapefile文件修改的例子
2019/12/06 Python
解决在keras中使用model.save()函数保存模型失败的问题
2020/05/21 Python
Python3使用tesserocr识别字母数字验证码的实现
2021/01/29 Python
使用html5实现表格实现标题合并的实例代码
2019/05/13 HTML / CSS
班会关于环保演讲稿
2013/12/29 职场文书
单位领导证婚词
2014/01/14 职场文书
市场营销管理毕业生自荐信
2014/03/03 职场文书
企业元宵节主持词
2014/03/25 职场文书
庆元旦文艺演出主持词
2014/03/27 职场文书
安全目标责任书
2014/07/22 职场文书
单方离婚协议书范本(2014版)
2014/09/30 职场文书
2016年国庆节宣传标语
2015/11/25 职场文书
2016大学生社会实践单位评语
2015/12/01 职场文书
同学聚会开幕词
2019/04/02 职场文书
详解SQL的窗口函数
2022/04/21 Oracle
linux目录管理方法介绍
2022/06/01 Servers