对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python内存管理分析
Apr 08 Python
深入浅析python中的多进程、多线程、协程
Jun 22 Python
python爬虫入门教程--快速理解HTTP协议(一)
May 25 Python
python pandas.DataFrame选取、修改数据最好用.loc,.iloc,.ix实现
Jun 11 Python
Python爬取视频(其实是一篇福利)过程解析
Aug 01 Python
python 字符串常用方法汇总详解
Sep 16 Python
在Python中利用pickle保存变量的实例
Dec 30 Python
基于Python计算圆周率pi代码实例
Mar 25 Python
怎么快速自学python
Jun 22 Python
Python用requests库爬取返回为空的解决办法
Feb 21 Python
Python基本知识点总结
Apr 07 Python
Python中tqdm的使用和例子
Sep 23 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
php的控制语句
2006/10/09 PHP
PHP4实际应用经验篇(6)
2006/10/09 PHP
PHP中date()日期函数有关参数整理
2011/07/19 PHP
php空间不支持socket但支持curl时recaptcha的用法
2011/11/07 PHP
thinkphp实现面包屑导航(当前位置)例子分享
2014/05/10 PHP
Yii 2.0如何使用页面缓存方法示例
2017/05/23 PHP
Code:findPosX 和 findPosY
2006/12/20 Javascript
基于jquery的让textarea自适应高度的插件
2010/08/03 Javascript
一个简单的js树形菜单
2011/12/09 Javascript
document.createElement()用法及注意事项(ff下不兼容)
2013/03/13 Javascript
jQuery中before()方法用法实例
2014/12/25 Javascript
浅谈Javascript Base64 加密解密
2014/12/28 Javascript
JavaScript计时器示例分析
2015/02/05 Javascript
js简单倒计时实现代码
2016/04/30 Javascript
JavaScript中removeChild 方法开发示例代码
2016/08/15 Javascript
AngularJS 在同一个界面启动多个ng-app应用模块详解
2016/12/20 Javascript
JavaScript继承定义与用法实践分析
2018/05/28 Javascript
vue实现多组关键词对应高亮显示功能
2019/07/25 Javascript
layui固定下拉框的显示条数(有滚动条)的方法
2019/09/10 Javascript
javascript设计模式 ? 组合模式原理与应用实例分析
2020/04/14 Javascript
原生JavaScript实现轮播图
2021/01/10 Javascript
[02:17]DOTA2亚洲邀请赛 RAVE战队出场宣传片
2015/02/07 DOTA
[03:21]【TI9纪实】Old Boys
2019/08/23 DOTA
python计算最大优先级队列实例
2013/12/18 Python
Python爬取附近餐馆信息代码示例
2017/12/09 Python
基于并发服务器几种实现方法(总结)
2017/12/29 Python
python简单商城购物车实例代码
2018/03/15 Python
pytorch中的自定义反向传播,求导实例
2020/01/06 Python
韩国邮政旗下生鲜食品网上超市:epost
2016/08/27 全球购物
Giuseppe Zanotti美国官方网站:将鞋履视为高级时装般精心制作
2018/02/06 全球购物
如何将整数int转换成字串String
2014/03/21 面试题
人力资源管理专业学生自我评价
2013/11/20 职场文书
文明之星事迹材料
2014/05/09 职场文书
实习护士自荐信
2014/06/21 职场文书
好的旅游活动方案
2014/08/19 职场文书
2016年五一国际劳动节活动总结
2016/04/06 职场文书