pytorch查看模型weight与grad方式


Posted in Python onJune 24, 2020

在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

1. 首先把你的模型打印出来,像这样

pytorch查看模型weight与grad方式

2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息

pytorch查看模型weight与grad方式

pytorch查看模型weight与grad方式

补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)

BatchNorm2d参数量

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C 
num_features ? C from an expected input of size (N, C, H, W)
>>> import torch
>>> m = torch.nn.BatchNorm2d(100)
>>> m.weight.shape
torch.Size([100])
>>> m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
>>> m.weight.numel()
100
>>> m.parameters().numel()
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
AttributeError: 'generator' object has no attribute 'numel'
>>> [p.numel() for p in m.parameters()]
[100, 100]

linear层

>>> import torch
>>> m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
>>> m1.weight.shape
torch.Size([10, 100])
>>> m1.bias.shape
torch.Size([10])
>>> m1.bias.numel()
10
>>> m1.weight.numel()
1000
>>> m11 = list(m1.parameters())
>>> m11[0].shape
# weight
torch.Size([10, 100])
>>> m11[1].shape
# bias
torch.Size([10])

weight and bias

# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2 
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor) 
 print layer,param

feature map

由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:

1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!

class Net(nn.Module):
  def __init__(self):
    self.conv1 = nn.Conv2d(1, 1, 3)
    self.conv2 = nn.Conv2d(1, 1, 3)
    self.conv3 = nn.Conv2d(1, 1, 3)  def forward(self, x):
    out1 = F.relu(self.conv1(x))
    out2 = F.relu(self.conv2(out1))
    out3 = F.relu(self.conv3(out2))
    return out1, out2, out3

2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:

# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) -> Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
  print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) -> Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
  outputs.append(output)
  print len(outputs)handle = model.features[0].register_backward_hook(hook)

注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种

计算模型参数数量

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pyqt4教程之实现windows窗口小示例分享
Mar 07 Python
跟老齐学Python之开始真正编程
Sep 12 Python
django 自定义用户user模型的三种方法
Nov 18 Python
深入解析Python中的线程同步方法
Jun 14 Python
python函数中return后的语句一定不会执行吗?
Jul 06 Python
Django中redis的使用方法(包括安装、配置、启动)
Feb 21 Python
Python 打印中文字符的三种方法
Aug 14 Python
Python简单读写Xls格式文档的方法示例
Aug 17 Python
Pycharm设置去除显示的波浪线方法
Oct 28 Python
Python安装selenium包详细过程
Jul 23 Python
Pycharm生成可执行文件.exe的实现方法
Jun 02 Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 #Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
You might like
php gd2 上传图片/文字水印/图片水印/等比例缩略图/实现代码
2010/05/15 PHP
PHP中trait使用方法详细介绍
2017/05/21 PHP
input+select(multiple) 实现下拉框输入值
2009/05/21 Javascript
jQuery 使用手册(一)
2009/09/23 Javascript
基于JQuery实现的类似购物商城的购物车
2011/12/06 Javascript
jQuery队列操作方法实例
2014/06/11 Javascript
JS通过ajax动态读取xml文件内容的方法
2015/03/24 Javascript
Jquery 1.9.1源码分析系列(十二)之筛选操作
2015/12/02 Javascript
JavaScript位移运算符(无符号) &gt;&gt;&gt; 三个大于号 的使用方法详解
2016/03/31 Javascript
javascript小数精度丢失的完美解决方法
2016/05/31 Javascript
bootstrap输入框组件使用方法详解
2017/01/19 Javascript
jQuery实现单击按钮遮罩弹出对话框效果(1)
2017/02/20 Javascript
jQuery插件zTree实现删除树子节点的方法示例
2017/03/08 Javascript
vue中的event bus非父子组件通信解析
2017/10/27 Javascript
three.js实现3D视野缩放效果
2017/11/16 Javascript
vue 使用async写数字动态加载效果案例
2020/07/18 Javascript
vuex管理状态仓库使用详解
2020/07/29 Javascript
[01:09:40]Newbee vs Pain 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python实现聚类算法原理
2018/02/12 Python
python读出当前时间精度到秒的代码
2019/07/05 Python
python脚本执行CMD命令并返回结果的例子
2019/08/14 Python
wxPython绘图模块wxPyPlot实现数据可视化
2019/11/19 Python
基于Python爬取爱奇艺资源过程解析
2020/03/02 Python
django-xadmin根据当前登录用户动态设置表单字段默认值方式
2020/03/13 Python
Python 获取异常(Exception)信息的几种方法
2020/12/29 Python
瑞典的玛丽小姐:Miss Mary of Sweden
2019/02/13 全球购物
MIS软件工程师的面试题
2016/04/22 面试题
会计出纳岗位职责
2013/12/25 职场文书
大学生校园创业计划书
2014/02/08 职场文书
乡镇消防工作实施方案
2014/03/27 职场文书
《小鹰学飞》教学反思
2014/04/23 职场文书
元旦联欢会策划方案
2014/06/11 职场文书
2014乡镇领导班子四风对照检查材料思想汇报
2014/10/05 职场文书
个人投资合作协议书
2014/10/12 职场文书
浅谈Redis存储数据类型及存取值方法
2021/05/08 Redis
Python 详解通过Scrapy框架实现爬取CSDN全站热榜标题热词流程
2021/11/11 Python