画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程之UDP通信实例(含服务器端、客户端、UDP广播例子)
Apr 25 Python
跟老齐学Python之深入变量和引用对象
Sep 24 Python
python操作redis的方法
Jul 07 Python
Python基于Pymssql模块实现连接SQL Server数据库的方法详解
Jul 20 Python
完美解决Python 2.7不能正常使用pip install的问题
Jun 12 Python
python实现简单日期工具类
Apr 24 Python
Python操作excel的方法总结(xlrd、xlwt、openpyxl)
Sep 02 Python
分享8点超级有用的Python编程建议(推荐)
Oct 13 Python
从pandas一个单元格的字符串中提取字符串方式
Dec 17 Python
如何基于python实现脚本加密
Dec 28 Python
如何使用Python破解ZIP或RAR压缩文件密码
Jan 09 Python
python 生成器需注意的小问题
Sep 29 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
两款万能的php分页类
2015/11/12 PHP
基于Jquery的文字滚动跑马灯插件(一个页面多个滚动区)
2010/07/26 Javascript
javascript 伪数组实现方法
2010/10/11 Javascript
yepnope.js 异步加载资源文件
2011/09/08 Javascript
poshytip 基于jquery的 插件 主要用于显示微博人的图像和鼠标提示等
2012/10/12 Javascript
Jquery仿淘宝京东多条件筛选可自行结合ajax加载示例
2013/08/28 Javascript
JavaScript 动态加载脚本和样式的方法
2015/04/13 Javascript
Wireshark基本介绍和学习TCP三次握手
2016/08/15 Javascript
Nodejs之TCP服务端与客户端聊天程序详解
2017/07/07 NodeJs
jQuery实现节点的追加、替换、删除、复制功能示例
2017/07/11 jQuery
JavaScript实现无刷新上传预览图片功能
2017/08/02 Javascript
Vue 2.0学习笔记之Vue中的computed属性
2017/10/16 Javascript
jQuery实现模糊搜索功能的方法分析
2018/06/29 jQuery
JS实现用特殊符号替换字符串的中间部分区域的实例代码
2018/07/24 Javascript
15 分钟掌握vue-next响应式原理
2019/10/13 Javascript
VsCode里的Vue模板的实现
2020/08/12 Javascript
[06:16]第十四期-国士无双绝地翻盘之撼地神牛
2014/06/24 DOTA
[36:52]DOTA2真视界:基辅特锦赛总决赛
2017/05/21 DOTA
python中 ? : 三元表达式的使用介绍
2013/10/09 Python
python实现分析apache和nginx日志文件并输出访客ip列表的方法
2015/04/04 Python
python解决Fedora解压zip时中文乱码的方法
2016/09/18 Python
python xml.etree.ElementTree遍历xml所有节点实例详解
2016/12/04 Python
获取Django项目的全部url方法详解
2017/10/26 Python
python逆序打印各位数字的方法
2018/06/25 Python
OpenCV HSV颜色识别及HSV基本颜色分量范围
2019/03/22 Python
python3利用Socket实现通信的方法示例
2019/05/06 Python
人工神经网络算法知识点总结
2019/06/11 Python
PyQt5基本控件使用之消息弹出、用户输入、文件对话框的使用方法
2019/08/06 Python
Python字符串查找基本操作代码案例
2020/10/27 Python
大学班长的职责
2014/01/27 职场文书
项目经理任命书内容
2014/06/06 职场文书
美术专业自荐信
2014/07/07 职场文书
领导干部群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
2015年中秋节演讲稿
2015/03/20 职场文书
mysql的数据压缩性能对比详情
2021/11/07 MySQL
MySQL 数据库范式化设计理论
2022/04/22 MySQL