使用pytorch实现可视化中间层的结果


Posted in Python onDecember 30, 2019

摘要

一直比较想知道图片经过卷积之后中间层的结果,于是使用pytorch写了一个脚本查看,先看效果

这是原图,随便从网上下载的一张大概224*224大小的图片,如下

使用pytorch实现可视化中间层的结果

网络介绍

我们使用的VGG16,包含RULE层总共有30层可以可视化的结果,我们把这30层分别保存在30个文件夹中,每个文件中根据特征的大小保存了64~128张图片

结果如下:

原图大小为224224,经过第一层后大小为64224*224,下面是第一层可视化的结果,总共有64张这样的图片:

使用pytorch实现可视化中间层的结果

下面看看第六层的结果

这层的输出大小是 1128112*112,总共有128张这样的图片

使用pytorch实现可视化中间层的结果

下面是完整的代码

import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models

#创建30个文件夹
def mkdir(path): # 判断是否存在指定文件夹,不存在则创建
  # 引入模块
  import os

  # 去除首位空格
  path = path.strip()
  # 去除尾部 \ 符号
  path = path.rstrip("\\")

  # 判断路径是否存在
  # 存在   True
  # 不存在  False
  isExists = os.path.exists(path)

  # 判断结果
  if not isExists:
    # 如果不存在则创建目录
    # 创建目录操作函数
    os.makedirs(path)
    return True
  else:

    return False


def preprocess_image(cv2im, resize_im=True):
  """
    Processes image for CNNs

  Args:
    PIL_img (PIL_img): Image to process
    resize_im (bool): Resize to 224 or not
  returns:
    im_as_var (Pytorch variable): Variable that contains processed float tensor
  """
  # mean and std list for channels (Imagenet)
  mean = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  # Resize image
  if resize_im:
    cv2im = cv2.resize(cv2im, (224, 224))
  im_as_arr = np.float32(cv2im)
  im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])
  im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H
  # Normalize the channels
  for channel, _ in enumerate(im_as_arr):
    im_as_arr[channel] /= 255
    im_as_arr[channel] -= mean[channel]
    im_as_arr[channel] /= std[channel]
  # Convert to float tensor
  im_as_ten = torch.from_numpy(im_as_arr).float()
  # Add one more channel to the beginning. Tensor shape = 1,3,224,224
  im_as_ten.unsqueeze_(0)
  # Convert to Pytorch variable
  im_as_var = Variable(im_as_ten, requires_grad=True)
  return im_as_var


class FeatureVisualization():
  def __init__(self,img_path,selected_layer):
    self.img_path=img_path
    self.selected_layer=selected_layer
    self.pretrained_model = models.vgg16(pretrained=True).features
    #print( self.pretrained_model)
  def process_image(self):
    img=cv2.imread(self.img_path)
    img=preprocess_image(img)
    return img

  def get_feature(self):
    # input = Variable(torch.randn(1, 3, 224, 224))
    input=self.process_image()
    print("input shape",input.shape)
    x=input
    for index,layer in enumerate(self.pretrained_model):
      #print(index)
      #print(layer)
      x=layer(x)
      if (index == self.selected_layer):
        return x

  def get_single_feature(self):
    features=self.get_feature()
    print("features.shape",features.shape)
    feature=features[:,0,:,:]
    print(feature.shape)
    feature=feature.view(feature.shape[1],feature.shape[2])
    print(feature.shape)
    return features

  def save_feature_to_img(self):
    #to numpy
    features=self.get_single_feature()
    for i in range(features.shape[1]):
      feature = features[:, i, :, :]
      feature = feature.view(feature.shape[1], feature.shape[2])
      feature = feature.data.numpy()
      # use sigmod to [0,1]
      feature = 1.0 / (1 + np.exp(-1 * feature))
      # to [0,255]
      feature = np.round(feature * 255)
      print(feature[0])
      mkdir('./feature/' + str(self.selected_layer))
      cv2.imwrite('./feature/'+ str( self.selected_layer)+'/' +str(i)+'.jpg', feature)
