keras小技巧——获取某一个网络层的输出方式


Posted in Python onMay 23, 2020

前言:

keras默认提供了如何获取某一个层的某一个节点的输出,但是没有提供如何获取某一个层的输出的接口,所以有时候我们需要获取某一个层的输出,则需要自己编写代码,但是鉴于keras高层封装的特性,编写起来实际上很简单,本文提供两种常见的方法来实现,基于上一篇文章的模型和代码: keras自定义回调函数查看训练的loss和accuracy

一、模型加载以及各个层的信息查看

从前面的定义可知,参见上一篇文章,一共定义了8个网络层,定义如下:

model.add(Convolution2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=(img_rows, img_cols, 1), activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(n_classes, activation='softmax'))

这里每一个层都没有起名字,实际上最好给每一个层取一个名字,所以这里就使用索引来访问层,如下:

for index in range(8):
 layer=model.get_layer(index=index)
 # layer=model.layers[index] # 这样获取每一个层也是一样的
 print(model)
 
'''运行结果如下:
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
<keras.engine.sequential.Sequential object at 0x0000012A4F232E10>
'''

当然由于 model.laters是一个列表,所以可以一次性打印出所有的层信息,即

print(model.layers) # 打印出所有的层

二、模型的加载

准备测试数据

# 训练参数
learning_rate = 0.001
epochs = 10
batch_size = 128
n_classes = 10
 
# 定义图像维度reshape
img_rows, img_cols = 28, 28
 
# 加载keras中的mnist数据集 分为60,000个训练集,10,000个测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
 
# 将图片转化为(samples,width,height,channels)的格式
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
 
# 将X_train, X_test的数据格式转为float32
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
# 将X_train, X_test归一化0-1
x_train /= 255
x_test /= 255
 
# 输出0-9转换为ont-hot形式
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)

模型的加载

model=keras.models.load_model('./models/lenet5_weight.h5')

注意事项:

keras的每一个层有一个input和output属性,但是它是只针对单节点的层而言的哦,否则就不需要我们再自己编写输出函数了,

如果一个层具有单个节点 (i.e. 如果它不是共享层), 你可以得到它的输入张量、输出张量、输入尺寸和输出尺寸:

layer.input
layer.output
layer.input_shape
layer.output_shape

如果层有多个节点 (参见: 层节点和共享层的概念), 您可以使用以下函数:

layer.get_input_at(node_index)
layer.get_output_at(node_index)
layer.get_input_shape_at(node_index)
layer.get_output_shape_at(node_index)

三、获取某一个层的输出的方法定义

3.1 第一种实现方法

def get_output_function(model,output_layer_index):
 '''
 model: 要保存的模型
 output_layer_index:要获取的那一个层的索引
 '''
 vector_funcrion=K.function([model.layers[0].input],[model.layers[output_layer_index].output])
 def inner(input_data):
  vector=vector_funcrion([input_data])[0]
  return vector
 
 return inner
 
# 现在仅仅测试一张图片
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
get_feature=get_output_function(model,6) # 该函数的返回值依然是一个函数哦,获取第6层输出
 
feature=get_feature(x) # 相当于调用 定义在里面的inner函数
print(feature)
'''运行结果为
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
 0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
 0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
 0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
 0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
 -0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
 0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
 -0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
 -0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
 0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
 0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
 0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
 -0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
 -0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''

但是上面的实现方法似乎不是很简单,还有更加简单的方法,思想来源与keras中,可以将整个模型model也当成是层layer来处理,实现如下面。

3.2 第二种实现方法

import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
 
model=keras.models.load_model('./models/lenet5_weight.h5')
 
#选择一张图片,选择第一张
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
# 将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
 
feature=layer_model.predict(x)
 
print(feature)
'''运行结果为:
[[-0.99986297 -0.9988328 -0.9273474 0.9101525 -0.9054705 -0.95798373
 0.9911243 0.78576803 0.99676156 0.39356467 -0.9724135 -0.74534595
 0.8527011 -0.9968267 -0.9420816 -0.32765102 -0.41667578 0.99942905
 0.92333794 0.7565034 -0.38416263 -0.994241 0.3781617 0.9621943
 0.9443946 0.9671554 -0.01000021 -0.9984282 -0.96650964 -0.9925837
 -0.48193568 -0.9749565 -0.79769516 0.9651831 0.9678705 -0.9444472
 0.9405674 0.97538495 -0.12366439 -0.9973782 0.05803521 0.9159217
 -0.9627071 0.99898154 0.99429387 -0.985909 0.5787794 -0.9789403
 -0.94316894 0.9999644 0.9156823 0.46314353 -0.01582102 0.98359734
 0.5586145 -0.97360635 0.99058044 0.9995654 -0.9800733 0.99942625
 0.8786553 -0.9992093 0.99916387 -0.5141877 0.99970615 0.28427476
 0.86589384 0.7649907 -0.9986046 0.9999706 -0.9892468 0.99854743
 -0.86872625 -0.9997323 0.98981035 -0.87805724 -0.9999373 -0.7842255
 -0.97456616 -0.97237325 -0.729563 0.98718935 0.9992022 -0.5294769 ]]
