keras在构建LSTM模型时对变长序列的处理操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

print(np.shape(X))#(1920, 45, 20)
X=sequence.pad_sequences(X, maxlen=100, padding='post')
print(np.shape(X))#(1920, 100, 20)

model = Sequential()
model.add(Masking(mask_value=0,input_shape=(100,20)))
model.add(LSTM(128,dropout_W=0.5,dropout_U=0.5))
model.add(Dense(13,activation='softmax'))
model.compile(loss='categorical_crossentropy',
       optimizer='adam',
       metrics=['accuracy'])

# 用于保存验证集误差最小的参数,当验证集误差减少时,保存下来
checkpointer = ModelCheckpoint(filepath="keras_rnn.hdf5", verbose=1, save_best_only=True, )
history = LossHistory()
result = model.fit(X, Y, batch_size=10,
          nb_epoch=500, verbose=1, validation_data=(testX, testY),
          callbacks=[checkpointer, history])

model.save('keras_rnn_epochend.hdf5')

补充知识:RNN(LSTM)数据形式及Padding操作处理变长时序序列dynamic_rnn

Summary

RNN

样本一样,计算的状态值和输出结构一致,也即是说只要当前时刻的输入值也前一状态值一样,那么其当前状态值和当前输出结果一致,因为在当前这一轮训练中权重参数和偏置均未更新

RNN的最终状态值与最后一个时刻的输出值一致

输入数据要求格式为,shape=(batch_size, step_time_size, input_size),那么,state的shape=(batch_size, state_size);output的shape=(batch_size, step_time_size, state_size),并且最后一个有效输出(有效序列长度,不包括padding的部分)与状态值会一样

LSTM

LSTM与RNN基本一致,不同在于其状态有两个c_state和h_state,它们的shape一样,输出值output的最后一个有效输出与h_state一致

用变长RNN训练,要求其输入格式仍然要求为shape=(batch_size, step_time_size, input_size),但可指定每一个批次中各个样本的有效序列长度,这样在有效长度内其状态值和输出值原理不变,但超过有效长度的部分的状态值将不会发生改变,而输出值都将是shape=(state_size,)的零向量(注:RNN也是这个原理)

需要说明的是,不是因为无效序列长度部分全padding为0而引起输出全为0,状态不变,因为输出值和状态值得计算不仅依赖当前时刻的输入值,也依赖于上一时刻的状态值。其内部原理是利用一个mask matrix矩阵标记有效部分和无效部分,这样在无效部分就不用计算了,也就是说,这一部分不会造成反向传播时对参数的更新。当然,如果padding不是零,那么padding的这部分输出和状态同样与padding为零的结果是一样的

'''
#样本数据为(batch_size,time_step_size, input_size[embedding_size])的形式,其中samples=4,timesteps=3,features=3,其中第二个、第四个样本是只有一个时间步长和二个时间步长的,这里自动补零
'''
import pandas as pd
import numpy as np
import tensorflow as tf

train_X = np.array([[[0, 1, 2], [9, 8, 7], [3,6,8]], 
          [[3, 4, 5], [0, 10, 110], [0,0,0]], 
          [[6, 7, 8], [6, 5, 4], [1,7,4]], 
          [[9, 0, 1], [3, 7, 4], [0,0,0]],
          [[9, 0, 1], [3, 3, 4], [0,0,0]]
          ])
          
sequence_length = [3, 1, 3, 2, 2]

train_X.shape, train_X[:,2:3,:].reshape(5, 3)
tf.reset_default_graph()

x = tf.placeholder(tf.float32, shape=(None, 3, 3)) # 输入数据只需能够迭代并符合要求shape即可,list也行,shape不指定表示没有shape约束,任意shape均可
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=6) # state_size[hidden_size]
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=6) # state_size[hidden_size]
outputs1, state1 = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32, sequence_length=sequence_length)
outputs2, state2 = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32, sequence_length=sequence_length)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer()) # 初始化rnn_cell中参数变量
  outputs1, state1 = sess.run((outputs1, state1), feed_dict={x: train_X})
  outputs2, state2 = sess.run([outputs2, state2], feed_dict={x: train_X})
  print(outputs1.shape, state1.shape) # (4, 3, 5)->(batch_size, time_step_size, state_size), (4, 5)->(batch_size, state_size)
  print(outputs2.shape) # state2为LSTMStateTuple(c_state, h_state)
  print("---------output1<rnn>state1-----------")
  print(outputs1) # 可以看出output1的最后一个时刻的输出即为state1, 即output1[:,-1,:]与state1相等
  print(state1)
  print(np.all(outputs1[:,-1,:] == state1))
  print("---------output2<lstm>state2-----------")
  print(outputs2) # 可以看出output2的最后一个时刻的输出即为LSTMStateTuple中的h
  print(state2)
  print(np.all(outputs2[:,-1,:] == state2[1]))

再来怼怼dynamic_rnn中数据序列长度tricks

keras在构建LSTM模型时对变长序列的处理操作

思路样例代码

from collections import Counter
import numpy as np

origin_data = np.array([[1, 2, 3],
            [3, 0, 2],
            [1, 1, 4],
            [2, 1, 2],
            [0, 1, 1],
            [2, 0, 3]
            ])
