pytorch 求网络模型参数实例


Posted in Python onDecember 30, 2019

用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量。下面分别介绍来两种方法求模型参数

一 .求得每一层的模型参数,然后自然的可以计算出总的参数。

1.先初始化一个网络模型model

比如我这里是 model=cliqueNet(里面是些初始化的参数)

2.调用model的Parameters类获取参数列表

pytorch 求网络模型参数实例

一个典型的操作就是将参数列表传入优化器里。如下

optimizer = optim.Adam(model.parameters(), lr=opt.lr)

言归正传,继续回到参数里面,参数在网络里面就是variable,下面分别求每层的尺寸大小和个数。

函数get_number_of_param( ) 里面的参数就是刚才第一步初始化的model

def get_number_of_param(model):
  """get the number of param for every element"""
  count = 0
  for param in model.parameters():
    param_size = param.size()
    count_of_one_param = 1
    for dis in param_size:
      count_of_one_param *= dis
    print(param.size(), count_of_one_param)
    count += count_of_one_param
    print(count)
  print('total number of the model is %d'%count)

再来看看结果:

torch.Size([64, 1, 3, 3]) 576
576
torch.Size([64]) 64
640
torch.Size([6, 36, 64, 3, 3]) 124416
125056
torch.Size([30, 36, 36, 3, 3]) 349920
474976
torch.Size([12, 36]) 432
475408
torch.Size([6, 36, 216, 3, 3]) 419904
895312
torch.Size([30, 36, 36, 3, 3]) 349920
1245232
torch.Size([12, 36]) 432
1245664
torch.Size([6, 36, 216, 3, 3]) 419904
1665568
torch.Size([30, 36, 36, 3, 3]) 349920
2015488
torch.Size([12, 36]) 432
2015920
torch.Size([6, 36, 216, 3, 3]) 419904
2435824
torch.Size([30, 36, 36, 3, 3]) 349920
2785744
torch.Size([12, 36]) 432
2786176
torch.Size([216, 216, 1, 1]) 46656
2832832
torch.Size([216]) 216
2833048
torch.Size([108, 216]) 23328
2856376
torch.Size([108]) 108
2856484
torch.Size([216, 108]) 23328
2879812
torch.Size([216]) 216
2880028
torch.Size([216, 216, 1, 1]) 46656
2926684
torch.Size([216]) 216
2926900
torch.Size([108, 216]) 23328
2950228
torch.Size([108]) 108
2950336
torch.Size([216, 108]) 23328
2973664
torch.Size([216]) 216
2973880
torch.Size([216, 216, 1, 1]) 46656
3020536
torch.Size([216]) 216
3020752
torch.Size([108, 216]) 23328
3044080
torch.Size([108]) 108
3044188
torch.Size([216, 108]) 23328
3067516
torch.Size([216]) 216
3067732
torch.Size([140, 280, 1, 1]) 39200
3106932
torch.Size([140]) 140
3107072
torch.Size([216, 432, 1, 1]) 93312
3200384
torch.Size([216]) 216
3200600
torch.Size([216, 432, 1, 1]) 93312
3293912
torch.Size([216]) 216
3294128
torch.Size([9, 572, 3, 3]) 46332
3340460
torch.Size([9]) 9
3340469
total number of the model is 3340469

可以通过计算验证一下,发现参数与网络是一致的。

二:一行代码就可以搞定参数总个数问题

2.1 先来看看torch.tensor.numel( )这个函数的功能就是求tensor中的元素个数,在网络里面每层参数就是多维数组组成的tensor。

实际上就是求多维数组的元素个数。看代码。

print('cliqueNet parameters:', sum(param.numel() for param in model.parameters()))

当然上面代码中的model还是上面初始化的网络模型。

看看两种的计算结果

torch.Size([64, 1, 3, 3]) 576
576
torch.Size([64]) 64
640
torch.Size([6, 36, 64, 3, 3]) 124416
125056
torch.Size([30, 36, 36, 3, 3]) 349920
474976
torch.Size([12, 36]) 432
475408
torch.Size([6, 36, 216, 3, 3]) 419904
895312
torch.Size([30, 36, 36, 3, 3]) 349920
1245232
torch.Size([12, 36]) 432
1245664
torch.Size([6, 36, 216, 3, 3]) 419904
1665568
torch.Size([30, 36, 36, 3, 3]) 349920
2015488
torch.Size([12, 36]) 432
2015920
torch.Size([6, 36, 216, 3, 3]) 419904
2435824
torch.Size([30, 36, 36, 3, 3]) 349920
2785744
torch.Size([12, 36]) 432
2786176
torch.Size([216, 216, 1, 1]) 46656
2832832
torch.Size([216]) 216
2833048
torch.Size([108, 216]) 23328
2856376
torch.Size([108]) 108
2856484
torch.Size([216, 108]) 23328
2879812
torch.Size([216]) 216
2880028
torch.Size([216, 216, 1, 1]) 46656
2926684
torch.Size([216]) 216
2926900
torch.Size([108, 216]) 23328
2950228
torch.Size([108]) 108
2950336
torch.Size([216, 108]) 23328
2973664
torch.Size([216]) 216
2973880
torch.Size([216, 216, 1, 1]) 46656
3020536
torch.Size([216]) 216
3020752
torch.Size([108, 216]) 23328
3044080
torch.Size([108]) 108
3044188
torch.Size([216, 108]) 23328
3067516
torch.Size([216]) 216
3067732
torch.Size([140, 280, 1, 1]) 39200
3106932
torch.Size([140]) 140
3107072
torch.Size([216, 432, 1, 1]) 93312
3200384
torch.Size([216]) 216
3200600
torch.Size([216, 432, 1, 1]) 93312
3293912
torch.Size([216]) 216
3294128
torch.Size([9, 572, 3, 3]) 46332
3340460
torch.Size([9]) 9
3340469
total number of the model is 3340469
cliqueNet parameters: 3340469

