pytorch中获取模型input/output shape实例


Posted in Python onDecember 30, 2019

Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见

https://github.com/pytorch/pytorch/pull/3043

以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。

例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。

该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape

#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
 
 
def get_output_size(summary_dict, output):
 if isinstance(output, tuple):
 for i in xrange(len(output)):
  summary_dict[i] = OrderedDict()
  summary_dict[i] = get_output_size(summary_dict[i],output[i])
 else:
 summary_dict['output_shape'] = list(output.size())
 return summary_dict
 
def summary(input_size, model):
 def register_hook(module):
 def hook(module, input, output):
  class_name = str(module.__class__).split('.')[-1].split("'")[0]
  module_idx = len(summary)
 
  m_key = '%s-%i' % (class_name, module_idx+1)
  summary[m_key] = OrderedDict()
  summary[m_key]['input_shape'] = list(input[0].size())
  summary[m_key] = get_output_size(summary[m_key], output)
 
  params = 0
  if hasattr(module, 'weight'):
  params += torch.prod(torch.LongTensor(list(module.weight.size())))
  if module.weight.requires_grad:
   summary[m_key]['trainable'] = True
  else:
   summary[m_key]['trainable'] = False
  #if hasattr(module, 'bias'):
  # params += torch.prod(torch.LongTensor(list(module.bias.size())))
 
  summary[m_key]['nb_params'] = params
  
 if not isinstance(module, nn.Sequential) and \
  not isinstance(module, nn.ModuleList) and \
  not (module == model):
  hooks.append(module.register_forward_hook(hook))
 
 # check if there are multiple inputs to the network
 if isinstance(input_size[0], (list, tuple)):
 x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
 else:
 x = Variable(torch.rand(1,*input_size))
 
 # create properties
 summary = OrderedDict()
 hooks = []
 # register hook
 model.apply(register_hook)
 # make a forward pass
 model(x)
 # remove these hooks
 for h in hooks:
 h.remove()
 
 return summary
 
crnn = crnn.CRNN(32, 1, 3755, 256, 1)
x = summary([1,32,128],crnn)
print json.dumps(x)

以pytorch版CRNN为例,输出shape如下

{
"Conv2d-1": {
"input_shape": [1, 1, 32, 128],
"output_shape": [1, 64, 32, 128],
"trainable": true,
"nb_params": 576
},
"ReLU-2": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 32, 128],
"nb_params": 0
},
"MaxPool2d-3": {
"input_shape": [1, 64, 32, 128],
"output_shape": [1, 64, 16, 64],
"nb_params": 0
},
"Conv2d-4": {
"input_shape": [1, 64, 16, 64],
"output_shape": [1, 128, 16, 64],
"trainable": true,
"nb_params": 73728
},
"ReLU-5": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 16, 64],
"nb_params": 0
},
"MaxPool2d-6": {
"input_shape": [1, 128, 16, 64],
"output_shape": [1, 128, 8, 32],
"nb_params": 0
},
"Conv2d-7": {
"input_shape": [1, 128, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 294912
},
"BatchNorm2d-8": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 256
},
"ReLU-9": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"Conv2d-10": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"trainable": true,
"nb_params": 589824
},
"ReLU-11": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 8, 32],
"nb_params": 0
},
"MaxPool2d-12": {
"input_shape": [1, 256, 8, 32],
"output_shape": [1, 256, 4, 33],
"nb_params": 0
},
"Conv2d-13": {
"input_shape": [1, 256, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 1179648
},
"BatchNorm2d-14": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-15": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"Conv2d-16": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"trainable": true,
"nb_params": 2359296
},
"ReLU-17": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 4, 33],
"nb_params": 0
},
"MaxPool2d-18": {
"input_shape": [1, 512, 4, 33],
"output_shape": [1, 512, 2, 34],
"nb_params": 0
},
"Conv2d-19": {
"input_shape": [1, 512, 2, 34],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 1048576
},
"BatchNorm2d-20": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"trainable": true,
"nb_params": 512
},
"ReLU-21": {
"input_shape": [1, 512, 1, 33],
"output_shape": [1, 512, 1, 33],
"nb_params": 0
},
"LSTM-22": {
"input_shape": [33, 1, 512],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-23": {
"input_shape": [33, 512],
"output_shape": [33, 256],
"trainable": true,
"nb_params": 131072
},
"BidirectionalLSTM-24": {
"input_shape": [33, 1, 512],
"output_shape": [33, 1, 256],
"nb_params": 0
},
"LSTM-25": {
"input_shape": [33, 1, 256],
"0": {
"output_shape": [33, 1, 512]
},
"1": {
"0": {
"output_shape": [2, 1, 256]
},
"1": {
"output_shape": [2, 1, 256]
}
},
"nb_params": 0
},
"Linear-26": {
"input_shape": [33, 512],
"output_shape": [33, 3755],
"trainable": true,
"nb_params": 1922560
},
"BidirectionalLSTM-27": {
"input_shape": [33, 1, 256],
"output_shape": [33, 1, 3755],
"nb_params": 0
}
}

