pytorch 求网络模型参数实例


Posted in Python onDecember 30, 2019

用pytorch训练一个神经网络时,我们通常会很关心模型的参数总量。下面分别介绍来两种方法求模型参数

一 .求得每一层的模型参数,然后自然的可以计算出总的参数。

1.先初始化一个网络模型model

比如我这里是 model=cliqueNet(里面是些初始化的参数)

2.调用model的Parameters类获取参数列表

pytorch 求网络模型参数实例

一个典型的操作就是将参数列表传入优化器里。如下

optimizer = optim.Adam(model.parameters(), lr=opt.lr)

言归正传,继续回到参数里面,参数在网络里面就是variable,下面分别求每层的尺寸大小和个数。

函数get_number_of_param( ) 里面的参数就是刚才第一步初始化的model

def get_number_of_param(model):
  """get the number of param for every element"""
  count = 0
  for param in model.parameters():
    param_size = param.size()
    count_of_one_param = 1
    for dis in param_size:
      count_of_one_param *= dis
    print(param.size(), count_of_one_param)
    count += count_of_one_param
    print(count)
  print('total number of the model is %d'%count)

再来看看结果:

torch.Size([64, 1, 3, 3]) 576
576
torch.Size([64]) 64
640
torch.Size([6, 36, 64, 3, 3]) 124416
125056
torch.Size([30, 36, 36, 3, 3]) 349920
474976
torch.Size([12, 36]) 432
475408
torch.Size([6, 36, 216, 3, 3]) 419904
895312
torch.Size([30, 36, 36, 3, 3]) 349920
1245232
torch.Size([12, 36]) 432
1245664
torch.Size([6, 36, 216, 3, 3]) 419904
1665568
torch.Size([30, 36, 36, 3, 3]) 349920
2015488
torch.Size([12, 36]) 432
2015920
torch.Size([6, 36, 216, 3, 3]) 419904
2435824
torch.Size([30, 36, 36, 3, 3]) 349920
2785744
torch.Size([12, 36]) 432
2786176
torch.Size([216, 216, 1, 1]) 46656
2832832
torch.Size([216]) 216
2833048
torch.Size([108, 216]) 23328
2856376
torch.Size([108]) 108
2856484
torch.Size([216, 108]) 23328
2879812
torch.Size([216]) 216
2880028
torch.Size([216, 216, 1, 1]) 46656
2926684
torch.Size([216]) 216
2926900
torch.Size([108, 216]) 23328
2950228
torch.Size([108]) 108
2950336
torch.Size([216, 108]) 23328
2973664
torch.Size([216]) 216
2973880
torch.Size([216, 216, 1, 1]) 46656
3020536
torch.Size([216]) 216
3020752
torch.Size([108, 216]) 23328
3044080
torch.Size([108]) 108
3044188
torch.Size([216, 108]) 23328
3067516
torch.Size([216]) 216
3067732
torch.Size([140, 280, 1, 1]) 39200
3106932
torch.Size([140]) 140
3107072
torch.Size([216, 432, 1, 1]) 93312
3200384
torch.Size([216]) 216
3200600
torch.Size([216, 432, 1, 1]) 93312
3293912
torch.Size([216]) 216
3294128
torch.Size([9, 572, 3, 3]) 46332
3340460
torch.Size([9]) 9
3340469
total number of the model is 3340469

可以通过计算验证一下,发现参数与网络是一致的。

二:一行代码就可以搞定参数总个数问题

2.1 先来看看torch.tensor.numel( )这个函数的功能就是求tensor中的元素个数,在网络里面每层参数就是多维数组组成的tensor。

实际上就是求多维数组的元素个数。看代码。

print('cliqueNet parameters:', sum(param.numel() for param in model.parameters()))

当然上面代码中的model还是上面初始化的网络模型。

看看两种的计算结果

torch.Size([64, 1, 3, 3]) 576
576
torch.Size([64]) 64
640
torch.Size([6, 36, 64, 3, 3]) 124416
125056
torch.Size([30, 36, 36, 3, 3]) 349920
474976
torch.Size([12, 36]) 432
475408
torch.Size([6, 36, 216, 3, 3]) 419904
895312
torch.Size([30, 36, 36, 3, 3]) 349920
1245232
torch.Size([12, 36]) 432
1245664
torch.Size([6, 36, 216, 3, 3]) 419904
1665568
torch.Size([30, 36, 36, 3, 3]) 349920
2015488
torch.Size([12, 36]) 432
2015920
torch.Size([6, 36, 216, 3, 3]) 419904
2435824
torch.Size([30, 36, 36, 3, 3]) 349920
2785744
torch.Size([12, 36]) 432
2786176
torch.Size([216, 216, 1, 1]) 46656
2832832
torch.Size([216]) 216
2833048
torch.Size([108, 216]) 23328
2856376
torch.Size([108]) 108
2856484
torch.Size([216, 108]) 23328
2879812
torch.Size([216]) 216
2880028
torch.Size([216, 216, 1, 1]) 46656
2926684
torch.Size([216]) 216
2926900
torch.Size([108, 216]) 23328
2950228
torch.Size([108]) 108
2950336
torch.Size([216, 108]) 23328
2973664
torch.Size([216]) 216
2973880
torch.Size([216, 216, 1, 1]) 46656
3020536
torch.Size([216]) 216
3020752
torch.Size([108, 216]) 23328
3044080
torch.Size([108]) 108
3044188
torch.Size([216, 108]) 23328
3067516
torch.Size([216]) 216
3067732
torch.Size([140, 280, 1, 1]) 39200
3106932
torch.Size([140]) 140
3107072
torch.Size([216, 432, 1, 1]) 93312
3200384
torch.Size([216]) 216
3200600
torch.Size([216, 432, 1, 1]) 93312
3293912
torch.Size([216]) 216
3294128
torch.Size([9, 572, 3, 3]) 46332
3340460
torch.Size([9]) 9
3340469
total number of the model is 3340469
cliqueNet parameters: 3340469

