tensorflow使用神经网络实现mnist分类


Posted in Python onSeptember 08, 2018

本文实例为大家分享了tensorflow神经网络实现mnist分类的具体代码,供大家参考,具体内容如下

只有两层的神经网络,直接上代码

#引入包
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#引入input_data文件
from tensorflow.examples.tutorials.mnist import input_data
#读取文件
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot=True)

#定义第一个隐藏层和第二个隐藏层,输入层输出层
n_hidden_1 = 256
n_hidden_2 = 128
n_input = 784
n_classes = 10

#由于不知道输入图片个数,所以用placeholder
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])

stddev = 0.1

#定义权重
weights = {
    'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev = stddev)),
    'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
    'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))    
    }

#定义偏置
biases = {
    'b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'b2':tf.Variable(tf.random_normal([n_hidden_2])),
    'out':tf.Variable(tf.random_normal([n_classes])), 
    }
print("Network is Ready")


#前向传播
def multilayer_perceptrin(_X,_weights,_biases):
  layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
  layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
  return (tf.matmul(layer2,_weights['out'])+_biases['out'])

#定义优化函数,精准度等
pred = multilayer_perceptrin(x,weights,biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred,labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate = 0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
print("Functions is ready")

#定义超参数
training_epochs = 80
batch_size = 200
display_step = 4

#会话开始
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

#优化
for epoch in range(training_epochs):
  avg_cost=0.
  total_batch = int(mnist.train.num_examples/batch_size)

  for i in range(total_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    feeds = {x:batch_xs,y:batch_ys}
    sess.run(optm,feed_dict = feeds)
    avg_cost += sess.run(cost,feed_dict=feeds)
  avg_cost = avg_cost/total_batch

  if (epoch+1) % display_step ==0:
    print("Epoch:%03d/%03d cost:%.9f"%(epoch,training_epochs,avg_cost))
    feeds = {x:batch_xs,y:batch_ys}
    train_acc = sess.run(accr,feed_dict = feeds)
    print("Train accuracy:%.3f"%(train_acc))
    feeds = {x:mnist.test.images,y:mnist.test.labels}
    test_acc = sess.run(accr,feed_dict = feeds)
    print("Test accuracy:%.3f"%(test_acc))
print("Optimization Finished")

程序部分运行结果如下:

Train accuracy:0.605
Test accuracy:0.633
Epoch:071/080 cost:1.810029302
Train accuracy:0.600
Test accuracy:0.645
Epoch:075/080 cost:1.761531130
Train accuracy:0.690
Test accuracy:0.649
Epoch:079/080 cost:1.711757494
Train accuracy:0.640
Test accuracy:0.660
Optimization Finished

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现删除Android工程中的冗余字符串
Jan 19 Python
Python使用dis模块把Python反编译为字节码的用法详解
Jun 14 Python
python Flask实现restful api service
Dec 04 Python
python将.ppm格式图片转换成.jpg格式文件的方法
Oct 27 Python
关于Django ForeignKey 反向查询中filter和_set的效率对比详解
Dec 15 Python
Python基础之文件读取的讲解
Feb 16 Python
Python求均值,方差,标准差的实例
Jun 29 Python
Python标准库shutil模块使用方法解析
Mar 10 Python
Python BeautifulReport可视化报告代码实例
Apr 13 Python
keras训练浅层卷积网络并保存和加载模型实例
Jul 02 Python
Python 实现一个计时器
Jul 28 Python
Python urllib库如何添加headers过程解析
Oct 05 Python
Python unittest单元测试框架总结
Sep 08 #Python
tensorflow实现加载mnist数据集
Sep 08 #Python
使用tensorflow实现线性回归
Sep 08 #Python
Python  unittest单元测试框架的使用
Sep 08 #Python
tensorflow实现逻辑回归模型
Sep 08 #Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
You might like
php UTF8 文件的签名问题
2009/10/30 PHP
PhpMyAdmin出现export.php Missing parameter: what /export_type错误解决方法
2012/08/09 PHP
php计算数组相同值出现次数的代码(array_count_values)
2015/01/20 PHP
php实现curl模拟ftp上传的方法
2015/07/29 PHP
一个符号插入器 中用到的js代码
2007/09/04 Javascript
contains和compareDocumentPosition 方法来确定是否HTML节点间的关系
2011/09/13 Javascript
THREE.JS入门教程(5)你应当知道的十件事
2013/01/24 Javascript
JS防止用户多次提交的简单代码
2013/08/01 Javascript
javascript显示用户停留时间的简单实例
2013/08/05 Javascript
使用jQuery插件创建常规模态窗口登陆效果
2013/08/23 Javascript
javascript简单事件处理和with用法介绍
2013/09/16 Javascript
jQuery基于当前元素进行下一步的遍历
2014/05/20 Javascript
node.js中的console.error方法使用说明
2014/12/10 Javascript
JavaScript对象创建模式实例汇总
2016/10/03 Javascript
微信小程序 textarea 详解及简单使用方法
2016/12/05 Javascript
jQuery实现的简单在线计算器功能
2017/05/11 jQuery
Angular实现搜索框及价格上下限功能
2018/01/19 Javascript
js实现删除li标签一行内容
2019/04/16 Javascript
学习LayUI时自研的表单参数校验框架案例分析
2019/07/29 Javascript
在vue项目中使用codemirror插件实现代码编辑器功能
2019/08/27 Javascript
iview实现图片上传功能
2020/06/29 Javascript
vue使用screenfull插件实现全屏功能
2020/09/17 Javascript
vue.js封装switch开关组件的操作
2020/10/26 Javascript
浅析vue中的nextTick
2020/12/28 Vue.js
Python实现读取目录所有文件的文件名并保存到txt文件代码
2014/11/22 Python
python 划分数据集为训练集和测试集的方法
2018/12/11 Python
Python版中国省市经纬度
2020/02/11 Python
python pandas利用fillna方法实现部分自动填充功能
2020/03/16 Python
HTML5 weui使用笔记
2019/11/21 HTML / CSS
埃弗顿足球俱乐部官方网上商店:Everton Direct
2018/01/13 全球购物
竞聘演讲稿精彩开头和结尾
2014/05/14 职场文书
学生会干部自我鉴定2014
2014/09/18 职场文书
中秋节寄语2015
2015/03/24 职场文书
2015年国庆节标语大全
2015/07/30 职场文书
解决Swagger2返回map复杂结构不能解析的问题
2021/07/02 Java/Android
在CSS中使用when/else的方法
2022/01/18 HTML / CSS