获取Pytorch中间某一层权重或者特征的例子


Posted in Python onAugust 17, 2019

问题:训练好的网络模型想知道中间某一层的权重或者看看中间某一层的特征,如何处理呢?

1、获取某一层权重,并保存到excel中;

以resnet18为例说明:

import torch
import pandas as pd
import numpy as np
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True)

parm={}
for name,parameters in resnet18.named_parameters():
  print(name,':',parameters.size())
  parm[name]=parameters.detach().numpy()

上述代码将每个模块参数存入parm字典中,parameters.detach().numpy()将tensor类型变量转换成numpy array形式,方便后续存储到表格中.输出为:

conv1.weight : torch.Size([64, 3, 7, 7])
bn1.weight : torch.Size([64])
bn1.bias : torch.Size([64])
layer1.0.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight : torch.Size([64])
layer1.0.bn1.bias : torch.Size([64])
layer1.0.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight : torch.Size([64])
layer1.0.bn2.bias : torch.Size([64])
layer1.1.conv1.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight : torch.Size([64])
layer1.1.bn1.bias : torch.Size([64])
layer1.1.conv2.weight : torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight : torch.Size([64])
layer1.1.bn2.bias : torch.Size([64])
layer2.0.conv1.weight : torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight : torch.Size([128])
layer2.0.bn1.bias : torch.Size([128])
layer2.0.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight : torch.Size([128])
layer2.0.bn2.bias : torch.Size([128])
layer2.0.downsample.0.weight : torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight : torch.Size([128])
layer2.0.downsample.1.bias : torch.Size([128])
layer2.1.conv1.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight : torch.Size([128])
layer2.1.bn1.bias : torch.Size([128])
layer2.1.conv2.weight : torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight : torch.Size([128])
layer2.1.bn2.bias : torch.Size([128])
layer3.0.conv1.weight : torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight : torch.Size([256])
layer3.0.bn1.bias : torch.Size([256])
layer3.0.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight : torch.Size([256])
layer3.0.bn2.bias : torch.Size([256])
layer3.0.downsample.0.weight : torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight : torch.Size([256])
layer3.0.downsample.1.bias : torch.Size([256])
layer3.1.conv1.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight : torch.Size([256])
layer3.1.bn1.bias : torch.Size([256])
layer3.1.conv2.weight : torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight : torch.Size([256])
layer3.1.bn2.bias : torch.Size([256])
layer4.0.conv1.weight : torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight : torch.Size([512])
layer4.0.bn1.bias : torch.Size([512])
layer4.0.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight : torch.Size([512])
layer4.0.bn2.bias : torch.Size([512])
layer4.0.downsample.0.weight : torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight : torch.Size([512])
layer4.0.downsample.1.bias : torch.Size([512])
layer4.1.conv1.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight : torch.Size([512])
layer4.1.bn1.bias : torch.Size([512])
layer4.1.conv2.weight : torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight : torch.Size([512])
layer4.1.bn2.bias : torch.Size([512])
fc.weight : torch.Size([1000, 512])
fc.bias : torch.Size([1000])
parm['layer1.0.conv1.weight'][0,0,:,:]

输出为:

array([[ 0.05759342, -0.09511436, -0.02027232],
[-0.07455588, -0.799308 , -0.21283598],
[ 0.06557069, -0.09653367, -0.01211061]], dtype=float32)

利用如下函数将某一层的所有参数保存到表格中,数据维持卷积核特征大小,如3*3的卷积保存后还是3x3的.

def parm_to_excel(excel_name,key_name,parm):
with pd.ExcelWriter(excel_name) as writer:
[output_num,input_num,filter_size,_]=parm[key_name].size()
for i in range(output_num):
for j in range(input_num):
data=pd.DataFrame(parm[key_name][i,j,:,:].detach().numpy())
#print(data)
data.to_excel(writer,index=False,header=True,startrow=i*(filter_size+1),startcol=j*filter_size)

由于权重矩阵中有很多的值非常小,取出固定大小的值,并将全部权重写入excel

counter=1
with pd.ExcelWriter('test1.xlsx') as writer:
  for key in parm_resnet50.keys():
    data=parm_resnet50[key].reshape(-1,1)
    data=data[data>0.001]
    
    data=pd.DataFrame(data,columns=[key])
    data.to_excel(writer,index=False,startcol=counter)
    counter+=1

