使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
布同 统计英文单词的个数的python代码
Mar 13 Python
详细解读Python中的__init__()方法
May 02 Python
Python语言描述KNN算法与Kd树
Dec 13 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 Python
python 实现两个线程交替执行
May 02 Python
Python函数参数定义及传递方式解析
Jun 10 Python
经验丰富程序员才知道的8种高级Python技巧
Jul 27 Python
python中字符串的编码与解码详析
Dec 03 Python
使用Python爬取Json数据的示例代码
Dec 07 Python
python的scipy.stats模块中正态分布常用函数总结
Feb 19 Python
使用Python快速打开一个百万行级别的超大Excel文件的方法
Mar 02 Python
numpy数据类型dtype转换实现
Apr 24 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
espresso double下 咖啡粉超细时 饼压力对咖啡的影响
2021/03/03 冲泡冲煮
php xml留言板 xml存储数据的简单例子
2009/08/24 PHP
PHP实现今天是星期几的几种写法
2013/09/26 PHP
PHP jQuery表单,带验证具体实现方法
2014/02/15 PHP
Laravel中错误与异常处理的用法示例
2018/09/16 PHP
IE7提供XMLHttpRequest对象为兼容
2007/03/08 Javascript
表单项的name命名为submit、reset引起的问题
2007/12/22 Javascript
简单的js分页脚本
2009/05/21 Javascript
js apply/call/caller/callee/bind使用方法与区别分析
2009/10/28 Javascript
javascript判断ie浏览器6/7版本加载不同样式表的实现代码
2011/12/26 Javascript
jQuery:delegate中select()不起作用的解决方法(实例讲解)
2014/01/26 Javascript
jQuery判断元素上是否绑定了指定事件的方法
2015/03/17 Javascript
JavaScript将数字转换成大写中文的方法
2015/03/23 Javascript
javascript为按钮注册回车事件(设置默认按钮)的方法
2015/05/09 Javascript
jQuery插件Easyui设置datagrid的pageNumber导致两次请求问题的解决方法
2016/08/06 Javascript
JS输出空格的简单实现方法
2016/09/08 Javascript
JS中BOM相关知识点总结(必看篇)
2016/11/22 Javascript
jquery实现点击页面回到顶部
2016/11/23 Javascript
Angularjs之filter过滤器(推荐)
2016/11/27 Javascript
详解js前端代码异常监控
2017/01/11 Javascript
微信小程序 request接口的封装实例代码
2017/04/26 Javascript
老生常谈Bootstrap媒体对象
2017/07/06 Javascript
echarts实现地图定时切换散点与多图表级联联动详解
2018/08/07 Javascript
每周一练 之 数据结构与算法(Stack)
2019/04/16 Javascript
JavaScript实现的弹出遮罩层特效经典示例【基于jQuery】
2019/07/10 jQuery
ant-design-vue 实现表格内部字段验证功能
2019/12/16 Javascript
Python contextlib模块使用示例
2015/02/18 Python
Python聚类算法之基本K均值实例详解
2015/11/20 Python
python决策树之CART分类回归树详解
2017/12/20 Python
Python统计python文件中代码,注释及空白对应的行数示例【测试可用】
2018/07/25 Python
Apache部署Django项目图文详解
2019/07/30 Python
应用心理学个人的求职信
2013/12/08 职场文书
军训感想500字
2014/02/20 职场文书
新年抽奖获奖感言
2014/03/02 职场文书
2014年小学辅导员工作总结
2014/12/23 职场文书
2015年学校保卫部工作总结
2015/05/11 职场文书