可以看出两种计算出来的是一模一样的。

以上这篇pytorch 求网络模型参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 深入理解yield
Sep 06 Python
如何解决django配置settings时遇到Could not import settings 'conf.local'
Nov 18 Python
python2.7实现邮件发送功能
Dec 12 Python
Python检测数据类型的方法总结
May 20 Python
pygame实现打字游戏
Feb 19 Python
python 变量初始化空列表的例子
Nov 28 Python
利用Pytorch实现简单的线性回归算法
Jan 15 Python
详解Python中pyautogui库的最全使用方法
Apr 01 Python
python同时遍历两个list用法说明
May 02 Python
python用分数表示矩阵的方法实例
Jan 11 Python
python 可视化库PyG2Plot的使用
Jan 21 Python
Python尝试实现蒙特卡罗模拟期权定价
Apr 21 Python
利用python3 的pygame模块实现塔防游戏
Dec 30 #Python
pytorch 批次遍历数据集打印数据的例子
Dec 30 #Python
python多线程使用方法实例详解
Dec 30 #Python
Python动态声明变量赋值代码实例
Dec 30 #Python
使用pytorch实现可视化中间层的结果
Dec 30 #Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
You might like
PHP不用第三变量交换2个变量的值的解决方法
2013/06/02 PHP
跟我学Laravel之视图 & Response
2014/10/15 PHP
PHP使用redis实现统计缓存mysql压力的方法
2015/11/14 PHP
Yii2框架引用bootstrap中日期插件yii2-date-picker的方法
2016/01/09 PHP
laravel中的错误与日志用法详解
2016/07/26 PHP
使用ThinkPHP生成缩略图及显示
2017/04/27 PHP
来自国外的页面JavaScript文件优化
2010/12/08 Javascript
jQuery+CSS 半开折叠效果原理及代码(自写)
2013/03/04 Javascript
基于jquery的手风琴图片展示效果实现方法
2014/12/16 Javascript
javascript 动态创建表格
2015/01/08 Javascript
jQuery经过一段时间自动隐藏指定元素的方法
2015/03/17 Javascript
JavaScript Array对象详解
2016/03/01 Javascript
jquery事件绑定解绑机制源码解析
2016/09/19 Javascript
AngularJS学习第一篇 AngularJS基础知识
2017/02/13 Javascript
vue过渡和animate.css结合使用详解
2017/06/14 Javascript
AngularJS实现的获取焦点及失去焦点时的表单验证功能示例
2017/10/25 Javascript
详解angular路由高亮之RouterLinkActive
2018/04/28 Javascript
在Vuex使用dispatch和commit来调用mutations的区别详解
2018/09/18 Javascript
js获取form表单中name属性的值
2019/02/27 Javascript
element-ui中按需引入的实现
2019/12/25 Javascript
微信小程序聊天功能的示例代码
2020/01/13 Javascript
微信小程序仿淘宝热搜词在搜索框中轮播功能
2020/01/21 Javascript
Node.js API详解之 os模块用法实例分析
2020/05/06 Javascript
[04:27]2014DOTA2国际邀请赛 NAVI战队官方纪录片
2014/07/21 DOTA
[50:04]DOTA2上海特级锦标赛D组小组赛#2 Liquid VS VP第二局
2016/02/28 DOTA
pycharm 使用心得(八)如何调用另一文件中的函数
2014/06/06 Python
Mac中升级Python2.7到Python3.5步骤详解
2017/04/27 Python
python使用pycharm环境调用opencv库
2018/02/11 Python
Python基于Logistic回归建模计算某银行在降低贷款拖欠率的数据示例
2019/01/23 Python
Tensorflow与Keras自适应使用显存方式
2020/06/22 Python
2015年依法行政工作总结
2015/04/29 职场文书
预备党员群众意见
2015/06/01 职场文书
2015大学迎新晚会策划书
2015/07/16 职场文书
运动会跳远广播稿
2015/08/19 职场文书
写给医护人员的一封感谢信
2019/09/16 职场文书
Nginx虚拟主机的配置步骤过程全解
2022/03/31 Servers