TensorFlow实现简单卷积神经网络


Posted in Python onMay 24, 2018

本文使用的数据集是MNIST,主要使用两个卷积层加一个全连接层构建的卷积神经网络。

先载入MNIST数据集(手写数字识别集),并创建默认的Interactive Session(在没有指定回话对象的情况下运行变量)

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession()

在定义一个初始化函数,因为卷积神经网络有很多权重和偏置需要创建。

def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1)
#给权重制造一些随机的噪声来打破完全对称, 
 return tf.Variable(initial) 
#使用relu,给偏置增加一些小正值0.1,用来避免死亡节点 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) 
 return tf.Variable(initial)

卷积移动步长都是1代表会不遗漏的划过图片的每一个点,padding代表边界处理方式,same表示给边界加上padding让卷积的输出和输入保持同样的尺寸。

def conv2d(x,W):#2维卷积函数,x输入,w是卷积的参数,strides代表卷积模板移动步长 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], 
       padding='SAME')

在正式设计卷积神经网络结构前,先定义输入的placeholder(类似于c++的cin,要求用户运行时输入)。因为卷积神经网络会利用到空间结构信息,因此需要将一维的输入向量转换为二维的图片结构。同时因为只有一个颜色通道,所以最后尺寸为【-1, 28,28, 1],-1代表样本数量不固定,1代表颜色通道的数量。

这里的tf.reshape是tensor变形函数。

x = tf.placeholder(tf.float32, [None, 784])# x 时特征 
y_ = tf.placeholder(tf.float32, [None, 10])# y_时真实的label 
x_image = tf.reshape(x, [-1, 28, 28,1])

接下来定义第一个卷积层。

w_conv1 = weight_variable([5, 5, 1, 32])
#代表卷积核尺寸为5X5,1个颜色通道,32个不同的卷积核,使用conv2d函数进行卷积操作, 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1)

定义第二个卷积层,与第一个卷积层一样,只不过卷积核的数量变成了64,即这层卷积会提取64种特征

w_conv2 = weight_variable([5, 5, 32, 64])#这层提取64种特征 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2)

经过两次步长为2x2的最大池化,此时图片尺寸变成了7x7,在使用tf.reshape函数,对第二个卷积层的输出tensor进行变形,将其从二维转为一维向量,在连接一个全连接层(隐含节点为1024),使用relu激活函数。

w_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout层:随机丢弃一部分节点的数据来减轻过拟合。这里是通过一个placeholder传入keep_prob比率来控制的。

#为了减轻过拟合,使用一个Dropout层 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
#dropout层的输出连接一个softmax层,得到最后的概率输出 
w_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

定义损失函数即评测准确率操作

#损失函数,并且定义优化器为Adam 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

开始训练

#初始化所有参数 
tf.global_variables_initializer().run() 
for i in range (20000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g"%(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

全部训练完成后,我们在最终的测试集上进行全面的测试,得到整体的分类准确率。

print("test accuracy %g" %accuracy.eval(feed_dict={ 
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

这个网络,参与训练的样本数量总共为100万,共进行20000次训练迭代,使用大小为50的mini_batch。

TensorFlow实现简单卷积神经网络

因为我安装的版本时CPU版的tensorflow,所以运行较慢,这个模型最终的准确性约为99.2%,基本可以满足对手写数字识别准确率的要求。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
总结网络IO模型与select模型的Python实例讲解
Jun 27 Python
教你学会使用Python正则表达式
Sep 07 Python
python读取文本绘制动态速度曲线
Jun 21 Python
python生成器与迭代器详解
Jan 01 Python
对Python中的条件判断、循环以及循环的终止方法详解
Feb 08 Python
python实现给微信指定好友定时发送消息
Apr 29 Python
pyqt5 禁止窗口最大化和禁止窗口拉伸的方法
Jun 18 Python
Python代理IP爬虫的新手使用教程
Sep 05 Python
使用python快速实现不同机器间文件夹共享方式
Dec 22 Python
Nginx+Uwsgi+Django 项目部署到服务器的思路详解
May 08 Python
Python机器学习算法之决策树算法的实现与优缺点
May 13 Python
Python数据类型最全知识总结
May 31 Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
解决Matplotlib图表不能在Pycharm中显示的问题
May 24 #Python
Python系统监控模块psutil功能与经典用法分析
May 24 #Python
You might like
Codeigniter生成Excel文档的简单方法
2014/06/12 PHP
PHP常见漏洞攻击分析
2016/02/21 PHP
Yii2中多表关联查询hasOne hasMany的方法
2017/02/15 PHP
关于IE浏览器以及Firefox下的javascript冒泡事件的响应层级
2010/10/14 Javascript
jQuery ajax serialize()方法的使用以及常见问题解决
2013/01/27 Javascript
js获取html页面节点方法(递归方式)
2013/12/13 Javascript
js防止页面被iframe调用的方法
2014/10/30 Javascript
JQuery中DOM实现事件移除的方法
2015/06/13 Javascript
巧方法 JavaScript获取超链接的绝对URL地址
2016/06/14 Javascript
AngularJS 过滤器的简单实例
2016/07/27 Javascript
AngularJS 整理一些优化的小技巧
2016/08/18 Javascript
Javascript实现找不同色块的游戏
2017/07/17 Javascript
angular中的cookie读写方法
2017/08/02 Javascript
Angular中封装fancyBox(图片预览)遇到问题小结
2017/09/01 Javascript
vue+vuecli+webpack中使用mockjs模拟后端数据的示例
2017/10/24 Javascript
简易Vue评论框架的实现(父组件的实现)
2018/01/08 Javascript
Angular2.0实现modal对话框的方法示例
2018/02/18 Javascript
Vue2.0 给Tab标签页和页面切换过渡添加样式的方法
2018/03/13 Javascript
解决Layui当中的导航条动态添加后渲染失败的问题
2019/09/25 Javascript
Ant Design moment对象和字符串之间的相互转化教程
2020/10/27 Javascript
[50:34]VGJ.T vs Fnatic 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
python dict remove数组删除(del,pop)
2013/03/24 Python
python文件比较示例分享
2014/01/10 Python
安装python3的时候就是输入python3死活没有反应的解决方法
2018/01/24 Python
python 生成图形验证码的方法示例
2018/11/11 Python
基于Python实现迪杰斯特拉和弗洛伊德算法
2020/05/27 Python
django使用haystack调用Elasticsearch实现索引搜索
2019/07/24 Python
Python类中的魔法方法之 __slots__原理解析
2019/08/26 Python
Python变量作用域LEGB用法解析
2020/02/04 Python
深入了解Python enumerate和zip
2020/07/16 Python
HTML5标签大全
2016/11/23 HTML / CSS
详解canvas绘制网络字体几种方法
2019/08/27 HTML / CSS
沃达丰英国有限公司:Vodafone英国
2019/04/16 全球购物
公司员工奖惩制度
2015/08/04 职场文书
详解Vue中$props、$attrs和$listeners的使用方法
2022/02/18 Vue.js