TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中as用法实例分析
Apr 30 Python
python链接Oracle数据库的方法
Jun 28 Python
python中字符串类型json操作的注意事项
May 02 Python
Python程序运行原理图文解析
Feb 10 Python
Python用于学习重要算法的模块pygorithm实例浅析
Aug 16 Python
用python一行代码得到数组中某个元素的个数方法
Jan 28 Python
使用coverage统计python web项目代码覆盖率的方法详解
Aug 05 Python
pytorch梯度剪裁方式
Feb 04 Python
Python使用type动态创建类操作示例
Feb 29 Python
解决Python数据可视化中文部分显示方块问题
May 16 Python
Manjaro、pip、conda更换国内源的方法
Nov 17 Python
Python OpenCV之常用滤波器使用详解
Apr 07 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
建立文件交换功能的脚本(三)
2006/10/09 PHP
PHP清除字符串中所有无用标签的方法
2014/12/01 PHP
laravel 5 实现模板主题功能(续)
2015/03/02 PHP
zend框架实现支持sql server的操作方法
2016/12/08 PHP
PHP封装的多文件上传类实例与用法详解
2017/02/07 PHP
PHP 断点续传实例详解
2017/11/11 PHP
php图像生成函数之间的区别分析
2012/12/06 Javascript
javascript-简单的日历实现及Date对象语法介绍(附图)
2013/05/30 Javascript
JQuery+Ajax无刷新分页的实例代码
2014/02/08 Javascript
JS实现队列与堆栈的方法
2016/04/21 Javascript
jQuery 自定义下拉框(DropDown)附源码下载
2016/07/22 Javascript
canvas绘制环形进度条
2017/02/23 Javascript
Bootstrap 网格系统布局详解
2017/03/19 Javascript
前端axios下载excel文件(二进制)的处理方法
2018/07/31 Javascript
JavaScript实现图片放大镜效果
2019/06/27 Javascript
layui关闭弹窗后刷新主页面和当前更改项的例子
2019/09/06 Javascript
微信小程序防止多次点击跳转和防止表单组件输入内容多次验证功能(函数防抖)
2019/09/19 Javascript
nodejs实现的http、https 请求封装操作示例
2020/02/06 NodeJs
JavaScript Event Loop相关原理解析
2020/06/10 Javascript
[03:04]2018年国际邀请赛典藏宝瓶&莱恩声望物品展示 片尾有彩蛋
2018/06/04 DOTA
Python增量循环删除MySQL表数据的方法
2016/09/23 Python
详解Python中的相对导入和绝对导入
2017/01/06 Python
python爬虫之验证码篇3-滑动验证码识别技术
2019/04/11 Python
详解numpy矩阵的创建与数据类型
2019/10/18 Python
python对Excel按条件进行内容补充(推荐)
2019/11/24 Python
html5 localStorage本地存储_动力节点Java学院整理
2017/07/06 HTML / CSS
试解释COMMIT操作和ROLLBACK操作的语义
2014/07/25 面试题
学术会议欢迎词
2014/01/09 职场文书
大学生职业生涯规划范文
2014/01/22 职场文书
陈胜吴广起义口号
2014/06/20 职场文书
小学阳光体育活动总结
2014/07/05 职场文书
2014年公务员工作总结
2014/11/18 职场文书
运动会开幕词
2015/01/28 职场文书
通知怎么写?
2019/04/17 职场文书
MySql统计函数COUNT的具体使用详解
2022/08/14 MySQL
CSS list-style-type属性使用方法
2023/05/21 HTML / CSS