TensorFlow实现卷积神经网络CNN


Posted in Python onMarch 09, 2018

一、卷积神经网络CNN简介

卷积神经网络(ConvolutionalNeuralNetwork,CNN)最初是为解决图像识别等问题设计的,CNN现在的应用已经不限于图像和视频,也可用于时间序列信号,比如音频信号和文本数据等。CNN作为一个深度学习架构被提出的最初诉求是降低对图像数据预处理的要求,避免复杂的特征工程。在卷积神经网络中,第一个卷积层会直接接受图像像素级的输入,每一层卷积(滤波器)都会提取数据中最有效的特征,这种方法可以提取到图像中最基础的特征,而后再进行组合和抽象形成更高阶的特征,因此CNN在理论上具有对图像缩放、平移和旋转的不变性。

卷积神经网络CNN的要点就是局部连接(LocalConnection)、权值共享(WeightsSharing)和池化层(Pooling)中的降采样(Down-Sampling)。其中,局部连接和权值共享降低了参数量,使训练复杂度大大下降并减轻了过拟合。同时权值共享还赋予了卷积网络对平移的容忍性,池化层降采样则进一步降低了输出参数量并赋予模型对轻度形变的容忍性,提高了模型的泛化能力。可以把卷积层卷积操作理解为用少量参数在图像的多个位置上提取相似特征的过程。

更多请参见:深度学习之卷积神经网络CNN

二、TensorFlow代码实现

#!/usr/bin/env python2 
# -*- coding: utf-8 -*- 
""" 
Created on Thu Mar 9 22:01:46 2017 
 
@author: marsjhao 
""" 
 
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
def weight_variable(shape): 
 initial = tf.truncated_normal(shape, stddev=0.1) #标准差为0.1的正态分布 
 return tf.Variable(initial) 
 
def bias_variable(shape): 
 initial = tf.constant(0.1, shape=shape) #偏差初始化为0.1 
 return tf.Variable(initial) 
 
def conv2d(x, W): 
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
       strides=[1, 2, 2, 1], padding='SAME') 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
# -1代表先不考虑输入的图片例子多少这个维度,1是channel的数量 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
keep_prob = tf.placeholder(tf.float32) 
 
# 构建卷积层1 
W_conv1 = weight_variable([5, 5, 1, 32]) # 卷积核5*5,1个channel,32个卷积核,形成32个featuremap 
b_conv1 = bias_variable([32]) # 32个featuremap的偏置 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) # 用relu非线性处理 
h_pool1 = max_pool_2x2(h_conv1) # pooling池化 
 
# 构建卷积层2 
W_conv2 = weight_variable([5, 5, 32, 64]) # 注意这里channel值是32 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
# 构建全连接层1 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool3 = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool3, W_fc1) + b_fc1) 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
# 构建全连接层2 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), 
            reduction_indices=[1])) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.arg_max(y_conv, 1), tf.arg_max(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
tf.global_variables_initializer().run() 
 
for i in range(20001): 
 batch = mnist.train.next_batch(50) 
 if i % 100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], 
             keep_prob: 1.0}) 
  print("step %d, training accuracy %g" %(i, train_accuracy)) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5}) 
print("test accuracy %g" %accuracy.eval(feed_dict={x: mnist.test.images, 
         y_: mnist.test.labels, keep_prob: 1.0}))

三、代码解读

该代码是用TensorFlow实现一个简单的卷积神经网络,在数据集MNIST上,预期可以实现99.2%左右的准确率。结构上使用两个卷积层和一个全连接层。

首先载入MNIST数据集,采用独热编码,并创建tf.InteractiveSession。然后为后续即将多次使用的部分代码创建函数,包括权重初始化weight_variable、偏置初始化bias_variable、卷积层conv2d、最大池化max_pool_2x2。其中权重初始化的时候要进行含有噪声的非对称初始化,打破完全对称。又由于我们要使用ReLU单元,也需要给偏置bias增加一些小的正值(0.1)用来避免死亡节点(dead neurons)。

构建卷积神经网络之前,先要定义输入的placeholder,特征x和真实标签y_,将1*784格式的特征x转换reshape为28*28的图片格式,又由于只有一个通道且不确定输入样本的数量,故最终尺寸为[-1, 28, 28, 1]。

接下来定义第一个卷积层,首先初始化weights和bias,然后使用conv2d进行卷积操作并加上偏置,随后使用ReLU激活函数进行非线性处理,最后使用最大池化函数对卷积的输出结果进行池化操作。

相同的步骤定义第二个卷积层,不同的地方是卷积核的数量为64,也就是说这一层的卷积会提取64种特征。经过两层不变尺寸的卷积和两次尺寸减半的池化,第二个卷积层后的输出尺寸为7*7*64。将其reshape为长度为7*7*64的1-D向量。经过ReLU后,为了减轻过拟合,使用一个Dropout层,在训练时随机丢弃部分节点的数据减轻过拟合,在预测的时候保留全部数据来追求最好的测试性能。

最后加一个Softmax层,得到最后的预测概率。随后的定义损失函数、优化器、评测准确率不再详细赘述。

训练过程首先进行初始化全部参数,训练时keep_prob比率设置为0.5,评测时设置为1。训练完成后,在最终的测试集上进行全面的测试,得到整体的分类准确率。