# 按照指定列索引进行分组(看作RNN中一个样本序列),如下为按照第二列分组的结果
# [[[1, 2, 3], [0, 0, 0], [0, 0, 0]],
# [[3, 0, 2], [2, 0, 3], [0, 0, 0]],
# [[1, 1, 4], [2, 1, 2], [0, 1, 1]]]

# 第一步,将原始数据按照某列序列化使之成为一个序列数据
def groupby(a, col_index): # 未加入索引越界判断
  max_len = max(Counter(a[:, col_index]).values())
  for i in set(a[:, col_index]):
    d[i] = []
  for sample in a:
    d[sample[col_index]].append(list(sample))
#   for key in d:
#     d[key].extend([[0]*a.shape[1] for _ in range(max_len-len(d[key]))])
  return list(d.values()), [len(_) for _ in d.values()]

samples, sizes = groupby(origin_data, 2)
# 第二步,根据当前这一批次的中最大序列长度max(sizes)作为padding标准(不同批次的样本序列长度可以不一样,但同一批次要求一样(包括padding的部分)),当然也可以一次性将所有样本(不按照批量)按照最大序列长度padding也行,可能空间浪费
paddig_samples = np.zeros([len(samples), max(sizes), 3])
for seq_index, seq in enumerate(samples):
  paddig_samples[seq_index, :len(seq), :] = seq
paddig_samples

以上这篇keras在构建LSTM模型时对变长序列的处理操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用Boolean操作符做真值测试实例
Jan 30 Python
Python中使用pprint函数进行格式化输出的教程
Apr 07 Python
简析Python的闭包和装饰器
Feb 26 Python
python pandas修改列属性的方法详解
Jun 09 Python
Python如何获得百度统计API的数据并发送邮件示例代码
Jan 27 Python
详解python列表生成式和列表生成式器区别
Mar 27 Python
Python文件操作中进行字符串替换的方法(保存到新文件/当前文件)
Jun 28 Python
Python调用百度根据经纬度查询地址的示例代码
Jul 07 Python
PyQt5基本控件使用详解:单选按钮、复选框、下拉框
Aug 05 Python
详解python 利用echarts画地图(热力图)(世界地图,省市地图,区县地图)
Aug 06 Python
python实现UDP协议下的文件传输
Mar 20 Python
Python虚拟环境virtualenv是如何使用的
Jun 20 Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
You might like
php面向对象全攻略 (十七) 自动加载类
2009/09/30 PHP
php+js实现异步图片上传实例分享
2014/06/02 PHP
thinkphp的静态缓存用法分析
2014/11/29 PHP
PHP的Yii框架使用中的一些错误解决方法与建议
2015/08/21 PHP
总结PHP代码规范、流程规范、git规范
2018/06/18 PHP
js的with语句使用方法
2007/09/21 Javascript
从父页面读取和操作iframe中内容方法
2009/07/25 Javascript
javascript textarea光标定位方法(兼容IE和FF)
2011/03/12 Javascript
JS cookie中文乱码解决方法
2014/01/28 Javascript
js漂浮广告实现代码
2015/08/15 Javascript
jQuery实现点击行选中或取消CheckBox的方法
2016/08/01 Javascript
微信小程序 loading 详解及实例代码
2016/11/09 Javascript
JS实现加载和读取XML文件的方法详解
2017/04/24 Javascript
js实现移动端编辑添加地址【模仿京东】
2017/04/28 Javascript
AngularJS实现自定义指令与控制器数据交互的方法示例
2017/06/19 Javascript
Bootstrap Table从零开始
2017/06/30 Javascript
关于react-router/react-router-dom v4 history不能访问问题的解决
2018/01/08 Javascript
python利用beautifulSoup实现爬虫
2014/09/29 Python
Python使用Flask框架同时上传多个文件的方法
2015/03/21 Python
简单介绍Python的Django框架的dj-scaffold项目
2015/05/30 Python
python 为什么说eval要慎用
2019/03/26 Python
python3发送邮件需要经过代理服务器的示例代码
2019/07/25 Python
浅谈Python访问MySQL的正确姿势
2020/01/07 Python
解决Tensorflow sess.run导致的内存溢出问题
2020/02/05 Python
python使用pandas抽样训练数据中某个类别实例
2020/02/28 Python
python爬虫开发之urllib模块详细使用方法与实例全解
2020/03/09 Python
PyCharm中如何直接使用Anaconda已安装的库
2020/05/28 Python
python基于pygame实现飞机大作战小游戏
2020/11/19 Python
HTML5 和小程序实现拍照图片旋转、压缩和上传功能
2018/10/08 HTML / CSS
园林毕业生自我鉴定范文
2013/12/29 职场文书
幼儿园门卫岗位职责
2014/02/14 职场文书
煤矿班组长竞聘书
2014/03/31 职场文书
经营理念口号
2014/06/21 职场文书
元旦趣味活动方案
2014/08/22 职场文书
Golang 入门 之url 包
2022/05/04 Golang
MySQL使用IF语句及用case语句对条件并结果进行判断 
2022/09/23 MySQL