使用TensorFlow搭建一个全连接神经网络教程


Posted in Python onFebruary 06, 2020

说明

本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。

先上代码

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

代码解析

1. 读取MNIST数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

2. 建立占位符

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

xs 代表图片像素数据, 每张图片(28×28)被展开成(1×784), 有多少图片还未定, 所以shape为None×784.

ys 代表图片标签数据, 0-9十个数字被表示成One-hot形式, 即只有对应bit为1, 其余为0.

3. 建立模型

weights = tf.Variable(tf.random_normal([784, 10]))


biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
            reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

使用Softmax函数作为激活函数:

使用TensorFlow搭建一个全连接神经网络教程

4. 计算正确率

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

5. 使用模型

with tf.Session() as sess:
 init = tf.global_variables_initializer()
 sess.run(init)
 for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={
   xs: batch_xs,
   ys: batch_ys
  })
  if i % 50 == 0:
   print(sess.run(accuracy, feed_dict={
    xs: mnist.test.images,
    ys: mnist.test.labels
   }))

运行结果

训练1000个循环, 准确率在87%左右.

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.1041
0.632
0.7357
0.7837
0.7971
0.8147
0.8283
0.8376
0.8423
0.8501
0.8501
0.8533
0.8567
0.8597
0.8552
0.8647
0.8654
0.8701
0.8712
0.8712

以上这篇使用TensorFlow搭建一个全连接神经网络教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python编写爬虫的基本模块及框架使用指南
Jan 20 Python
深入了解Python数据类型之列表
Jun 24 Python
Python 如何访问外围作用域中的变量
Sep 11 Python
详解MySQL数据类型int(M)中M的含义
Nov 20 Python
python django事务transaction源码分析详解
Mar 17 Python
python3操作mysql数据库的方法
Jun 23 Python
django2用iframe标签完成网页内嵌播放b站视频功能
Jun 20 Python
对Python多线程读写文件加锁的实例详解
Jan 14 Python
tensorflow:指定gpu 限制使用量百分比,设置最小使用量的实现
Feb 06 Python
Python爬虫工具requests-html使用解析
Apr 29 Python
在Keras中利用np.random.shuffle()打乱数据集实例
Jun 15 Python
使用Keras建立模型并训练等一系列操作方式
Jul 02 Python
详解python 降级到3.6终极解决方案
Feb 06 #Python
PyCharm如何导入python项目的方法
Feb 06 #Python
tensorflow 环境变量设置方式
Feb 06 #Python
快速查找Python安装路径方法
Feb 06 #Python
运行tensorflow python程序,限制对GPU和CPU的占用操作
Feb 06 #Python
如何在django中添加日志功能
Feb 06 #Python
keras tensorflow 实现在python下多进程运行
Feb 06 #Python
You might like
openPNE常用方法分享
2011/11/29 PHP
CI框架源码阅读,系统常量文件constants.php的配置
2013/02/28 PHP
Javascript 生成指定范围数值随机数
2009/01/09 Javascript
js计算字符串长度包含的中文是utf8格式
2013/10/15 Javascript
利用JavaScript检测CPU使用率自己写的
2014/03/22 Javascript
一个简单的实现下拉框多选的插件可移植性比较好
2014/05/05 Javascript
了不起的node.js读书笔记之node.js中的特性
2014/12/22 Javascript
JavaScript中判断变量是数组、函数或是对象类型的方法
2015/02/25 Javascript
WordPress 单页面上一页下一页的实现方法【附代码】
2016/03/10 Javascript
AngularJS ng-app 指令实例详解
2016/07/30 Javascript
AngularJS中过滤器的使用与自定义实例代码
2016/09/17 Javascript
移动端刮刮乐的实现方式(js+HTML5)
2017/03/23 Javascript
JavaScript判断输入是否为数字类型的方法总结
2017/09/28 Javascript
京东优选小程序的实现代码示例
2020/02/25 Javascript
[05:49]DOTA2-DPC中国联赛 正赛 Elephant vs LBZS 选手采访
2021/03/11 DOTA
Python中的复制操作及copy模块中的浅拷贝与深拷贝方法
2016/07/02 Python
Python中使用多进程来实现并行处理的方法小结
2017/08/09 Python
Python实现识别手写数字 简易图片存储管理系统
2018/01/29 Python
python安装模块如何通过setup.py安装(超简单)
2018/05/05 Python
python与caffe改变通道顺序的方法
2018/08/04 Python
Python之列表实现栈的工作功能
2019/01/28 Python
Python笔记之代理模式
2019/11/20 Python
PyCharm GUI界面开发和exe文件生成的实现
2020/03/04 Python
解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
2020/06/12 Python
基于Tensorflow读取MNIST数据集时网络超时的解决方式
2020/06/22 Python
如何利用python读取micaps文件详解
2020/10/18 Python
css3之UI元素状态伪类选择器实例演示
2017/08/11 HTML / CSS
深入研究HTML5实现图片压缩上传功能
2016/03/25 HTML / CSS
平面设计自荐信
2013/10/07 职场文书
思想汇报范文
2013/11/04 职场文书
会计系个人求职信范文分享
2013/12/20 职场文书
2014年百日安全生产活动总结
2014/05/04 职场文书
人民调解员先进事迹材料
2014/05/08 职场文书
入党推优材料
2014/06/02 职场文书
2016教师年度考核评语大全
2015/12/01 职场文书
Redis 哨兵机制及配置实现
2022/03/25 Redis