Pytorch实现GoogLeNet的方法


Posted in Python onAugust 18, 2019

GoogLeNet也叫InceptionNet,在2014年被提出,如今已到V4版本。GoogleNet比VGGNet具有更深的网络结构,一共有22层,但是参数比AlexNet要少12倍,但是计算量是AlexNet的4倍,原因就是它采用很有效的Inception模块,并且没有全连接层。

最重要的创新点就在于使用inception模块,通过使用不同维度的卷积提取不同尺度的特征图。左图是最初的Inception模块,右图是使用的1×1得卷积对左图的改进,降低了输入的特征图维度,同时降低了网络的参数量和计算复杂度,称为inception V1。

Pytorch实现GoogLeNet的方法

GoogleNet在架构设计上为保持低层为传统卷积方式不变,只在较高的层开始用Inception模块。

Pytorch实现GoogLeNet的方法

Pytorch实现GoogLeNet的方法

inception V2中将5x5的卷积改为2个3x3的卷积,扩大了感受野,原来是5x5,现在是6x6。Pytorch实现GoogLeNet(inception V2):

'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编写卷积+bn+relu模块
class BasicConv2d(nn.Module):
  def __init__(self, in_channels, out_channals, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channals, **kwargs)
    self.bn = nn.BatchNorm2d(out_channals)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x)