2、获取中间某一层的特性

重写一个函数,将需要输出的层输出即可.

def resnet_cifar(net,input_data):
  x = net.conv1(input_data)
  x = net.bn1(x)
  x = F.relu(x)
  x = net.layer1(x)
  x = net.layer2(x)
  x = net.layer3(x)
  x = net.layer4[0].conv1(x) #这样就提取了layer4第一块的第一个卷积层的输出
  x=x.view(x.shape[0],-1)
  return x

model = models.resnet18()
x = resnet_cifar(model,input_data)

以上这篇获取Pytorch中间某一层权重或者特征的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python编写批量卸载手机中安装的android应用脚本
Jul 21 Python
python对指定目录下文件进行批量重命名的方法
Apr 18 Python
使用Python进行二进制文件读写的简单方法(推荐)
Sep 12 Python
基于python的多进程共享变量正确打开方式
Apr 28 Python
Sanic框架应用部署方法详解
Jul 18 Python
简单了解Python生成器是什么
Jul 02 Python
python基于opencv检测程序运行效率
Dec 28 Python
python爬虫开发之urllib模块详细使用方法与实例全解
Mar 09 Python
Python新手学习装饰器
Jun 04 Python
Python StringIO及BytesIO包使用方法解析
Jun 15 Python
Python-split()函数实例用法讲解
Dec 18 Python
这样写python注释让代码更加的优雅
Jun 02 Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
对django的User模型和四种扩展/重写方法小结
Aug 17 #Python
You might like
967 个函式
2006/10/09 PHP
php Smarty 字符比较代码
2011/02/27 PHP
PHP数组交集的优化代码分析
2011/03/06 PHP
PHP获取windows登录用户名的方法
2014/06/24 PHP
PHP面向对象精要总结
2014/11/07 PHP
thinkphp分页实现效果
2016/10/13 PHP
浅析PHP7的多进程及实例源码
2019/04/14 PHP
那些年,我还在学习jquery 学习笔记
2012/03/05 Javascript
20行代码实现的一个CSS覆盖率测试脚本
2013/07/07 Javascript
在jquery中combobox多选的不兼容问题总结
2013/12/24 Javascript
JavaScript数值转换的三种方式总结
2014/07/31 Javascript
Bootstrap零基础入门教程(三)
2016/07/18 Javascript
jquery实现简单的瀑布流布局
2016/12/11 Javascript
ES5学习教程之Array对象
2017/04/01 Javascript
详解微信小程序中的页面代码中的模板的封装
2017/10/12 Javascript
three.js中文文档学习之如何本地运行详解
2017/11/20 Javascript
解决Jquery下拉框数据动态获取的问题
2018/01/25 jQuery
如何根据业务封装自己的功能组件
2019/04/19 Javascript
20个必会的JavaScript面试题(小结)
2019/07/02 Javascript
Vue中跨域及打包部署到nginx跨域设置方法
2019/08/26 Javascript
vue实现学生信息管理系统
2020/05/30 Javascript
关于vue-cli3打包代码后白屏的解决方案
2020/09/02 Javascript
vue 自定指令生成uuid滚动监听达到tab表格吸顶效果的代码
2020/09/16 Javascript
如何利用node转发请求详解
2020/09/17 Javascript
[01:07:19]2018DOTA2亚洲邀请赛 4.5 淘汰赛 Mineski vs VG 第一场
2018/04/06 DOTA
Python读取指定目录下指定后缀文件并保存为docx
2017/04/23 Python
python PyTorch参数初始化和Finetune
2018/02/11 Python
Python安装图文教程 Pycharm安装教程
2018/03/27 Python
python爬取淘宝商品销量信息
2018/11/16 Python
python根据字典的键来删除元素的方法
2020/08/16 Python
荷兰照明、灯具和配件网上商店:dmlights
2019/08/25 全球购物
大学生专科学习生活的自我评价
2013/12/07 职场文书
自动化专业职业生涯规划书范文
2014/01/16 职场文书
2019年个人工作总结范文
2019/03/25 职场文书
年终奖金发放管理制度,中小企业适用,拿去救急吧!
2019/07/12 职场文书
python3操作redis实现List列表实例
2021/08/04 Python