使用pytorch实现论文中的unet网络


Posted in Python onJune 24, 2020

设计神经网络的一般步骤:

1. 设计框架

2. 设计骨干网络

Unet网络设计的步骤:

1. 设计Unet网络工厂模式

2. 设计编解码结构

3. 设计卷积模块

4. unet实例模块

Unet网络最重要的特征:

1. 编解码结构。

2. 解码结构,比FCN更加完善,采用连接方式。

3. 本质是一个框架,编码部分可以使用很多图像分类网络。

示例代码:

import torch
import torch.nn as nn

class Unet(nn.Module):
 #初始化参数:Encoder,Decoder,bridge
 #bridge默认值为无,如果有参数传入,则用该参数替换None
 def __init__(self,Encoder,Decoder,bridge = None):
  super(Unet,self).__init__()
  self.encoder = Encoder(encoder_blocks)
  self.decoder = Decoder(decoder_blocks)
  self.bridge = bridge
 def forward(self,x):
  res = self.encoder(x)
  out,skip = res[0],res[1,:]
  if bridge is not None:
   out = bridge(out)
  out = self.decoder(out,skip)
  return out
#设计编码模块
class Encoder(nn.Module):
 def __init__(self,blocks):
  super(Encoder,self).__init__()
  #assert:断言函数,避免出现参数错误
  assert len(blocks) > 0
  #nn.Modulelist():模型列表,所有的参数可以纳入网络,但是没有forward函数
  self.blocks = nn.Modulelist(blocks)
 def forward(self,x):
  skip = []
  for i in range(len(self.blocks) - 1):
   x = self.blocks[i](x)
   skip.append(x)
  res = [self.block[i+1](x)]
  #列表之间可以通过+号拼接
  res += skip
  return res
#设计Decoder模块
class Decoder(nn.Module):
 def __init__(self,blocks):
  super(Decoder, self).__init__()
  assert len(blocks) > 0
  self.blocks = nn.Modulelist(blocks)
 def ceter_crop(self,skips,x):
  _,_,height1,width1 = skips.shape()
  _,_,height2,width2 = x.shape()
  #对图像进行剪切处理,拼接的时候保持对应size参数一致
  ht,wt = min(height1,height2),min(width1,width2)
  dh1 = (height1 - height2)//2 if height1 > height2 else 0
  dw1 = (width1 - width2)//2 if width1 > width2 else 0
  dh2 = (height2 - height1)//2 if height2 > height1 else 0
  dw2 = (width2 - width1)//2 if width2 > width1 else 0
  return skips[:,:,dh1:(dh1 + ht),dw1:(dw1 + wt)],\
    x[:,:,dh2:(dh2 + ht),dw2 : (dw2 + wt)]

 def forward(self, skips,x,reverse_skips = True):
  assert len(skips) == len(blocks) - 1
  if reverse_skips is True:
   skips = skips[: : -1]
  x = self.blocks[0](x)
  for i in range(1, len(self.blocks)):
   skip = skips[i-1]
   x = torch.cat(skip,x,1)
   x = self.blocks[i](x)
  return x
#定义了一个卷积block
def unet_convs(in_channels,out_channels,padding = 0):
 #nn.Sequential:与Modulelist相比,包含了forward函数
 return nn.Sequential(
  nn.Conv2d(in_channels, out_channels, kernal_size = 3, padding = padding, bias = False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace = True),
  nn.Conv2d(in_channels, out_channels, kernal_size=3, padding=padding, bias=False),
  nn.BatchNorm2d(outchannels),
  nn.ReLU(inplace=True),
 )
#实例化Unet模型
def unet(in_channels,out_channels):
 encoder_blocks = [unet_convs(in_channels, 64),\
      nn.Sequential(nn.Maxpool2d(kernal_size = 2, stride = 2, ceil_mode = True),\
         unet_convs(64,128)), \
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(128, 256)),
      nn.Sequential(nn.Maxpool2d(kernal_size=2, stride=2, ceil_mode=True), \
         unet_convs(256, 512)),
      ]
 bridge = nn.Sequential(unet_convs(512, 1024))
 decoder_blocks = [nn.conTranpose2d(1024, 512), \
      nn.Sequential(unet_convs(1024, 512),
         nn.conTranpose2d(512, 256)),\
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(512, 256),
         nn.conTranpose2d(256, 128)), \
      nn.Sequential(unet_convs(256, 128),
         nn.conTranpose2d(128, 64))
      ]
 return Unet(encoder_blocks,decoder_blocks,bridge)

