可视化pytorch 模型中不同BN层的running mean曲线实例


Posted in Python onJune 24, 2020

加载模型字典

逐一判断每一层,如果该层是bn 的 running mean,就取出参数并取平均作为该层的代表

对保存的每个BN层的数值进行曲线可视化

from functools import partial
import pickle
import torch
import matplotlib.pyplot as plt

pth_path = 'checkpoint.pth'

pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
pretrained_dict = torch.load(pth_path, map_location=lambda storage, loc: storage, pickle_module=pickle)
pretrained_dict = pretrained_dict['state_dict']

means = []
for name, param in pretrained_dict.items():
 print(name)
 if 'running_mean' in name:
  means.append(mean.numpy())

layers = [i for i in range(len(means))]

plt.plot(layers, means, color='blue')
plt.legend()
plt.xticks(layers)
plt.xlabel('layers')
plt.show()

可视化pytorch 模型中不同BN层的running mean曲线实例

补充知识:关于pytorch中BN层(具体实现)的一些小细节

最近在做目标检测,需要把训好的模型放到嵌入式设备上跑前向,因此得把各种层的实现都用C手撸一遍,,,此为背景。

其他层没什么好说的,但是BN层这有个小坑。pytorch在打印网络参数的时候,只打出weight和bias这两个参数。咦,说好的BN层有四个参数running_mean、running_var 、gamma 、beta的呢?一开始我以为是pytorch把BN层的计算简化成weight * X + bias,但马上反应过来应该没这么简单,因为pytorch中只有可学习的参数才称为parameter。上网找了一些资料但都没有说到这么细的,毕竟大部分用户使用时只要模型能跑起来就行了,,,于是开始看BN层有哪些属性,果然发现了熟悉的running_mean和running_var,原来pytorch的BN层实现并没有不同。这里吐个槽:为啥要把gamma和beta改叫weight、bias啊,很有迷惑性的好不好,,,

扯了这么多,干脆捋一遍pytorch里BN层的具体实现过程,帮自己理清思路,也可以给大家提供参考。再吐槽一下,在网上搜“pytorch bn层”出来的全是关于这一层怎么用的、初始化时要输入哪些参数,没找到一个pytorch中BN层是怎么实现的,,,

众所周知,BN层的输出Y与输入X之间的关系是:Y = (X - running_mean) / sqrt(running_var + eps) * gamma + beta,此不赘言。其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),训练时通过反向传播更新;而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。所以在训练阶段,running_mean和running_var在每次前向时更新一次;在测试阶段,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。

以上这篇可视化pytorch 模型中不同BN层的running mean曲线实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python兔子毒药问题实例分析
Mar 05 Python
详解python中的json的基本使用方法
Dec 21 Python
Python网络爬虫出现乱码问题的解决方法
Jan 05 Python
Python面向对象编程基础解析(二)
Oct 26 Python
TensorFlow模型保存和提取的方法
Mar 08 Python
便捷提取python导入包的属性方法
Oct 15 Python
Python远程视频监控程序的实例代码
May 05 Python
Python实现快速排序的方法详解
Oct 25 Python
python生成13位或16位时间戳以及反向解析时间戳的实例
Mar 03 Python
python torch.utils.data.DataLoader使用方法
Apr 02 Python
vscode+PyQt5安装详解步骤
Aug 12 Python
Django URL参数Template反向解析
Nov 24 Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
python不同系统中打开方法
Jun 23 #Python
自学python用什么系统好
Jun 23 #Python
You might like
PHP迅雷、快车、旋风下载专用链转换代码
2010/06/15 PHP
PHP.ini安全配置检测工具pcc简单介绍
2015/07/02 PHP
基于php(Thinkphp)+jquery 实现ajax多选反选不选删除数据功能
2017/02/24 PHP
dwr spring的集成实现代码
2009/03/22 Javascript
通过jQuery打造支持汉字,拼音,英文快速定位查询的超级select插件
2010/06/18 Javascript
基于JQuery实现异步刷新的代码(转载)
2011/03/29 Javascript
javascript实现tab切换的两个实例
2015/11/05 Javascript
Javascript中神奇的this
2016/01/20 Javascript
JavaScript事件处理的方式(三种)
2016/04/26 Javascript
jQuery+CSS实现一个侧滑导航菜单代码
2016/05/09 Javascript
Bootstrap免费字体和图标网站(值得收藏)
2017/03/16 Javascript
利用Javascript实现一套自定义事件机制
2017/12/14 Javascript
JavaScript学习总结(一) ECMAScript、BOM、DOM(核心、浏览器对象模型与文档对象模型)
2018/01/07 Javascript
javascript标准库(js的标准内置对象)总结
2018/05/26 Javascript
详解Ant Design of React的安装和使用方法
2018/12/27 Javascript
基于vue+echarts 数据可视化大屏展示的方法示例
2020/03/09 Javascript
[15:15]教你分分钟做大人:狙击手
2014/10/30 DOTA
[01:10:48]完美世界DOTA2联赛PWL S2 GXR vs PXG 第一场 11.18
2020/11/18 DOTA
在Python中使用zlib模块进行数据压缩的教程
2015/06/26 Python
详解tensorflow训练自己的数据集实现CNN图像分类
2018/02/07 Python
Python RabbitMQ消息队列实现rpc
2018/05/30 Python
python 读取鼠标点击坐标的实例
2018/12/29 Python
用 Python 制作地球仪的方法
2020/04/24 Python
Python中有几个关键字
2020/06/04 Python
html5 postMessage解决跨域、跨窗口消息传递方案
2016/12/20 HTML / CSS
Lentiamo丹麦:购买便宜的隐形眼镜
2021/01/13 全球购物
Java中的异常处理机制的简单原理和应用
2013/04/27 面试题
环境工程求职简历的自我评价范文
2013/10/24 职场文书
公开承诺书格式
2014/05/21 职场文书
习近平在党的群众路线教育实践活动总结大会上的讲话全文
2014/10/25 职场文书
会计师事务所实习证明
2014/11/16 职场文书
收入及婚姻状况证明
2014/11/20 职场文书
邀请函样本
2015/02/02 职场文书
认真学习保证书
2015/02/26 职场文书
apache基于端口创建虚拟主机的示例
2021/04/24 Servers
canvas实现贪食蛇的实践
2022/02/15 Javascript