PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在RedHat系Linux上部署Python的Celery框架的教程
Apr 07 Python
python正则表达式的使用
Jun 12 Python
Python之csv文件从MySQL数据库导入导出的方法
Jun 21 Python
Python类和对象的定义与实际应用案例分析
Dec 27 Python
Python数据可视化之画图
Jan 15 Python
完美解决Python matplotlib绘图时汉字显示不正常的问题
Jan 29 Python
python单线程文件传输的实例(C/S)
Feb 13 Python
Python增强赋值和共享引用注意事项小结
May 28 Python
Python自动化运维之Ansible定义主机与组规则操作详解
Jun 13 Python
如何写python的配置文件
Jun 07 Python
python语音识别指南终极版(有这一篇足矣)
Sep 09 Python
Python OpenGL基本配置方式
May 20 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
discuz 首页四格:最新话题+最新回复+热门话题+精华文章插件
2007/08/19 PHP
thinkphp5.1框架中容器(Container)和门面(Facade)的实现方法分析
2019/08/05 PHP
[IE&FireFox兼容]JS对select操作
2007/01/07 Javascript
Jquery中的CheckBox、RadioButton、DropDownList的取值赋值实现代码
2011/10/12 Javascript
JavaScript实现穷举排列(permutation)算法谜题解答
2014/12/29 Javascript
js漂浮广告实现代码
2015/08/15 Javascript
简要了解jQuery移动web开发的响应式布局设计
2015/12/04 Javascript
javascript html5移动端轻松实现文件上传
2020/03/27 Javascript
举例说明JavaScript中的实例对象与原型对象
2016/03/11 Javascript
AngularJS教程 ng-style 指令简单示例
2016/08/03 Javascript
微信小程序 保留小数(toFixed)详细介绍
2016/11/16 Javascript
基于cropper.js封装vue实现在线图片裁剪组件功能
2018/03/01 Javascript
JavaScript的数据类型转换原则(干货)
2018/03/15 Javascript
web页面和微信小程序页面实现瀑布流效果
2018/09/26 Javascript
原生JS实现的跳一跳小游戏完整实例
2019/01/27 Javascript
简单通过settimeout看javascript的运行机制
2019/05/10 Javascript
vue中uni-app 实现小程序登录注册功能
2019/10/12 Javascript
jQuery实现鼠标滑动切换图片
2020/05/27 jQuery
[01:25]2015国际邀请赛最佳短片奖——斧王《拆塔英雄:天赋异禀》
2015/09/22 DOTA
Python开发如何在ubuntu 15.10 上配置vim
2016/01/25 Python
python实现基于信息增益的决策树归纳
2018/12/18 Python
详解小白之KMP算法及python实现
2019/04/04 Python
python实现图片转字符小工具
2019/04/30 Python
使用python搭建服务器并实现Android端与之通信的方法
2019/06/28 Python
Python 分享10个PyCharm技巧
2019/07/13 Python
pytorch实现CNN卷积神经网络
2020/02/19 Python
电气工程及其自动化学生实习自我鉴定
2013/09/19 职场文书
财务管理专业应届毕业生求职信
2013/09/22 职场文书
专业销售业务员求职信
2013/11/18 职场文书
汉语言文学专业求职信
2014/06/19 职场文书
2014年文员工作总结
2014/11/18 职场文书
检讨书模板
2015/01/29 职场文书
总经理司机岗位职责
2015/04/10 职场文书
创业计划书之家教托管
2019/09/25 职场文书
CSS filter 有什么神奇用途
2021/05/25 HTML / CSS
css3应用示例:新增的选择器
2022/03/16 HTML / CSS