pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python求列表交集的方法汇总
Nov 10 Python
Python基于回溯法子集树模板解决马踏棋盘问题示例
Sep 11 Python
Python使用arrow库优雅地处理时间数据详解
Oct 10 Python
Python中Numpy包的安装与使用方法简明教程
Jul 03 Python
Python 按字典dict的键排序,并取出相应的键值放于list中的实例
Feb 12 Python
Python通过TensorFlow卷积神经网络实现猫狗识别
Mar 14 Python
pytorch 模型可视化的例子
Aug 17 Python
opencv python在视屏上截图功能的实现
Mar 05 Python
python+selenium+Chrome options参数的使用
Mar 18 Python
基于Python共轭梯度法与最速下降法之间的对比
Apr 02 Python
Python pandas如何向excel添加数据
May 22 Python
使用tensorflow 实现反向传播求导
May 26 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
利用文件属性结合Session实现在线人数统计
2006/10/09 PHP
PHP数组操作汇总 php数组的使用技巧
2011/07/17 PHP
php动态生成版权所有信息的方法
2015/03/24 PHP
php实现可逆加密的方法
2015/08/11 PHP
PHP引用的调用方法分析
2016/04/25 PHP
Win10 下安装配置IIS + MySQL + nginx + php7.1.7
2017/08/04 PHP
PHP 范围解析操作符(::)用法分析【访问静态成员和类常量】
2020/04/14 PHP
JS控件autocomplete 0.11演示及下载 1月5日已更新
2007/01/09 Javascript
基于jquery的滚动新闻列表
2010/06/19 Javascript
JavaScript高级程序设计(第3版)学习笔记4 js运算符和操作符
2012/10/11 Javascript
Javascript实现重力弹跳拖拽运动效果示例
2013/06/28 Javascript
js实现全屏漂浮广告移入光标停止移动
2013/12/02 Javascript
js数值和和字符串进行转换时可以对不同进制进行操作
2014/03/05 Javascript
分享我的jquery实现下拉菜单心的
2015/11/29 Javascript
js运动事件函数详解
2016/10/21 Javascript
微信小程序 tabs选项卡效果的实现
2017/01/05 Javascript
js中let和var定义变量的区别
2018/02/08 Javascript
vue + element-ui的分页问题实现
2018/12/17 Javascript
详解如何用webpack4从零开始构建react开发环境
2019/01/27 Javascript
vue实现可视化可拖放的自定义表单的示例代码
2019/03/20 Javascript
JS使用百度地图API自动获取地址和经纬度操作示例
2019/04/16 Javascript
[01:39](回顾)各路豪强针锋相对,几经鏖战四强产生
2014/07/01 DOTA
python判断端口是否打开的实现代码
2013/02/10 Python
Python 实现文件的全备份和差异备份详解
2016/12/27 Python
python如何把嵌套列表转变成普通列表
2018/03/20 Python
Python 必须了解的5种高级特征
2020/09/10 Python
TCP/IP模型的分界线
2012/12/01 面试题
《问银河》教学反思
2014/02/19 职场文书
大学开学计划书
2014/04/30 职场文书
企业党建工作汇报材料
2014/08/19 职场文书
法定代表人授权委托书
2014/09/19 职场文书
员工保密协议书
2014/09/27 职场文书
四年级小学生评语
2014/12/26 职场文书
护士实习自荐信
2015/03/06 职场文书
小兵张嘎观后感
2015/06/03 职场文书
Python创建SQL数据库流程逐步讲解
2022/09/23 Python