解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中反射用法实例
Mar 27 Python
python BeautifulSoup设置页面编码的方法
Apr 03 Python
Python多进程并发与多线程并发编程实例总结
Feb 08 Python
使用Python对微信好友进行数据分析
Jun 27 Python
numpy中loadtxt 的用法详解
Aug 03 Python
python创建文件时去掉非法字符的方法
Oct 31 Python
Python numpy.zero() 初始化矩阵实例
Nov 27 Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 Python
Python类的绑定方法和非绑定方法实例解析
Mar 04 Python
Python 如何批量更新已安装的库
May 26 Python
python 图像插值 最近邻、双线性、双三次实例
Jul 05 Python
详解Python流程控制语句
Oct 28 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
php单例模式实现方法分析
2015/03/14 PHP
PHP判断字符串长度的两种方法很实用
2015/09/22 PHP
谈谈php对接芝麻信用踩的坑
2016/12/01 PHP
php判断str字符串是否是xml格式数据的方法示例
2017/07/26 PHP
PHP中十六进制颜色与RGB颜色值互转的方法
2019/03/18 PHP
PHP实现通过二维数组键值获取一维键名操作示例
2019/10/11 PHP
javascript css在IE和Firefox中区别分析
2009/02/18 Javascript
EasyUi tabs的高度与宽度根据IE窗口的变化自适应代码
2010/10/26 Javascript
js用闭包遍历树状数组的方法
2014/03/19 Javascript
3种Jquery限制文本框只能输入数字字母的方法
2014/12/03 Javascript
javascript中Math.random()使用详解
2015/04/15 Javascript
javascript实现倒计时并弹窗提示特效
2015/06/05 Javascript
JS冒泡事件与事件捕获实例详解
2016/11/25 Javascript
BootStrapValidator校验方式
2016/12/19 Javascript
浅谈jQuery操作类数组的工具方法
2016/12/23 Javascript
js如何判断是否在iframe中及防止网页被别站用iframe嵌套
2017/01/11 Javascript
vue element-ui 绑定@keyup事件无效的解决方法
2018/03/09 Javascript
vue项目国际化vue-i18n的安装使用教程
2018/03/14 Javascript
C#程序员入门学习微信小程序的笔记
2019/03/05 Javascript
JQueryDOM之样式操作
2019/03/27 jQuery
VUE解决微信签名及SPA微信invalid signature问题(完美处理)
2019/03/29 Javascript
20个必会的JavaScript面试题(小结)
2019/07/02 Javascript
如何在 Vue 表单中处理图片
2021/01/26 Vue.js
python 实现堆排序算法代码
2012/06/05 Python
详细介绍Python语言中的按位运算符
2013/11/26 Python
python实现查找excel里某一列重复数据并且剔除后打印的方法
2015/05/26 Python
Python寻找两个有序数组的中位数实例详解
2018/12/05 Python
Sentry错误日志监控使用方法解析
2020/11/12 Python
selenium框架中driver.close()和driver.quit()关闭浏览器
2020/12/08 Python
CSS3教程(8):CSS3透明度指南
2009/04/02 HTML / CSS
FC-Moto瑞典:欧洲最大的摩托车服装和头盔商店之一
2018/11/27 全球购物
计算机专业推荐信范文
2013/11/20 职场文书
蜜蜂引路教学反思
2014/02/04 职场文书
2014年高校辅导员工作总结
2014/12/09 职场文书
2015年个人现实表现材料
2014/12/10 职场文书
小程序与后端Java接口交互实现HelloWorld入门
2021/07/09 Java/Android