keras K.function获取某层的输出操作


Posted in Python onJune 29, 2020

如下所示:

from keras import backend as K
from keras.models import load_model

models = load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr = process_image(image, (224, 224, 3))
image_arr = np.expand_dims(image_arr, axis=0)
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output])
f1 = layer_1([image_arr])[0]

加载训练好并保存的网络模型

加载数据(图像),并将数据处理成array形式

指定输出层

将处理后的数据输入,然后获取输出

其中,K.function有两种不同的写法:

1. 获取名为layer_name的层的输出

layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output]) #指定输出层的名称

2. 获取第n层的输出

layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output]) #指定输出层的序号(层号从0开始)

另外,需要注意的是,书写不规范会导致报错:

报错:

TypeError: inputs to a TensorFlow backend function should be a list or tuple

将该句:

f1 = layer_1(image_arr)[0]

修改为:

f1 = layer_1([image_arr])[0]

补充知识:keras.backend.function()

如下所示:

def function(inputs, outputs, updates=None, **kwargs):
 """Instantiates a Keras function.
 Arguments:
   inputs: List of placeholder tensors.
   outputs: List of output tensors.
   updates: List of update ops.
   **kwargs: Passed to `tf.Session.run`.
 Returns:
   Output values as Numpy arrays.
 Raises:
   ValueError: if invalid kwargs are passed in.
 """
 if kwargs:
  for key in kwargs:
   if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
     key not in tf_inspect.getargspec(Function.__init__)[0]):
    msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
        'backend') % key
    raise ValueError(msg)
 return Function(inputs, outputs, updates=updates, **kwargs)

这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。

我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。

class Function(object):
 """Runs a computation graph.
 Arguments:
   inputs: Feed placeholders to the computation graph.
   outputs: Output tensors to fetch.
   updates: Additional update ops to be run at function call.
   name: a name to help users identify what this function does.
 """

 def __init__(self, inputs, outputs, updates=None, name=None,
        **session_kwargs):
  updates = updates or []
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` to a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(outputs, (list, tuple)):
   raise TypeError('`outputs` of a TensorFlow backend function '
           'should be a list or tuple.')
  if not isinstance(updates, (list, tuple)):
   raise TypeError('`updates` in a TensorFlow backend function '
           'should be a list or tuple.')
  self.inputs = list(inputs)
  self.outputs = list(outputs)
  with ops.control_dependencies(self.outputs):
   updates_ops = []
   for update in updates:
    if isinstance(update, tuple):
     p, new_p = update
     updates_ops.append(state_ops.assign(p, new_p))
    else:
     # assumed already an op
     updates_ops.append(update)
   self.updates_op = control_flow_ops.group(*updates_ops)
  self.name = name
  self.session_kwargs = session_kwargs

 def __call__(self, inputs):
  if not isinstance(inputs, (list, tuple)):
   raise TypeError('`inputs` should be a list or tuple.')
  feed_dict = {}
  for tensor, value in zip(self.inputs, inputs):
   if is_sparse(tensor):
    sparse_coo = value.tocoo()
    indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
                 np.expand_dims(sparse_coo.col, 1)), 1)
    value = (indices, sparse_coo.data, sparse_coo.shape)
   feed_dict[tensor] = value
  session = get_session()
  updated = session.run(
    self.outputs + [self.updates_op],
    feed_dict=feed_dict,
    **self.session_kwargs)
  return updated[:len(self.outputs)]

所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。

以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 自动提交和抓取网页
Jul 13 Python
python使用正则表达式分析网页中的图片并进行替换的方法
Mar 26 Python
Python Paramiko模块的安装与使用详解
Nov 18 Python
Python基础教程之浅拷贝和深拷贝实例详解
Jul 15 Python
分享给Python新手们的几道简单练习题
Sep 21 Python
Python实现的字典值比较功能示例
Jan 08 Python
django认证系统 Authentication使用详解
Jul 22 Python
Django 重写用户模型的实现
Jul 29 Python
Python configparser模块配置文件过程解析
Mar 03 Python
使用keras实现BiLSTM+CNN+CRF文字标记NER
Jun 29 Python
python实现简单区块链结构
Apr 25 Python
Python虚拟环境virtualenv是如何使用的
Jun 20 Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
Python如何对XML 解析
Jun 28 #Python
keras 自定义loss层+接受输入实例
Jun 28 #Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
You might like
PHP aes (ecb)解密后乱码问题
2015/06/22 PHP
详解Yii2 定制表单输入字段的标签和样式
2017/01/04 PHP
利用laravel+ajax实现文件上传功能方法示例
2017/08/13 PHP
ThinkPHP5框架缓存查询操作分析
2018/05/30 PHP
phpstudy隐藏index.php的方法
2020/09/21 PHP
JQuery在页面中添加和除移DOM示例代码
2013/06/24 Javascript
动态创建script在IE中缓存js文件时导致编码的解决方法
2014/05/04 Javascript
jQuery实现跨域iframe接口方法调用
2015/03/14 Javascript
JavaScript数据结构学习之数组、栈与队列
2017/05/02 Javascript
Node.js服务器开启Gzip压缩教程
2017/08/11 Javascript
分享5个好用的javascript文件上传插件
2018/09/16 Javascript
jQuery事件多次绑定与解绑问题实例分析
2019/02/19 jQuery
vue项目在线上服务器访问失败原因分析
2020/08/14 Javascript
[00:55]2015国际邀请赛中国区预选赛5月23日——28日约战上海
2015/05/25 DOTA
[01:06:32]DOTA2上海特级锦标赛D组资格赛#1 EG VS VP第一局
2016/02/28 DOTA
Python实现迪杰斯特拉算法过程解析
2020/09/18 Python
乌克兰时尚鞋子和衣服购物网站:Born2be
2018/05/24 全球购物
Topshop美国官网:英国快速时尚品牌
2019/05/16 全球购物
说出ArrayList,Vector, LinkedList的存储性能和特性
2015/01/04 面试题
如何开发安全的AJAX应用
2014/03/26 面试题
遇到的Mysql的面试题
2014/06/29 面试题
编程输出如下图形
2013/11/24 面试题
幼儿园元旦家长感言
2014/02/27 职场文书
社团活动总结
2014/04/28 职场文书
大学生新学期计划书
2014/04/28 职场文书
申论倡议书范文
2014/05/13 职场文书
普通党员个人剖析材料
2014/10/08 职场文书
2015年建筑工程工作总结
2015/05/13 职场文书
反腐倡廉观后感
2015/06/08 职场文书
入党宣誓大会后的感想
2015/08/10 职场文书
2015年物业管理员工工作总结
2015/10/15 职场文书
2017大学生寒假社会实践心得体会
2016/01/14 职场文书
导游词之河姆渡遗址博物馆
2019/10/10 职场文书
关于PHP数组迭代器的使用方法实例
2021/11/17 PHP
什么是动态刷新率DRR? Windows11动态刷新率功能介绍
2021/11/21 数码科技
前端JS获取URL参数的4种方法总结
2022/04/05 Javascript