tensorflow实现简单的卷积神经网络


Posted in Python onMay 24, 2018

本文实例为大家分享了Android九宫格图片展示的具体代码,供大家参考,具体内容如下

一.知识点总结

1.  卷积神经网络出现的初衷是降低对图像的预处理,避免建立复杂的特征工程。因为卷积神经网络在训练的过程中,自己会提取特征。

2.   灵感来自于猫的视觉皮层研究,每一个视觉神经元只会处理一小块区域的视觉图像,即感知野。放到卷积神经网络里就是每一个隐含节点只与设定范围内的像素点相连(设定范围就是卷积核的尺寸),而全连接层是每个像素点与每个隐含节点相连。这种感知野也称之为局部感知。

例如,一张1000*1000的图片,如果隐含层有100*100个隐含节点全连接,则需要1000*1000*100*100+100*100个参数,而如果有10*10的范围局部感知,用同样多的隐含节点,只需要10*10*100*100+100*100个参数。

3.  把卷积的过程称作卷积滤波,除了上面的局部感知,卷积滤波还有一个化简操作——权值共享。即一个卷积滤波中的所有隐含节点与感知图像连接的权值是一样的,这样,上述例子的参数减少为10*10+100*100个了。W的数量等于感知范围的尺寸。

4.  为了抗变形和减小复杂度,卷积层同时还要做激活和池化。激活函数前一章已经弄明白了,池化,相当于降采样,将n*n的像素区域采样为m*m区域,m通常小于n。通常选择最大池化,即选择区域内的最大像素点。 

5.  总结来讲,卷积有三个要点:局部连接、权值共享、池化降采样。一个卷积过程包含三个步骤:卷积滤波、激活、池化。 

6.  卷积滤波中的卷积范围可以用一个词来代替——卷积核,卷积核等同于卷积滤波中的一个隐含节点感知范围。由于权值共享,相当于一个卷积核对整个图像做多次小范围滤波,每滤一次波生成一个小的特征图像,多次滤波后将所有小特征图像组合起来,生成了对整个图像的feature map。通常,一个卷积滤波过程有多个卷积核卷积,生成多张feature map。

所有的feature map都会被池化,然后输入下一层。 

7.  需要训练的权值(参数)的数量只和卷积核尺寸有关,隐含节点(即卷积核要卷积的次数)只和卷积的卷积步长、图像尺寸有关。

个人理解,一个卷积核对整个图像卷积的过程,就像是一个棋子,在整个棋盘上按照步长跳动,每跳动一次,对感知范围内的像素点做一次连接计算。 

8.  CNN在结构上和图像的结构更为接近,都是2D的,因此,早期用在图像上效果很好,但是最近,CNN用于NLP也很热门。

二.程序解析

# coding: utf-8 
 
# In[1]: 
 
from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

补充一下目前三个网络在mnist上的精度分别为:

无隐含层的softmax:91.5%

加入一个全连接隐含层的感知机:98.1%

此cnn:99.07%

和作者的训练结果有细微的差异,可能设备不同吧。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python fabric实现远程操作和部署示例
Mar 25 Python
python简单读取大文件的方法
Jul 01 Python
用python脚本24小时刷浏览器的访问量方法
Dec 07 Python
python爬虫超时的处理的实例
Dec 19 Python
python opencv 二值化 计算白色像素点的实例
Jul 03 Python
Tensorflow 自定义loss的情况下初始化部分变量方式
Jan 06 Python
TensorFlow实现保存训练模型为pd文件并恢复
Feb 06 Python
python实现最速下降法
Mar 24 Python
Django实现将一个字典传到前端显示出来
Apr 03 Python
TensorFlow tf.nn.conv2d_transpose是怎样实现反卷积的
Apr 20 Python
使用python-cv2实现视频的分解与合成的示例代码
Oct 26 Python
python实现发送邮件
Mar 02 Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
You might like
在数据量大(超过10万)的情况下
2007/01/15 PHP
php环境无法上传文件的解决方法
2014/04/30 PHP
PHP登录环节防止sql注入的方法浅析
2014/06/30 PHP
PHP AjaxForm提交图片上传并显示图片源码
2016/11/29 PHP
php-app开发接口加密详解
2018/04/18 PHP
PHP按一定比例压缩图片的方法
2018/10/12 PHP
jQuery Ajax文件上传(php)
2009/06/16 Javascript
jQuery焦点图切换特效插件封装实例
2013/08/18 Javascript
LABjs、RequireJS、SeaJS的区别
2014/03/04 Javascript
一张表格告诉你windows.onload()与$(document).ready()的区别
2014/05/16 Javascript
JavaScript事件类型中焦点、鼠标和滚轮事件详解
2016/01/25 Javascript
easyui combobox开启搜索自动完成功能的实例代码
2016/11/08 Javascript
详解有关easyUI的拖动操作中droppable,draggable用法例子
2017/06/03 Javascript
js使用generator函数同步执行ajax任务
2017/09/05 Javascript
微信小程序云开发如何实现数据库自动备份实现
2019/08/16 Javascript
JavaScript this使用方法图解
2020/02/04 Javascript
Vue 实现v-for循环的时候更改 class的样式名称
2020/07/17 Javascript
基于JQuery和DWR实现异步数据传递
2020/10/16 jQuery
Python去除列表中重复元素的方法
2015/03/20 Python
举例讲解Python中装饰器的用法
2015/04/27 Python
在java中如何定义一个抽象属性示例详解
2017/08/18 Python
Django ORM框架的定时任务如何使用详解
2017/10/19 Python
python+matplotlib演示电偶极子实例代码
2018/01/12 Python
Python元组及文件核心对象类型详解
2018/02/11 Python
python: 自动安装缺失库文件的方法
2018/10/22 Python
Python实现的微信支付方式总结【三种方式】
2019/04/13 Python
python 计算概率密度、累计分布、逆函数的例子
2020/02/25 Python
Python视频编辑库MoviePy的使用
2020/04/01 Python
python如何删除文件、目录
2020/06/23 Python
介绍Java的内部类
2012/10/27 面试题
小学生自我评价范例
2013/09/24 职场文书
宾馆仓管员岗位职责
2014/07/27 职场文书
社保转移委托书范本
2014/10/08 职场文书
具结保证书
2015/01/17 职场文书
如何用JS实现网页瀑布流布局
2021/04/24 Javascript
win server2012 r2服务器共享文件夹如何设置
2022/06/21 Servers