pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python处理csv数据的方法
Mar 11 Python
使用Python简单的实现树莓派的WEB控制
Feb 18 Python
用tensorflow实现弹性网络回归算法
Jan 09 Python
Python3 XML 获取雅虎天气的实现方法
Feb 01 Python
用Eclipse写python程序
Feb 10 Python
Flask框架响应、调度方法和蓝图操作实例分析
Jul 24 Python
python根据url地址下载小文件的实例
Dec 18 Python
解决python测试opencv时imread导致的错误问题
Jan 26 Python
Python模拟登入的N种方式(建议收藏)
May 31 Python
matplotlib.pyplot.matshow 矩阵可视化实例
Jun 16 Python
Python内置函数property()如何使用
Sep 01 Python
浅析python 字典嵌套
Sep 29 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
通过ODBC连接的SQL SERVER实例
2006/10/09 PHP
php读取3389的脚本
2014/05/06 PHP
PHP-Java-Bridge使用笔记
2014/09/22 PHP
PHP实现的简单日历类
2014/11/29 PHP
php上传文件并显示上传进度的方法
2015/03/24 PHP
在Mac OS上编译安装Nginx+PHP+MariaDB开发环境的教程
2016/02/23 PHP
老生常谈PHP面向对象之标识映射
2017/06/21 PHP
关于Curl在Swoole协程中的解决方案详析
2019/09/12 PHP
JQuery Tab选项卡效果代码改进版
2010/04/01 Javascript
关于Javascript模块化和命名空间管理的问题说明
2010/12/06 Javascript
js移除事件 js绑定事件实例应用
2012/11/28 Javascript
不同的jQuery API来处理不同的浏览器事件
2012/12/09 Javascript
浅析Node.js中使用依赖注入的相关问题及解决方法
2015/06/24 Javascript
jQuery实现产品对比功能附源码下载
2016/08/09 Javascript
简单谈谈Vue 模板各类数据绑定
2016/09/25 Javascript
js实现4个方向滚动的球
2017/03/06 Javascript
微信小程序 选项卡的简单实例
2017/05/24 Javascript
详解使用Visual Studio Code对Node.js进行断点调试
2017/09/14 Javascript
bootstrap treeview 树形菜单带复选框及级联选择功能
2018/06/08 Javascript
vuejs使用axios异步访问时用get和post的实例讲解
2018/08/09 Javascript
详解vue中使用微信jssdk
2019/04/19 Javascript
jquery+ajax实现上传图片并显示上传进度功能【附php后台接收】
2019/06/06 jQuery
jQuery实现文本显示一段时间后隐藏的方法分析
2019/06/20 jQuery
为什么Vue3.0使用Proxy实现数据监听(defineProperty表示不背这个锅)
2019/10/14 Javascript
深入理解javascript中的this
2021/02/08 Javascript
浅谈python中的__init__、__new__和__call__方法
2017/07/18 Python
python3实现公众号每日定时发送日报和图片
2018/02/24 Python
python引用(import)某个模块提示没找到对应模块的解决方法
2019/01/19 Python
小学生防溺水广播稿
2014/01/12 职场文书
2014新年寄语
2014/01/20 职场文书
食堂个人先进事迹
2014/01/22 职场文书
旅行社各个岗位职责
2014/03/15 职场文书
保研推荐信格式
2015/03/25 职场文书
党员电教片《信仰》心得体会
2016/01/15 职场文书
2016年区委书记抓基层党建工作公开承诺书
2016/03/25 职场文书
 python中的元类metaclass详情
2022/05/30 Python