pytorch 可视化feature map的示例代码


Posted in Python onAugust 20, 2019

之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的。

不多说了,直接看代码:

import torch
from torch.autograd import Variable
import torch.nn as nn
import pickle

from sys import path
path.append('/residual model path')
import residual_model
from residual_model import Residual_Model

model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))



class myNet(nn.Module):
  def __init__(self,pretrained_model,layers):
    super(myNet,self).__init__()
    self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
    self.net2 = nn.Sequential(*list(pretrained_model.children())[:layers[1]])
    self.net3 = nn.Sequential(*list(pretrained_model.children())[:layers[2]])

  def forward(self,x):
    out1 = self.net1(x)
    out2 = self.net(out1)
    out3 = self.net(out2)
    return out1,out2,out3

def get_features(pretrained_model, x, layers = [3, 4, 9]): ## get_features 其实很简单
'''
1.首先import model 
2.将weights load 进model
3.熟悉model的每一层的位置,提前知道要输出feature map的网络层是处于网络的那一层
4.直接将test_x输入网络,*list(model.chidren())是用来提取网络的每一层的结构的。net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) ,就是第三层前的所有层。

'''
  net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) 
#  print net1 
  out1 = net1(x) 

  net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]]) 
#  print net2 
  out2 = net2(out1) 

  #net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]]) 
  #out3 = net3(out2) 

  return out1, out2
with open('test.pickle','rb') as f:
  data = pickle.load(f)
x = data['test_mains'][0]
x = Variable(torch.from_numpy(x)).view(1,1,128,1) ## test_x必须为Varibable
#x = Variable(torch.randn(1,1,128,1))
if torch.cuda.is_available():
  x = x.cuda() # 如果模型的训练是用cuda加速的话,输入的变量也必须是cuda加速的,两个必须是对应的,网络的参数weight都是用cuda加速的,不然会报错
  model = model.cuda()
output1,output2 = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output1.shape:',output1.shape)
print('output2.shape:',output2.shape)
#print('output3.shape:',output3.shape)
output_1 = torch.squeeze(output2,dim = 0)
output_1_arr = output_1.data.cpu().numpy() # 得到的cuda加速的输出不能直接转变成numpy格式的,当时根据报错的信息首先将变量转换为cpu的,然后转换为numpy的格式
output_1_arr = output_1_arr.reshape([output_1_arr.shape[0],output_1_arr.shape[1]])

以上这篇pytorch 可视化feature map的示例代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 测试实现方法
Dec 24 Python
python分割和拼接字符串
Nov 01 Python
python连接远程ftp服务器并列出目录下文件的方法
Apr 01 Python
python 转换 Javascript %u 字符串为python unicode的代码
Sep 06 Python
Python爬虫实现简单的爬取有道翻译功能示例
Jul 13 Python
python MNIST手写识别数据调用API的方法
Aug 08 Python
Python pygorithm模块用法示例【常见算法测试】
Aug 16 Python
Django 限制用户访问频率的中间件的实现
Aug 23 Python
PyQt5的安装配置过程,将ui文件转为py文件后显示窗口的实例
Jun 19 Python
Python 线性回归分析以及评价指标详解
Apr 02 Python
Python批量修改xml的坐标值全部转为整数的实例代码
Nov 26 Python
利用Python批量识别电子账单数据的方法
Feb 08 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 #Python
python爬虫 urllib模块url编码处理详解
Aug 20 #Python
pytorch实现用Resnet提取特征并保存为txt文件的方法
Aug 20 #Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
You might like
laravel安装和配置教程
2014/10/29 PHP
Yii列表定义与使用分页方法小结(3种方法)
2016/07/15 PHP
PHP基于面向对象封装的分页类示例
2019/03/15 PHP
jquery下利用jsonp跨域访问实现方法
2010/07/29 Javascript
js关闭当前页面(窗口)的几种方式总结
2013/03/05 Javascript
js 判断图片是否加载完以及实现图片的预下载
2014/08/14 Javascript
JavaScript极简入门教程(三):数组
2014/10/25 Javascript
jquery实现适用于门户站的导航下拉菜单效果代码
2015/08/24 Javascript
AngularJS 模块化详解及实例代码
2016/09/14 Javascript
微信js-sdk地理位置接口用法示例
2016/10/12 Javascript
Vue.js tab实现选项卡切换
2017/05/16 Javascript
微信小程序wx:for和wx:for-item的用法详解
2018/04/01 Javascript
Node.js的Koa实现JWT用户认证方法
2018/05/05 Javascript
微信小程序云开发如何使用npm安装依赖
2019/05/18 Javascript
简单了解微信小程序的目录结构
2019/07/01 Javascript
vue实现在线预览pdf文件和下载(pdf.js)
2019/11/26 Javascript
[01:56]《DOTA2》中文配音CG
2013/04/22 DOTA
python实现发送邮件功能代码
2017/12/14 Python
使用Python实现分别输出每个数组
2019/12/06 Python
python之生成多层json结构的实现
2020/02/27 Python
Python定时从Mysql提取数据存入Redis的实现
2020/05/03 Python
Keras自动下载的数据集/模型存放位置介绍
2020/06/19 Python
学python最电脑配置有要求么
2020/07/05 Python
如何在python中实现线性回归
2020/08/10 Python
学会迭代器设计模式,帮你大幅提升python性能
2021/01/03 Python
matplotlib部件之矩形选区(RectangleSelector)的实现
2021/02/01 Python
Canvas 文字碰撞检测并抽稀的方法
2019/05/27 HTML / CSS
阿里健康大药房:阿里自营网上药店
2017/08/01 全球购物
Shop Apotheke瑞士:您的健康与美容网上商店
2019/10/09 全球购物
大学生毕业自荐信
2013/10/10 职场文书
工商治理实习生的自我评价分享
2014/02/20 职场文书
金融事务专业毕业生求职信
2014/02/23 职场文书
工程类专业自荐信范文
2014/03/09 职场文书
房屋买卖委托公证书
2014/04/08 职场文书
运动会400米加油稿(8篇)
2014/09/22 职场文书
2015年话务员工作总结
2015/04/29 职场文书