pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中shape计算矩阵的方法示例
Apr 21 Python
Python中字典(dict)合并的四种方法总结
Aug 10 Python
Python中enumerate()函数编写更Pythonic的循环
Mar 06 Python
python 对dataframe下面的值进行大规模赋值方法
Jun 09 Python
Python实现求解一元二次方程的方法示例
Jun 20 Python
Python pymongo模块常用操作分析
Sep 01 Python
python协程之动态添加任务的方法
Feb 19 Python
python 实现识别图片上的数字
Jul 30 Python
基于pandas中expand的作用详解
Dec 17 Python
flask 实现上传图片并缩放作为头像的例子
Jan 09 Python
利用python3筛选excel中特定的行(行值满足某个条件/行值属于某个集合)
Sep 04 Python
python工具——Mimesis的简单使用教程
Jan 16 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
第八节 访问方式 [8]
2006/10/09 PHP
php中文本数据翻页(留言本翻页)
2006/10/09 PHP
ThinkPHP3.0略缩图不能保存到子目录的解决方法
2012/09/30 PHP
深入php处理整数函数的详解
2013/06/09 PHP
thinkphp3查询mssql数据库乱码解决方法分享
2014/02/11 PHP
php命令行(cli)模式下报require 加载路径错误的解决方法
2015/11/23 PHP
用JavaScript脚本实现Web页面信息交互
2006/12/21 Javascript
jquery+ajax实现跨域请求的方法
2015/01/20 Javascript
JS获取图片lowsrc属性的方法
2015/04/01 Javascript
java必学必会之static关键字
2015/12/03 Javascript
基于jquery实现图片上传本地预览功能
2016/01/08 Javascript
小白谈谈对JS原型链的理解
2016/05/03 Javascript
jQuery插件EasyUI获取当前Tab中iframe窗体对象的方法
2016/08/05 Javascript
js拼接html字符串的注意事项
2016/10/13 Javascript
微信小程序中单位rpx和rem的使用
2016/12/06 Javascript
js 获取元素的具体样式信息getcss(实例讲解)
2017/07/05 Javascript
Webpack性能优化 DLL 用法详解
2017/08/10 Javascript
微信小程序getPhoneNumber获取用户手机号
2017/09/29 Javascript
分析javascript原型及原型链
2018/03/18 Javascript
使用Angular 6创建各种动画效果的方法
2018/10/10 Javascript
Python实现批量读取word中表格信息的方法
2015/07/30 Python
Python3中使用urllib的方法详解(header,代理,超时,认证,异常处理)
2016/09/21 Python
简单的python后台管理程序
2017/04/13 Python
python时间日期函数与利用pandas进行时间序列处理详解
2018/03/13 Python
基于python实现百度翻译功能
2019/05/09 Python
python中单下划线(_)和双下划线(__)的特殊用法
2019/08/29 Python
Python 保存加载mat格式文件的示例代码
2020/08/04 Python
CSS3 选择器 基本选择器介绍
2012/01/21 HTML / CSS
canvas线条的属性详解
2018/03/27 HTML / CSS
捷克家居装饰及图书音像购物网站:Velký košík
2018/04/16 全球购物
设计师珠宝:Ylang 23
2018/05/11 全球购物
《长征》教学反思
2014/04/27 职场文书
2014小学生国庆65周年演讲稿
2014/09/21 职场文书
车辆转让协议书
2014/09/24 职场文书
四风问题个人剖析材料
2014/10/07 职场文书
Python可视化学习之seaborn绘制矩阵图详解
2022/02/24 Python