pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python用装饰器自动注册Tornado路由详解
Feb 14 Python
解决python3 urllib 链接中有中文的问题
Jul 16 Python
Python中__slots__属性介绍与基本使用方法
Sep 05 Python
Django xadmin开启搜索功能的实现
Nov 15 Python
Python使用Pandas读写Excel实例解析
Nov 19 Python
win10下python2和python3共存问题解决方法
Dec 23 Python
python读取文件指定行内容实例讲解
Mar 02 Python
selenium WebDriverWait类等待机制的实现
Mar 18 Python
matplotlib jupyter notebook 图像可视化 plt show操作
Apr 24 Python
在Sublime Editor中配置Python环境的详细教程
May 03 Python
python 实现图与图之间的间距调整subplots_adjust
May 21 Python
python通过函数名调用函数的几种方法总结
Jun 07 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
咖啡店都有些什么常规豆子呢?有什么风味在里面
2021/03/04 咖啡文化
用PHP编程语言开发动态WAP页面
2006/10/09 PHP
PHP正则判断一个变量是否为正整数的方法
2019/02/27 PHP
javascript 一些用法小结
2009/09/11 Javascript
document节点对象的获取方式示例介绍
2013/12/24 Javascript
iframe里的页面禁止右键事件的方法
2014/06/10 Javascript
Javascript检查图片大小不要让大图片撑破页面
2014/11/04 Javascript
jQuery轻松实现表格的隔行变色和点击行变色的实例代码
2016/05/09 Javascript
所见即所得的富文本编辑器bootstrap-wysiwyg使用方法详解
2016/05/27 Javascript
Angular外部使用js调用Angular控制器中的函数方法或变量用法示例
2016/08/05 Javascript
js 性能优化之快速响应的用户界面
2017/02/15 Javascript
three.js加载obj模型的实例代码
2017/11/10 Javascript
React组件内事件传参实现tab切换的示例代码
2018/07/04 Javascript
JavaScript实现图片懒加载的方法分析
2018/07/05 Javascript
详解ES6新增字符串扩张方法includes()、startsWith()、endsWith()
2020/05/12 Javascript
python封装对象实现时间效果
2020/04/23 Python
Python实现的圆形绘制(画圆)示例
2018/01/31 Python
PyQt5每天必学之创建窗口居中效果
2018/04/19 Python
Django1.9 加载通过ImageField上传的图片方法
2018/05/25 Python
python url 参数修改方法
2018/12/26 Python
Python hashlib模块加密过程解析
2019/11/05 Python
python numpy库linspace相同间隔采样的实现
2020/02/25 Python
Python函数必须先定义,后调用说明(函数调用函数例外)
2020/06/02 Python
python关于倒排列的知识点总结
2020/10/13 Python
Python用摘要算法生成token及检验token的示例代码
2020/12/01 Python
详解HTML5中的元素与元素
2015/08/17 HTML / CSS
世界上最大的高分辨率在线图片库:Alamy
2018/07/07 全球购物
日常奢侈品,轻松购物:Verishop
2019/08/20 全球购物
潘多拉珠宝美国官方网站:Pandora US
2020/06/18 全球购物
JSF如何进行表格处理及取值
2012/08/06 面试题
质量主管工作职责
2014/09/26 职场文书
电影雨中的树观后感
2015/06/15 职场文书
创业计划书之餐饮
2019/09/02 职场文书
Nginx同一个域名配置多个项目的实现方法
2021/03/31 Servers
php+laravel 扫码二维码签到功能
2021/05/15 PHP
分析JVM源码之Thread.interrupt系统级别线程打断
2021/06/29 Java/Android