获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用模块介绍
Nov 21 Python
python 3利用BeautifulSoup抓取div标签的方法示例
May 28 Python
Python实现识别手写数字 Python图片读入与处理
Mar 23 Python
python 拷贝特定后缀名文件,并保留原始目录结构的实例
Apr 27 Python
Python Flask 搭建微信小程序后台详解
May 06 Python
Python学习笔记之读取文件、OS模块、异常处理、with as语法示例
Jun 04 Python
Django REST Framework之频率限制的使用
Sep 29 Python
详解使用Python写一个向数据库填充数据的小工具(推荐)
Sep 11 Python
详解vscode实现远程linux服务器上Python开发
Nov 10 Python
python 通过 pybind11 使用Eigen加速代码的步骤
Dec 07 Python
Python OpenCV超详细讲解读取图像视频和网络摄像头
Apr 02 Python
python中validators库的使用方法详解
Sep 23 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
PHP中的函数-- foreach()的用法详解
2013/06/24 PHP
在laravel-admin中列表中禁止某行编辑、删除的方法
2019/10/03 PHP
js 多浏览器分别判断代码
2010/04/01 Javascript
jquery nth-child()选择器的简单应用
2010/07/10 Javascript
关于JavaScript的变量的数据类型的判断方法
2015/08/14 Javascript
浅谈javascript基础之客户端事件驱动
2016/06/10 Javascript
vue中用动态组件实现选项卡切换效果
2017/03/25 Javascript
探索webpack模块及webpack3新特性
2017/09/18 Javascript
详解Vue中localstorage和sessionstorage的使用
2017/12/22 Javascript
图片懒加载imgLazyLoading.js使用详解
2020/09/15 Javascript
微信小程序的tab选项卡的实现效果
2019/05/15 Javascript
VUE动态生成word的实现
2020/07/26 Javascript
如何用JS模拟实现数组的map方法
2020/07/30 Javascript
vue 限制input只能输入正数的操作
2020/08/05 Javascript
vue开发chrome插件,实现获取界面数据和保存到数据库功能
2020/12/01 Vue.js
JavaScript 生成唯一ID的几种方式
2021/02/19 Javascript
python模块之sys模块和序列化模块(实例讲解)
2017/09/13 Python
matplotlib subplots 调整子图间矩的实例
2018/05/25 Python
Python生成器generator用法示例
2018/08/10 Python
Python3.6 + TensorFlow 安装配置图文教程(Windows 64 bit)
2020/02/24 Python
Python Socketserver实现FTP文件上传下载代码实例
2020/03/27 Python
python中return不返回值的问题解析
2020/07/22 Python
骆驼官方商城:CAMEL
2016/11/22 全球购物
瑞典耳机品牌:URBANISTA
2019/12/03 全球购物
电视节目策划方案
2014/05/16 职场文书
拔河比赛口号
2014/06/10 职场文书
2014年银行员工年终自我评价
2014/09/19 职场文书
八年级英语教学计划
2015/01/23 职场文书
2015年公司行政后勤工作总结
2015/05/20 职场文书
隐形的翅膀观后感
2015/06/10 职场文书
如何让2019年上半年的工作总结更出色!
2019/07/01 职场文书
Win10 和 Win11可以共存吗? win10/11产品生命周期/服务更新介绍
2021/11/21 数码科技
Appium中scroll和drag_and_drop根据元素位置滑动
2022/02/15 Python
MySQL 计算连续登录天数
2022/05/11 MySQL
Win11无法安装更新补丁KB3045316怎么办 附KB3045316补丁修复教程
2022/08/14 数码科技
css弧边选项卡的项目实践
2023/05/07 HTML / CSS