以上这篇pytorch中获取模型input/output shape实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python爬虫入门教程之点点美女图片爬虫代码分享
Sep 02 Python
在Django同1个页面中的多表单处理详解
Jan 25 Python
Python中使用多进程来实现并行处理的方法小结
Aug 09 Python
python 获取字符串MD5值方法
May 29 Python
python把数组中的数字每行打印3个并保存在文档中的方法
Jul 17 Python
python实现在cmd窗口显示彩色文字
Jun 24 Python
Python中变量的输入输出实例代码详解
Jul 28 Python
Python全栈之列表数据类型详解
Oct 01 Python
python二维键值数组生成转json的例子
Dec 06 Python
浅谈python处理json和redis hash的坑
Jul 16 Python
基于python爬取链家二手房信息代码示例
Oct 21 Python
Python利用zhdate模块实现农历日期处理
Mar 31 Python
Python读取csv文件实例解析
Dec 30 #Python
Pytorch Tensor的统计属性实例讲解
Dec 30 #Python
PyTorch中permute的用法详解
Dec 30 #Python
python实现多进程按序号批量修改文件名的方法示例
Dec 30 #Python
Pytorch Tensor基本数学运算详解
Dec 30 #Python
python垃圾回收机制(GC)原理解析
Dec 30 #Python
利用Python代码实现一键抠背景功能
Dec 29 #Python
You might like
PHP5中实现多态的两种方法实例分享
2014/04/21 PHP
全新Mac配置PHP开发环境教程
2016/02/03 PHP
php制作简单模版引擎
2016/04/07 PHP
Laravel框架实现的记录SQL日志功能示例
2018/06/19 PHP
gearman中worker常驻后台,导致MySQL server has gone away的解决方法
2020/02/27 PHP
PHP实现简易用户登录系统
2020/07/10 PHP
javascript中length属性的探索
2011/07/31 Javascript
js Dialog 实践分享
2012/10/22 Javascript
JS 有趣的eval优化输入验证实例代码
2013/09/22 Javascript
jquery获取css中的选择器(实例讲解)
2013/12/02 Javascript
jQuery中:animated选择器用法实例
2014/12/29 Javascript
JavaScript中的值类型转换介绍
2014/12/31 Javascript
JQuery CheckBox(复选框)操作方法汇总
2015/04/15 Javascript
jquery获取节点名称
2015/04/26 Javascript
jquery实现鼠标拖拽滑动效果来选择数字的方法
2015/05/04 Javascript
Bootstrap路径导航与分页学习使用
2017/02/08 Javascript
让bootstrap的carousel支持滑动滚屏的实现代码
2017/11/27 Javascript
Angularjs过滤器实现动态搜索与排序功能示例
2017/12/13 Javascript
Vue动态生成el-checkbox点击无法赋值的解决方法
2019/02/21 Javascript
详解使用WebPack搭建React开发环境
2019/08/06 Javascript
理解python多线程(python多线程简明教程)
2014/06/09 Python
Python合并多个装饰器小技巧
2015/04/28 Python
Python3内置模块pprint让打印比print更美观详解
2019/06/02 Python
python 接口实现 供第三方调用的例子
2019/08/13 Python
移动Web—CSS为Retina屏幕替换更高质量的图片
2012/12/24 HTML / CSS
HTML5中的Web Notification桌面右下角通知功能的实现
2018/04/19 HTML / CSS
美国汽车轮胎和轮毂销售网站:Tire Rack
2018/01/11 全球购物
德国童装购物网站:NICKI´S.com
2018/04/20 全球购物
开学典礼主持词
2014/03/19 职场文书
新学期开学标语
2014/06/30 职场文书
争先创优公开承诺书
2014/08/30 职场文书
2015年服务员工作总结
2015/04/08 职场文书
2015年学校食堂工作总结
2015/04/22 职场文书
关于ObjectUtils.isEmpty() 和 null 的区别
2022/02/28 Java/Android
MySQL数据库表约束讲解
2022/06/21 MySQL
教你使用RustDesk 搭建一个自己的远程桌面中继服务器
2022/08/14 Servers