对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用中文的方法
Feb 19 Python
python使用Berkeley DB数据库实例
Sep 26 Python
在Python中移动目录结构的方法
Jan 31 Python
Python实现简单过滤文本段的方法
May 24 Python
从CentOS安装完成到生成词云python的实例
Dec 01 Python
python+matplotlib实现动态绘制图片实例代码(交互式绘图)
Jan 20 Python
Python实现删除时保留特定文件夹和文件的示例
Apr 27 Python
在Django中URL正则表达式匹配的方法
Dec 20 Python
使用Python Pandas处理亿级数据的方法
Jun 24 Python
pytorch制作自己的LMDB数据操作示例
Dec 18 Python
tf.concat中axis的含义与使用详解
Feb 07 Python
对python中各个response的使用说明
Mar 28 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
拼音码表的生成
2006/10/09 PHP
用header 发送cookie的php代码
2007/03/16 PHP
删除数组元素实用的PHP数组函数
2008/08/18 PHP
php解决抢购秒杀抽奖等大流量并发入库导致的库存负数的问题
2014/06/19 PHP
Laravel4中的Validator验证扩展用法详解
2016/07/26 PHP
php实现通过stomp协议连接ActiveMQ操作示例
2020/02/23 PHP
javascript之ESC(第二类混淆)
2007/05/06 Javascript
jQuery学习之prop和attr的区别示例介绍
2013/11/15 Javascript
js使用循环清空某个div中的input标签值
2014/09/29 Javascript
jquery图形密码实现方法
2015/03/11 Javascript
Bootstrap表单布局
2016/07/19 Javascript
AngularJS中的缓存使用
2017/01/11 Javascript
Bootstrap3多级下拉菜单
2017/02/24 Javascript
谈谈VUE种methods watch和compute的区别和联系
2017/08/01 Javascript
vue 添加vux的代码讲解
2017/11/30 Javascript
vue框架搭建之axios使用教程
2018/07/11 Javascript
微信公众平台 发送模板消息(Java接口开发)
2019/04/17 Javascript
JS实现页面鼠标点击出现图片特效
2020/08/19 Javascript
解决vue数据不实时更新的问题(数据更改了,但数据不实时更新)
2020/10/27 Javascript
[03:46]DAC趣味视频-中文考试.mp4
2017/04/02 DOTA
[00:44]2016完美“圣”典 风云人物:Mikasa宣传片
2016/12/07 DOTA
python网络编程学习笔记(四):域名系统
2014/06/09 Python
Python连接MySQL并使用fetchall()方法过滤特殊字符
2016/03/13 Python
Python中矩阵创建和矩阵运算方法
2018/08/04 Python
Python实现京东秒杀功能代码
2019/05/16 Python
Python学习笔记之函数的定义和作用域实例详解
2019/08/13 Python
Python中remove漏删和索引越界问题的解决
2020/03/18 Python
tensorflow下的图片标准化函数per_image_standardization用法
2020/06/30 Python
基于CSS3 animation动画属性实现轮播图效果
2017/09/12 HTML / CSS
西班牙英格列斯百货英国官网:El Corte Inglés英国
2017/10/30 全球购物
马德里著名的运动鞋商店:NOIRFONCE
2019/04/12 全球购物
怎样写好自我评价呢?
2014/02/16 职场文书
优质服务口号
2014/06/11 职场文书
学校安全责任书范本
2014/07/23 职场文书
2019年年中职场激励人心语录30条
2019/08/07 职场文书
导游词之镜泊湖
2019/12/09 职场文书