python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
python中bisect模块用法实例
Sep 25 Python
Python实现的文本简单可逆加密算法示例
May 18 Python
Python数据抓取爬虫代理防封IP方法
Dec 23 Python
Django+Xadmin构建项目的方法步骤
Mar 06 Python
基于Python打造账号共享浏览器功能
May 30 Python
python五子棋游戏的设计与实现
Jun 18 Python
Python内置类型性能分析过程实例
Jan 29 Python
Python实现AI自动抠图实例解析
Mar 05 Python
pycharm下pyqt4安装及环境配置的教程
Apr 24 Python
Django实现后台上传并显示图片功能
May 29 Python
使用Python制作一个数据预处理小工具(多种操作一键完成)
Feb 07 Python
python小程序之飘落的银杏
Apr 17 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
点评山进PR-D3L三波段收音机
2021/03/02 无线电
基于mysql的论坛(6)
2006/10/09 PHP
destoon首页调用求购供应信息的地区名称的方法
2014/08/21 PHP
php实现随机显示图片方法汇总
2015/05/21 PHP
PHP + plupload.js实现多图上传并显示进度条加删除实例代码
2017/03/06 PHP
PHP高精确度运算BC函数库实例详解
2017/08/15 PHP
javascript第一课
2007/02/27 Javascript
js 图片轮播(5张图片)
2008/12/30 Javascript
JavaScript 封装Ajax传递的数据代码
2009/06/05 Javascript
(jQuery,mootools,dojo)使用适合自己的编程别名命名
2010/09/14 Javascript
js+css实现增加表单可用性之提示文字
2013/06/03 Javascript
js导出table数据到excel即导出为EXCEL文档的方法
2013/10/10 Javascript
JQuery的ready函数与JS的onload的区别详解
2013/11/21 Javascript
jQuery定义插件的方法
2015/12/18 Javascript
JavaScript+CSS实现的可折叠二级菜单实例
2016/02/29 Javascript
js无法获取到html标签的属性的解决方法
2016/07/26 Javascript
BACKBONE.JS 简单入门范例
2017/10/17 Javascript
vue和better-scroll实现列表左右联动效果详解
2019/04/29 Javascript
原生JS与CSS实现软件卸载对话框功能
2019/12/05 Javascript
详解Python3.1版本带来的核心变化
2015/04/07 Python
Python2.7读取PDF文件的方法示例
2017/07/13 Python
基于循环神经网络(RNN)实现影评情感分类
2018/03/26 Python
python用match()函数爬数据方法详解
2019/07/23 Python
解决Python pip 自动更新升级失败的问题
2020/02/21 Python
python3光学字符识别模块tesserocr与pytesseract的使用详解
2020/02/26 Python
解决Pymongo insert时会自动添加_id的问题
2020/12/05 Python
Fanatics法国官网:美国体育电商
2019/08/27 全球购物
平面设计师工作职责范文
2013/12/03 职场文书
《理想的风筝》教学反思
2014/04/11 职场文书
交通违章检讨书
2014/09/21 职场文书
教师师德师风自我剖析材料
2014/09/29 职场文书
公安领导班子四风问题个人整改措施思想汇报
2014/10/09 职场文书
判缓刑人员个人思想汇报
2014/10/10 职场文书
2014年教师业务工作总结
2014/12/19 职场文书
推销搭讪开场白
2015/05/28 职场文书
解决Python字典查找报Keyerror的问题
2021/05/26 Python