tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中死锁的形成示例及死锁情况的防止
Jun 14 Python
python实现随机漫步方法和原理
Jun 10 Python
python读取Excel表格文件的方法
Sep 02 Python
python中提高pip install速度
Feb 14 Python
django模型动态修改参数,增加 filter 字段的方式
Mar 16 Python
Python基于pyecharts实现关联图绘制
Mar 27 Python
Python闭包及装饰器运行原理解析
Jun 17 Python
如何实现一个python函数装饰器(Decorator)
Oct 12 Python
Django url 路由匹配过程详解
Jan 22 Python
python 列表推导和生成器表达式的使用
Feb 01 Python
Python之Sklearn使用入门教程
Feb 19 Python
Python中常见的导入方式总结
May 06 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
thinkPHP5.0框架简单配置作用域的方法
2017/03/17 PHP
php判断str字符串是否是xml格式数据的方法示例
2017/07/26 PHP
Laravel中获取路由参数Route Parameters的五种方法示例
2017/09/29 PHP
PHP实现图的邻接矩阵表示及几种简单遍历算法分析
2017/11/24 PHP
PHP实现的最大正向匹配算法示例
2017/12/19 PHP
ThinkPHP中图片按比例切割的代码实例
2019/03/08 PHP
Yii框架分页技术实例分析
2019/08/30 PHP
js播放wav文件(源码)
2013/04/22 Javascript
js获取某月的最后一天日期的简单实例
2013/06/22 Javascript
JavaScript电子时钟倒计时第二款
2016/01/10 Javascript
一步一步封装自己的HtmlHelper组件BootstrapHelper(三)
2016/09/14 Javascript
微信小程序实现图片预加载组件
2017/01/18 Javascript
EasyUI为Numberbox添加blur事件的方法
2017/03/05 Javascript
JS实现table表格固定表头且表头随横向滚动而滚动
2017/10/26 Javascript
ES6使用Set数据结构实现数组的交集、并集、差集功能示例
2017/10/31 Javascript
基于vue实现分页效果
2017/11/06 Javascript
js使用ajax传值给后台,后台返回字符串处理方法
2018/08/08 Javascript
详解关于Angular4 ng-zorro使用过程中遇到的问题
2018/12/05 Javascript
详解React项目中碰到的IE问题
2019/03/14 Javascript
javascript实现弹出层效果
2019/12/10 Javascript
JavaScript实现指定数量的并发限制的示例代码
2020/03/10 Javascript
Node.js API详解之 querystring用法实例分析
2020/04/29 Javascript
深入了解JS之作用域和闭包
2020/06/16 Javascript
python复制文件代码实现
2013/12/23 Python
Python创建xml文件示例
2017/03/22 Python
详解django中自定义标签和过滤器
2017/07/03 Python
Python中的探索性数据分析(功能式)
2017/12/22 Python
对numpy和pandas中数组的合并和拆分详解
2018/04/11 Python
详解使用python绘制混淆矩阵(confusion_matrix)
2019/07/14 Python
Django处理Ajax发送的Get请求代码详解
2019/07/29 Python
英国工具中心:UK Tool Centre
2017/07/10 全球购物
Groupon法国官方网站:特卖和网上购物高达-70%
2019/09/02 全球购物
Alexandre Birman美国官网:亚历山大·伯曼
2019/10/30 全球购物
2014年评职称工作总结
2014/11/20 职场文书
工伤私了协议书范本
2014/11/24 职场文书
中学教代会开幕词
2016/03/04 职场文书