keras 简单 lstm实例(基于one-hot编码)


Posted in Python onJuly 02, 2020

简单的LSTM问题,能够预测一句话的下一个字词是什么

固定长度的句子,一个句子有3个词。

使用one-hot编码

各种引用

import keras
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import numpy as np

数据预处理

data = 'abcdefghijklmnopqrstuvwxyz'
data_set = set(data)
 
word_2_int = {b:a for a,b in enumerate(data_set)}
int_2_word = {a:b for a,b in enumerate(data_set)}
 
word_len = len(data_set)
print(word_2_int)
print(int_2_word)

一些辅助函数

def words_2_ints(words):
 ints = []
 for itmp in words:
  ints.append(word_2_int[itmp])
 return ints
 
print(words_2_ints('ab'))
 
def words_2_one_hot(words, num_classes=word_len):
 return keras.utils.to_categorical(words_2_ints(words), num_classes=num_classes)
print(words_2_one_hot('a'))
def get_one_hot_max_idx(one_hot):
 idx_ = 0
 max_ = 0
 for i in range(len(one_hot)):
  if max_ < one_hot[i]:
   max_ = one_hot[i]
   idx_ = i
 return idx_
 
def one_hot_2_words(one_hot):
 tmp = []
 for itmp in one_hot:
  tmp.append(int_2_word[get_one_hot_max_idx(itmp)])
 return "".join(tmp)
 
print( one_hot_2_words(words_2_one_hot('adhjlkw')) )

构造样本

time_step = 3 #一个句子有3个词
 
def genarate_data(batch_size=5, genarate_num=100):
 #genarate_num = -1 表示一直循环下去,genarate_num=1表示生成一个batch的数据,以此类推
 #这里,我也不知道数据有多少,就这么循环的生成下去吧。
 #入参batch_size 控制一个batch 有多少数据,也就是一次要yield进多少个batch_size的数据
 '''
 例如,一个batch有batch_size=5个样本,那么对于这个例子,需要yield进的数据为:
 abc->d
 bcd->e
 cde->f
 def->g
 efg->h
 然后把这些数据都转换成one-hot形式,最终数据,输入x的形式为:
 
 [第1个batch]
 [第2个batch]
 ...
 [第genarate_num个batch]
 
 每个batch的形式为:
 
 [第1句话(如abc)]
 [第2句话(如bcd)]
 ...
 每一句话的形式为:
 
 [第1个词的one-hot表示]
 [第2个词的one-hot表示]
 ...
 '''
 cnt = 0
 batch_x = []
 batch_y = []
 sample_num = 0
 while(True):
  for i in range(len(data) - time_step):
   batch_x.append(words_2_one_hot(data[i : i+time_step]))
   batch_y.append(words_2_one_hot(data[i+time_step])[0]) #这里数据加[0],是为了符合keras的输出数据格式。 因为不加[0],表示是3维的数据。 你可以自己尝试不加0,看下面的test打印出来是什么
   sample_num += 1
   #print('sample num is :', sample_num)
   if len(batch_x) == batch_size:
    yield (np.array(batch_x), np.array(batch_y))
    batch_x = []
    batch_y = []
    if genarate_num != -1:
     cnt += 1
 
    if cnt == genarate_num:
     return
   
for test in genarate_data(batch_size=3, genarate_num=1):
 print('--------x:')
 print(test[0])
 print('--------y:')
 print(test[1])

搭建模型并训练

model = Sequential()
 
# LSTM输出维度为 128
# input_shape控制输入数据的形态
# time_stemp表示一句话有多少个单词
# word_len 表示一个单词用多少维度表示,这里是26维
 
model.add(LSTM(128, input_shape=(time_step, word_len)))
model.add(Dense(word_len, activation='softmax')) #输出用一个softmax,来分类,维度就是26,预测是哪一个字母
 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
 
model.fit_generator(generator=genarate_data(batch_size=5, genarate_num=-1), epochs=50, steps_per_epoch=10)
#steps_per_epoch的意思是,一个epoch中,执行多少个batch
#batch_size是一个batch中,有多少个样本。
#所以,batch_size*steps_per_epoch就等于一个epoch中,训练的样本数量。(这个说法不对!再观察看看吧)
#可以将epochs设置成1,或者2,然后在genarate_data中打印样本序号,观察到样本总数。

使用训练后的模型进行预测:

result = model.predict(np.array([words_2_one_hot('bcd')]))

print(one_hot_2_words(result))

可以看到,预测结果为

e