经过实验,这个CNN的模型可以得到99.2%的准确率,相比于MLP又有了较大幅度的提高。

四、其他解读补充

1. tf.nn.conv2d(x,W, strides=[1, 1, 1, 1], padding='SAME')

tf.nn.conv2d是TensorFlow的2维卷积函数,x和W都是4-D的tensors。x是输入input shape=[batch,in_height, in_width, in_channels],W是卷积的参数filter / kernel shape=[filter_height, filter_width, in_channels,out_channels]。strides参数是长度为4的1-D参数,代表了卷积核(滑动窗口)移动的步长,其中对于图片strides[0]和strides[3]必须是1,都是1表示不遗漏地划过图片的每一个点。padding参数中SAME代表给边界加上Padding让卷积的输出和输入保持相同的尺寸。

2. tf.nn.max_pool(x,ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

tf.nn.max_pool是TensorFlow中的最大池化函数,x是4-D的输入tensor shape=[batch, height, width, channels],ksize参数表示池化窗口的大小,取一个4维向量,一般是[1, height, width, 1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1,strides与tf.nn.conv2d相同,strides=[1, 2, 2, 1]可以缩小图片尺寸。padding参数也参见tf.nn.conv2d。

Python 相关文章推荐
python获取Linux下文件版本信息、公司名和产品名的方法
Oct 05 Python
Python计算已经过去多少个周末的方法
Jul 25 Python
使用Python对Access读写操作
Mar 30 Python
python3获取当前文件的上一级目录实例
Apr 26 Python
Tesserocr库的正确安装方式
Oct 19 Python
python 判断字符串中是否含有汉字或非汉字的实例
Jul 15 Python
Python如何筛选序列中的元素的方法实现
Jul 15 Python
Python高级编程之消息队列(Queue)与进程池(Pool)实例详解
Nov 01 Python
python 协程 gevent原理与用法分析
Nov 22 Python
python def 定义函数,调用函数方式
Jun 02 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 Python
用python进行视频剪辑
Nov 02 Python
新手常见6种的python报错及解决方法
Mar 09 #Python
Python 函数基础知识汇总
Mar 09 #Python
Python 使用with上下文实现计时功能
Mar 09 #Python
TensorFlow搭建神经网络最佳实践
Mar 09 #Python
TensorFlow实现Batch Normalization
Mar 08 #Python
用Django实现一个可运行的区块链应用
Mar 08 #Python
Python pyinotify日志监控系统处理日志的方法
Mar 08 #Python
You might like
如何使用PHP获取指定日期所在月的开始日期与结束日期
2013/08/01 PHP
php 浮点数比较方法详解
2017/05/05 PHP
PHP实现的注册,登录及查询用户资料功能API接口示例
2017/06/06 PHP
实用javaScript技术-屏蔽类
2006/08/15 Javascript
一个js封装的不错的选项卡效果代码
2008/02/15 Javascript
JS 密码强度验证(兼容IE,火狐,谷歌)
2010/03/15 Javascript
jQuery ajax BUG:object doesn't support this property or method
2010/07/06 Javascript
javascript中的对象创建 实例附注释
2011/02/08 Javascript
js 代码优化点滴记录
2012/02/19 Javascript
第一次接触神奇的Bootstrap网格系统
2016/07/27 Javascript
JS实现获取当前URL和来源URL的方法
2016/08/24 Javascript
etmvc+jQuery EasyUI+combobox多值操作实现角色授权实例
2016/11/09 Javascript
AngularJS入门教程之Helloworld示例
2016/12/25 Javascript
微信小程序 中wx.chooseAddress(OBJECT)实例详解
2017/03/31 Javascript
AngularJS遍历获取数组元素的方法示例
2017/11/11 Javascript
js正则表达式校验指定字符串的方法
2018/07/23 Javascript
小程序中this.setData的使用和注意事项
2019/08/28 Javascript
[42:32]完美世界DOTA2联赛PWL S2 LBZS vs FTD.C 第二场 11.27
2020/12/01 DOTA
简单介绍Python中的len()函数的使用
2015/04/07 Python
在Python的Flask框架中使用日期和时间的教程
2015/04/21 Python
Python实现将Excel转换成xml的方法示例
2018/08/25 Python
python集合比较(交集,并集,差集)方法详解
2018/09/13 Python
CentOS7下安装python3.6.8的教程详解
2020/01/03 Python
Html5 FileReader实现即时上传图片功能实例代码
2014/09/01 HTML / CSS
德国化妆品和天然化妆品网上商店:kosmetikfuchs.de
2017/06/09 全球购物
JD Sports瑞典:英国领先的运动时尚商店
2018/01/28 全球购物
亚洲最大的运动鞋寄售店:KicksCrew
2020/11/26 全球购物
Nobody Denim官网:购买高级女士牛仔裤
2021/03/15 全球购物
广播电视新闻学专业应届生求职信
2013/10/08 职场文书
函授教育个人学习的自我评价
2013/12/31 职场文书
项目计划书范文
2014/01/09 职场文书
停车位租赁协议书
2014/09/24 职场文书
村主任“四风”问题个人对照检查材料思想汇报
2014/10/02 职场文书
党员带头倡议书
2015/04/29 职场文书
培训班开班主持词
2015/07/02 职场文书
分析Python感知线程状态的解决方案之Event与信号量
2021/06/16 Python