Pytorch 实现自定义参数层的例子


Posted in Python onAugust 17, 2019

注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。

官方Linear层:

class Linear(Module):
  def __init__(self, in_features, out_features, bias=True):
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
      self.bias = Parameter(torch.Tensor(out_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(1))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    return F.linear(input, self.weight, self.bias)

  def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None
    )

实现view层

class Reshape(nn.Module):
  def __init__(self, *args):
    super(Reshape, self).__init__()
    self.shape = args

  def forward(self, x):
    return x.view((x.size(0),)+self.shape)

实现LinearWise层

class LinearWise(nn.Module):
  def __init__(self, in_features, bias=True):
    super(LinearWise, self).__init__()
    self.in_features = in_features

    self.weight = nn.Parameter(torch.Tensor(self.in_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(self.in_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(0))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    x = input * self.weight
    if self.bias is not None:
      x = x + self.bias
    return x

以上这篇Pytorch 实现自定义参数层的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PyQt5每天必学之事件与信号
Apr 20 Python
Python之文字转图片方法
May 10 Python
tensorflow实现加载mnist数据集
Sep 08 Python
Python Unittest根据不同测试环境跳过用例的方法
Dec 16 Python
Python批量删除只保留最近几天table的代码实例
Apr 01 Python
Python中新式类与经典类的区别详析
Jul 10 Python
Django 重写用户模型的实现
Jul 29 Python
python列表插入append(), extend(), insert()用法详解
Sep 14 Python
解析python 类方法、对象方法、静态方法
Aug 15 Python
python爬虫智能翻页批量下载文件的实例详解
Feb 02 Python
python 实现定时任务的四种方式
Apr 01 Python
Python使用pandas导入xlsx格式的excel文件内容操作代码
Dec 24 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
You might like
简单PHP上传图片、删除图片实现代码
2010/05/12 PHP
php返回相对时间(如:20分钟前,3天前)的方法
2015/04/14 PHP
PHP自定义函数格式化json数据示例
2016/09/14 PHP
IE6下CSS图片缓存问题解决方法
2010/12/09 Javascript
javascript数字数组去重复项的实现代码
2010/12/30 Javascript
JavaScript高级程序设计(第3版)学习笔记4 js运算符和操作符
2012/10/11 Javascript
改进版通过Json对象实现深复制的方法
2012/10/24 Javascript
jquery插件之信息弹出框showInfoDialog(成功/错误/警告/通知/背景遮罩)
2013/01/09 Javascript
js时间比较示例分享(日期比较)
2014/03/05 Javascript
html的DOM中Event对象onblur事件用法实例
2015/01/21 Javascript
AngularJs动态加载模块和依赖注入详解
2016/01/11 Javascript
jQuery.form插件的使用及跨域异步上传文件
2016/04/27 Javascript
JS不用正则验证输入的字符串是否为空(包含空格)的实现代码
2016/06/14 Javascript
Javascript 基础---Ajax入门必看
2016/07/06 Javascript
JavaScript计算值然后把值嵌入到html中的实现方法
2016/10/29 Javascript
Node.js用readline模块实现输入输出
2016/12/16 Javascript
jQuery Validate 校验多个相同name的方法
2017/05/18 jQuery
vue.js国际化 vue-i18n插件的使用详解
2017/07/07 Javascript
Jquery中.bind()、.live()、.delegate()和.on()之间的区别详解
2017/08/01 jQuery
Vue 表情包输入组件的实现代码
2019/01/21 Javascript
微信小程序 select 下拉框组件功能
2019/09/09 Javascript
js实现简单的倒计时
2021/01/28 Javascript
[01:57]2018年度DOTA2最具潜力解说-完美盛典
2018/12/16 DOTA
使用numba对Python运算加速的方法
2018/10/15 Python
python实现石头剪刀布小游戏
2021/01/20 Python
python多线程共享变量的使用和效率方法
2019/07/16 Python
HTML5 canvas绘制的玫瑰花效果
2014/05/29 HTML / CSS
马来西亚排名第一的宠物用品店:Pets Wonderland
2020/04/16 全球购物
C# Debug和Testing相关面试题
2015/10/25 面试题
应届大专毕业生个人自荐信
2013/09/22 职场文书
中班下学期幼儿评语
2014/12/30 职场文书
前台文员岗位职责
2015/02/04 职场文书
三严三实·严以修身心得体会
2016/01/15 职场文书
python 逐步回归算法
2021/04/06 Python
浅谈Python中对象是如何被调用的
2022/04/06 Python
Android中View.post和Handler.post的关系
2022/06/05 Java/Android