python神经网络 使用Keras构建RNN训练


Posted in Python onMay 04, 2022

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE

具体训练过程如下:

for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)
        x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。

import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam

TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3

(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
 
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255

Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)

model = Sequential()

# conv1
model.add(
    SimpleRNN(
        batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
        output_dim = CELL_SIZE,
    )
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])

## tarin
for i in range(500):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        ## acc
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        ## W,b = model.layers[0].get_weights()
        print("accuracy:",accuracy)

实验结果为:

10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698

以上就是python神经网络使用Keras构建RNN训练的详细内容!


Tags in this post...

Python 相关文章推荐
Python中使用copy模块实现列表(list)拷贝
Apr 14 Python
Python的Flask框架中实现登录用户的个人资料和头像的教程
Apr 20 Python
Python类的用法实例浅析
May 27 Python
Python遍历numpy数组的实例
Apr 04 Python
Python实现的查询mysql数据库并通过邮件发送信息功能
May 17 Python
对Pycharm创建py文件时自定义头部模板的方法详解
Feb 12 Python
django框架创建应用操作示例
Sep 26 Python
Python实现打印实心和空心菱形
Nov 23 Python
查看已安装tensorflow版本的方法示例
Apr 19 Python
Python基于QQ邮箱实现SSL发送
Apr 26 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 Python
pycharm debug 断点调试心得分享
Apr 16 Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
python神经网络学习 使用Keras进行简单分类
May 04 #Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 #Python
Python3使用Qt5来实现简易的五子棋小游戏
May 02 #Python
python开发制作好看的时钟效果
关于的python五子棋的算法
python开发人人对战的五子棋小游戏
You might like
第十三节--对象串行化
2006/11/16 PHP
php面象对象数据库操作类实例
2014/12/02 PHP
Laravel 5 框架入门(一)
2015/04/09 PHP
PHP递归遍历指定文件夹内的文件实现方法
2016/11/15 PHP
PHP实现随机生成水印图片功能
2017/03/22 PHP
laravel 解决Validator使用中出现的问题
2019/10/25 PHP
ajax的hide隐藏问题解决方法
2012/12/11 Javascript
零基础搭建Node.js、Express、Ejs、Mongodb服务器及应用开发入门
2014/12/20 Javascript
常用的jQuery前端技巧收集
2014/12/24 Javascript
window.setInterval()方法的定义和用法及offsetLeft与style.left的区别
2015/11/11 Javascript
js输出数据精确到小数点后n位代码
2016/07/02 Javascript
js H5 canvas投篮小游戏
2016/08/18 Javascript
AngularJs directive详解及示例代码
2016/09/01 Javascript
微信小程序 location API接口详解及实例代码
2016/10/12 Javascript
JS实现点击网页判断是否安装app并打开否则跳转app store
2016/11/18 Javascript
Vue Transition实现类原生组件跳转过渡动画的示例
2017/08/19 Javascript
新手vue构建单页面应用实例代码
2017/09/18 Javascript
angular的输入和输出的使用方法
2018/09/22 Javascript
vue实现select下拉显示隐藏功能
2019/09/30 Javascript
jstree中的checkbox默认选中和隐藏示例代码
2019/12/29 Javascript
使用Vue实现简单计算器
2020/02/25 Javascript
vue Element左侧无限级菜单实现
2020/06/10 Javascript
在vue中封装的弹窗组件使用队列模式实现方法
2020/07/23 Javascript
pandas.DataFrame选取/排除特定行的方法
2018/07/03 Python
Opencv+Python 色彩通道拆分及合并的示例
2018/12/08 Python
python实现石头剪刀布程序
2021/01/20 Python
解决Keyerror ''acc'' KeyError: ''val_acc''问题
2020/06/18 Python
Application Cache未缓存文件无法访问无法加载问题
2014/05/31 HTML / CSS
WEB控件可以激发服务端事件,请谈谈服务端事件是怎么发生并解释其原理?自动传回是什么?为什么要使用自动传回?
2012/02/21 面试题
五分钟演讲稿
2014/04/30 职场文书
社会实践的活动方案
2014/08/22 职场文书
青年文明号汇报材料
2014/12/23 职场文书
2015年个人自我剖析材料
2014/12/29 职场文书
团支部组织委员竞选稿
2015/11/21 职场文书
MySQL 8.0 之不可见列的基本操作
2021/05/20 MySQL
Python 第三方库 openpyxl 的安装过程
2022/12/24 Python