pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理PHP数组文本文件实例
Sep 18 Python
Python读写Excel文件方法介绍
Nov 22 Python
Python的Django框架中settings文件的部署建议
May 30 Python
python实现的正则表达式功能入门教程【经典】
Jun 05 Python
Python中使用支持向量机(SVM)算法
Dec 26 Python
django 通过ajax完成邮箱用户注册、激活账号的方法
Apr 17 Python
浅谈numpy生成数组的零值问题
Nov 12 Python
Python多项式回归的实现方法
Mar 11 Python
Pytorch 数据加载与数据预处理方式
Dec 31 Python
浅析python标准库中的glob
Mar 13 Python
python实现测试工具(一)——命令行发送get请求
Oct 19 Python
Python实现位图分割的效果
Nov 20 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
PHP中spl_autoload_register函数的用法总结
2013/11/07 PHP
PHP实现PDO的mysql数据库操作类
2014/12/12 PHP
PHP上传文件时自动分配路径的方法
2015/01/09 PHP
ThinkPHP开发框架函数详解:C方法
2015/08/14 PHP
PHP中的Iterator迭代对象属性详解
2019/04/12 PHP
JavaScript Eval 函数使用
2010/03/23 Javascript
jQuery遍历Form示例代码
2013/09/03 Javascript
关于删除时的提示处理(确定删除吗)
2013/11/03 Javascript
jQuery遍历对象、数组、集合实例
2014/11/08 Javascript
JavaScript匿名函数用法分析
2015/02/13 Javascript
jquery实现用户打分评分特效
2015/05/28 Javascript
jQuery实现hover合成事件的方法
2015/08/06 Javascript
深入理解vue路由的使用
2017/03/24 Javascript
JavaScript继承与聚合实例详解
2019/01/22 Javascript
微信公众号生成新浪短网址的实现(快速生成)
2019/08/18 Javascript
在vue中高德地图引入和轨迹的绘制的实现
2019/10/11 Javascript
小程序自动化测试的示例代码
2020/08/11 Javascript
antd日期选择器禁止选择当天之前的时间操作
2020/10/29 Javascript
Python实现两款计算器功能示例
2017/12/19 Python
mac安装pytorch及系统的numpy更新方法
2018/07/26 Python
python直接获取API传递回来的参数方法
2018/12/17 Python
对python周期性定时器的示例详解
2019/02/19 Python
Python如何爬取实时变化的WebSocket数据的方法
2019/03/09 Python
Python Datetime模块和Calendar模块用法实例分析
2019/04/15 Python
python实现登录密码重置简易操作代码
2019/08/14 Python
Django-rest-framework中过滤器的定制实例
2020/04/01 Python
python实现简单文件读写函数
2021/02/25 Python
CSS3实现的闪烁跳跃进度条示例(附源码)
2013/08/19 HTML / CSS
俄罗斯香水在线商店:AromaCode
2019/12/04 全球购物
关于.NET, HTML的五个问题
2012/08/29 面试题
新学期教师寄语
2014/04/02 职场文书
2014年社区党建工作汇报材料
2014/11/02 职场文书
领导参观欢迎词
2015/01/26 职场文书
白银帝国观后感
2015/06/17 职场文书
草房子读书笔记
2015/06/29 职场文书
房地产置业顾问工作总结
2015/10/23 职场文书