使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之开始真正编程
Sep 12 Python
python自动翻译实现方法
May 28 Python
Python读写txt文本文件的操作方法全解析
Jun 26 Python
详谈Python3 操作系统与路径 模块(os / os.path / pathlib)
Apr 26 Python
pandas使用get_dummies进行one-hot编码的方法
Jul 10 Python
Python中应该使用%还是format来格式化字符串
Sep 25 Python
numpy向空的二维数组中添加元素的方法
Nov 01 Python
pandas.cut具体使用总结
Jun 24 Python
python常用库之NumPy和sklearn入门
Jul 11 Python
Python 实现平台类游戏添加跳跃功能
Mar 27 Python
python3.7 openpyxl 在excel单元格中写入数据实例
Sep 01 Python
通过Python pyecharts输出保存图片代码实例
Nov 25 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
php获取目录所有文件并将结果保存到数组(实例)
2013/10/25 PHP
phpstorm 配置xdebug的示例代码
2019/03/31 PHP
jQuery 方法大全方便学习参考
2010/02/25 Javascript
JAVASCRIPT实现的WEB页面跳转以及页面间传值方法
2010/05/13 Javascript
12款经典的白富美型—jquery图片轮播插件—前端开发必备
2013/01/08 Javascript
Javascript表格翻页效果的具体实现
2013/10/05 Javascript
jQuery创建DOM元素实例解析
2015/01/19 Javascript
深入理解JavaScript系列(48):对象创建模式(下篇)
2015/03/04 Javascript
javascript中this的四种用法
2015/05/11 Javascript
javascript实现在线客服效果
2015/07/15 Javascript
微信+angularJS的SPA应用中用router进行页面跳转,jssdk校验失败问题解决
2016/09/09 Javascript
JS实现密码框的显示密码和隐藏密码功能示例
2016/12/26 Javascript
jQuery插件zTree实现获取一级节点数据的方法
2017/03/08 Javascript
基于JavaScript实现评论框展开和隐藏功能
2017/08/25 Javascript
JavaScript实现职责链模式概述
2018/01/25 Javascript
layui问题之模拟select点击事件的实例讲解
2018/08/15 Javascript
Electron中实现大文件上传和断点续传功能
2018/10/28 Javascript
小程序指纹验证的实现代码
2018/12/04 Javascript
vue3 源码解读之 time slicing的使用方法
2019/10/31 Javascript
JSONP解决JS跨域问题的实现
2020/05/25 Javascript
Jquery cookie插件实现原理代码解析
2020/08/04 jQuery
python在windows和linux下获得本机本地ip地址方法小结
2015/03/20 Python
用PyInstaller把Python代码打包成单个独立的exe可执行文件
2018/05/26 Python
浅谈python中对于json写入txt文件的编码问题
2018/06/07 Python
Python基础学习之基本数据结构详解【数字、字符串、列表、元组、集合、字典】
2019/06/18 Python
Keras中的两种模型:Sequential和Model用法
2020/06/27 Python
HTML5中图片之间的缝隙完美解决方法
2017/07/07 HTML / CSS
Kenneth Cole官网:纽约时尚优雅品牌
2016/11/14 全球购物
喜诗官方在线巧克力店:See’s Candies
2017/01/01 全球购物
求∏的近似值,直到最后一项的绝对值小于指定的数
2016/02/12 面试题
创建卫生先进单位实施方案
2014/03/10 职场文书
学生未请假就回家检讨书
2014/09/22 职场文书
质监局领导班子对照检查材料思想汇报
2014/09/27 职场文书
反腐倡廉影片观后感
2015/06/08 职场文书
追悼会悼词大全
2015/06/23 职场文书
Python 装饰器(decorator)常用的创建方式及解析
2022/04/24 Python