深度学习tensorflow基础mnist


Posted in Python onApril 14, 2021

软件架构

mnist数据集的识别使用了两个非常小的网络来实现,第一个是最简单的全连接网络,第二个是卷积网络,mnist数据集是入门数据集,所以不需要进行图像增强,或者用生成器读入内存,直接使用简单的fit()命令就可以一次性训练

安装教程

  1. 使用到的主要第三方库有tensorflow1.x,基于TensorFlow的Keras,基础的库包括numpy,matplotlib
  2. 安装方式也很简答,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 注意tensorflow版本不能是2.x

使用说明

  1. 首先,我们预览数据集,运行mnistplt.py,绘制了4张训练用到的图像
  2. 训练全连接网络则运行Densemnist.py,得到权重Dense.h5,加载模型并预测运行Denseload.py
  3. 训练卷积网络则运行CNNmnist.py,得到权重CNN.h5,加载模型并预测运行CNNload.py

结果图

深度学习tensorflow基础mnist

深度学习tensorflow基础mnist

训练过程注释

全连接网络训练:

"""多层感知机训练"""
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import  Sequential
from keras.layers import Dense
#模拟原始灰度数据读入
img_size=28
num=10
mnist=input_data.read_data_sets("./data",one_hot=True)
X_train,y_train,X_test,y_test=mnist.train.images,mnist.train.labels,mnist.test.images,mnist.test.labels
X_train=X_train.reshape(-1,img_size,img_size)
X_test=X_test.reshape(-1,img_size,img_size)
X_train=X_train*255
X_test=X_test*255
y_train=y_train.reshape(-1,num)
y_test=y_test.reshape(-1,num)
print(X_train.shape)
print(y_train.shape)
#全连接层只能输入一维
num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0],num_pixels).astype('float32')
X_test = X_test.reshape(X_test.shape[0],num_pixels).astype('float32')
#归一化
X_train=X_train/255
X_test=X_test/255
# one hot编码,这里编好了,省略
#y_train = np_utils.to_categorical(y_train)
#y_test = np_utils.to_categorical(y_test)
#搭建网络
def baseline():
    """
    optimizer:优化器,如Adam
    loss:计算损失,当使用categorical_crossentropy损失函数时,标签应为多类模式,例如如果你有10个类别,
    每一个样本的标签应该是一个10维的向量,该向量在对应有值的索引位置为1其余为0
    metrics: 列表,包含评估模型在训练和测试时的性能的指标
    """
    model=Sequential()
    #第一步是确定输入层的数目:在创建模型时用input_dim参数确定,例如,有784个个输入变量,就设成num_pixels。
    #全连接层用Dense类定义:第一个参数是本层神经元个数,然后是初始化方式和激活函数,初始化方法有0到0.05的连续型均匀分布(uniform
    #Keras的默认方法也是这个,也可以用高斯分布进行初始化normal,初始化实际就是该层连接上权重与偏置的初始化
    model.add(Dense(num_pixels,input_dim=num_pixels,kernel_initializer='normal',activation='relu'))
    #softmax是一种用到该层所有神经元的激活函数
    model.add(Dense(num,kernel_initializer='normal',activation='softmax'))
    #categorical_crossentropy适用于多分类问题,并使用softmax作为输出层的激活函数的情况
    model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
    return model
#训练模型
model = baseline()
"""
batch_size
整数
每次梯度更新的样本数。
未指定,默认为32
epochs
整数
训练模型迭代次数
verbose
日志展示,整数
0:为不在标准输出流输出日志信息
1:显示进度条
2:每个epoch输出一行记录
对于一个有 2000 个训练样本的数据集,将 2000 个样本分成大小为 500 的 batch,那么完成一个 epoch 需要 4 个 iteration
"""
model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=10,batch_size=200,verbose=2)
#模型概括打印
model.summary()
#model.evaluate()返回的是 损失值和你选定的指标值(例如,精度accuracy)
"""
verbose:控制日志显示的方式
verbose = 0  不在标准输出流输出日志信息
verbose = 1  输出进度条记录
"""
scores = model.evaluate(X_test,y_test,verbose=0)
print(scores)
#模型保存
model_dir="./Dense.h5"
model.save(model_dir)

CNN训练:

