pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python open读写文件实现脚本
Sep 06 Python
python中的__slots__使用示例
Feb 26 Python
python自带的http模块详解
Nov 06 Python
简单学习Python多进程Multiprocessing
Aug 29 Python
Python编程实现粒子群算法(PSO)详解
Nov 13 Python
Flask框架配置与调试操作示例
Jul 23 Python
Python中list的交、并、差集获取方法示例
Aug 01 Python
用python写测试数据文件过程解析
Sep 25 Python
python实现五子棋游戏(pygame版)
Jan 19 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 Python
Python极值整数的边界探讨分析
Sep 15 Python
Python 视频画质增强
Apr 28 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
PHP IDE phpstorm 常用快捷键
2015/05/18 PHP
在Thinkphp中使用ajax实现无刷新分页的方法
2016/10/25 PHP
php微信公众平台示例代码分析(二)
2016/12/06 PHP
PHP常见数组排序方法小结
2018/08/20 PHP
php curl操作API接口类完整示例
2019/05/21 PHP
PHP工厂模式、单例模式与注册树模式实例详解
2019/06/03 PHP
web css实现整站样式互相切换
2013/10/29 Javascript
检查输入的是否是数字使用keyCode配合onkeypress事件
2014/01/23 Javascript
jQuery判断多个input file 都不能为空的例子
2015/06/23 Javascript
你不需要jQuery(三) 新AJAX方法fetch()
2016/06/14 Javascript
AngularJS中的包含详细介绍及实现示例
2016/07/28 Javascript
JS中substring与substr的用法
2016/11/16 Javascript
深入理解vue.js中的v-if和v-show
2017/06/22 Javascript
Vue-router路由判断页面未登录跳转到登录页面的实例
2017/10/26 Javascript
JS简单获取并修改input文本框内容的方法示例
2018/04/08 Javascript
JS实现的RC4加密算法示例
2018/08/16 Javascript
layui将table转化表单显示的方法(即table.render转为表单展示)
2019/09/24 Javascript
React组件设计模式之组合组件应用实例分析
2020/04/29 Javascript
[54:25]Ti4 循环赛第三日LGD vs MOUZ
2014/07/12 DOTA
Python多进程编程技术实例分析
2014/09/16 Python
Python使用pymysql小技巧
2017/06/04 Python
python opencv实现图片旋转矩形分割
2018/07/26 Python
解决python3.5 正常安装 却不能直接使用Tkinter包的问题
2019/02/22 Python
Python 导入文件过程图解
2019/10/15 Python
Python大数据之从网页上爬取数据的方法详解
2019/11/16 Python
wxpython多线程防假死与线程间传递消息实例详解
2019/12/13 Python
Python3监控疫情的完整代码
2020/02/20 Python
在Python 的线程中运行协程的方法
2020/02/24 Python
解析Tensorflow之MNIST的使用
2020/06/30 Python
django form和field具体方法和属性说明
2020/07/09 Python
倩碧澳大利亚官网:Clinique澳大利亚
2019/07/22 全球购物
理工大学毕业生自荐信
2013/11/01 职场文书
法律进机关实施方案
2014/03/12 职场文书
中文专业毕业生自荐信
2014/05/24 职场文书
信用卡工作证明模板
2014/09/14 职场文书
征求意见函
2015/06/05 职场文书