pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用pyodbc访问数据库操作方法详解
Jul 05 Python
Flask框架使用DBUtils模块连接数据库操作示例
Jul 20 Python
Anaconda下配置python+opencv+contribx的实例讲解
Aug 06 Python
Python实现的调用C语言函数功能简单实例
Mar 13 Python
Python统计一个字符串中每个字符出现了多少次的方法【字符串转换为列表再统计】
May 05 Python
python re.sub()替换正则的匹配内容方法
Jul 22 Python
Python学习笔记之错误和异常及访问错误消息详解
Aug 08 Python
Python如何使用内置库matplotlib绘制折线图
Feb 24 Python
python 装饰器功能与用法案例详解
Mar 06 Python
Python 实现一个计时器
Jul 28 Python
python如何实现word批量转HTML
Sep 30 Python
python分分钟绘制精美地图海报
Feb 15 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
php adodb连接带密码access数据库实例,测试成功
2008/05/14 PHP
preg_match_all使用心得分享
2014/01/31 PHP
php使用curl通过代理获取数据的实现方法
2016/05/16 PHP
基于CI框架的微信网页授权库示例
2016/11/25 PHP
php数组实现根据某个键值将相同键值合并生成新二维数组的方法
2017/04/26 PHP
Laravel中错误与异常处理的用法示例
2018/09/16 PHP
laravel实现登录时监听事件,添加登录用户的记录方法
2019/09/30 PHP
调试Node.JS的辅助工具(NodeWatcher)
2012/01/04 Javascript
JS格式化数字金额用逗号隔开保留两位小数
2013/10/18 Javascript
Jquery图片延迟加载插件jquery.lazyload.js的使用方法
2014/05/21 Javascript
JS实现关键字搜索时的相关下拉字段效果
2014/08/05 Javascript
jquery插件jquery.nicescroll实现图片无滚动条左右拖拽的方法
2015/08/10 Javascript
Node.js静态文件服务器改进版
2016/01/10 Javascript
setTimeout学习小结
2017/02/08 Javascript
JS打开摄像头并截图上传示例
2017/02/18 Javascript
ES6模块化的import和export用法方法总结
2017/08/08 Javascript
vue渲染时闪烁{{}}的问题及解决方法
2018/03/28 Javascript
jquery实现购物车基本功能
2019/10/25 jQuery
[06:04]DOTA2英雄梦之声Vol19卓尔游侠
2014/06/20 DOTA
[07:48]DOTA2上海特级锦标赛主赛事首日RECAP
2016/03/04 DOTA
使用 Python 获取 Linux 系统信息的代码
2014/07/13 Python
一张图带我们入门Python基础教程
2017/02/05 Python
python学习之matplotlib绘制散点图实例
2017/12/09 Python
python3.6 +tkinter GUI编程 实现界面化的文本处理工具(推荐)
2017/12/20 Python
Python处理中文标点符号大集合
2018/05/14 Python
Jupyter Notebook添加代码自动补全功能的实现
2021/01/07 Python
手把手教你用Django执行原生SQL的方法
2021/02/18 Python
Sneaker Studio匈牙利:购买运动鞋
2018/03/26 全球购物
利物浦足球俱乐部官方商店(美国):Liverpool FC US
2019/10/09 全球购物
建议书的格式
2014/05/12 职场文书
教师节宣传方案
2014/05/23 职场文书
抵押贷款承诺书
2014/05/30 职场文书
2014年国庆节活动总结
2014/08/26 职场文书
在校学生证明格式
2015/06/24 职场文书
2016年教师节感言
2015/12/09 职场文书
TV动画《史上最强大魔王转生为村民A》番宣CM公布
2022/04/01 日漫