"""
模型构建与训练
Sequential 模型结构: 层(layers)的线性堆栈,它是一个简单的线性结构,没有多余分支,是多个网络层的堆叠
多少个滤波器就输出多少个特征图,即卷积核(滤波器)的深度
3通道RGB图片,一个滤波器有3个通道的小卷积核,但还是只算1个滤波器
"""
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
#Flatten层用来将输入“压平”,即把多维的输入一维化,
#常用在从卷积层到全连接层的过渡
from keras.layers import Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
#模拟原始灰度数据读入
img_size=28
num=10
mnist=input_data.read_data_sets("./data",one_hot=True)
X_train,y_train,X_test,y_test=mnist.train.images,mnist.train.labels,mnist.test.images,mnist.test.labels
X_train=X_train.reshape(-1,img_size,img_size)
X_test=X_test.reshape(-1,img_size,img_size)
X_train=X_train*255
X_test=X_test*255
y_train=y_train.reshape(-1,num)
y_test=y_test.reshape(-1,num)
print(X_train.shape) #(55000, 28, 28)
print(y_train.shape) #(55000, 10)
#此处卷积输入的形状要与模型中的input_shape匹配
X_train = X_train.reshape(X_train.shape[0],28,28,1).astype('float32')
X_test = X_test.reshape(X_test.shape[0],28,28,1).astype('float32')
print(X_train.shape)#(55000,28,28,1)
#归一化
X_train=X_train/255
X_test=X_test/255
# one hot编码,这里编好了,省略
#y_train = np_utils.to_categorical(y_train)
#y_test = np_utils.to_categorical(y_test)
#搭建CNN网络
def CNN():
    """
    第一层是卷积层。该层有32个feature map,作为模型的输入层,接受[pixels][width][height]大小的输入数据。feature map的大小是1*5*5,其输出接一个‘relu'激活函数
    下一层是pooling层,使用了MaxPooling,大小为2*2
    Flatten压缩一维后作为全连接层的输入层
    接下来是全连接层,有128个神经元,激活函数采用‘relu'
    最后一层是输出层,有10个神经元,每个神经元对应一个类别,输出值表示样本属于该类别的概率大小
    """
    model = Sequential()
    model.add(Conv2D(32, (5, 5), input_shape=(img_size,img_size,1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num, activation='softmax'))
    #编译
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
#模型训练
model=CNN()
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=5, batch_size=200, verbose=1)
model.summary()
scores = model.evaluate(X_test,y_test,verbose=1)
print(scores)
#模型保存
model_dir="./CNN.h5"
model.save(model_dir)

到此这篇关于mnist的文章就介绍到这了,希望可以帮到你们,更多相关深度学习内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章,希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用Python进行稳定可靠的文件操作详解
Dec 31 Python
Python用list或dict字段模式读取文件的方法
Jan 10 Python
pyenv命令管理多个Python版本
Mar 26 Python
Python中用post、get方式提交数据的方法示例
Sep 22 Python
Python中的id()函数指的什么
Oct 17 Python
python之django母板页面的使用
Jul 03 Python
Anaconda之conda常用命令介绍(安装、更新、删除)
Oct 06 Python
python实现的读取网页并分词功能示例
Oct 29 Python
使用Bazel编译TensorBoard教程
Feb 15 Python
python实现图像拼接功能
Mar 23 Python
Python 如何安装Selenium
May 06 Python
python数据处理之Pandas类型转换
Apr 28 Python
Python 多线程之threading 模块的使用
Apr 14 #Python
教你如何用python开发一款数字推盘小游戏
深度学习详解之初试机器学习
正确的理解和使用Django信号(Signals)
Apr 14 #Python
编写python程序的90条建议
Apr 14 #Python
Python基础知识之变量的详解
理解深度学习之深度学习简介
Apr 14 #Python
You might like
学习discuz php 引入文件的方法DISCUZ_ROOT
2009/06/21 PHP
PHP开发规范手册之PHP代码规范详解
2011/01/13 PHP
php获取用户IPv4或IPv6地址的代码
2012/11/15 PHP
laravel框架 api自定义全局异常处理方法
2019/10/11 PHP
TP3.2框架分页相关实现方法分析
2020/06/03 PHP
JavaScript 直接操作本地文件的实现代码
2009/12/01 Javascript
使用js dom和jquery分别实现简单增删改
2014/09/11 Javascript
JQuery显示隐藏DIV的方法及代码实例
2015/04/16 Javascript
jquery可定制的在线UEditor编辑器
2015/11/17 Javascript
Jquery1.9.1源码分析系列(六)延时对象应用之jQuery.ready
2015/11/24 Javascript
JavaScript手机振动API
2016/06/11 Javascript
原生js仿浏览器滚动条效果
2017/03/02 Javascript
jQuery插件HighCharts实现2D柱状图、折线图的组合多轴图效果示例【附demo源码下载】
2017/03/09 Javascript
详解Angular 4.x 动态创建组件
2017/04/25 Javascript
微信小程序之页面跳转和参数传递的实现
2017/09/29 Javascript
浅谈Vue的加载顺序探讨
2017/10/25 Javascript
Three.js 再探 - 写一个微信跳一跳极简版游戏
2018/01/04 Javascript
vue2.0 父组件给子组件传递数据的方法
2018/01/15 Javascript
JS将网址url转化为JSON格式的方法
2018/07/02 Javascript
JS实现图片幻灯片效果代码实例
2020/05/21 Javascript
使用nodejs实现JSON文件自动转Excel的工具(推荐)
2020/06/24 NodeJs
基于openlayers实现角度测量功能
2020/09/28 Javascript
[03:26]回顾2015国际邀请赛中国区预选赛
2015/06/09 DOTA
使用python的pandas库读取csv文件保存至mysql数据库
2018/08/20 Python
Python基于codecs模块实现文件读写案例解析
2020/05/11 Python
python 使用多线程创建一个Buffer缓存器的实现思路
2020/07/02 Python
详解HTML5 Canvas绘制不规则图形时的非零环绕原则
2016/03/21 HTML / CSS
Puma印度官网:德国运动品牌
2019/10/06 全球购物
Python文件操作的面试题
2013/06/22 面试题
英语生日邀请函
2014/01/23 职场文书
《鸟岛》教学反思
2014/04/26 职场文书
安全宣传标语口号
2014/06/06 职场文书
企业开业庆典答谢词
2015/01/20 职场文书
收银员岗位职责范本
2015/04/07 职场文书
2015年化工厂工作总结
2015/05/04 职场文书
mysql联合索引的使用规则
2021/06/23 MySQL