pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之做一个小游戏
Sep 28 Python
python解析xml文件操作实例
Oct 05 Python
Python编程中用close()方法关闭文件的教程
May 24 Python
Python内置函数—vars的具体使用方法
Dec 04 Python
Django contenttypes 框架详解(小结)
Aug 13 Python
pyqt5 QScrollArea设置在自定义侧(任何位置)
Sep 25 Python
python取均匀不重复的随机数方式
Nov 27 Python
python同时遍历两个list用法说明
May 02 Python
Python爬虫之爬取淘女郎照片示例详解
Jul 28 Python
利用python如何实现猫捉老鼠小游戏
Dec 04 Python
基于Python-turtle库绘制路飞的草帽骷髅旗、美国队长的盾牌、高达的源码
Feb 18 Python
聊聊python在linux下与windows下导入模块的区别说明
Mar 03 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
is_uploaded_file函数引发的不能上传文件问题
2013/10/29 PHP
PHP微信红包生成代码分享
2016/10/06 PHP
PHP实现通过strace定位故障原因的方法
2018/04/29 PHP
PHP 裁剪图片
2021/03/09 PHP
很多人都是用下面的js刷新站IP和PV
2008/09/05 Javascript
Javascript 生成指定范围数值随机数
2009/01/09 Javascript
javascript html 静态页面传参数
2009/04/10 Javascript
jquery post方式传递多个参数值后台以数组的方式进行接收
2013/01/11 Javascript
ionic实现可滑动的tab选项卡切换效果
2020/04/15 Javascript
jQuery基于xml格式数据实现模糊查询及分页功能的方法
2016/12/25 Javascript
jQuery实现单击按钮遮罩弹出对话框效果(1)
2017/02/20 Javascript
详解webpack解惑:require的五种用法
2017/06/09 Javascript
利用Jasmine对Angular进行单元测试的方法详解
2017/06/12 Javascript
深入理解基于vue-cli的vuex配置
2017/07/24 Javascript
在ABP框架中使用BootstrapTable组件的方法
2017/07/31 Javascript
微信小程序导入Vant报错VM292:1 thirdScriptError的解决方法
2019/08/01 Javascript
vue 表单输入框不支持focus及blur事件的解决方案
2020/11/17 Vue.js
python利用正则表达式排除集合中字符的功能示例
2017/10/10 Python
Python创建二维数组实例(关于list的一个小坑)
2017/11/07 Python
详解pandas如何去掉、过滤数据集中的某些值或者某些行?
2019/05/15 Python
python实现将json多行数据传入到mysql中使用
2019/12/31 Python
Python3爬虫mitmproxy的安装步骤
2020/07/29 Python
让ie浏览器成为支持html5的浏览器的解决方法(使用html5shiv)
2014/04/08 HTML / CSS
大学生学习自我评价
2014/01/13 职场文书
法律专业学生的自我评价
2014/02/07 职场文书
和解协议书
2014/04/16 职场文书
车辆工程专业求职信
2014/04/28 职场文书
2014入党积极分子批评与自我批评思想报告
2014/10/06 职场文书
化工见习报告范文
2014/10/31 职场文书
工会工作个人总结
2015/03/03 职场文书
2015年小学重阳节活动总结
2015/07/29 职场文书
幼儿园安全管理制度
2015/08/05 职场文书
Python进行区间取值案例讲解
2021/08/02 Python
Python 中的Sympy详细使用
2021/08/07 Python
MySQ InnoDB和MyISAM存储引擎介绍
2022/04/26 MySQL
CentOS7 minimal 最小化安装网络设置过程
2022/12/24 Servers