PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的is和id用法分析
Jan 26 Python
python通过smpt发送邮件的方法
Apr 30 Python
详解Python自建logging模块
Jan 29 Python
python实现log日志的示例代码
Apr 28 Python
python3爬取数据至mysql的方法
Jun 26 Python
python中map的基本用法示例
Sep 10 Python
Python编程深度学习计算库之numpy
Dec 28 Python
Python使用字典的嵌套功能详解
Feb 27 Python
python Pandas库基础分析之时间序列的处理详解
Jul 13 Python
在pycharm中配置Anaconda以及pip源配置详解
Sep 09 Python
python打印异常信息的两种实现方式
Dec 24 Python
python 制作python包,封装成可用模块教程
Jul 13 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
PHP脚本中include文件出错解决方法
2008/11/20 PHP
php字符串函数学习之strstr()
2015/03/27 PHP
PHP结合Ffmpeg快速搭建流媒体服务的实践记录
2018/10/31 PHP
php设计模式之模板模式实例分析【星际争霸游戏案例】
2020/03/24 PHP
PHP dirname功能及原理实例解析
2020/10/28 PHP
IE 下的只读 innerHTML
2009/08/21 Javascript
JS俄罗斯方块,包含完整的设计理念
2010/12/11 Javascript
js操作textarea方法集合封装(兼容IE,firefox)
2011/02/22 Javascript
游览器中javascript的执行过程(图文)
2012/05/20 Javascript
jQuery之选择组件的深入解析
2013/06/19 Javascript
如何使用Javascript正则表达式来格式化XML内容
2013/07/04 Javascript
JavaScript实现带缓冲效果的随屏滚动漂浮广告代码
2015/11/06 Javascript
详谈Angular路由与Nodejs路由的区别
2017/03/05 NodeJs
Vue 组件(component)教程之实现精美的日历方法示例
2018/01/08 Javascript
webpack源码之loader机制详解
2018/04/06 Javascript
浅谈webpack 自动刷新与解析
2018/04/09 Javascript
Nodejs实现多文件夹文件同步
2018/10/17 NodeJs
Vue商品控件与购物车联动效果的实例代码
2019/07/21 Javascript
vuex state中的数组变化监听实例
2019/11/06 Javascript
JavaScript 装逼指南(js另类写法)
2020/05/10 Javascript
使用Python脚本和ADB命令实现卸载App
2017/02/10 Python
python 显示数组全部元素的方法
2018/04/19 Python
使用Python来开发微信功能
2018/06/13 Python
django使用admin站点上传图片的实例
2019/07/28 Python
调用其他python脚本文件里面的类和方法过程解析
2019/11/15 Python
如何使用PyCharm将代码上传到GitHub上(图文详解)
2020/04/27 Python
python help函数实例用法
2020/12/06 Python
HTML5中的Article和Section元素认识及使用
2013/03/22 HTML / CSS
英国知名奢侈品包包品牌:Milli Millu
2016/12/22 全球购物
澳大利亚购买最佳炊具品牌网站:Cookware Brands
2019/02/16 全球购物
三月学雷锋活动总结
2014/06/26 职场文书
群众路线剖析材料怎么写
2014/10/09 职场文书
2014年检验科工作总结
2014/11/22 职场文书
2014年卫生工作总结
2014/11/27 职场文书
给男朋友的道歉短信
2015/05/12 职场文书
关于办理居住证的介绍信模板
2019/11/27 职场文书