Pytorch: 自定义网络层实例


Posted in Python onJanuary 07, 2020

自定义Autograd函数

对于浅层的网络,我们可以手动的书写前向传播和反向传播过程。但是当网络变得很大时,特别是在做深度学习时,网络结构变得复杂。前向传播和反向传播也随之变得复杂,手动书写这两个过程就会存在很大的困难。幸运地是在pytorch中存在了自动微分的包,可以用来解决该问题。在使用自动求导的时候,网络的前向传播会定义一个计算图(computational graph),图中的节点是张量(tensor),两个节点之间的边对应了两个张量之间变换关系的函数。有了计算图的存在,张量的梯度计算也变得容易了些。例如, x是一个张量,其属性 x.requires_grad = True,那么 x.grad就是一个保存这个张量x的梯度的一些标量值。

最基础的自动求导操作在底层就是作用在两个张量上。前向传播函数是从输入张量到输出张量的计算过程;反向传播是输入输出张量的梯度(一些标量)并输出输入张量的梯度(一些标量)。在pytorch中我们可以很容易地定义自己的自动求导操作,通过继承torch.autograd.Function并定义forward和backward函数。

forward(): 前向传播操作。可以输入任意多的参数,任意的python对象都可以。

backward():反向传播(梯度公式)。输出的梯度个数需要与所使用的张量个数保持一致,且返回的顺序也要对应起来。

# Inherit from Function
class LinearFunction(Function):

  # Note that both forward and backward are @staticmethods
  @staticmethod
  # bias is an optional argument
  def forward(ctx, input, weight, bias=None):
    # ctx在这里类似self,ctx的属性可以在backward中调用
    ctx.save_for_backward(input, weight, bias)
    output = input.mm(weight.t())
    if bias is not None:
      output += bias.unsqueeze(0).expand_as(output)
    return output

  # This function has only a single output, so it gets only one gradient
  @staticmethod
  def backward(ctx, grad_output):
    # This is a pattern that is very convenient - at the top of backward
    # unpack saved_tensors and initialize all gradients w.r.t. inputs to
    # None. Thanks to the fact that additional trailing Nones are
    # ignored, the return statement is simple even when the function has
    # optional inputs.
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None

    # These needs_input_grad checks are optional and there only to
    # improve efficiency. If you want to make your code simpler, you can
    # skip them. Returning gradients for inputs that don't require it is
    # not an error.
    if ctx.needs_input_grad[0]:
      grad_input = grad_output.mm(weight)
    if ctx.needs_input_grad[1]:
      grad_weight = grad_output.t().mm(input)
    if bias is not None and ctx.needs_input_grad[2]:
      grad_bias = grad_output.sum(0).squeeze(0)

    return grad_input, grad_weight, grad_bias

#调用自定义的自动求导函数
linear = LinearFunction.apply(*args) #前向传播
linear.backward()#反向传播
linear.grad_fn.apply(*args)#反向传播

对于非参数化的张量(权重是常量,不需要更新),此时可以定义为:

class MulConstant(Function):
  @staticmethod
  def forward(ctx, tensor, constant):
    # ctx is a context object that can be used to stash information
    # for backward computation
    ctx.constant = constant
    return tensor * constant

  @staticmethod
  def backward(ctx, grad_output):
    # We return as many input gradients as there were arguments.
    # Gradients of non-Tensor arguments to forward must be None.
    return grad_output * ctx.constant, None

高阶导数

grad_x =t.autograd.grad(y, x, create_graph=True)

grad_grad_x = t.autograd.grad(grad_x[0],x)

自定义Module

计算图和自动求导在定义复杂网络和求梯度的时候非常好用,但对于大型的网络,这个还是有点偏底层。在我们构建网络的时候,经常希望将计算限制在每个层之内(参数更新分层更新)。而且在TensorFlow等其他深度学习框架中都提供了高级抽象结构。因此,在pytorch中也提供了类似的包nn,它定义了一组等价于层(layer)的模块(Modules)。一个Module接受输入张量并得到输出张量,同时也会包含可学习的参数。

有时候,我们希望运用一些新的且nn包中不存在的Module。此时就需要定义自己的Module了。自定义的Module需要继承nn.Module且自定义forward函数。其中forward函数可以接受输入张量并利用其它模型或者其他自动求导操作来产生输出张量。但并不需要重写backward函数,因此nn使用了autograd。这也就意味着,需要自定义Module, 都必须有对应的autograd函数以调用其中的backward。

class Linear(nn.Module):
  def __init__(self, input_features, output_features, bias=True):
    super(Linear, self).__init__()
    self.input_features = input_features
    self.output_features = output_features

    # nn.Parameter is a special kind of Tensor, that will get
    # automatically registered as Module's parameter once it's assigned
    # as an attribute. Parameters and buffers need to be registered, or
    # they won't appear in .parameters() (doesn't apply to buffers), and
    # won't be converted when e.g. .cuda() is called. You can use
    # .register_buffer() to register buffers.
    # (很重要!!!参数一定需要梯度!)nn.Parameters require gradients by default.
    self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(output_features))
    else:
      # You should always register all possible parameters, but the
      # optional ones can be None if you want.
      self.register_parameter('bias', None)

    # Not a very smart way to initialize weights
    self.weight.data.uniform_(-0.1, 0.1)
    if bias is not None:
      self.bias.data.uniform_(-0.1, 0.1)

  def forward(self, input):
    # See the autograd section for explanation of what happens here.
    return LinearFunction.apply(input, self.weight, self.bias)

  def extra_repr(self):
    # (Optional)Set the extra information about this module. You can test
    # it by printing an object of this class.
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None

