python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
Python 备份程序代码实现
Mar 06 Python
Python数据结构与算法之完全树与最小堆实例
Dec 13 Python
python tensorflow基于cnn实现手写数字识别
Jan 01 Python
python爬取网页转换为PDF文件
Jun 07 Python
TensorFlow实现Logistic回归
Sep 07 Python
详解Django的model查询操作与查询性能优化
Oct 16 Python
python微信好友数据分析详解
Nov 19 Python
python 限制函数执行时间,自己实现timeout的实例
Jan 12 Python
Python利用字典破解WIFI密码的方法
Feb 27 Python
python实现两个dict合并与计算操作示例
Jul 01 Python
Python 离线工作环境搭建的方法步骤
Jul 29 Python
Python调用.NET库的方法步骤
Dec 27 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
php的list()的一步操作给一组变量进行赋值的使用
2011/05/18 PHP
PHP 以POST方式提交XML、获取XML,解析XML详解及实例
2016/10/26 PHP
浅析Node.js中使用依赖注入的相关问题及解决方法
2015/06/24 Javascript
jQuery实例—选项卡的简单实现(js源码和jQuery)
2016/06/14 Javascript
Bootstrap复选框和单选按钮美化插件(推荐)
2016/11/23 Javascript
js a标签点击事件
2017/03/30 Javascript
Angular2使用Angular CLI快速搭建工程(一)
2017/05/21 Javascript
JS 中LocalStorage和SessionStorage的使用
2017/08/17 Javascript
微信小程序用户自定义模版用法实例分析
2017/11/28 Javascript
Vue-cli中为单独页面设置背景色的实现方法
2018/02/11 Javascript
JS实现前端页面的搜索功能
2018/06/12 Javascript
快速解决vue动态绑定多个class的官方实例语法无效的问题
2018/09/05 Javascript
浅谈React Native 传参的几种方式(小结)
2019/05/21 Javascript
Vue实现todo应用的示例
2021/02/20 Vue.js
用Python一键搭建Http服务器的方法
2018/06/01 Python
Python 实现子类获取父类的类成员方法
2019/01/11 Python
python使用 request 发送表单数据操作示例
2019/09/25 Python
如何使用Python发送HTML格式的邮件
2020/02/11 Python
详解用Python进行时间序列预测的7种方法
2020/03/13 Python
django model的update时auto_now不被更新的原因及解决方式
2020/04/01 Python
python3 简单实现组合设计模式
2020/07/02 Python
安装pyinstaller遇到的各种问题(小结)
2020/11/20 Python
html5画布旋转效果示例
2014/01/27 HTML / CSS
Html5元素及基本语法详解
2016/08/02 HTML / CSS
ECHT官方网站:男女健身服
2020/02/14 全球购物
介绍一下linux文件系统分配策略
2013/02/25 面试题
会计与审计专业大专生求职信
2013/10/03 职场文书
优秀实习自我鉴定
2013/12/04 职场文书
国土资源局开展党的群众路线教育实践活动整改措施
2014/09/26 职场文书
2016年幼儿园庆六一开幕词
2016/03/04 职场文书
一文搞懂Python Sklearn库使用
2021/08/23 Python
Nginx中使用Lua脚本与图片的缩略图处理的实现
2022/03/18 Servers
Python实现自动玩连连看的脚本分享
2022/04/04 Python
排查Tomcat进程假死的问题
2022/05/06 Servers
Python OpenGL基本配置方式
2022/05/20 Python
基于redis+lua进行限流的方法
2022/07/23 Redis