pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中处理时间之clock()方法的使用
May 22 Python
浅谈Python中用datetime包进行对时间的一些操作
Jun 23 Python
Python实现查找匹配项作处理后再替换回去的方法
Jun 10 Python
使用Kivy将python程序打包为apk文件
Jul 29 Python
python处理大日志文件
Jul 23 Python
如何在Django配置文件里配置session链接
Aug 06 Python
使用Python自动生成HTML的方法示例
Aug 06 Python
python使用配置文件过程详解
Dec 28 Python
TensorFlow Saver:保存和读取模型参数.ckpt实例
Feb 10 Python
Python字符串三种格式化输出
Sep 17 Python
PyTorch 实现L2正则化以及Dropout的操作
May 27 Python
Python用any()函数检查字符串中的字母以及如何使用all()函数
Apr 14 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
新版mysql+apache+php Linux安装指南
2006/10/09 PHP
PHP 木马攻击防御技巧
2009/06/13 PHP
php中apc缓存使用示例
2013/12/25 PHP
php导出csv格式数据并将数字转换成文本的思路以及代码分享
2014/06/05 PHP
浅析php原型模式
2014/11/25 PHP
yii2.0整合阿里云oss上传单个文件的示例
2017/09/19 PHP
PHP count_chars()函数讲解
2019/02/14 PHP
jQuery实现统计复选框选中数量
2014/11/24 Javascript
JavaScript中的关联数组问题
2015/03/04 Javascript
JavaScript获得表单target属性的方法
2015/04/02 Javascript
JSON简介以及用法汇总
2016/02/21 Javascript
概述VUE2.0不可忽视的很多变化
2016/09/25 Javascript
js replace()去除代码中空格的实例
2017/02/14 Javascript
JQuery判断正整数整理小结
2017/08/21 jQuery
JS中跳出循环的示例代码
2017/09/14 Javascript
JavaScript设计模式之缓存代理模式原理与简单用法示例
2018/08/07 Javascript
angular 实现下拉列表组件的示例代码
2019/03/09 Javascript
js实现表格数据搜索
2020/08/09 Javascript
[15:41]教你分分钟做大人——灰烬之灵
2015/03/11 DOTA
python两种遍历字典(dict)的方法比较
2014/05/29 Python
利用pyinstaller或virtualenv将python程序打包详解
2017/03/22 Python
学习python的前途 python挣钱
2019/02/27 Python
nginx黑名单和django限速,最简单的防恶意请求方法分享
2019/08/09 Python
Linux下通过python获取本机ip方法示例
2019/09/06 Python
Python 中如何实现参数化测试的方法示例
2019/12/10 Python
python3 循环读取excel文件并写入json操作
2020/07/14 Python
基于HTML5 WebGL的3D机房的示例
2018/03/16 HTML / CSS
一站式跨境收款解决方案:Payoneer(派安盈)
2018/09/06 全球购物
意大利单身交友网站:Meetic
2020/07/12 全球购物
新西兰最大的天然保健及护肤品网站:HealthPost(直邮中国)
2021/02/13 全球购物
C#和SQL Server的面试题
2016/08/12 面试题
交通法规咨询中心工作职责
2013/11/27 职场文书
最新会计专业求职信范文
2014/01/28 职场文书
四年级语文教学反思
2014/02/05 职场文书
留学推荐信英文范文
2015/03/26 职场文书
六年级作文之关于梦
2019/10/22 职场文书