Function与Module的异同

Function与Module都可以对pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有十分重要的不同:

Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络

Function需要定义三个方法:init, forward, backward(需要自己写求导公式);Module:只需定义init和forward,而backward的计算由自动求导机制构成

可以不严谨的认为,Module是由一系列Function组成,因此其在forward的过程中,Function和Variable组成了计算图,在backward时,只需调用Function的backward就得到结果,因此Module不需要再定义backward。

Module不仅包括了Function,还包括了对应的参数,以及其他函数与变量,这是Function所不具备的。

module 是 pytorch 组织神经网络的基本方式。Module 包含了模型的参数以及计算逻辑。Function 承载了实际的功能,定义了前向和后向的计算逻辑。

Module 是任何神经网络的基类,pytorch 中所有模型都必需是 Module 的子类。 Module 可以套嵌,构成树状结构。一个 Module 可以通过将其他 Module 做为属性的方式,完成套嵌。

Function 是 pytorch 自动求导机制的核心类。Function 是无参数或者说无状态的,它只负责接收输入,返回相应的输出;对于反向,它接收输出相应的梯度,返回输入相应的梯度。

在调用loss.backward()时,使用的是Function子类中定义的backward()函数。

以上这篇Pytorch: 自定义网络层实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 元组(Tuple)操作详解
Mar 11 Python
介绍Python的Django框架中的静态资源管理器django-pipeline
Apr 25 Python
python 匹配url中是否存在IP地址的方法
Jun 04 Python
使用Python opencv实现视频与图片的相互转换
Jul 08 Python
在Django model中设置多个字段联合唯一约束的实例
Jul 17 Python
Django Rest framework解析器和渲染器详解
Jul 25 Python
Python @property使用方法解析
Sep 17 Python
Python通过递归获取目录下指定文件代码实例
Nov 07 Python
python3 pathlib库Path类方法总结
Dec 26 Python
基于CentOS搭建Python Django环境过程解析
Aug 24 Python
python PIL模块的基本使用
Sep 29 Python
python绘制高斯曲线
Feb 19 Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
python单例设计模式实现解析
Jan 07 #Python
You might like
php 时间计算问题小结
2009/01/04 PHP
php 数组使用详解 推荐
2011/06/02 PHP
php读取纯真ip数据库使用示例
2014/01/26 PHP
PHP中mysql_field_type()函数用法
2014/11/24 PHP
利用PHP获取网站访客的所在地位置
2017/01/18 PHP
PHP排序算法之快速排序(Quick Sort)及其优化算法详解
2018/04/21 PHP
php判断目录存在的简单方法
2019/09/26 PHP
laravel-admin 在列表页添加自定义按钮的例子
2019/09/30 PHP
用js重建星际争霸
2006/12/22 Javascript
javascript或asp实现的判断身份证号码是否正确两种验证方法
2009/11/26 Javascript
Google排名中的10个最著名的 JavaScript库
2010/04/27 Javascript
Confirmer JQuery确认对话框组件
2010/06/09 Javascript
一起来写段JS drag拖动代码
2010/12/09 Javascript
真正的JQuery.ajax传递中文参数的解决方法
2011/05/28 Javascript
jQuery cdn使用介绍
2013/05/08 Javascript
ScrollDown的基本操作示例
2013/06/09 Javascript
在jquery中combobox多选的不兼容问题总结
2013/12/24 Javascript
jquery跟js初始化加载的多种方法及区别介绍
2014/04/02 Javascript
node.js中的console用法总结
2014/12/15 Javascript
DOM 高级编程
2015/05/06 Javascript
js验证框架之RealyEasy验证详解
2016/06/08 Javascript
JS继承与闭包及JS实现继承的三种方式
2017/10/15 Javascript
vue登录以及权限验证相关的实现
2019/10/25 Javascript
vue中 数字相加为字串转化为数值的例子
2019/11/07 Javascript
Centos Python2 升级到Python3的简单实现
2016/06/21 Python
Python类装饰器实现方法详解
2018/12/21 Python
Python实现读取txt文件中的数据并绘制出图形操作示例
2019/02/26 Python
Python Collatz序列实现过程解析
2019/10/12 Python
python之MSE、MAE、RMSE的使用
2020/02/24 Python
django rest framework serializer返回时间自动格式化方法
2020/03/31 Python
OpenCV Python实现图像指定区域裁剪
2021/03/12 Python
一个不错的HTML5 Canvas多层点击事件监听实例
2014/04/29 HTML / CSS
企业治理工作自我评价
2013/09/26 职场文书
2014年民政工作总结
2014/11/26 职场文书
少年的你:世界上没有如果,要在第一次就勇敢的反抗
2019/11/20 职场文书
ant design vue的form表单取值方法
2022/06/01 Vue.js