pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Django中模型Model添加JSON类型字段的方法
Jun 17 Python
Python 提取dict转换为xml/json/table并输出的实现代码
Aug 28 Python
浅谈Django REST Framework限速
Dec 12 Python
详解Python中的四种队列
May 21 Python
Centos部署django服务nginx+uwsgi的方法
Jan 02 Python
Python学习笔记之lambda表达式用法详解
Aug 08 Python
python实现WebSocket服务端过程解析
Oct 18 Python
使用Python的networkx绘制精美网络图教程
Nov 21 Python
pycharm快捷键汇总
Feb 14 Python
Anaconda使用IDLE的实现示例
Sep 23 Python
python 多线程共享全局变量的优劣
Sep 24 Python
python通过函数名调用函数的几种方法总结
Jun 07 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
LotusPhp笔记之:Logger组件的使用方法
2013/05/06 PHP
微信access_token的获取开发示例
2015/04/16 PHP
yii框架使用分页的方法分析
2019/07/25 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
js option删除代码集合
2008/11/12 Javascript
JQuery循环滚动图片代码
2011/12/08 Javascript
js 验证密码强弱的小例子
2013/03/21 Javascript
javascript自定义函数参数传递为字符串格式
2014/07/29 Javascript
js delete 用法(删除对象属性及变量)
2014/08/24 Javascript
jQuery中Ajax全局事件引用方式及各个事件(全局/局部)执行顺序
2016/06/02 Javascript
javascript 实现文本使用省略号替代(超出固定高度的情况)
2017/02/21 Javascript
AngularJS1.X学习笔记2-数据绑定详解
2017/04/01 Javascript
angular使用post、get向后台传参的问题实例
2017/05/27 Javascript
JS中LocalStorage与SessionStorage五种循序渐进的使用方法
2017/07/12 Javascript
nodejs中Express与Koa2对比分析
2018/02/06 NodeJs
通过vue-router懒加载解决首次加载时资源过多导致的速度缓慢问题
2018/04/08 Javascript
jQuery扩展方法实现Form表单与Json互相转换的实例代码
2018/09/05 jQuery
详解vue文件中使用echarts.js的两种方式
2018/10/18 Javascript
vue滑动吸顶及锚点定位的示例代码
2020/05/10 Javascript
element-ui中el-upload多文件一次性上传的实现
2020/12/02 Javascript
Python内置函数dir详解
2015/04/14 Python
利用scrapy将爬到的数据保存到mysql(防止重复)
2018/03/31 Python
python递归全排列实现方法
2018/08/18 Python
opencv实现简单人脸识别
2021/02/19 Python
使用Python的turtle模块画国旗
2019/09/24 Python
深入浅析Python科学计算库Scipy及安装步骤
2019/10/12 Python
Python和Sublime整合过程图示
2019/12/25 Python
django-利用session机制实现唯一登录的例子
2020/03/16 Python
NOTINO英国:在线购买美容和香水
2020/02/25 全球购物
优秀大学生的自我评价
2014/01/16 职场文书
2014两会学习心得:时代的发展
2014/03/17 职场文书
网络编辑岗位职责
2014/03/18 职场文书
酒店员工培训方案
2014/06/02 职场文书
祖国在我心中演讲稿200字
2014/08/28 职场文书
护理专业自荐信范文
2015/03/06 职场文书
患者身份识别制度
2015/08/06 职场文书