TensorFlow实现简单卷积神经网络


Posted in Python onMay 24, 2018

本文使用的数据集是MNIST,主要使用两个卷积层加一个全连接层构建的卷积神经网络。

先载入MNIST数据集(手写数字识别集),并创建默认的Interactive Session(在没有指定回话对象的情况下运行变量)

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession()

在定义一个初始化函数,因为卷积神经网络有很多权重和偏置需要创建。

def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1)
#给权重制造一些随机的噪声来打破完全对称, 
 return tf.Variable(initial) 
#使用relu,给偏置增加一些小正值0.1,用来避免死亡节点 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) 
 return tf.Variable(initial)

卷积移动步长都是1代表会不遗漏的划过图片的每一个点,padding代表边界处理方式,same表示给边界加上padding让卷积的输出和输入保持同样的尺寸。

def conv2d(x,W):#2维卷积函数,x输入,w是卷积的参数,strides代表卷积模板移动步长 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 
       padding='SAME')

在正式设计卷积神经网络结构前,先定义输入的placeholder(类似于c++的cin,要求用户运行时输入)。因为卷积神经网络会利用到空间结构信息,因此需要将一维的输入向量转换为二维的图片结构。同时因为只有一个颜色通道,所以最后尺寸为【-1, 28,28, 1],-1代表样本数量不固定,1代表颜色通道的数量。

这里的tf.reshape是tensor变形函数。

x = tf.placeholder(tf.float32, [None, 784])# x 时特征 
y_ = tf.placeholder(tf.float32, [None, 10])# y_时真实的label 
x_image = tf.reshape(x, [-1, 28, 28,1])

接下来定义第一个卷积层。

w_conv1 = weight_variable([5, 5, 1, 32])
#代表卷积核尺寸为5X5,1个颜色通道,32个不同的卷积核,使用conv2d函数进行卷积操作, 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1)

定义第二个卷积层,与第一个卷积层一样,只不过卷积核的数量变成了64,即这层卷积会提取64种特征

w_conv2 = weight_variable([5, 5, 32, 64])#这层提取64种特征 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2)

经过两次步长为2x2的最大池化,此时图片尺寸变成了7x7,在使用tf.reshape函数,对第二个卷积层的输出tensor进行变形,将其从二维转为一维向量,在连接一个全连接层(隐含节点为1024),使用relu激活函数。

w_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout层:随机丢弃一部分节点的数据来减轻过拟合。这里是通过一个placeholder传入keep_prob比率来控制的。

#为了减轻过拟合,使用一个Dropout层 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
#dropout层的输出连接一个softmax层,得到最后的概率输出 
w_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

定义损失函数即评测准确率操作

#损失函数,并且定义优化器为Adam 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

开始训练

#初始化所有参数 
tf.global_variables_initializer().run() 
for i in range (20000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g"%(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

全部训练完成后,我们在最终的测试集上进行全面的测试,得到整体的分类准确率。

print("test accuracy %g" %accuracy.eval(feed_dict={ 
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

这个网络,参与训练的样本数量总共为100万,共进行20000次训练迭代,使用大小为50的mini_batch。

TensorFlow实现简单卷积神经网络

因为我安装的版本时CPU版的tensorflow,所以运行较慢,这个模型最终的准确性约为99.2%,基本可以满足对手写数字识别准确率的要求。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现批量转换文件编码(批转换编码示例)
Jan 23 Python
Python在线运行代码助手
Jul 15 Python
通过源码分析Python中的切片赋值
May 08 Python
python+django加载静态网页模板解析
Dec 12 Python
Python将多份excel表格整理成一份表格
Jan 03 Python
python OpenCV学习笔记实现二维直方图
Feb 08 Python
对Python正则匹配IP、Url、Mail的方法详解
Dec 25 Python
在python中利用try..except来代替if..else的用法
Dec 19 Python
Windows 下python3.8环境安装教程图文详解
Mar 11 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
Apr 09 Python
python 发送邮件的示例代码(Python2/3都可以直接使用)
Dec 03 Python
Python爬虫爬取微博热搜保存为 Markdown 文件的源码
Feb 22 Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 #Python
Python系统监控模块psutil功能与经典用法分析
May 24 #Python
You might like
第五节--克隆
2006/11/16 PHP
php is_file 判断给定文件名是否为一个正常的文件
2010/05/10 PHP
ThinkPHP模板判断输出Empty标签用法详解
2014/06/30 PHP
浅谈PHP安全防护之Web攻击
2017/01/03 PHP
防止动态加载JavaScript引起的内存泄漏问题
2009/10/08 Javascript
让你的博客飘雪花超出屏幕依然看得见
2013/01/04 Javascript
javascript中自定义对象的属性方法分享
2013/07/12 Javascript
时间戳转换为时间 年月日时间的JS函数
2013/08/19 Javascript
javascript获取URL参数与参数值的示例代码
2013/12/20 Javascript
jquery fancybox ie6不显示关闭按钮的解决办法
2013/12/25 Javascript
使用jquery.validate自定义方法实现"手机号码或者固话至少填写一个"的逻辑验证
2014/09/01 Javascript
JavaScript中使用Callback控制流程介绍
2015/03/16 Javascript
基于Javascript实现弹出页面效果
2016/01/01 Javascript
浅谈jQuery中的$.extend方法来扩展JSON对象
2017/02/12 Javascript
jQuery自定义多选下拉框效果
2017/06/19 jQuery
Node.js中Bootstrap-table的两种分页的实现方法
2017/09/18 Javascript
详解vue.js之props传递参数
2017/12/12 Javascript
Vue实现美团app的影院推荐选座功能【推荐】
2018/08/29 Javascript
vue-cli 引入jQuery,Bootstrap,popper的方法
2018/09/03 jQuery
Node.js API详解之 console模块用法详解
2020/05/12 Javascript
js实现简单音乐播放器
2020/06/30 Javascript
[44:26]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#4EG VS Fnatic第二局
2016/03/03 DOTA
Python 转义字符详细介绍
2017/03/21 Python
利用python批量修改word文件名的方法示例
2017/10/17 Python
python实现列表中由数值查到索引的方法
2018/06/27 Python
详解如何在css中引入自定义字体(font-face)
2018/05/17 HTML / CSS
会计学财务管理专业个人的自我评价
2013/10/19 职场文书
员工年终演讲稿
2014/01/03 职场文书
道路交通安全实施方案
2014/03/12 职场文书
学习型党组织心得体会
2014/09/12 职场文书
医药公司采购员岗位职责
2014/09/12 职场文书
入党积极分子十八届四中全会思想汇报
2014/10/23 职场文书
义诊活动总结
2015/02/04 职场文书
2016年社区“6.26”禁毒日宣传活动总结
2016/04/05 职场文书
html5移动端禁止长按图片保存的实现
2021/04/20 HTML / CSS
JavaScript阻止事件冒泡的方法
2021/12/06 Javascript