TensorFlow实现卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了TensorFlow实现卷积神经网络的具体代码,供大家参考,具体内容如下

代码(源代码都有详细的注释)和数据集可以在github下载:

# -*- coding: utf-8 -*-
'''卷积神经网络测试MNIST数据'''

#########导入MNIST数据########
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 创建默认InteractiveSession
sess = tf.InteractiveSession()


#########卷积网络会有很多的权重和偏置需要创建,先定义好初始化函数以便复用########
# 给权重制造一些随机噪声打破完全对称(比如截断的正态分布噪声,标准差设为0.1)
def weight_variable(shape):
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)
# 因为我们要使用ReLU,也给偏置增加一些小的正值(0.1)用来避免死亡节点(dead neurons)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)


########卷积层、池化层接下来重复使用的,分别定义创建函数########
# tf.nn.conv2d是TensorFlow中的2维卷积函数
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 使用2*2的最大池化
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')


########正式设计卷积神经网络之前先定义placeholder########
# x是特征,y_是真实label。将图片数据从1D转为2D。使用tensor的变形函数tf.reshape
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
x_image = tf.reshape(x,[-1,28,28,1])


########设计卷积神经网络########
# 第一层卷积
# 卷积核尺寸为5*5,1个颜色通道,32个不同的卷积核
W_conv1 = weight_variable([5, 5, 1, 32])
# 用conv2d函数进行卷积操作,加上偏置
b_conv1 = bias_variable([32])
# 把x_image和权值向量进行卷积,加上偏置项,然后应用ReLU激活函数,
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 对卷积的输出结果进行池化操作
h_pool1 = max_pool_2x2(h_conv1)

# 第二层卷积(和第一层大致相同,卷积核为64,这一层卷积会提取64种特征)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# 全连接层。隐含节点数1024。使用ReLU激活函数
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# 为了防止过拟合,在输出层之前加Dropout层
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 输出层。添加一个softmax层,就像softmax regression一样。得到概率输出。
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)


########模型训练设置########
# 定义loss function为cross entropy,优化器使用Adam,并给予一个比较小的学习速率1e-4
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 定义评测准确率的操作
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


########开始训练过程########
# 初始化所有参数
tf.global_variables_initializer().run()

# 训练(设置训练时Dropout的kepp_prob比率为0.5。mini-batch为50,进行2000次迭代训练,参与训练样本5万)
# 其中每进行100次训练,对准确率进行一次评测keep_prob设置为1,用以实时监测模型的性能
for i in range(1000):
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})
  print "-->step %d, training accuracy %.4f"%(i, train_accuracy)
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 全部训练完成之后,在最终测试集上进行全面测试,得到整体的分类准确率
print "卷积神经网络在MNIST数据集正确率: %g"%accuracy.eval(feed_dict={
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

TensorFlow实现卷积神经网络

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python urllib、urllib2、httplib抓取网页代码实例
May 09 Python
python获取当前时间对应unix时间戳的方法
May 15 Python
Python的Flask框架中SQLAlchemy使用时的乱码问题解决
Nov 07 Python
Django unittest 设置跳过某些case的方法
Dec 26 Python
python绘制直方图和密度图的实例
Jul 08 Python
浅谈python图片处理Image和skimage的区别
Aug 04 Python
pytorch多进程加速及代码优化方法
Aug 19 Python
使用Python paramiko模块利用多线程实现ssh并发执行操作
Dec 05 Python
pip安装tensorflow的坑的解决
Apr 19 Python
解决pip install psycopg2出错问题
Jul 09 Python
Python基于爬虫实现全网搜索并下载音乐
Feb 14 Python
python 爬取吉首大学网站成绩单
Jun 02 Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
You might like
解析php中memcache的应用
2013/06/18 PHP
PHP中对缓冲区的控制实现代码
2013/09/29 PHP
php网站地图生成类示例
2014/01/13 PHP
PHP中常用的字符串格式化函数总结
2014/11/19 PHP
PHP+Jquery与ajax相结合实现下拉淡出瀑布流效果【无需插件】
2016/05/06 PHP
一款JavaScript压缩工具:X2JSCompactor
2007/06/13 Javascript
ie 7/8不支持trim的属性的解决方案
2014/05/23 Javascript
javascript实现回车键提交表单方法总结
2015/01/10 Javascript
在javascript中随机数 math random如何生成指定范围数值的随机数
2015/10/21 Javascript
JS实现模拟百度搜索“2012世界末日”网页地震撕裂效果代码
2015/10/31 Javascript
jQuery实现的可编辑表格完整实例
2016/06/20 Javascript
JS中对Cookie的操作详解
2016/08/05 Javascript
jQuery中animate的几种用法与注意事项
2016/12/12 Javascript
Vue项目中使用Vux的安装过程
2018/05/01 Javascript
Webpack之tree-starking 解析
2018/09/11 Javascript
深入理解与使用keep-alive(配合router-view缓存整个路由页面)
2018/09/25 Javascript
模块化react-router配置方法详解
2019/06/03 Javascript
JS 遍历 json 和 JQuery 遍历json操作完整示例
2019/11/11 jQuery
详解Vue之计算属性
2020/06/20 Javascript
vue 微信分享回调iOS和安卓回调出现错误的解决
2020/09/07 Javascript
基于Python实现文件大小输出
2016/01/11 Python
Linux 发邮件磁盘空间监控(python)
2016/04/23 Python
python paramiko利用sftp上传目录到远程的实例
2019/01/03 Python
python基于递归解决背包问题详解
2019/07/03 Python
Python如何计算语句执行时间
2019/11/22 Python
Python通过正则库爬取淘宝商品信息代码实例
2020/03/02 Python
Selenium结合BeautifulSoup4编写简单的python爬虫
2020/11/06 Python
详解CSS3中常用的样式【基本文本和字体样式】
2020/10/20 HTML / CSS
美国内衣第一品牌:Hanes(恒适)
2016/07/29 全球购物
美国领先的奢侈手表在线零售商:WatchMaxx
2017/12/17 全球购物
工商管理毕业生推荐信
2013/12/24 职场文书
教育科研先进个人材料
2014/01/26 职场文书
陈斌强事迹观后感
2015/06/17 职场文书
如何开发一个渐进式Web应用程序PWA
2021/05/10 Javascript
JavaGUI模仿QQ聊天功能完整版
2021/07/04 Java/Android
python数据可视化JupyterLab实用扩展程序Mito
2021/11/20 Python