解决Keras中循环使用K.ctc_decode内存不释放的问题


Posted in Python onJune 29, 2020

如下一段代码,在多次调用了K.ctc_decode时,会发现程序占用的内存会越来越高,执行速度越来越慢。

data = generator(...)
model = init_model(...)
for i in range(NUM):
  x, y = next(data)
  _y = model.predict(x)
  shape = _y.shape
  input_length = np.ones(shape[0]) * shape[1]
  ctc_decode = K.ctc_decode(_y, input_length)[0][0]
  out = K.get_value(ctc_decode)

原因

每次执行ctc_decode时都会向计算图中添加一个节点,这样会导致计算图逐渐变大,从而影响计算速度和内存。

PS:有资料说是由于get_value导致的,其中也给出了解决方案。

但是我将ctc_decode放在循环体之外就不再出现内存和速度问题,这是否说明get_value影响其实不大呢?

解决方案

通过K.function封装K.ctc_decode,只需初始化一次,只向计算图中添加一个计算节点,然后多次调用该节点(函数)

data = generator(...)
model = init_model(...)
x = model.output  # [batch_sizes, series_length, classes]
input_length = KL.Input(batch_shape=[None], dtype='int32')
ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1])
decode = K.function([model.input, input_length], [ctc_decode[0][0]])
for i in range(NUM):
  _x, _y = next(data)
  out = decode([_x, np.ones(1)])

补充知识:CTC_loss和CTC_decode的模型封装代码避免节点不断增加

