pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序设计入门(4)模块和包
Jun 16 Python
python中常用的九种预处理方法分享
Sep 11 Python
浅谈Python 的枚举 Enum
Jun 12 Python
python获取外网IP并发邮件的实现方法
Oct 01 Python
Python+OpenCV图片局部区域像素值处理改进版详解
Jan 23 Python
对django后台admin下拉框进行过滤的实例
Jul 26 Python
手动安装python3.6的操作过程详解
Jan 13 Python
python实现一个猜拳游戏
Apr 05 Python
使用openCV去除文字中乱入的线条实例
Jun 02 Python
使用keras实现孪生网络中的权值共享教程
Jun 11 Python
如何Tkinter模块编写Python图形界面
Oct 14 Python
解决pip安装tensorflow中出现的no module named tensorflow.python 问题方法
Feb 20 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
php图片验证码代码
2008/03/27 PHP
PHP的Yii框架中过滤器相关的使用总结
2016/03/29 PHP
ThinkPHP5.1验证码功能实现的示例代码
2020/06/08 PHP
javascript读取RSS数据
2007/01/20 Javascript
json 实例详细说明教程
2009/10/31 Javascript
JavaScript Distilled 基础知识与函数
2010/04/07 Javascript
js实现网站首页图片滚动显示
2013/02/04 Javascript
jquery 实现密码框的显示与隐藏示例代码
2013/09/18 Javascript
JQuery.get提交页面不跳转的解决方法
2015/01/13 Javascript
关于JS中match() 和 exec() 返回值和属性的测试
2016/03/21 Javascript
基于javascript实现tab选项卡切换特效调试笔记
2016/03/30 Javascript
浅谈javascript中遇到的字符串对象处理
2016/11/18 Javascript
微信JS-SDK自定义分享功能实例详解【分享给朋友/分享到朋友圈】
2016/11/25 Javascript
js实现tab选项卡切换功能
2017/01/13 Javascript
ES6新特性之函数的扩展实例详解
2017/04/01 Javascript
Vue动态实现评分效果
2017/05/24 Javascript
JavaScript调用模式与this关键字绑定的关系
2018/04/21 Javascript
教你如何用Node实现API的转发(某音乐)
2019/09/20 Javascript
[16:56]教你分分钟做大人:司夜刺客
2014/10/30 DOTA
Python中对元组和列表按条件进行排序的方法示例
2015/11/10 Python
Python基于递归实现电话号码映射功能示例
2018/04/13 Python
python scipy求解非线性方程的方法(fsolve/root)
2018/11/12 Python
值得收藏,Python 开发中的高级技巧
2018/11/23 Python
python处理multipart/form-data的请求方法
2018/12/26 Python
Python 实现数据结构-循环队列的操作方法
2019/07/17 Python
python中seaborn包常用图形使用详解
2019/11/25 Python
python——全排列数的生成方式
2020/02/26 Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
2020/06/12 Python
解决python3输入的坑——input()
2020/12/05 Python
HTML5如何使用SVG的方法示例
2019/01/11 HTML / CSS
Omio西班牙:全欧洲低价大巴、火车和航班搜索和比价
2017/02/11 全球购物
西部世纪.net笔试题面试题
2014/04/03 面试题
工程招投标邀请书
2014/01/30 职场文书
保护环境倡议书范文
2014/05/13 职场文书
病危通知书样本
2015/04/17 职场文书
pytorch通过训练结果的复现设置随机种子
2021/06/01 Python