'''

可见和上面的结果是一样的,

总结:

由于keras的层与模型之间实际上的转化关系,所以提供了非常灵活的输出方法,推荐使用第二种方法获得某一个层的输出。总结为以下几个主要的步骤(四步走):

import keras
import numpy as np
from keras.datasets import mnist
from keras.models import Model
 
# 第一步:准备输入数据
x= np.expand_dims(x_test[1],axis=0) #[1,28,28,1] 的形状
 
# 第二步:加载已经训练的模型
model=keras.models.load_model('./models/lenet5_weight.h5')
 
# 第三步:将模型作为一个层,输出第7层的输出
layer_model = Model(inputs=model.input, outputs=model.layers[6].output)
 
# 第四步:调用新建的“曾模型”的predict方法,得到模型的输出
feature=layer_model.predict(x)
 
print(feature)

以上这篇keras小技巧——获取某一个网络层的输出方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
下载糗事百科的内容_python版
Dec 07 Python
python里对list中的整数求平均并排序
Sep 12 Python
python开发之字符串string操作方法实例详解
Nov 12 Python
一些常用的Python爬虫技巧汇总
Sep 28 Python
Django文件存储 自己定制存储系统解析
Aug 02 Python
Python Web框架之Django框架Model基础详解
Aug 16 Python
如何基于Python创建目录文件夹
Dec 31 Python
tensorflow模型继续训练 fineturn实例
Jan 21 Python
利用Tensorboard绘制网络识别准确率和loss曲线实例
Feb 15 Python
Python关于__name__属性的含义和作用详解
Feb 19 Python
使用keras实现孪生网络中的权值共享教程
Jun 11 Python
python安装cx_Oracle和wxPython的方法
Sep 14 Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
Python接口测试文件上传实例解析
May 22 #Python
计算Python Numpy向量之间的欧氏距离实例
May 22 #Python
python numpy矩阵信息说明,shape,size,dtype
May 22 #Python
You might like
Smarty结合Ajax实现无刷新留言本实例
2007/01/02 PHP
Yii调试SQL的常用方法
2014/07/09 PHP
制作个性化的WordPress登陆界面的实例教程
2016/05/21 PHP
javascript中字符串替换函数replace()方法与c# 、vb 替换有一点不同
2010/06/25 Javascript
基于JQuery的浮动DIV显示提示信息并自动隐藏
2011/02/11 Javascript
javascript创建createXmlHttpRequest对象示例代码
2014/02/10 Javascript
JavaScript获取网页、浏览器、屏幕高度和宽度汇总
2014/12/18 Javascript
JavaScript中setUTCMilliseconds()方法的使用详解
2015/06/12 Javascript
javascript获取系统当前时间的方法
2015/11/19 Javascript
JavaScript继承学习笔记【新手必看】
2016/05/10 Javascript
javascript中对象的定义、使用以及对象和原型链操作小结
2016/12/14 Javascript
基于js 各种排序方法和sort方法的区别(详解)
2018/01/03 Javascript
React+Redux实现简单的待办事项列表ToDoList
2019/09/29 Javascript
Openlayers测量距离与面积的实现方法
2020/09/25 Javascript
[02:22]《新闻直播间》2017年08月14日
2017/08/15 DOTA
用Python程序抓取网页的HTML信息的一个小实例
2015/05/02 Python
python中threading超线程用法实例分析
2015/05/16 Python
Python设计模式中单例模式的实现及在Tornado中的应用
2016/03/02 Python
redis之django-redis的简单缓存使用
2018/06/07 Python
django框架之cookie/session的使用示例(小结)
2018/10/15 Python
浅析python的Lambda表达式
2019/02/27 Python
详解Python数据分析--Pandas知识点
2019/03/23 Python
PyCharm搭建Spark开发环境的实现步骤
2019/09/05 Python
pandas read_excel()和to_excel()函数解析
2019/09/19 Python
OpenCV图片漫画效果的实现示例
2020/08/18 Python
Martinelli官方商店:西班牙皮鞋和高跟鞋品牌
2019/07/30 全球购物
英国最大的独立玩具专卖店:The Entertainer
2019/09/06 全球购物
常见的软件开发流程有哪些
2015/11/14 面试题
艺术设计专业个人求职信
2014/04/10 职场文书
教师节演讲稿
2014/05/06 职场文书
个人委托书
2014/07/31 职场文书
离婚财产分隔协议书
2014/10/23 职场文书
圣诞晚会主持词
2015/07/01 职场文书
班级管理经验交流材料
2015/11/02 职场文书
导游词之上饶龟峰
2019/10/25 职场文书
Linux中Nginx的防盗链和优化的实现代码
2021/06/20 Servers