PyTorch实现ResNet50、ResNet101和ResNet152示例


Posted in Python onJanuary 14, 2020

PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

PyTorch实现ResNet50、ResNet101和ResNet152示例

import torch
import torch.nn as nn
import torchvision
import numpy as np

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101','ResNet152']

def Conv1(in_planes, places, stride=2):
  return nn.Sequential(
    nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=7,stride=stride,padding=3, bias=False),
    nn.BatchNorm2d(places),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  )

class Bottleneck(nn.Module):
  def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4):
    super(Bottleneck,self).__init__()
    self.expansion = expansion
    self.downsampling = downsampling

    self.bottleneck = nn.Sequential(
      nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
      nn.BatchNorm2d(places),
      nn.ReLU(inplace=True),
      nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False),
      nn.BatchNorm2d(places*self.expansion),
    )

    if self.downsampling:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(places*self.expansion)
      )
    self.relu = nn.ReLU(inplace=True)
  def forward(self, x):
    residual = x
    out = self.bottleneck(x)

    if self.downsampling:
      residual = self.downsample(x)

    out += residual
    out = self.relu(out)
    return out

class ResNet(nn.Module):
  def __init__(self,blocks, num_classes=1000, expansion = 4):
    super(ResNet,self).__init__()
    self.expansion = expansion

    self.conv1 = Conv1(in_planes = 3, places= 64)

    self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1)
    self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2)
    self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2)
    self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2)

    self.avgpool = nn.AvgPool2d(7, stride=1)
    self.fc = nn.Linear(2048,num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

  def make_layer(self, in_places, places, block, stride):
    layers = []
    layers.append(Bottleneck(in_places, places,stride, downsampling =True))
    for i in range(1, block):
      layers.append(Bottleneck(places*self.expansion, places))

    return nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def ResNet50():
  return ResNet([3, 4, 6, 3])

def ResNet101():
  return ResNet([3, 4, 23, 3])

def ResNet152():
  return ResNet([3, 8, 36, 3])


if __name__=='__main__':
  #model = torchvision.models.resnet50()
  model = ResNet50()
  print(model)

  input = torch.randn(1, 3, 224, 224)
  out = model(input)
  print(out.shape)

以上这篇PyTorch实现ResNet50、ResNet101和ResNet152示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python转换摩斯密码示例
Feb 16 Python
python实现颜色空间转换程序(Tkinter)
Dec 31 Python
让代码变得更易维护的7个Python库
Oct 09 Python
pandas去重复行并分类汇总的实现方法
Jan 29 Python
python根据时间获取周数代码实例
Sep 30 Python
python打印n位数“水仙花数”(实例代码)
Dec 25 Python
python实现批量修改文件名
Mar 23 Python
Python自带的IDE在哪里
Jul 01 Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 Python
python 实现一个图形界面的汇率计算器
Nov 09 Python
pandas apply使用多列计算生成新的列实现示例
Feb 24 Python
Python可视化学习之matplotlib内置单颜色
Feb 24 Python
python重要函数eval多种用法解析
Jan 14 #Python
关于ResNeXt网络的pytorch实现
Jan 14 #Python
Python属性和内建属性实例解析
Jan 14 #Python
Python程序控制语句用法实例分析
Jan 14 #Python
dpn网络的pytorch实现方式
Jan 14 #Python
Django之form组件自动校验数据实现
Jan 14 #Python
简单了解python filter、map、reduce的区别
Jan 14 #Python
You might like
MOTOROLA 摩托罗拉 MODEL 66-XI五灯中波收音机
2021/03/02 无线电
PHP函数实现分页含文本分页和数字分页
2014/10/23 PHP
总结PHP中DateTime的常用方法
2016/08/11 PHP
PHP的介绍以及优势详细分析
2019/09/05 PHP
Avengerls vs Newbee BO3 第二场2.18
2021/03/10 DOTA
JS获取页面窗口大小的代码解读
2011/12/01 Javascript
js加强的经典分页实例
2013/03/15 Javascript
xmlhttp缓存清除的2种解决方法
2013/12/13 Javascript
jquery select 设置默认选中的示例代码
2014/02/07 Javascript
Javascript实现Web颜色值转换
2015/02/05 Javascript
jquery实现标签上移、下移、置顶
2015/04/26 Javascript
js获得当前系统日期时间的方法
2015/05/06 Javascript
JavaScript实现标题栏文字轮播效果代码
2015/10/24 Javascript
Jquery ajax 同步阻塞引起的UI线程阻塞问题
2015/11/17 Javascript
快速实现JS图片懒加载(可视区域加载)示例代码
2017/01/04 Javascript
JSON与JS对象的区别与对比
2017/03/01 Javascript
微信小程序实现移动端滑动分页效果(ajax)
2017/06/13 Javascript
利用yarn代替npm管理前端项目模块依赖的方法详解
2017/09/04 Javascript
什么是Vue.js框架 为什么选择它?
2017/10/17 Javascript
关于在vue 中使用百度ueEditor编辑器的方法实例代码
2018/09/14 Javascript
用原生 JS 实现 innerHTML 功能实例详解
2019/04/03 Javascript
python Django模板的使用方法(图文)
2013/11/04 Python
python多线程方式执行多个bat代码
2016/06/07 Python
浅谈python中的getattr函数 hasattr函数
2016/06/14 Python
Python2.7编程中SQLite3基本操作方法示例
2017/08/09 Python
浅谈配置OpenCV3 + Python3的简易方法(macOS)
2018/04/02 Python
基于python的socket实现单机五子棋到双人对战
2020/03/24 Python
python调用摄像头的示例代码
2020/09/28 Python
Hurley官方网站:扎根于海滩生活方式的全球青年文化品牌
2020/05/18 全球购物
交通事故协议书
2014/04/15 职场文书
舞蹈教育学专业自荐信
2014/06/15 职场文书
2014年护士个人工作总结
2014/11/11 职场文书
五年级小学生评语
2014/12/26 职场文书
2015年幼儿园教育教学工作总结
2015/05/25 职场文书
办公室规章制度范本
2015/08/04 职场文书
MySQL 使用事件(Events)完成计划任务
2021/05/24 MySQL