pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python入门之语句(if语句、while语句、for语句)
Jan 19 Python
python中enumerate函数用法实例分析
May 20 Python
Python标准库之Sys模块使用详解
May 23 Python
Python实现控制台中的进度条功能代码
Dec 22 Python
python读取txt文件并取其某一列数据的示例
Feb 19 Python
微信小程序python用户认证的实现
Jul 29 Python
使用python快速在局域网内搭建http传输文件服务的方法
Nov 14 Python
python编程进阶之类和对象用法实例分析
Feb 21 Python
Python+redis通过限流保护高并发系统
Apr 15 Python
python代码实现将列表中重复元素之间的内容全部滤除
May 22 Python
Python编程根据字典列表相同键的值进行合并
Oct 05 Python
Python可变集合和不可变集合的构造方法大全
Dec 06 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
用PHP制作静态网站的模板框架
2006/10/09 PHP
thinkPHP下ueditor的使用方法详解
2015/12/26 PHP
老生常谈PHP面向对象之标识映射
2017/06/21 PHP
基于ThinkPHP5.0实现图片上传插件
2017/09/25 PHP
TP5框架实现自定义分页样式的方法示例
2020/04/05 PHP
Javascript注入技巧
2007/06/22 Javascript
js 操作符实例代码
2009/10/24 Javascript
基于jquery实现漂亮的动态信息提示效果
2011/08/02 Javascript
javascript自然分类法算法实现代码
2013/10/11 Javascript
jquery和ajax的关系详细介绍
2013/11/29 Javascript
jquery操作HTML5 的data-*的用法实例分享
2014/08/17 Javascript
node.js中EJS 模板快速入门教程
2017/05/08 Javascript
基于vue2框架的机器人自动回复mini-project实例代码
2017/06/13 Javascript
JS实现换肤功能的方法实例详解
2019/01/30 Javascript
js最实用string(字符串)类型的使用及截取与拼接详解
2019/04/26 Javascript
JS使用iView的Dropdown实现一个右键菜单
2019/05/06 Javascript
JS多个异步请求 按顺序执行next实现解析
2019/09/16 Javascript
解决layui轮播图有数据不显示的情况
2019/09/16 Javascript
javascript设计模式之装饰者模式
2020/01/30 Javascript
JavaScript常用工具函数大全
2020/05/06 Javascript
linux系统使用python获取内存使用信息脚本分享
2014/01/15 Python
Python编写电话薄实现增删改查功能
2016/05/07 Python
python用装饰器自动注册Tornado路由详解
2017/02/14 Python
Python断言assert的用法代码解析
2018/02/03 Python
python3.6使用urllib完成下载的实例
2018/12/19 Python
几个适合python初学者的简单小程序,看完受益匪浅!(推荐)
2019/04/16 Python
Flask模板引擎之Jinja2语法介绍
2019/06/26 Python
python获取网络图片方法及整理过程详解
2019/12/20 Python
HTML5如何使用SVG的方法示例
2019/01/11 HTML / CSS
HTML5播放实现rtmp流直播
2020/06/16 HTML / CSS
万宝龙英国官网:Montblanc手表、书写工具、皮革和珠宝
2018/10/16 全球购物
中文系师范生自荐信
2013/10/01 职场文书
质检员的岗位职责
2013/11/15 职场文书
高三上学期学习自我评价
2014/04/23 职场文书
2014年个人教学工作总结
2014/12/09 职场文书
HTML怎么设置下划线?html文字加下划线方法
2021/12/06 HTML / CSS