if __name__=='__main__':
  # get class
  for k in range(30):
    myClass=FeatureVisualization('/home/lqy/examples/TRP.PNG',k)
    print (myClass.pretrained_model)
    myClass.save_feature_to_img()

以上这篇使用pytorch实现可视化中间层的结果就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字符转换
Sep 06 Python
python实现得到一个给定类的虚函数
Sep 28 Python
Python中尝试多线程编程的一个简明例子
Apr 07 Python
介绍Python中的文档测试模块
Apr 28 Python
python 检查是否为中文字符串的方法
Dec 28 Python
使用python实现简单五子棋游戏
Jun 18 Python
python 实现12bit灰度图像映射到8bit显示的方法
Jul 08 Python
Python使用scipy模块实现一维卷积运算示例
Sep 05 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 Python
jupyter notebook实现显示行号
Apr 13 Python
pytorch 实现在测试的时候启用dropout
May 27 Python
FP-growth算法发现频繁项集——发现频繁项集
Jun 24 Python
在Pytorch中计算自己模型的FLOPs方式
Dec 30 #Python
Pytorch之保存读取模型实例
Dec 30 #Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
You might like
小偷PHP+Html+缓存
2006/11/25 PHP
CURL的学习和应用(附多线程实现)
2013/06/03 PHP
php二维数组排序详解
2013/11/06 PHP
php生成QRcode实例
2014/09/22 PHP
Use Word to Search for Files
2007/06/15 Javascript
JavaScript Event学习补遗 addEventSimple
2010/02/11 Javascript
js操作ajax返回的json的注意问题!
2010/02/23 Javascript
JS图片无缝、平滑滚动代码
2014/03/11 Javascript
js子页面获取父页面数据示例
2014/05/15 Javascript
Node.js开发之访问Redis数据库教程
2015/01/14 Javascript
JavaScript使用Max函数返回两个数字中较大数的方法
2015/04/06 Javascript
jsMind通过鼠标拖拽的方式调整节点位置
2015/04/13 Javascript
JavaScript 数据类型详解
2017/03/13 Javascript
Angular2开发——组件规划篇
2017/03/28 Javascript
Node.js dgram模块实现UDP通信示例代码
2017/09/26 Javascript
浅谈Webpack下多环境配置的思路
2018/06/27 Javascript
解决vue 项目引入字体图标报错、不显示等问题
2018/09/01 Javascript
vue+django实现一对一聊天功能的实例代码
2019/07/17 Javascript
python 图片验证码代码分享
2012/07/04 Python
在Python中使用M2Crypto模块实现AES加密的教程
2015/04/08 Python
Python实现查找系统盘中需要找的字符
2015/07/14 Python
django 信号调度机制详解
2019/07/19 Python
django实现支付宝支付实例讲解
2019/10/17 Python
浅谈pytorch池化maxpool2D注意事项
2020/02/18 Python
Python实现动态给类和对象添加属性和方法操作示例
2020/02/29 Python
Volcom法国官网:美国冲浪滑板品牌
2017/05/25 全球购物
英国比较机场停车场网站:Airport Parking Essentials
2019/12/01 全球购物
中学生校园广播稿
2014/01/16 职场文书
事业单位辞职信范文
2014/01/19 职场文书
毕业生学校推荐信范文
2014/05/21 职场文书
法人任命书范本
2014/06/04 职场文书
优秀党员事迹材料
2014/12/18 职场文书
检察院起诉意见书
2015/05/20 职场文书
一个独生女的故事观后感
2015/06/04 职场文书
Redis中缓存穿透/击穿/雪崩问题和解决方法
2021/12/04 Redis
springboot新建项目pom.xml文件第一行报错的解决
2022/01/18 Java/Android