补充知识:训练集产生的onehot编码特征如何在测试集、预测集复现

数据处理中有时要用到onehot编码,如果使用pandas自带的get_dummies方法,训练集产生的onehot编码特征会跟测试集、预测集不一样,正确的方式是使用sklearn自带的OneHotEncoder。

代码

import pandas as pd
from sklearn.preprocessing import OneHotEncoder
ohe = OneHotEncoder(handle_unknown='ignore')
data_train=pd.DataFrame({'职业':['数据挖掘工程师','数据库开发工程师','数据分析师','数据分析师'],
     '籍贯':['福州','厦门','泉州','龙岩']})
ohe.fit(data_train)#训练规则
feature_names=ohe.get_feature_names(data_train.columns)#获取编码后的特征名
data_train_onehot=pd.DataFrame(ohe.transform(data_train).toarray(),columns=feature_names)#应用规则在训练集上
 
data_new=pd.DataFrame({'职业':['数据挖掘工程师','jave工程师'],
     '籍贯':['福州','莆田']})
data_new_onehot=pd.DataFrame(ohe.transform(data_new).toarray(),columns=feature_names)#应用规则在预测集上

以上这篇keras 简单 lstm实例(基于one-hot编码)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中readline判断文件读取结束的方法
Nov 08 Python
python实现根据用户输入从电影网站获取影片信息的方法
Apr 07 Python
python去掉行尾的换行符方法
Jan 04 Python
深入理解python中的select模块
Apr 23 Python
Python使用正则表达式过滤或替换HTML标签的方法详解
Sep 25 Python
Python2和Python3中print的用法示例总结
Oct 25 Python
python爬虫爬取网页表格数据
Mar 07 Python
python获取当前目录路径和上级路径的实例
Apr 26 Python
Pytest参数化parametrize使用代码实例
Feb 22 Python
Django models filter筛选条件详解
Mar 16 Python
python中for in的用法详解
Apr 17 Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
Oct 31 Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 #Python
You might like
php图片上传存储源码并且可以预览
2011/08/26 PHP
JSON在PHP中的应用介绍
2012/09/08 PHP
详解在PHP的Yii框架中使用行为Behaviors的方法
2016/03/18 PHP
JavaScript 组件之旅(三):用 Ant 构建组件
2009/10/28 Javascript
node.js中的fs.statSync方法使用说明
2014/12/16 Javascript
JS实现让网页背景图片斜向移动的方法
2015/02/25 Javascript
AngularJs concepts详解及示例代码
2016/09/01 Javascript
jQuery Easy UI中根据第一个下拉框选中的值设置第二个下拉框是否可以编辑
2016/11/29 Javascript
基于input动态模糊查询的实现方法
2017/12/12 Javascript
React 路由懒加载的几种实现方案
2018/10/23 Javascript
Nginx设置为Node.js的前端服务器方法总结
2019/03/27 Javascript
vue解决花括号数据绑定不成功的问题
2019/10/30 Javascript
Openlayers显示瓦片网格信息的方法
2020/09/28 Javascript
Python中自定义函数的教程
2015/04/27 Python
详解Django中Request对象的相关用法
2015/07/17 Python
Python中的字符串操作和编码Unicode详解
2017/01/18 Python
Python之str操作方法(详解)
2017/06/19 Python
详解Python核心编程中的浅拷贝与深拷贝
2018/01/07 Python
python中退出多层循环的方法
2018/11/27 Python
在Pycharm中自动添加时间日期作者等信息的方法
2019/01/16 Python
Python选择网卡发包及接收数据包
2019/04/04 Python
python实践项目之监控当前联网状态详情
2019/05/23 Python
python flask解析json数据不完整的解决方法
2019/05/26 Python
用django-allauth实现第三方登录的示例代码
2019/06/24 Python
python3获取当前目录的实现方法
2019/07/29 Python
Django models文件模型变更错误解决
2020/05/11 Python
Python读写锁实现实现代码解析
2020/11/28 Python
几个SQL的面试题
2014/03/08 面试题
竟聘演讲稿范文
2013/12/31 职场文书
校运会入场式解说词
2014/02/10 职场文书
精彩广告词大全
2014/03/19 职场文书
餐饮周年庆活动方案
2014/08/14 职场文书
招商引资工作汇报
2014/10/28 职场文书
单方投资意向书
2015/05/11 职场文书
刑事撤诉申请书
2015/05/18 职场文书
vue-router中hash模式与history模式的区别
2021/06/23 Vue.js