画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现二分法算法实例
Feb 02 Python
使用Python的urllib2模块处理url和图片的技巧两则
Feb 18 Python
Python排序算法实例代码
Aug 10 Python
python实现简单的单变量线性回归方法
Nov 08 Python
对python cv2批量灰度图片并保存的实例讲解
Nov 09 Python
Python中dict和set的用法讲解
Mar 28 Python
django mysql数据库及图片上传接口详解
Jul 18 Python
python 图片二值化处理(处理后为纯黑白的图片)
Nov 01 Python
Python笔记之观察者模式
Nov 20 Python
解决Jupyter无法导入已安装的 module问题
Apr 17 Python
python对execl 处理操作代码
Jun 22 Python
python保存图片的四个常用方法
Feb 28 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
一步一步学习PHP(1) php开发环境配置
2010/02/15 PHP
php DOS攻击实现代码(附如何防范)
2012/05/29 PHP
wordpress自定义url参数实现路由功能的代码示例
2013/11/28 PHP
php采用ajax数据提交post与post常见方法总结
2014/11/10 PHP
php实现smarty模板无限极分类的方法
2015/12/07 PHP
PHP单元测试框架PHPUnit用法详解
2019/01/23 PHP
基于jquery的划词搜索实现(备忘)
2010/09/14 Javascript
js实现带农历和八字等信息的日历特效
2016/05/16 Javascript
js print打印网页指定区域内容的简单实例
2016/11/01 Javascript
概述jQuery中的ajax方法
2016/12/16 Javascript
详解A标签中href=""的几种用法
2017/08/20 Javascript
vue axios请求超时的正确处理方法
2018/04/02 Javascript
vuex提交state&&实时监听state数据的改变方法
2018/09/16 Javascript
JS实现继承的几种常用方式示例
2019/06/22 Javascript
javascript触发模拟鼠标点击事件
2019/06/26 Javascript
JS前端基于canvas给图片添加水印
2020/11/11 Javascript
python中enumerate函数遍历元素用法分析
2016/03/11 Python
Python贪心算法实例小结
2018/04/22 Python
python获取命令行输入参数列表的实例代码
2018/06/23 Python
详解Python3中setuptools、Pip安装教程
2019/06/18 Python
对pyqt5之menu和action的使用详解
2019/06/20 Python
Python机器学习算法库scikit-learn学习之决策树实现方法详解
2019/07/04 Python
django的分页器Paginator 从django中导入类
2019/07/25 Python
python cookie反爬处理的实现
2020/11/01 Python
python爬取抖音视频的实例分析
2021/01/19 Python
远程Wi-Fi宠物监控相机:Petcube
2017/04/26 全球购物
J2EE模式面试题
2016/10/11 面试题
财务专业大学生职业生涯规划范文
2013/12/30 职场文书
办加油卡单位介绍信
2014/01/09 职场文书
远程研修随笔感言
2014/02/10 职场文书
思想作风整顿个人剖析材料
2014/10/06 职场文书
2016暑期社会实践新闻稿
2015/11/25 职场文书
Python爬虫进阶之Beautiful Soup库详解
2021/04/29 Python
HTML5页面音频自动播放的实现方式
2021/06/21 HTML / CSS
Python3的进程和线程你了解吗
2022/03/16 Python
MySQL中order by的执行过程
2022/06/05 MySQL