# 编写Inception模块
class Inception(nn.Module):
  def __init__(self, in_planes,
         n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
    super(Inception, self).__init__()
    # 1x1 conv branch
    self.b1 = BasicConv2d(in_planes, n1x1, kernel_size=1)

    # 1x1 conv -> 3x3 conv branch
    self.b2_1x1_a = BasicConv2d(in_planes, n3x3red, 
                  kernel_size=1)
    self.b2_3x3_b = BasicConv2d(n3x3red, n3x3, 
                  kernel_size=3, padding=1)

    # 1x1 conv -> 3x3 conv -> 3x3 conv branch
    self.b3_1x1_a = BasicConv2d(in_planes, n5x5red, 
                  kernel_size=1)
    self.b3_3x3_b = BasicConv2d(n5x5red, n5x5, 
                  kernel_size=3, padding=1)
    self.b3_3x3_c = BasicConv2d(n5x5, n5x5, 
                  kernel_size=3, padding=1)

    # 3x3 pool -> 1x1 conv branch
    self.b4_pool = nn.MaxPool2d(3, stride=1, padding=1)
    self.b4_1x1 = BasicConv2d(in_planes, pool_planes, 
                 kernel_size=1)

  def forward(self, x):
    y1 = self.b1(x)
    y2 = self.b2_3x3_b(self.b2_1x1_a(x))
    y3 = self.b3_3x3_c(self.b3_3x3_b(self.b3_1x1_a(x)))
    y4 = self.b4_1x1(self.b4_pool(x))
    # y的维度为[batch_size, out_channels, C_out,L_out]
    # 合并不同卷积下的特征图
    return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
  def __init__(self):
    super(GoogLeNet, self).__init__()
    self.pre_layers = BasicConv2d(3, 192, 
                   kernel_size=3, padding=1)

    self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
    self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

    self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

    self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
    self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
    self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
    self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
    self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

    self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
    self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

    self.avgpool = nn.AvgPool2d(8, stride=1)
    self.linear = nn.Linear(1024, 10)

  def forward(self, x):
    out = self.pre_layers(x)
    out = self.a3(out)
    out = self.b3(out)
    out = self.maxpool(out)
    out = self.a4(out)
    out = self.b4(out)
    out = self.c4(out)
    out = self.d4(out)
    out = self.e4(out)
    out = self.maxpool(out)
    out = self.a5(out)
    out = self.b5(out)
    out = self.avgpool(out)
    out = out.view(out.size(0), -1)
    out = self.linear(out)
    return out


def test():
  net = GoogLeNet()
  x = torch.randn(1,3,32,32)
  y = net(x)
  print(y.size())

test()

以上这篇Pytorch实现GoogLeNet的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python局部赋值的规则
Mar 07 Python
Python设计模式之观察者模式简单示例
Jan 10 Python
Python数据分析库pandas基本操作方法
Apr 08 Python
解决Python网页爬虫之中文乱码问题
May 11 Python
Python访问MongoDB,并且转换成Dataframe的方法
Oct 15 Python
python批量获取html内body内容的实例
Jan 02 Python
python3.x提取中文的正则表达式示例代码
Jul 23 Python
pandas的相关系数与协方差实例
Dec 27 Python
Python双链表原理与实现方法详解
Feb 22 Python
在tensorflow下利用plt画论文中loss,acc等曲线图实例
Jun 15 Python
PyQt5爬取12306车票信息程序的实现
May 14 Python
python opencv通过按键采集图片源码
May 20 Python
PyTorch之图像和Tensor填充的实例
Aug 18 #Python
Pytorch Tensor的索引与切片例子
Aug 18 #Python
在PyTorch中Tensor的查找和筛选例子
Aug 18 #Python
对Pytorch神经网络初始化kaiming分布详解
Aug 18 #Python
pytorch中的embedding词向量的使用方法
Aug 18 #Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
You might like
PHP中file_exists函数不支持中文名的解决方法
2014/07/26 PHP
thinkphp中ajax与php响应过程详解
2014/12/08 PHP
PHP读取mssql json数据中文乱码的解决办法
2016/04/11 PHP
Thinkphp5 微信公众号token验证不成功的原因及解决方法
2017/11/12 PHP
PDO::commit讲解
2019/01/27 PHP
实例介绍PHP删除数组中的重复元素
2019/03/03 PHP
javascript 控制 html元素 显示/隐藏实现代码
2009/09/01 Javascript
Js 代码中,ajax请求地址后加随机数防止浏览器缓存的原因
2013/05/07 Javascript
javascript实现分栏显示小技巧附图
2014/10/13 Javascript
使用JavaScript实现弹出层效果的简单实例
2016/05/31 Javascript
js实现图片淡入淡出切换简易效果
2016/08/22 Javascript
使用BootStrapValidator完成前端输入验证
2016/09/28 Javascript
JS当前页面登录注册框,固定DIV,底层阴影的实例代码
2016/09/29 Javascript
js判断iframe中元素是否存在的实现代码
2016/12/24 Javascript
JS中利用swiper实现3d翻转幻灯片实例代码
2017/08/25 Javascript
web前端开发中常见的多列布局解决方案整理(一定要看)
2017/10/15 Javascript
微信小程序支付之c#后台实现方法
2017/10/19 Javascript
React native ListView 增加顶部下拉刷新和底下点击刷新示例
2018/04/27 Javascript
详解封装基础的angular4的request请求方法
2018/06/05 Javascript
jQuery移动端跑马灯抽奖特效升级版(抽奖概率固定)实现方法
2019/01/18 jQuery
利用Dectorator分模块存储Vuex状态的实现
2019/02/05 Javascript
vue实现购物车案例
2020/05/30 Javascript
Python实现批量检测HTTP服务的状态
2016/10/27 Python
总结python实现父类调用两种方法的不同
2017/01/15 Python
python 生成器生成杨辉三角的方法(必看)
2017/04/10 Python
解决安装python库时windows error5 报错的问题
2018/10/21 Python
对python中xlsx,csv以及json文件的相互转化方法详解
2018/12/25 Python
Pycharm 2020年最新激活码(亲测有效)
2020/09/18 Python
浅析Django 接收所有文件,前端展示文件(包括视频,文件,图片)ajax请求
2020/03/09 Python
CSS3模拟IOS滑动开关效果
2016/09/28 HTML / CSS
GLAMGLOW格莱魅美国官网:美国知名的面膜品牌
2016/12/31 全球购物
Arti-shopping中文官网:大型海外商品一站式直邮平台
2020/03/23 全球购物
美国婴儿服装购物网站:Gerber Childrenswear
2020/05/06 全球购物
20岁生日感言
2014/01/13 职场文书
趣味比赛活动方案
2014/02/15 职场文书
2015年监理工作总结范文
2015/04/07 职场文书