pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python检测lvs real server状态
Jan 22 Python
详解Python使用simplejson模块解析JSON的方法
Mar 24 Python
深入理解 Python 中的多线程 新手必看
Nov 20 Python
Python深入06——python的内存管理详解
Dec 07 Python
python构建自定义回调函数详解
Jun 20 Python
python内置函数:lambda、map、filter简单介绍
Nov 16 Python
浅谈python实现Google翻译PDF,解决换行的问题
Nov 28 Python
在python中实现将一张图片剪切成四份的方法
Dec 05 Python
Python中文分词库jieba,pkusegwg性能准确度比较
Feb 11 Python
Pycharm创建python文件自动添加日期作者等信息(步骤详解)
Feb 03 Python
python四个坐标点对图片区域最小外接矩形进行裁剪
Jun 04 Python
pandas中对文本类型数据的处理小结
Nov 01 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
php中通过正则表达式下载内容中的远程图片的函数代码
2012/01/10 PHP
PHP中的使用curl发送请求(GET请求和POST请求)
2017/02/08 PHP
Thinkphp5.0自动生成模块及目录的方法详解
2017/04/17 PHP
PHP实现网页内容html标签补全和过滤的方法小结【2种方法】
2017/04/27 PHP
利用Ext Js生成动态树实例代码
2008/09/08 Javascript
js继承的实现代码
2010/08/05 Javascript
利用百度地图JSAPI生成h7n9禽流感分布图实现代码
2013/04/15 Javascript
JS在IE下缺少标识符的错误
2014/07/23 Javascript
jquery库文件略庞大用纯js替换jquery的方法
2014/08/12 Javascript
使用JQuery在线制作ppt并在线演示源码特效
2015/09/08 Javascript
学习vue.js中class与style绑定
2016/12/03 Javascript
BootStrap的select2既可以查询又可以输入的实现代码
2017/02/17 Javascript
Vue组件之Tooltip的示例代码
2017/10/18 Javascript
详解angular分页插件tm.pagination二次触发问题解决方案
2018/07/20 Javascript
讲解Python中运算符使用时的优先级
2015/05/14 Python
一张图带我们入门Python基础教程
2017/02/05 Python
Python实现图片转字符画的示例代码
2017/08/21 Python
Python装饰器用法实例总结
2018/02/07 Python
Python高级特性切片(Slice)操作详解
2018/09/27 Python
Python Pandas批量读取csv文件到dataframe的方法
2018/10/08 Python
Python代码太长换行的实现
2019/07/05 Python
在Python中获取操作系统的进程信息
2019/08/27 Python
python+mysql实现个人论文管理系统
2019/10/25 Python
python实现根据文件格式分类
2019/10/31 Python
python自动化测试之异常及日志操作实例分析
2019/11/09 Python
有关Tensorflow梯度下降常用的优化方法分享
2020/02/04 Python
详解python常用命令行选项与环境变量
2020/02/20 Python
TensorFlow的环境配置与安装教程详解(win10+GeForce GTX1060+CUDA 9.0+cuDNN7.3+tensorflow-gpu 1.12.0+python3.5.5)
2020/06/22 Python
NICKIS.com荷兰:设计师儿童时装
2020/01/08 全球购物
大学毕业生文采飞扬的自我鉴定
2013/12/03 职场文书
学生会竞选演讲稿
2014/04/24 职场文书
机关副主任个人四风问题整改措施
2014/09/26 职场文书
2015年销售助理工作总结
2015/05/11 职场文书
婚礼领导致辞大全
2015/07/28 职场文书
毕业生的自我鉴定表范文
2019/05/16 职场文书
Python办公自动化之教你如何用Python将任意文件转为PDF格式
2021/06/28 Python