该问题可以参考上面的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
  '''
  用于计算CTC loss
  '''
  def ctc_lambda_func(self,args):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)` 真实标签
      y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
      input_length: tensor `(samples, 1)` 每一个y_pred的长度
      label_length: tensor `(samples, 1)` 每一个y_true的长度

      # Returns
        Tensor with shape (samples,1) 包含了每一个样本的ctc loss
      """
    y_true, y_pred, input_length, label_length = args

    # y_pred = y_pred[:, :, :]
    # y_pred = y_pred[:, 2:, :]
    return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

  def __call__(self, args):
    '''
    ctc_decode 每次创建会生成一个节点,这里参考了上面的内容
    将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
    '''
    y_true = Input(shape=(None,))
    y_pred = Input(shape=(None,None))
    input_length = Input(shape=(1,))
    label_length = Input(shape=(1,))

    lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
    model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")

    # return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
    return model(args)

  def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
      y_true: tensor `(samples, max_string_length)`
        containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
        containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
        each batch item in `y_true`.

    # Returns
      Tensor with shape (samples,1) containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
    input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
    sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))

    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)

    # 注意这里的True是为了忽略解码失败的情况,此时loss会变成nan直到下一个个batch
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                      labels=sparse_labels,
                      sequence_length=input_length,
                      ignore_longer_outputs_than_inputs=True), 1)

# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer

class CTCDecodeLayer(Layer):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def _ctc_decode(self,args):
    base_pred, in_len = args
    in_len = K.squeeze(in_len,axis=-1)

    r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
    r1 = r[0][0]
    prob = r[1][0]
    return [r1,prob]

  def call(self, inputs, **kwargs):
    return self._ctc_decode(inputs)

  def compute_output_shape(self, input_shape):
    return [(None,None),(1,)]

class CTCDecode():
  '''用与CTC 解码,得到真实语音序列
      2019年7月18日所写,对ctc_decode使用模型进行了封装,从而在初始化完成后不会再有新节点的产生
  '''
  def __init__(self):
    base_pred = Input(shape=[None,None],name="pred")
    feature_len = Input(shape=[1,],name="feature_len")
    r1, prob = CTCDecodeLayer()([base_pred,feature_len])
    self.model = Model([base_pred,feature_len],[r1,prob])
    pass

  def ctc_decode(self,base_pred,in_len,return_prob = False):
    '''
    :param base_pred:[sample,timestamp,vector]
    :param in_len: [sample,1]
    :return:
    '''
    result,prob = self.model.predict([base_pred,in_len])
    if return_prob:
      return result,prob
    return result

  def __call__(self,base_pred,in_len,return_prob = False):
    return self.ctc_decode(base_pred,in_len,return_prob)


# 使用方法:(注意shape,是batch级的输入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)

以上这篇解决Keras中循环使用K.ctc_decode内存不释放的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将多个文本文件合并为一个文本的代码(便于搜索)
Mar 13 Python
Python实现拼接多张图片的方法
Dec 01 Python
python魔法方法-自定义序列详解
Jul 21 Python
pygame加载中文名mp3文件出现error
Mar 31 Python
Python3实现获取图片文字里中文的方法分析
Dec 13 Python
python安装pywin32clipboard的操作方法
Jan 24 Python
Django如何自定义model创建数据库索引的顺序
Jun 20 Python
Python使用Beautiful Soup爬取豆瓣音乐排行榜过程解析
Aug 15 Python
TensorFlow Saver:保存和读取模型参数.ckpt实例
Feb 10 Python
python 实现多维数组(array)排序
Feb 28 Python
Python+Selenium随机生成手机验证码并检查页面上是否弹出重复手机号码提示框
Sep 21 Python
python 基于opencv实现高斯平滑
Dec 18 Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
You might like
计算2000年01月01日起到指定日的天数
2006/10/09 PHP
php小经验:解析preg_match与preg_match_all 函数
2013/06/29 PHP
php根据操作系统转换文件名大小写的方法
2014/02/24 PHP
关于php中的json_encode()和json_decode()函数的一些说明
2016/11/20 PHP
PHP laravel中的多对多关系实例详解
2017/06/07 PHP
PHP操作MongoDB实现增删改查功能【附php7操作MongoDB方法】
2018/04/24 PHP
Highslide.js是一款基于js实现的网页中图片展示插件
2020/03/30 Javascript
JQuery动态给table添加、删除行 改进版
2011/01/19 Javascript
读jQuery之一(对象的组成)
2011/06/11 Javascript
你未必知道的JavaScript和CSS交互的5种方法
2014/04/02 Javascript
jquery实现的代替传统checkbox样式插件
2015/06/19 Javascript
javascript RegExp 使用说明
2016/05/21 Javascript
jquery html动态添加的元素绑定事件详解
2016/05/24 Javascript
jQuery实现获取h1-h6标题元素值的方法
2017/03/06 Javascript
JavaScript 上传文件(psd,压缩包等),图片,视频的实现方法
2017/06/19 Javascript
JS 仿支付宝input文本输入框放大组件的实例
2017/11/14 Javascript
微信小程序滑动选择器的实现代码
2018/08/10 Javascript
小程序实现列表删除功能
2018/10/30 Javascript
小程序实现锚点滑动效果
2019/09/23 Javascript
javascript全局自定义鼠标右键菜单
2020/12/08 Javascript
[01:00:35]2018DOTA2亚洲邀请赛3月30日B组 EffcetVSMineski
2018/03/31 DOTA
python 测试实现方法
2008/12/24 Python
python 随机打乱 图片和对应的标签方法
2018/12/14 Python
Python3数字求和的实例
2019/02/19 Python
总结python中pass的作用
2019/02/27 Python
matplotlib 范围选区(SpanSelector)的使用
2021/02/24 Python
什么是CSS3 HSLA色彩模式?HSLA模拟渐变色条
2016/04/26 HTML / CSS
HTML5移动端开发中的Viewport标签及相关CSS用法解析
2016/04/15 HTML / CSS
英国最大的在线蜡烛商店:Candles Direct
2019/03/26 全球购物
承诺书范文
2014/06/03 职场文书
励志演讲稿3分钟
2014/08/21 职场文书
高一学年自我鉴定范文(3篇)
2014/09/26 职场文书
科长个人四风问题整改措施思想汇报
2014/10/13 职场文书
会议通知范文
2015/04/15 职场文书
活动宣传稿范文
2015/07/23 职场文书
html+css实现滚动到元素位置显示加载动画效果
2021/08/02 HTML / CSS