补充知识:Pytorch搭建U-Net网络

U-Net: Convolutional Networks for Biomedical Image Segmentation

使用pytorch实现论文中的unet网络

import torch.nn as nn
import torch
from torch import autograd
from torchsummary import summary

class DoubleConv(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(DoubleConv, self).__init__()
  self.conv = nn.Sequential(
   nn.Conv2d(in_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True),
   nn.Conv2d(out_ch, out_ch, 3, padding=0),
   nn.BatchNorm2d(out_ch),
   nn.ReLU(inplace=True)
  )

 def forward(self, input):
  return self.conv(input)

class Unet(nn.Module):
 def __init__(self, in_ch, out_ch):
  super(Unet, self).__init__()
  self.conv1 = DoubleConv(in_ch, 64)
  self.pool1 = nn.MaxPool2d(2)
  self.conv2 = DoubleConv(64, 128)
  self.pool2 = nn.MaxPool2d(2)
  self.conv3 = DoubleConv(128, 256)
  self.pool3 = nn.MaxPool2d(2)
  self.conv4 = DoubleConv(256, 512)
  self.pool4 = nn.MaxPool2d(2)
  self.conv5 = DoubleConv(512, 1024)
  # 逆卷积,也可以使用上采样
  self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  self.conv6 = DoubleConv(1024, 512)
  self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  self.conv7 = DoubleConv(512, 256)
  self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  self.conv8 = DoubleConv(256, 128)
  self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  self.conv9 = DoubleConv(128, 64)
  self.conv10 = nn.Conv2d(64, out_ch, 1)

 def forward(self, x):
  c1 = self.conv1(x)
  crop1 = c1[:,:,88:480,88:480]
  p1 = self.pool1(c1)
  c2 = self.conv2(p1)
  crop2 = c2[:,:,40:240,40:240]
  p2 = self.pool2(c2)
  c3 = self.conv3(p2)
  crop3 = c3[:,:,16:120,16:120]
  p3 = self.pool3(c3)
  c4 = self.conv4(p3)
  crop4 = c4[:,:,4:60,4:60]
  p4 = self.pool4(c4)
  c5 = self.conv5(p4)
  up_6 = self.up6(c5)
  merge6 = torch.cat([up_6, crop4], dim=1)
  c6 = self.conv6(merge6)
  up_7 = self.up7(c6)
  merge7 = torch.cat([up_7, crop3], dim=1)
  c7 = self.conv7(merge7)
  up_8 = self.up8(c7)
  merge8 = torch.cat([up_8, crop2], dim=1)
  c8 = self.conv8(merge8)
  up_9 = self.up9(c8)
  merge9 = torch.cat([up_9, crop1], dim=1)
  c9 = self.conv9(merge9)
  c10 = self.conv10(c9)
  out = nn.Sigmoid()(c10)
  return out

if __name__=="__main__":
 test_input=torch.rand(1, 1, 572, 572)
 model=Unet(in_ch=1, out_ch=2)
 summary(model, (1,572,572))
 ouput=model(test_input)
 print(ouput.size())
----------------------------------------------------------------
  Layer (type)    Output Shape   Param #
================================================================
   Conv2d-1   [-1, 64, 570, 570]    640
  BatchNorm2d-2   [-1, 64, 570, 570]    128
    ReLU-3   [-1, 64, 570, 570]    0
   Conv2d-4   [-1, 64, 568, 568]   36,928
  BatchNorm2d-5   [-1, 64, 568, 568]    128
    ReLU-6   [-1, 64, 568, 568]    0
  DoubleConv-7   [-1, 64, 568, 568]    0
   MaxPool2d-8   [-1, 64, 284, 284]    0
   Conv2d-9  [-1, 128, 282, 282]   73,856
  BatchNorm2d-10  [-1, 128, 282, 282]    256
    ReLU-11  [-1, 128, 282, 282]    0
   Conv2d-12  [-1, 128, 280, 280]   147,584
  BatchNorm2d-13  [-1, 128, 280, 280]    256
    ReLU-14  [-1, 128, 280, 280]    0
  DoubleConv-15  [-1, 128, 280, 280]    0
  MaxPool2d-16  [-1, 128, 140, 140]    0
   Conv2d-17  [-1, 256, 138, 138]   295,168
  BatchNorm2d-18  [-1, 256, 138, 138]    512
    ReLU-19  [-1, 256, 138, 138]    0
   Conv2d-20  [-1, 256, 136, 136]   590,080
  BatchNorm2d-21  [-1, 256, 136, 136]    512
    ReLU-22  [-1, 256, 136, 136]    0
  DoubleConv-23  [-1, 256, 136, 136]    0
  MaxPool2d-24   [-1, 256, 68, 68]    0
   Conv2d-25   [-1, 512, 66, 66]  1,180,160
  BatchNorm2d-26   [-1, 512, 66, 66]   1,024
    ReLU-27   [-1, 512, 66, 66]    0
   Conv2d-28   [-1, 512, 64, 64]  2,359,808
  BatchNorm2d-29   [-1, 512, 64, 64]   1,024
    ReLU-30   [-1, 512, 64, 64]    0
  DoubleConv-31   [-1, 512, 64, 64]    0
  MaxPool2d-32   [-1, 512, 32, 32]    0
   Conv2d-33   [-1, 1024, 30, 30]  4,719,616
  BatchNorm2d-34   [-1, 1024, 30, 30]   2,048
    ReLU-35   [-1, 1024, 30, 30]    0
   Conv2d-36   [-1, 1024, 28, 28]  9,438,208
  BatchNorm2d-37   [-1, 1024, 28, 28]   2,048
    ReLU-38   [-1, 1024, 28, 28]    0
  DoubleConv-39   [-1, 1024, 28, 28]    0
 ConvTranspose2d-40   [-1, 512, 56, 56]  2,097,664
   Conv2d-41   [-1, 512, 54, 54]  4,719,104
  BatchNorm2d-42   [-1, 512, 54, 54]   1,024
    ReLU-43   [-1, 512, 54, 54]    0
   Conv2d-44   [-1, 512, 52, 52]  2,359,808
  BatchNorm2d-45   [-1, 512, 52, 52]   1,024
    ReLU-46   [-1, 512, 52, 52]    0
  DoubleConv-47   [-1, 512, 52, 52]    0
 ConvTranspose2d-48  [-1, 256, 104, 104]   524,544
   Conv2d-49  [-1, 256, 102, 102]  1,179,904
  BatchNorm2d-50  [-1, 256, 102, 102]    512
    ReLU-51  [-1, 256, 102, 102]    0
   Conv2d-52  [-1, 256, 100, 100]   590,080
  BatchNorm2d-53  [-1, 256, 100, 100]    512
    ReLU-54  [-1, 256, 100, 100]    0
  DoubleConv-55  [-1, 256, 100, 100]    0
 ConvTranspose2d-56  [-1, 128, 200, 200]   131,200
   Conv2d-57  [-1, 128, 198, 198]   295,040
  BatchNorm2d-58  [-1, 128, 198, 198]    256
    ReLU-59  [-1, 128, 198, 198]    0
   Conv2d-60  [-1, 128, 196, 196]   147,584
  BatchNorm2d-61  [-1, 128, 196, 196]    256
    ReLU-62  [-1, 128, 196, 196]    0
  DoubleConv-63  [-1, 128, 196, 196]    0
 ConvTranspose2d-64   [-1, 64, 392, 392]   32,832
   Conv2d-65   [-1, 64, 390, 390]   73,792
  BatchNorm2d-66   [-1, 64, 390, 390]    128
    ReLU-67   [-1, 64, 390, 390]    0
   Conv2d-68   [-1, 64, 388, 388]   36,928
  BatchNorm2d-69   [-1, 64, 388, 388]    128
    ReLU-70   [-1, 64, 388, 388]    0
  DoubleConv-71   [-1, 64, 388, 388]    0
   Conv2d-72   [-1, 2, 388, 388]    130
================================================================
Total params: 31,042,434
Trainable params: 31,042,434
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3280.59
Params size (MB): 118.42
Estimated Total Size (MB): 3400.26
----------------------------------------------------------------
torch.Size([1, 2, 388, 388])

以上这篇使用pytorch实现论文中的unet网络就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用PDB简单调试Python程序简明指南
Apr 25 Python
简单介绍Python中的try和finally和with方法
May 05 Python
python任务调度实例分析
May 19 Python
详解如何用OpenCV + Python 实现人脸识别
Oct 20 Python
[原创]教女朋友学Python3(二)简单的输入输出及内置函数查看
Nov 30 Python
Django restframework 源码分析之认证详解
Feb 22 Python
Python Django框架单元测试之文件上传测试示例
May 17 Python
分享8个非常流行的 Python 可视化工具包
Jun 05 Python
在python中用print()输出多个格式化参数的方法
Jul 16 Python
使用Keras实现Tensor的相乘和相加代码
Jun 18 Python
python爬取豆瓣电影排行榜(requests)的示例代码
Feb 18 Python
python开发的自动化运维工具ansible详解
Aug 07 Python
python连接mysql有哪些方法
Jun 24 #Python
pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)
Jun 24 #Python
Python Tornado核心及相关原理详解
Jun 24 #Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 #Python
pytorch SENet实现案例
Jun 24 #Python
利用PyTorch实现VGG16教程
Jun 24 #Python
python安装读取grib库总结(推荐)
Jun 24 #Python
You might like
深入PHP与浏览器缓存的分析
2013/06/03 PHP
php的SimpleXML方法读写XML接口文件实例解析
2014/06/16 PHP
phpstorm编辑器乱码问题解决
2014/12/01 PHP
大家都应该掌握的PHP关联数组使用技巧
2015/12/25 PHP
PHP常用操作类之通信数据封装类的实现
2017/07/16 PHP
javascript背投广告代码的完善
2008/04/08 Javascript
一款由jquery实现的整屏切换特效
2014/09/15 Javascript
jQuery控制网页打印指定区域的方法
2015/04/07 Javascript
JavaScript解八皇后问题的方法总结
2016/06/12 Javascript
js实现简单的获取验证码按钮效果
2017/03/03 Javascript
详解vue的数据binding绑定原理
2017/04/12 Javascript
js实现canvas图片与img图片的相互转换的示例
2017/08/31 Javascript
vue项目实现github在线预览功能
2018/06/20 Javascript
如何将HTML字符转换为DOM节点并动态添加到文档中详解
2018/08/19 Javascript
微信小程序实现登录注册tab切换效果
2020/12/29 Javascript
通过实例学习React中事件节流防抖
2019/06/17 Javascript
webpack4.0+vue2.0利用批处理生成前端单页或多页应用的方法
2019/06/28 Javascript
[34:56]Ti4冒泡赛LGD vs Liquid 1
2014/07/14 DOTA
[01:09:20]NB vs NAVI Supermajor小组赛A组 BO3 第二场 6.2
2018/06/03 DOTA
打印出python 当前全局变量和入口参数的所有属性
2009/07/01 Python
Python实现在Linux系统下更改当前进程运行用户
2015/02/04 Python
Python遍历目录中的所有文件的方法
2016/07/08 Python
Python用户推荐系统曼哈顿算法实现完整代码
2017/12/01 Python
python实现用户管理系统
2018/01/10 Python
详解通过API管理或定制开发ECS实例
2018/09/30 Python
Python实现某论坛自动签到功能
2019/08/20 Python
pygame实现俄罗斯方块游戏(基础篇3)
2019/10/29 Python
Python FtpLib模块应用操作详解
2019/12/12 Python
Python判断字符串是否为空和null方法实例
2020/04/26 Python
Softmax函数原理及Python实现过程解析
2020/05/22 Python
pytorch SENet实现案例
2020/06/24 Python
Ralph Lauren英国官方网站:Ralph Lauren UK
2018/04/03 全球购物
《匆匆》教学反思
2014/02/22 职场文书
学生吸烟检讨书
2014/09/14 职场文书
大学生操行评语大全
2014/12/31 职场文书
国产动画《万圣街》日语配音版制作决定!
2022/03/20 国漫