pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
将Python中的数据存储到系统本地的简单方法
Apr 11 Python
Python使用tablib生成excel文件的简单实现方法
Mar 16 Python
Python基于列表list实现的CRUD操作功能示例
Jan 05 Python
Tensorflow实现卷积神经网络用于人脸关键点识别
Mar 05 Python
python2.7和NLTK安装详细教程
Sep 19 Python
在python中利用最小二乘拟合二次抛物线函数的方法
Dec 29 Python
将python2.7添加进64位系统的注册表方式
Nov 20 Python
scrapy数据存储在mysql数据库的两种方式(同步和异步)
Feb 18 Python
Python3 assert断言实现原理解析
Mar 02 Python
Python爬虫代理池搭建的方法步骤
Sep 28 Python
Python字典dict常用方法函数实例
Nov 09 Python
python网络爬虫实现发送短信验证码的方法
Feb 25 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
PHP中数组的三种排序方法分享
2012/05/07 PHP
PHP中date与gmdate的区别及默认时区设置
2014/05/12 PHP
php实现按照权重随机排序数据的方法
2015/01/09 PHP
Thinkphp关闭缓存的方法
2015/06/26 PHP
PHP Header用于页面跳转时的几个注意事项
2016/10/21 PHP
PHP数据的提交与过滤基本操作实例详解
2016/11/11 PHP
破除网页鼠标右键被禁用的绝招大全
2006/12/27 Javascript
javascript 动态修改样式和层叠样式表代码
2010/04/27 Javascript
javascript setTimeout和setInterval计时的区别详解
2013/06/21 Javascript
利用JS延迟加载百度分享代码,提高网页速度
2013/07/01 Javascript
按Enter键触发事件的jquery方法实现代码
2014/02/17 Javascript
node.js中的fs.symlink方法使用说明
2014/12/15 Javascript
JS实现点击颜色块切换指定区域背景颜色的方法
2015/02/25 Javascript
深入理解JS addLoadEvent函数
2016/05/20 Javascript
mvc 、bootstrap 结合分布式图简单实现分页
2016/10/10 Javascript
ECMAScript6--解构
2017/03/30 Javascript
基于JavaScript实现无缝滚动效果
2017/07/21 Javascript
JS中实现浅拷贝和深拷贝的代码详解
2019/06/05 Javascript
Python按行读取文件的实现方法【小文件和大文件读取】
2016/09/19 Python
Python网络爬虫项目:内容提取器的定义
2016/10/25 Python
python下载图片实现方法(超简单)
2017/07/21 Python
CentOS下使用yum安装python-pip失败的完美解决方法
2017/08/16 Python
使用python装饰器计算函数运行时间的实例
2018/04/21 Python
python dict 相同key 合并value的实例
2019/01/21 Python
python模块之subprocess模块级方法的使用
2019/03/26 Python
python for和else语句趣谈
2019/07/02 Python
Pytorch高阶OP操作where,gather原理
2020/04/30 Python
用python发送微信消息
2020/12/21 Python
HTML5 CSS3实现一个精美VCD包装盒个性幻灯片案例
2014/06/16 HTML / CSS
美国在线宠物用品商店:Entirely Pets
2017/01/01 全球购物
女孩每月服装订阅盒:kidpik
2019/04/17 全球购物
工程部经理岗位职责
2013/12/08 职场文书
母亲节感恩寄语
2014/02/21 职场文书
聊聊golang中多个defer的执行顺序
2021/05/08 Golang
Flask搭建一个API服务器的步骤
2021/05/28 Python
浅谈Java实现分布式事务的三种方案
2021/06/11 Java/Android