使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python生成url短链接的方法
May 04 Python
使用PyV8在Python爬虫中执行js代码
Feb 16 Python
python字典快速保存于读取的方法
Mar 23 Python
基于python中theano库的线性回归
Aug 31 Python
Python关于excel和shp的使用在matplotlib
Jan 03 Python
利用python和ffmpeg 批量将其他图片转换为.yuv格式的方法
Jan 08 Python
对PyQt5中树结构的实现方法详解
Jun 17 Python
python正则爬取某段子网站前20页段子(request库)过程解析
Aug 10 Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 Python
python将unicode和str互相转化的实现
May 11 Python
Python实现发票自动校核微信机器人的方法
May 22 Python
python openCV实现摄像头获取人脸图片
Aug 20 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
php 来访国内外IP判断代码并实现页面跳转
2009/12/18 PHP
php对关联数组循环遍历的实现方法
2015/03/13 PHP
Yii2实现log输出到file及database的方法
2016/11/12 PHP
分析php://output和php://stdout的区别
2018/05/06 PHP
ThinkPHP 3使用OSS的方法
2018/07/19 PHP
JavaScript的面向对象(一)
2006/11/09 Javascript
artDialog双击会关闭对话框的修改过程分享
2013/08/05 Javascript
JQuery中阻止事件冒泡几种方式及其区别介绍
2014/01/15 Javascript
通过Jquery的Ajax方法读取将table转换为Json
2014/05/31 Javascript
js/jquery判断浏览器的方法小结
2014/09/02 Javascript
jQuery中使用each处理json数据
2015/04/23 Javascript
详解AngularJS中$http缓存以及处理多个$http请求的方法
2016/02/06 Javascript
微信小程序-图片、录音、音频播放、音乐播放、视频、文件代码实例
2016/11/22 Javascript
jQuery实现淡入淡出的模态框
2017/02/09 Javascript
React diff算法的实现示例
2018/04/20 Javascript
react-native动态切换tab组件的方法
2018/07/07 Javascript
Vue实现将数据库中带html标签的内容输出(原始HTML(Raw HTML))
2019/10/28 Javascript
ant design实现圈选功能
2019/12/17 Javascript
[00:37]2016完美“圣”典风云人物:AMS宣传片
2016/12/06 DOTA
[41:56]Spirit vs Liquid Supermajor小组赛A组 BO3 第一场 6.2
2018/06/03 DOTA
python解析html开发库pyquery使用方法
2014/02/07 Python
Python命令行参数解析模块optparse使用实例
2015/04/13 Python
Python编程之多态用法实例详解
2015/05/19 Python
Python中map和列表推导效率比较实例分析
2015/06/17 Python
Python实现Sqlite将字段当做索引进行查询的方法
2016/07/21 Python
基于Python_脚本CGI、特点、应用、开发环境(详解)
2017/05/23 Python
Python中摘要算法MD5,SHA1简介及应用实例代码
2018/01/09 Python
全网首秀之Pycharm十大实用技巧(推荐)
2020/04/27 Python
碧欧泉法国官网:Biotherm法国
2019/10/23 全球购物
员工培训邀请函
2014/02/02 职场文书
军训教官感言
2014/03/02 职场文书
yy婚礼司仪主持词
2014/03/14 职场文书
超市创业计划书
2014/09/15 职场文书
2014年班干部工作总结
2014/11/25 职场文书
用 Python 元类的特性实现 ORM 框架
2021/05/19 Python
源码安装apache脚本部署过程详解
2022/09/23 Servers