画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中实现对list做减法操作介绍
Jan 09 Python
python uuid模块使用实例
Apr 08 Python
python直接访问私有属性的简单方法
Jul 25 Python
python fabric实现远程部署
Jan 05 Python
python根据unicode判断语言类型实例代码
Jan 17 Python
python自动登录12306并自动点击验证码完成登录的实现源代码
Apr 25 Python
基于python log取对数详解
Jun 08 Python
解决Python 命令行执行脚本时,提示导入的包找不到的问题
Jan 19 Python
利用Python半自动化生成Nessus报告的方法
Mar 19 Python
Python多版本开发环境管理工具介绍
Jul 03 Python
python manage.py runserver流程解析
Nov 08 Python
解决Python logging模块无法正常输出日志的问题
Feb 21 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php 论坛采集程序 模拟登陆,抓取页面 实现代码
2009/07/09 PHP
php输出xml必须header的解决方法
2014/10/17 PHP
PHP使用星号替代用户名手机和邮箱的实现代码
2018/02/07 PHP
利用PHP扩展Xhprof分析项目性能实践教程
2018/09/05 PHP
google地图的路线实现代码
2009/08/20 Javascript
jQuery EasyUI 开源插件套装 完全替代ExtJS
2010/03/24 Javascript
JS解决ie6下png透明的方法实例
2013/08/02 Javascript
jQuery回车实现登录简单实现
2013/08/20 Javascript
jquery禁止输入数字以外的字符的示例(纯数字验证码)
2014/04/10 Javascript
JavaScript实现按Ctrl键打开新页面
2014/09/04 Javascript
HTML页面,测试JS对C函数的调用简单实例
2016/08/09 Javascript
jquery插件treegrid树状表格的使用方法详解(.Net平台)
2017/01/03 Javascript
从零开始做一个pagination分页组件
2017/03/15 Javascript
bootstrap suggest搜索建议插件使用详解
2017/03/25 Javascript
实现div滚动条默认最底部以及默认最右边的示例代码
2017/11/15 Javascript
react实现菜单权限控制的方法
2017/12/11 Javascript
axios发送post请求springMVC接收不到参数的解决方法
2018/03/05 Javascript
微信小程序局部刷新触发整页刷新效果的实现代码
2018/11/21 Javascript
使用Python操作Elasticsearch数据索引的教程
2015/04/08 Python
编写Python脚本把sqlAlchemy对象转换成dict的教程
2015/05/29 Python
python 计算两个列表的相关系数的实现
2019/08/29 Python
python多进程间通信代码实例
2019/09/30 Python
意大利奢侈品购物网站:Deliberti
2019/10/08 全球购物
请说出这段代码执行后a和b的值分别是多少
2015/03/28 面试题
幼儿园英语教学反思
2014/01/30 职场文书
导师就业推荐信范文
2014/05/22 职场文书
校园绿化美化方案
2014/06/08 职场文书
爱国口号
2014/06/19 职场文书
学校领导班子四风对照检查材料
2014/09/27 职场文书
布达拉宫导游词
2015/02/02 职场文书
肖申克救赎观后感
2015/06/02 职场文书
入党群众意见范文
2015/06/02 职场文书
国庆节新闻稿
2015/07/17 职场文书
自信主题班会
2015/08/14 职场文书
java设计模式--七大原则详解
2021/07/21 Java/Android
win10如何更改appdata文件夹的默认位置?
2022/07/15 数码科技