pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自定义解析简单xml格式文件的方法
May 11 Python
Python中Random和Math模块学习笔记
May 18 Python
Python实现模拟登录及表单提交的方法
Jul 25 Python
Python黑魔法@property装饰器的使用技巧解析
Jun 16 Python
Python三种遍历文件目录的方法实例代码
Jan 19 Python
使用python的pandas库读取csv文件保存至mysql数据库
Aug 20 Python
利用Python查看微信共同好友功能的实现代码
Apr 24 Python
使用Python实现牛顿法求极值
Feb 10 Python
Python 实现日志同时输出到屏幕和文件
Feb 19 Python
python GUI库图形界面开发之PyQt5选项卡控件QTabWidget详细使用方法与实例
Mar 01 Python
零基础学Python之前需要学c语言吗
Jul 21 Python
Python3爬虫带上cookie的实例代码
Jul 28 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
真正面向对象编程:PHP5.01发布
2006/10/09 PHP
CI框架在CLI下执行占用内存过大问题的解决方法
2014/06/17 PHP
php中eval函数的危害与正确禁用方法
2014/06/30 PHP
PHP中实现接收多个name相同但Value不相同表单数据实例
2015/02/03 PHP
PHP新特性详解之命名空间、性状与生成器
2017/07/18 PHP
Laravel中unique和exists验证规则的优化详解
2018/01/28 PHP
Laravel框架实现的批量删除功能示例
2019/01/16 PHP
不用写JS也能使用EXTJS视频演示
2008/12/29 Javascript
为调试JavaScript添加输出窗口的代码
2010/02/07 Javascript
js 事件截取enter按键页面提交事件示例代码
2014/03/04 Javascript
javascript 应用小技巧方法汇总
2015/07/05 Javascript
jQuery实现订单提交页发送短信功能前端处理方法
2016/07/04 Javascript
解决wx.onMenuShareTimeline出现的问题
2016/08/16 Javascript
AngularJS入门教程之Helloworld示例
2016/12/25 Javascript
js获取当前页的URL与window.location.href简单方法
2017/02/13 Javascript
JavaScript实现无穷滚动加载数据
2017/05/06 Javascript
Vue 应用中结合vux使用微信 jssdk的方法
2018/08/28 Javascript
Vue的双向数据绑定实现原理解析
2020/02/17 Javascript
[01:16:37]【全国守擂赛】第三周决赛 Dark Knight vs. 一个弱队
2020/05/04 DOTA
pygame游戏之旅 载入小车图片、更新窗口
2018/11/20 Python
Python爬取破解无线网络wifi密码过程解析
2019/09/17 Python
Python中的list与tuple集合区别解析
2019/10/12 Python
Windows下Anaconda和PyCharm的安装与使用详解
2020/04/23 Python
python 操作excel表格的方法
2020/12/05 Python
秘鲁购物网站:Linio秘鲁
2017/04/07 全球购物
新锐科技Java程序员面试题
2016/07/25 面试题
在校大学生的职业生涯规划书
2014/03/14 职场文书
校庆标语集锦
2014/06/25 职场文书
党在我心中演讲稿
2014/09/02 职场文书
赔偿协议书范本
2014/09/12 职场文书
2014年节能工作总结
2014/12/18 职场文书
鲁迅故居导游词
2015/02/05 职场文书
营业员岗位职责
2015/02/11 职场文书
辩论赛开场白大全(主持人+辩手)
2015/05/29 职场文书
创业计划书之家教中心
2019/09/25 职场文书
深入理解go slice结构
2021/09/15 Golang