使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(一)
Jun 09 Python
Python字典操作简明总结
Apr 13 Python
python数据封装json格式数据
Mar 04 Python
python调用百度语音识别实现大音频文件语音识别功能
Aug 30 Python
python学习之hook钩子的原理和使用
Oct 25 Python
python输出电脑上所有的串口名的方法
Jul 02 Python
numpy.linalg.eig() 计算矩阵特征向量方式
Nov 29 Python
使用Pandas的Series方法绘制图像教程
Dec 04 Python
Django 设置多环境配置文件载入问题
Feb 25 Python
详解django使用include无法跳转的解决方法
Mar 19 Python
PyQt5爬取12306车票信息程序的实现
May 14 Python
Python中的pprint模块
Nov 27 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
PHP编码规范-php coding standard
2007/03/16 PHP
第五章 php数组操作
2011/12/30 PHP
php基于自定义函数记录log日志方法
2017/07/21 PHP
YII2框架中ActiveDataProvider与GridView的配合使用操作示例
2020/03/18 PHP
javascript 日期常用的方法
2009/11/11 Javascript
基于MooTools的很有创意的滚动条时钟动画
2010/11/14 Javascript
js 用CreateElement动态创建标签示例
2013/11/20 Javascript
指定区域的图片自动按比例缩小的js代码(防止页面被图片撑破)
2014/02/21 Javascript
javascript框架设计读书笔记之字符串的扩展和修复
2014/12/02 Javascript
jQuery大于号(>)选择器的作用解释
2015/01/13 Javascript
剖析Node.js异步编程中的回调与代码设计模式
2016/02/16 Javascript
Vue.set()实现数据动态响应的方法
2018/02/07 Javascript
Vue slot用法(小结)
2018/10/22 Javascript
Vue通过配置WebSocket并实现群聊功能
2019/12/31 Javascript
javascript设计模式 ? 访问者模式原理与用法实例分析
2020/04/26 Javascript
vue props default Array或是Object的正确写法说明
2020/07/30 Javascript
haskell实现多线程服务器实例代码
2013/11/26 Python
Python cookbook(字符串与文本)在字符串的开头或结尾处进行文本匹配操作
2018/04/20 Python
Pandas 同元素多列去重的实例
2018/07/03 Python
Python高级特性与几种函数的讲解
2019/03/08 Python
python 通过可变参数计算n个数的乘积方法
2019/06/13 Python
python如何统计代码运行的时长
2019/07/24 Python
python conda操作方法
2019/09/11 Python
django admin 添加自定义链接方式
2020/03/11 Python
jupyter notebook 恢复误删单元格或者历史代码的实现
2020/04/17 Python
Python getattr()函数使用方法代码实例
2020/08/10 Python
python 使用OpenCV进行简单的人像分割与合成
2021/02/02 Python
html5使用canvas实现弹幕功能示例
2017/09/11 HTML / CSS
NBA欧洲商店(英国):NBA Europe Store UK
2018/07/27 全球购物
Otticanet英国:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/02/10 全球购物
2019年c语言经典面试题目
2016/08/17 面试题
营业经理岗位职责
2013/11/10 职场文书
写给学生的新学期寄语
2014/01/18 职场文书
优秀共产党员先进事迹材料
2014/05/06 职场文书
购房意向书
2014/08/30 职场文书
残联2016年全国助残日活动总结
2016/04/01 职场文书