pytorch查看网络参数显存占用量等操作


Posted in Python onMay 12, 2021

1.使用torchstat

pip install torchstat 

from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))

关于stat函数的参数,第一个应该是模型,第二个则是输入尺寸,3为通道数。我没有调研该函数的详细参数,也不知道为什么使用的时候并不提示相应的参数。

2.使用torchsummary

pip install torchsummary
 
from torchsummary import summary
summary(model.cuda(),input_size=(3,32,32),batch_size=-1)

使用该函数直接对参数进行提示,可以发现直接有显式输入batch_size的地方,我自己的感觉好像该函数更好一些。但是!!!不知道为什么,该函数在我的机器上一直报错!!!

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Update:经过论坛咨询,报错的原因找到了,只需要把

pip install torchsummary

修改为

pip install torch-summary

补充:Pytorch查看模型参数并计算模型参数量与可训练参数量

查看模型参数(以AlexNet为例)

import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000):
        super(AlexNet,self).__init__()
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
            nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6,out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes),
        )
    def forward(self,x):
        x = self.feature_extraction(x)
        x = x.view(x.size(0),256*6*6)
        x = self.classifier(x)
        return x
if __name__ =='__main__':
    # model = torchvision.models.AlexNet()
    model = AlexNet()
    
    # 打印模型参数
    #for param in model.parameters():
        #print(param)
    
    #打印模型名称与shape
    for name,parameters in model.named_parameters():
        print(name,':',parameters.size())
feature_extraction.0.weight : torch.Size([96, 3, 11, 11])
feature_extraction.3.weight : torch.Size([192, 96, 5, 5])
feature_extraction.6.weight : torch.Size([384, 192, 3, 3])
feature_extraction.8.weight : torch.Size([256, 384, 3, 3])
feature_extraction.10.weight : torch.Size([256, 256, 3, 3])
classifier.1.weight : torch.Size([4096, 9216])
classifier.1.bias : torch.Size([4096])
classifier.4.weight : torch.Size([4096, 4096])
classifier.4.bias : torch.Size([4096])
classifier.6.weight : torch.Size([1000, 4096])
classifier.6.bias : torch.Size([1000])

计算参数量与可训练参数量

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

第三方工具

from torchstat import stat
import torchvision.models as models
model = models.alexnet()
stat(model, (3, 224, 224))

pytorch查看网络参数显存占用量等操作

from torchvision.models import alexnet
import torch
from thop import profile
model = alexnet()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))
print(flops, params)

pytorch查看网络参数显存占用量等操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
解决Python中由于logging模块误用导致的内存泄露
Apr 23 Python
Python升级导致yum、pip报错的解决方法
Sep 06 Python
nohup后台启动Python脚本,log不刷新的解决方法
Jan 14 Python
python之pexpect实现自动交互的例子
Jul 25 Python
利用Python实现kNN算法的代码
Aug 16 Python
Django 拆分model和view的实现方法
Aug 16 Python
python list多级排序知识点总结
Oct 23 Python
Python操作列表常用方法实例小结【创建、遍历、统计、切片等】
Oct 25 Python
TensorFlow 显存使用机制详解
Feb 03 Python
Python计算指定日期是今年的第几天(三种方法)
Mar 26 Python
pandas之分组groupby()的使用整理与总结
Jun 18 Python
详解Python IO编程
Jul 24 Python
Python入门之使用pandas分析excel数据
May 12 #Python
将Python代码打包成.exe可执行文件的完整步骤
python3实现Dijkstra算法最短路径的实现
pytorch 中autograd.grad()函数的用法说明
python3实现无权最短路径的方法
Python入门之基础语法详解
May 11 #Python
如何利用Matlab制作一款真正的拼图小游戏
You might like
CI框架中libraries,helpers,hooks文件夹详细说明
2014/06/10 PHP
PHP中应该避免使用同名变量(拆分临时变量)
2015/04/03 PHP
Laravel4中的Validator验证扩展用法详解
2016/07/26 PHP
PHP判断数组是否为空的常用方法(五种方法)
2017/02/08 PHP
Asp.net下使用Jquery Ajax传送和接收DataTable的代码
2010/09/12 Javascript
js中数组Array的一些常用方法总结
2013/08/12 Javascript
js分页代码分享
2014/04/28 Javascript
jquery div模态窗口的简单实例
2016/05/28 Javascript
微信小程序 在Chrome浏览器上运行以及WebStorm的使用
2016/09/27 Javascript
JavaScript实现数组降维详解
2017/01/05 Javascript
jquery与js实现全选功能的区别
2017/06/11 jQuery
js实现水平滚动菜单导航
2017/07/21 Javascript
JS 实现获取验证码 倒计时功能
2018/10/29 Javascript
在layui中select更改后生效的方法
2019/09/05 Javascript
使用vue实现一个电子签名组件的示例代码
2020/01/06 Javascript
js实现数据导出为EXCEL(支持大量数据导出)
2020/03/31 Javascript
Python3控制路由器——使用requests重启极路由.py
2016/05/11 Python
python学习教程之Numpy和Pandas的使用
2017/09/11 Python
Python实现将一个正整数分解质因数的方法分析
2017/12/14 Python
对python中数组的del,remove,pop区别详解
2018/11/07 Python
对python xlrd读取datetime类型数据的方法详解
2018/12/26 Python
Django使用Channels实现WebSocket的方法
2019/07/28 Python
python 实现turtle画图并导出图片格式的文件
2019/12/07 Python
4行Python代码生成图像验证码(2种)
2020/04/07 Python
python实现画图工具
2020/08/27 Python
不同浏览器对CSS3和HTML5的支持状况
2009/10/31 HTML / CSS
html5设计原理(推荐收藏)
2014/05/17 HTML / CSS
HTML5制作酷炫音频播放器插件图文教程
2014/12/30 HTML / CSS
巴西Mr. Cat在线商店:购买包包和鞋子
2019/09/08 全球购物
软件测试面试题
2014/01/05 面试题
应用化学专业本科生求职信
2013/09/29 职场文书
关爱留守儿童倡议书
2014/04/15 职场文书
硕士毕业论文导师评语
2014/12/31 职场文书
高考诚信考试承诺书
2015/04/29 职场文书
运动会宣传稿50字
2015/07/23 职场文书
员工手册董事长致辞
2015/07/29 职场文书