可以看出两种计算出来的是一模一样的。

以上这篇pytorch 求网络模型参数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转换HTML到Text纯文本的方法
Jan 15 Python
深入解析Python中的descriptor描述器的作用及用法
Jun 27 Python
利用python实现简单的循环购物车功能示例代码
Jul 05 Python
使用Python+wxpy 找出微信里把你删除的好友实例
Feb 21 Python
python求最大值,不使用内置函数的实现方法
Jul 09 Python
Python3离线安装Requests模块问题
Oct 13 Python
python3 pillow模块实现简单验证码
Oct 31 Python
python计算Content-MD5并获取文件的Content-MD5值方式
Apr 03 Python
Python 整行读取文本方法并去掉readlines换行\n操作
Sep 03 Python
python里反向传播算法详解
Nov 22 Python
python opencv实现直线检测并测出倾斜角度(附源码+注释)
Dec 31 Python
Python Pytorch查询图像的特征从集合或数据库中查找图像
Apr 09 Python
利用python3 的pygame模块实现塔防游戏
Dec 30 #Python
pytorch 批次遍历数据集打印数据的例子
Dec 30 #Python
python多线程使用方法实例详解
Dec 30 #Python
Python动态声明变量赋值代码实例
Dec 30 #Python
使用pytorch实现可视化中间层的结果
Dec 30 #Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
You might like
有关JSON以及JSON在PHP中的应用
2010/04/09 PHP
php判断邮箱地址是否存在的方法
2016/02/13 PHP
PHP简单预防sql注入的方法
2016/09/27 PHP
thinkPHP5 ACL用户权限模块用法详解
2017/05/10 PHP
juqery 学习之四 筛选查找
2010/11/30 Javascript
THREE.JS入门教程(1)THREE.JS使用前了解
2013/01/24 Javascript
JavaScript将相对地址转换为绝对地址示例代码
2013/07/19 Javascript
node.js WEB开发中图片验证码的实现方法
2014/06/03 Javascript
JS实现网页背景颜色与select框中颜色同时变化的方法
2015/02/27 Javascript
jquery实现鼠标滑过后动态图片提示效果实例
2015/08/10 Javascript
Web打印解决方案之证件套打的实现思路
2016/08/29 Javascript
js中获取 table节点各tr及td的内容简单实例
2016/10/14 Javascript
jQuery插件HighCharts绘制简单2D柱状图效果示例【附demo源码】
2017/03/21 jQuery
基于Vue实现拖拽效果
2018/04/27 Javascript
jQuery访问json文件中数据的方法示例
2019/01/28 jQuery
node.js微信小程序配置消息推送的实现
2019/02/13 Javascript
JS函数动态传递参数的方法分析【基于arguments对象】
2019/06/05 Javascript
vue.js 2.0实现简单分页效果
2019/07/29 Javascript
微信小程序实现单个卡片左滑显示按钮并防止上下滑动干扰功能
2019/12/06 Javascript
react的hooks的用法详解
2020/10/12 Javascript
vue中可编辑树状表格的实现代码
2020/10/31 Javascript
Python Tkinter GUI编程入门介绍
2015/03/10 Python
简单的编程0基础下Python入门指引
2015/04/01 Python
Python import用法以及与from...import的区别
2015/05/28 Python
python list元素为tuple时的排序方法
2018/04/18 Python
Python实现获取邮箱内容并解析的方法示例
2018/06/16 Python
Python面向对象之类和对象实例详解
2018/12/10 Python
python 将日期戳(五位数时间)转换为标准时间
2019/07/11 Python
Zavvi西班牙:电子游戏、极客服装、Blu-ray、Funko Pop等
2019/05/03 全球购物
医学毕业生自我鉴定
2013/10/30 职场文书
2014年会演讲稿范文
2014/01/06 职场文书
2014年母亲节寄语
2014/05/07 职场文书
运动会口号8字
2014/06/07 职场文书
2014最新自愿离婚协议书范本
2014/11/19 职场文书
专业技术职务聘任证明
2015/03/02 职场文书
sql查询结果列拼接成逗号分隔的字符串方法
2021/05/25 SQL Server