kaggle+mnist实现手写字体识别


Posted in Python onJuly 26, 2018

现在的许多手写字体识别代码都是基于已有的mnist手写字体数据集进行的,而kaggle需要用到网站上给出的数据集并生成测试集的输出用于提交。这里选择keras搭建卷积网络进行识别,可以直接生成测试集的结果,最终结果识别率大概97%左右的样子。

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 6 19:07:10 2017

@author: Administrator
"""

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten 
from keras.layers import Convolution2D, MaxPooling2D 
from keras.utils import np_utils
import os
import pandas as pd
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
import tensorflow as tf

# 全局变量 
batch_size = 100 
nb_classes = 10 
epochs = 20
# input image dimensions 
img_rows, img_cols = 28, 28 
# number of convolutional filters to use 
nb_filters = 32 
# size of pooling area for max pooling 
pool_size = (2, 2) 
# convolution kernel size 
kernel_size = (3, 3) 

inputfile='F:/data/kaggle/mnist/train.csv'
inputfile2= 'F:/data/kaggle/mnist/test.csv'
outputfile= 'F:/data/kaggle/mnist/test_label.csv'


pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
train= pd.read_csv(os.path.basename(inputfile)) #从训练数据文件读取数据
os.chdir(pwd)

pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
test= pd.read_csv(os.path.basename(inputfile2)) #从测试数据文件读取数据
os.chdir(pwd)

x_train=train.iloc[:,1:785] #得到特征数据
y_train=train['label']
y_train = np_utils.to_categorical(y_train, 10)

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #导入数据
x_test=mnist.test.images
y_test=mnist.test.labels
# 根据不同的backend定下不同的格式 
if K.image_dim_ordering() == 'th': 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
 input_shape = (1, img_rows, img_cols) 
 test = test.reshape(test.shape[0], 1, img_rows, img_cols) 
else: 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
 X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
 test = test.reshape(test.shape[0], img_rows, img_cols, 1) 
 input_shape = (img_rows, img_cols, 1) 

x_train = x_train.astype('float32') 
x_test = X_test.astype('float32') 
test = test.astype('float32') 
x_train /= 255 
X_test /= 255
test/=255 
print('X_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
print(test.shape[0], 'testOuput samples') 

model=Sequential()#model initial
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]), 
      padding='same', 
      input_shape=input_shape)) # 卷积层1 
model.add(Activation('relu')) #激活层 
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]))) #卷积层2 
model.add(Activation('relu')) #激活层 
model.add(MaxPooling2D(pool_size=pool_size)) #池化层 
model.add(Dropout(0.25)) #神经元随机失活 
model.add(Flatten()) #拉成一维数据 
model.add(Dense(128)) #全连接层1 
model.add(Activation('relu')) #激活层 
model.add(Dropout(0.5)) #随机失活 
model.add(Dense(nb_classes)) #全连接层2 
model.add(Activation('softmax')) #Softmax评分 

#编译模型 
model.compile(loss='categorical_crossentropy', 
    optimizer='adadelta', 
    metrics=['accuracy']) 
#训练模型 

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,verbose=1) 
model.predict(x_test)
#评估模型 
score = model.evaluate(x_test, y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

y_test=model.predict(test)

sess=tf.InteractiveSession()
y_test=sess.run(tf.arg_max(y_test,1))
y_test=pd.DataFrame(y_test)
y_test.to_csv(outputfile)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python encode和decode的妙用
Sep 02 Python
centos下更新Python版本的步骤
Feb 12 Python
Python接收Gmail新邮件并发送到gtalk的方法
Mar 10 Python
Python批量修改文本文件内容的方法
Apr 29 Python
转换科学计数法的数值字符串为decimal类型的方法
Jul 16 Python
Django之模型层多表操作的实现
Jan 08 Python
Python3中exp()函数用法分析
Feb 19 Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 Python
使用Puppeteer爬取微信文章的实现
Feb 11 Python
win10下python3.8的PIL库安装过程
Jun 08 Python
如何利用Python 进行边缘检测
Oct 14 Python
详细介绍python类及类的用法
May 31 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
You might like
PHP 设计模式之观察者模式介绍
2012/02/22 PHP
ThinkPHP之用户注册登录留言完整实例
2014/07/22 PHP
使用ltrace工具跟踪PHP库函数调用的方法
2016/04/25 PHP
手把手编写PHP框架 深入了解MVC运行流程
2016/09/19 PHP
php empty 函数判断结果为空但实际值却为非空的原因解析
2018/05/28 PHP
PHP连接SQL Server的方法分析【基于thinkPHP5.1框架】
2019/05/06 PHP
thinkphp5框架实现数据库读取的数据转换成json格式示例
2019/10/10 PHP
Javascript isArray 数组类型检测函数
2009/10/08 Javascript
jQuery函数的第二个参数获取指定上下文中的DOM元素
2014/05/19 Javascript
jQuery中:only-child选择器用法实例
2015/01/03 Javascript
jQuery幻灯片带缩略图轮播效果代码分享
2015/08/17 Javascript
JS实用技巧小结(屏蔽错误、div滚动条设置、背景图片位置等)
2016/06/16 Javascript
利用vue-router实现二级菜单内容转换
2016/11/30 Javascript
javascript 中Cookie读、写与删除操作
2017/03/29 Javascript
JavaScript禁止微信浏览器下拉回弹效果
2017/05/16 Javascript
jQuery实现火车票买票城市选择切换功能
2017/09/15 jQuery
微信小程序实现打开内置地图功能【附源码下载】
2017/12/07 Javascript
解决使用Vue.js显示数据的时,页面闪现原始代码的问题
2018/02/11 Javascript
vuex与组件联合使用的方法
2018/05/10 Javascript
AngularJS发送异步Get/Post请求方法
2018/08/13 Javascript
ajax与jsonp的区别及用法
2018/10/16 Javascript
vue-cli3+typescript初体验小结
2019/02/28 Javascript
CentOS6.5设置Django开发环境
2016/10/13 Python
Python不同目录间进行模块调用的实现方法
2019/01/29 Python
python如何爬取网站数据并进行数据可视化
2019/07/08 Python
python安装cx_Oracle和wxPython的方法
2020/09/14 Python
python中not、and和or的优先级与详细用法介绍
2020/11/03 Python
HTML5中的网络存储实现方式
2020/04/28 HTML / CSS
怀旧收藏品和经典纪念品:Betty’s Attic
2018/08/29 全球购物
Vilebrequin美国官方网上商店:法国豪华泳装品牌
2020/02/22 全球购物
会计与审计毕业生自荐信范文
2013/12/30 职场文书
项目合作计划书
2014/01/09 职场文书
自愿离婚协议书2015
2015/01/26 职场文书
志愿者服务活动总结报告
2015/05/06 职场文书
《分数乘法》教学反思
2016/02/24 职场文书
python实现层次聚类的方法
2021/11/01 Python