pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python编写的微博应用
Oct 17 Python
Python Sleep休眠函数使用简单实例
Feb 02 Python
python处理csv数据的方法
Mar 11 Python
在Python的循环体中使用else语句的方法
Mar 30 Python
Python代码解决RenderView窗口not found问题
Aug 28 Python
python构建自定义回调函数详解
Jun 20 Python
Pycharm技巧之代码跳转该如何回退
Jul 16 Python
python 实现读取一个excel多个sheet表并合并的方法
Feb 12 Python
python BlockingScheduler定时任务及其他方式的实现
Sep 19 Python
解决python调用自己文件函数/执行函数找不到包问题
Jun 01 Python
详解查看Python解释器路径的两种方式
Oct 15 Python
Ubuntu20.04环境安装tensorflow2的方法步骤
Jan 29 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
使用PHP强制下载PDF文件示例
2014/01/17 PHP
php使用多个进程同时控制文件读写示例
2014/02/28 PHP
php实现评论回复删除功能
2017/05/23 PHP
Laravel5.5以下版本中如何自定义日志行为详解
2018/08/01 PHP
有趣的JavaScript数组长度问题代码说明
2011/01/20 Javascript
jquery插件制作 提示框插件实现代码
2012/08/17 Javascript
js解析与序列化json数据(二)序列化探讨
2013/02/01 Javascript
JS 数字转换研究总结
2013/12/26 Javascript
JS动态添加iframe的代码
2015/09/14 Javascript
Bootstrap每天必学之js插件
2015/11/30 Javascript
js实现表单及时验证功能 用户信息立即验证
2016/09/13 Javascript
Bootstrap源码解读按钮(5)
2016/12/23 Javascript
百度地图JavascriptApi Marker平滑移动及车头指向行径方向
2017/03/13 Javascript
Vue.js展示AJAX数据简单示例讲解
2017/03/29 Javascript
javascript、php关键字搜索函数的使用方法
2018/05/29 Javascript
JavaScript面试技巧之数组的一些不low操作
2019/03/22 Javascript
Node.js实现一个HTTP服务器的方法示例
2019/05/13 Javascript
Vue内部渲染视图的方法
2019/09/02 Javascript
[57:53]Secret vs Pain 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
[58:46]OG vs NAVI 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
使用python分析git log日志示例
2014/02/27 Python
在Django中进行用户注册和邮箱验证的方法
2016/05/09 Python
python实现FTP服务器服务的方法
2017/04/11 Python
通过Python 接口使用OpenCV的方法
2018/04/02 Python
Python实现对文件进行单词划分并去重排序操作示例
2018/07/10 Python
python爬取酷狗音乐排行榜
2019/02/20 Python
Python 中的 import 机制之实现远程导入模块
2019/10/29 Python
将tf.batch_matmul替换成tf.matmul的实现
2020/06/18 Python
使用Python绘制台风轨迹图的示例代码
2020/09/21 Python
html5 web本地存储将取代我们的cookie
2012/12/26 HTML / CSS
美国名牌太阳镜折扣网站:Eyedictive
2017/05/15 全球购物
JSF面试题:如何管量web层中的Bean,用什么标签。如何通过jsp页面与Bean绑定在一起进行处理?
2012/10/05 面试题
管理专员自荐信
2014/01/26 职场文书
Python Numpy之linspace用法说明
2021/04/17 Python
Python中三种花式打印的示例详解
2022/03/19 Python
解决vue中provide inject的响应式监听
2022/04/19 Vue.js