pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中unittest模块做UT(单元测试)使用实例
Jun 12 Python
Python中特殊函数集锦
Jul 27 Python
浅谈django model的get和filter方法的区别(必看篇)
May 23 Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 Python
TensorFLow用Saver保存和恢复变量
Mar 10 Python
TensorFlow Session使用的两种方法小结
Jul 30 Python
python中使用zip函数出现错误的原因
Sep 28 Python
python3使用matplotlib绘制条形图
Mar 25 Python
python 正则表达式参数替换实例详解
Jan 17 Python
python对数组进行排序,并输出排序后对应的索引值方式
Feb 28 Python
Python classmethod装饰器原理及用法解析
Oct 17 Python
python爬虫利器之requests库的用法(超全面的爬取网页案例)
Dec 17 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
PHP新手上路(十三)
2006/10/09 PHP
一个简单的PHP&MYSQL留言板源码
2020/07/19 PHP
PHP生成唯一的促销/优惠/折扣码(附源码)
2012/12/28 PHP
解析在zend Farmework下如何创立一个FORM表单
2013/06/28 PHP
64位windows系统下安装Memcache缓存
2015/12/06 PHP
弹出模态框modal的实现方法及实例
2017/09/19 PHP
Laravel 6.2 中添加了可调用容器对象的方法
2019/10/22 PHP
如何通过PHP实现Des加密算法代码实例
2020/05/09 PHP
contains和compareDocumentPosition 方法来确定是否HTML节点间的关系
2011/09/13 Javascript
JSON语法五大要素图文介绍
2012/12/04 Javascript
XMLHttpRequest处理xml格式的返回数据(示例代码)
2013/11/21 Javascript
用javascript为页面添加天气显示实现思路及代码
2013/12/02 Javascript
AngularJS表单详解及示例代码
2016/08/17 Javascript
Js实现中国公民身份证号码有效性验证实例代码
2017/05/03 Javascript
js中split()方法得到的数组长度问题
2018/07/19 Javascript
bootstrap table列和表头对不齐的解决方法
2019/07/19 Javascript
微信小程序 组件的外部样式externalClasses使用详解
2019/09/06 Javascript
详解关于Vue单元测试的几个坑
2020/04/26 Javascript
[02:42]DOTA2英雄基础教程 杰奇洛
2013/12/23 DOTA
[01:48]完美圣典齐天大圣至宝宣传片
2016/12/17 DOTA
python ip正则式
2009/05/07 Python
python网络编程学习笔记(二):socket建立网络客户端
2014/06/09 Python
python数据处理 根据颜色对图片进行分类的方法
2018/12/08 Python
实例详解Matlab 与 Python 的区别
2019/04/26 Python
python3.x 生成3维随机数组实例
2019/11/28 Python
python基于socket函数实现端口扫描
2020/05/28 Python
萌新的HTML5 入门指南
2020/11/06 HTML / CSS
印度购物网站:TATA CLiQ
2017/11/23 全球购物
捷克体育用品购物网站:D-sport
2017/12/28 全球购物
乌克兰移动电子产品和相关配件的在线商店:iTMag
2020/03/16 全球购物
CSS实现fullpage.js全屏滚动效果的示例代码
2021/03/24 HTML / CSS
经典大学生求职信范文
2014/01/06 职场文书
暂住证明怎么写
2015/06/19 职场文书
圣诞晚会主持词
2015/07/01 职场文书
Spring Boot 启动、停止、重启、状态脚本
2021/06/26 Java/Android
解决Windows Server2012 R2 无法安装 .NET Framework 3.5
2022/04/29 Servers