画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.5/2.6实用教程 入门基础篇
Nov 29 Python
Python实现从订阅源下载图片的方法
Mar 11 Python
Python解析树及树的遍历
Feb 03 Python
Python的装饰器用法学习笔记
Jun 24 Python
python中 chr unichr ord函数的实例详解
Aug 06 Python
详解Python核心对象类型字符串
Feb 11 Python
python文本数据相似度的度量
Mar 12 Python
浅谈Pycharm调用同级目录下的py脚本bug
Dec 03 Python
django将数组传递给前台模板的方法
Aug 06 Python
python字符串反转的四种方法详解
Dec 02 Python
Python中base64与xml取值结合问题
Dec 22 Python
Django表单提交后实现获取相同name的不同value值
May 14 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
Php Image Resize图片大小调整的函数代码
2011/01/17 PHP
Php output buffering缓存及程序缓存深入解析
2013/07/15 PHP
腾讯微博提示missing parameter errorcode 102 错误的解决方法
2014/12/22 PHP
javascript实现行拖动的方法
2015/05/27 Javascript
js判断子窗体是否关闭的方法
2015/08/11 Javascript
jQuery点击按钮弹出遮罩层且内容居中特效
2015/12/14 Javascript
JS实现弹出居中的模式窗口示例
2016/06/20 Javascript
jQuery基本过滤选择器用法示例
2016/09/09 Javascript
值得分享的JavaScript实现图片轮播组件
2016/11/21 Javascript
微信小程序 页面滑动事件的实例详解
2017/10/12 Javascript
react学习笔记之state以及setState的使用
2017/12/07 Javascript
Angular6中使用Swiper的方法示例
2018/07/09 Javascript
解决angularJS中input标签的ng-change事件无效问题
2018/09/13 Javascript
Vue.js中的组件系统
2019/05/30 Javascript
[02:57]2014DOTA2国际邀请赛-观众采访
2014/07/19 DOTA
[44:15]DOTA2上海特级锦标赛主赛事日 - 5 败者组决赛Liquid VS EG第二局
2016/03/06 DOTA
python中的实例方法、静态方法、类方法、类变量和实例变量浅析
2014/04/26 Python
Python 遍历子文件和所有子文件夹的代码实例
2016/12/21 Python
Odoo中如何生成唯一不重复的序列号详解
2018/02/10 Python
Python随机函数random()使用方法小结
2018/04/29 Python
python实现登录密码重置简易操作代码
2019/08/14 Python
python可视化 matplotlib画图使用colorbar工具自定义颜色
2020/12/07 Python
canvas之自定义头像功能实现代码示例
2017/09/29 HTML / CSS
ZWILLING双立人英国网上商店:德国刀具锅具厨具品牌
2018/05/15 全球购物
幼儿园大班教学反思
2014/02/10 职场文书
2014学习优秀共产党员先进事迹材料思想汇报
2014/09/14 职场文书
见习报告格式要求
2014/11/04 职场文书
论文答谢词
2015/01/20 职场文书
青岛导游词
2015/02/12 职场文书
廉洁自律证明
2015/06/24 职场文书
军训决心书范文
2015/09/22 职场文书
2016猴年开门红标语口号
2015/12/26 职场文书
教师学习心得体会范文
2016/01/21 职场文书
小学思品教学反思
2016/02/20 职场文书
Python Django ORM连表正反操作技巧
2021/06/13 Python
Spring 使用注解开发
2022/05/20 Java/Android