Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
测试、预发布后用python检测网页是否有日常链接
Jun 03 Python
Python的迭代器和生成器使用实例
Jan 14 Python
Python httplib模块使用实例
Apr 11 Python
Python编程之event对象的用法实例分析
Mar 23 Python
Python分析学校四六级过关情况
Nov 22 Python
Python中反射和描述器总结
Sep 23 Python
对matplotlib改变colorbar位置和方向的方法详解
Dec 13 Python
Python实现某论坛自动签到功能
Aug 20 Python
python的Jenkins接口调用方式
May 12 Python
matplotlib运行时配置(Runtime Configuration,rc)参数rcParams解析
Jan 05 Python
python脚本使用阿里云slb对恶意攻击进行封堵的实现
Feb 04 Python
总结python多进程multiprocessing的相关知识
Jun 29 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
简单的页面缓冲技术
2006/10/09 PHP
php MySQL与分页效率
2008/06/04 PHP
利用Memcached在php下实现session机制 替换PHP的原生session支持
2010/08/21 PHP
基于PHP常用字符串的总结(待续)
2013/06/07 PHP
mac环境中使用brew安装php5.5.15
2014/08/18 PHP
php+mysqli实现批量执行插入、更新及删除数据的方法
2015/01/29 PHP
PHP实现在对象之外访问其私有属性private及保护属性protected的方法
2017/11/20 PHP
PHP笛卡尔积实现原理及代码实例
2020/12/09 PHP
JS+CSS实现仿支付宝菜单选中效果代码
2015/09/25 Javascript
Underscore源码分析
2015/12/30 Javascript
Backbone.js框架中简单的View视图编写学习笔记
2016/02/14 Javascript
Ajax跨域实现代码(后台jsp)
2017/01/21 Javascript
JavaScript通过filereader接口读取文件
2017/05/10 Javascript
详解React Native网络请求fetch简单封装
2017/08/10 Javascript
微信小程序使用Socket的实例
2017/09/19 Javascript
详解微信小程序调起键盘性能优化
2018/07/24 Javascript
vue2 拖动排序 vuedraggable组件的实现
2019/08/08 Javascript
写一个Vue loading 插件
2020/11/09 Javascript
[49:41]NB vs NAVI Supermajor小组赛A组 BO3 第一场 6.2
2018/06/03 DOTA
pyramid配置session的方法教程
2013/11/27 Python
python bmp转换为jpg 并删除原图的方法
2018/10/25 Python
解决每次打开pycharm直接进入项目的问题
2018/10/28 Python
这可能是最好玩的python GUI入门实例(推荐)
2019/07/19 Python
css3选择器基本介绍
2014/12/15 HTML / CSS
html5+css3气泡组件的实现
2014/11/21 HTML / CSS
捷克厨房用品购物网站:Tescoma
2018/07/13 全球购物
什么是符号链接,什么是硬链接?符号链接与硬链接的区别是什么?
2014/01/19 面试题
农行实习自我鉴定
2013/09/22 职场文书
银行委托书范本
2014/04/04 职场文书
英语导游词
2015/02/13 职场文书
看上去很美观后感
2015/06/10 职场文书
大学新生入学感想
2015/08/07 职场文书
2015年秋学期教研工作总结
2015/10/14 职场文书
2016廉洁从业学习心得体会
2016/01/19 职场文书
2016优秀教师先进个人事迹材料
2016/02/25 职场文书
python自动化八大定位元素讲解
2021/07/09 Python