pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3.4用函数操作mysql5.7数据库
Jun 23 Python
Python实现的插入排序算法原理与用法实例分析
Nov 22 Python
对Python 数组的切片操作详解
Jul 02 Python
pandas 将list切分后存入DataFrame中的实例
Jul 03 Python
python统计中文字符数量的两种方法
Jan 31 Python
详解opencv中画圆circle函数和椭圆ellipse函数
Dec 27 Python
Python unittest框架操作实例解析
Apr 13 Python
Python如何把Spark数据写入ElasticSearch
Apr 18 Python
Pandas把dataframe或series转换成list的方法
Jun 14 Python
彻底解决Python包下载慢问题
Nov 15 Python
Python图像处理之膨胀与腐蚀的操作
Feb 07 Python
python爬取新闻门户网站的示例
Apr 25 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
浅谈Windows下 PHP4.0与oracle 8的连接设置
2006/10/09 PHP
简单分析ucenter 会员同步登录通信原理
2014/08/25 PHP
PHP中substr()与explode()函数用法分析
2014/11/24 PHP
PHP微信网页授权的配置文件操作分析
2019/05/29 PHP
JS类的封装及实现代码
2009/12/02 Javascript
jquery对象和DOM对象的区别介绍
2013/08/09 Javascript
CheckBoxList多选样式jquery、C#获取选择项
2013/09/06 Javascript
javascript页面渲染速度测试脚本分享
2014/04/15 Javascript
chrome下jq width()方法取值为0的解决方法
2014/05/26 Javascript
JavaScript跨域方法汇总
2014/10/16 Javascript
javascript多行字符串的简单实现方式
2015/05/04 Javascript
JQuery ztree 异步加载实例讲解
2016/02/25 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
Vue.directive自定义指令的使用详解
2017/03/10 Javascript
javascript基本数据类型和转换
2017/03/17 Javascript
JS+CSS实现下拉刷新/上拉加载插件
2017/03/31 Javascript
快速解决vue-cli在ie9+中无效的问题
2018/09/04 Javascript
如何在Vue中使用CleaveJS格式化你的输入内容
2018/12/14 Javascript
详解在网页上通过JS实现文本的语音朗读
2019/03/28 Javascript
微信小程序上传文件到阿里OSS教程
2019/05/20 Javascript
微信小程序可滑动月日历组件使用详解
2019/10/21 Javascript
Vue中实现回车键切换焦点的方法
2020/02/19 Javascript
python实现通过代理服务器访问远程url的方法
2015/04/29 Python
Python中return语句用法实例分析
2015/08/04 Python
使用Python对SQLite数据库操作
2017/04/06 Python
python实现求两个字符串的最长公共子串方法
2018/07/20 Python
Django密码系统实现过程详解
2019/07/19 Python
Python 简单计算要求形状面积的实例
2020/01/18 Python
django实现更改数据库某个字段以及字段段内数据
2020/03/31 Python
Python识别验证码的实现示例
2020/09/30 Python
Python xlwings插入Excel图片的实现方法
2021/02/26 Python
德国滑雪和户外用品网上商店:XSPO
2019/10/30 全球购物
美国排名第一的葡萄酒俱乐部:Firstleaf Wine Club
2020/01/02 全球购物
教师个人自我剖析材料
2014/09/29 职场文书
创业计划之特色精品店
2019/08/12 职场文书
五年级作文之劳动作文
2019/11/12 职场文书