tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
web.py中调用文件夹内模板的方法
Aug 26 Python
python利用beautifulSoup实现爬虫
Sep 29 Python
CentOS中升级Python版本的方法详解
Jul 10 Python
python学习笔记之列表(list)与元组(tuple)详解
Nov 23 Python
使用Flask-Cache缓存实现给Flask提速的方法详解
Jun 11 Python
python命令行参数用法实例分析
Jun 25 Python
Windows 下更改 jupyterlab 默认启动位置的教程详解
May 18 Python
Python常用数据分析模块原理解析
Jul 20 Python
matplotlib基础绘图命令之errorbar的使用
Aug 13 Python
解决Python安装cryptography报错问题
Sep 03 Python
Python通过递归函数输出嵌套列表元素
Oct 15 Python
简单介绍Python的第三方库yaml
Jun 18 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
最省空间的计数器
2006/10/09 PHP
php四种基础算法代码实例
2013/10/29 PHP
php用ini_get获取php.ini里变量值的方法
2015/03/04 PHP
ThinkPHP的常用配置选项汇总
2016/03/24 PHP
Laravel 实现数据软删除功能
2019/08/21 PHP
JavaScript 学习笔记(十三)Dom创建表格
2010/01/21 Javascript
javascript encodeURI和encodeURIComponent的比较
2010/04/03 Javascript
jquery 弹出层注册页面等(asp.net后台)
2010/06/17 Javascript
JS验证IP,子网掩码,网关和MAC的方法
2015/07/02 Javascript
JavaScript中判断两个字符串是否相等的方法
2015/07/07 Javascript
React创建组件的三种方式及其区别
2017/01/12 Javascript
bootstrap table插件的分页与checkbox使用详解
2017/07/23 Javascript
详解React 在服务端渲染的实现
2017/11/16 Javascript
AngularJS实现的鼠标拖动画矩形框示例【可兼容IE8】
2019/05/17 Javascript
json字符串对象转换代码实例
2019/09/28 Javascript
微信小程序去除左上角返回键的实现方法
2020/03/06 Javascript
React实现类似淘宝tab居中切换效果的示例代码
2020/06/02 Javascript
flask使用session保存登录状态及拦截未登录请求代码
2018/01/19 Python
查看TensorFlow checkpoint文件中的变量名和对应值方法
2018/06/14 Python
使用Selenium破解新浪微博的四宫格验证码
2018/10/19 Python
解决pycharm 误删掉项目文件的处理方法
2018/10/22 Python
Python获取数据库数据并保存在excel表格中的方法
2019/06/12 Python
Python qqbot 实现qq机器人的示例代码
2019/07/11 Python
python GUI库图形界面开发之PyQt5窗口背景与不规则窗口实例
2020/02/25 Python
浅谈Python线程的同步互斥与死锁
2020/03/22 Python
简述网络文件系统NFS,并说明其作用
2016/10/19 面试题
体育教育毕业生自荐信
2013/11/21 职场文书
说明书格式及范文
2014/05/07 职场文书
促销活动总结模板
2014/07/01 职场文书
2014最新房贷收入证明范本
2014/09/12 职场文书
2014年个人债务授权委托书范本
2014/09/22 职场文书
公司领导九九重阳节发言稿2014
2014/09/25 职场文书
公司员工离职证明书
2014/10/04 职场文书
2016年先进班集体事迹材料
2016/02/26 职场文书
简述python四种分词工具,盘点哪个更好用?
2021/04/13 Python
CSS使用Flex和Grid布局实现3D骰子
2022/08/05 HTML / CSS