pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python判断直线和矩形是否相交的方法
Jul 14 Python
Python和Perl绘制中国北京跑步地图的方法
Mar 03 Python
python实现list由于numpy array的转换
Apr 04 Python
Python 中的lambda函数介绍
Oct 10 Python
PyTorch搭建一维线性回归模型(二)
May 22 Python
Python实现FM算法解析
Jun 18 Python
浅谈Python3识别判断图片主要颜色并和颜色库进行对比的方法
Oct 25 Python
wxPython色环电阻计算器
Nov 18 Python
自定义Django Form中choicefield下拉菜单选取数据库内容实例
Mar 13 Python
django model 条件过滤 queryset.filter(**condtions)用法详解
May 20 Python
python3.7调试的实例方法
Jul 21 Python
Python使用pandas导入csv文件内容的示例代码
Dec 24 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
php关于array_multisort多维数组排序的使用说明
2011/01/04 PHP
PHP 中关于ord($str)&amp;gt;0x80的详细说明
2012/09/23 PHP
ThinkPHP调用百度翻译类实现在线翻译
2014/06/26 PHP
Yii的CDbCriteria查询条件用法实例
2014/12/04 PHP
深入浅出讲解:php的socket通信原理
2016/12/03 PHP
PHP实现的字符串匹配算法示例【sunday算法】
2017/12/19 PHP
js弹出模式对话框,并接收回传值的方法
2013/03/12 Javascript
jQuery.extend()的实现方式详解及实例
2013/06/29 Javascript
解析dom中的children对象数组元素firstChild,lastChild的使用
2013/07/10 Javascript
纯文字版返回顶端的js代码
2013/08/01 Javascript
详细介绍8款超实用JavaScript框架
2013/10/25 Javascript
JavaScript跨平台的开源框架NativeScript
2015/03/24 Javascript
jQuery实现页面内锚点平滑跳转特效的方法总结
2015/05/11 Javascript
借助FileReader实现将文件编码为Base64后通过AJAX上传
2015/12/24 Javascript
用jquery获取自定义的标签属性的值简单实例
2016/09/17 Javascript
纯JS代码实现隔行变色鼠标移入高亮
2016/11/23 Javascript
js 概率计算(简单版)
2017/09/12 Javascript
vue实现吸顶、锚点和滚动高亮按钮效果
2019/10/21 Javascript
python使用PyFetion来发送短信的例子
2014/04/22 Python
Python简单实现TCP包发送十六进制数据的方法
2016/04/16 Python
Python数据类型详解(三)元祖:tuple
2016/05/08 Python
Python request设置HTTPS代理代码解析
2018/02/12 Python
Python中dict和set的用法讲解
2019/03/28 Python
python中的反斜杠问题深入讲解
2019/08/12 Python
python通过实例讲解反射机制
2019/10/17 Python
python 字典套字典或列表的示例
2019/12/16 Python
Python如何读取文件中图片格式
2020/01/13 Python
Python2.7:使用Pyhook模块监听鼠标键盘事件-获取坐标实例
2020/03/14 Python
python学生管理系统的实现
2020/04/05 Python
Viking比利时:购买办公用品
2019/10/30 全球购物
Internet主要有哪些网络群组成
2015/12/24 面试题
什么是规则表达式
2012/05/03 面试题
高中运动会入场词
2014/02/14 职场文书
求职意向书范文
2014/04/01 职场文书
Python趣味挑战之给幼儿园弟弟生成1000道算术题
2021/05/28 Python
Python 语言实现六大查找算法
2021/06/30 Python