tensorflow实现简单的卷积网络


Posted in Python onMay 24, 2018

使用tensorflow实现一个简单的卷积神经,使用的数据集是MNIST,本节将使用两个卷积层加一个全连接层,构建一个简单有代表性的卷积网络。

代码是按照书上的敲的,第一步就是导入数据库,设置节点的初始值,Tf.nn.conv2d是tensorflow中的2维卷积,参数x是输入,W是卷积的参数,比如【5,5,1,32】,前面两个数字代表卷积核的尺寸,第三个数字代表有几个通道,比如灰度图是1,彩色图是3.最后一个代表卷积的数量,总的实现代码如下:

from tensorflow.examples.tutorials.mnist import input_data 
import tensorflow as tf 
mnist = input_data.read_data_sets("MNSIT_data/", one_hot=True) 
sess = tf.InteractiveSession() 
 
 
# In[2]: 
#由于W和b在各层中均要用到,先定义乘函数。 
#tf.truncated_normal:截断正态分布,即限制范围的正态分布 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
 
 
# In[7]: 
#bias初始化值0.1. 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
 
# In[12]: 
#tf.nn.conv2d:二维的卷积 
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None) 
#filter:A 4-D tensor of shape 
#   `[filter_height, filter_width, in_channels, out_channels]` 
#strides:步长,都是1表示所有点都不会被遗漏。1-D 4值,表示每歌dim的移动步长。 
# padding:边界的处理方式,“SAME"、"VALID”可选 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
#tf.nn.max_pool:最大值池化函数,即求2*2区域的最大值,保留最显著的特征。 
#max_pool(value, ksize, strides, padding, data_format="NHWC", name=None) 
#ksize:池化窗口的尺寸 
#strides:[1,2,2,1]表示横竖方向步长为2 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding='SAME') 
 
 
x = tf.placeholder(tf.float32, [None, 784]) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
#tf.reshape:tensor的变形函数。 
#-1:样本数量不固定 
#28,28:新形状的shape 
#1:颜色通道数 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
 
#卷积层包含三部分:卷积计算、激活、池化 
#[5,5,1,32]表示卷积核的尺寸为5×5, 颜色通道为1, 有32个卷积核 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
#经过2次2×2的池化后,图像的尺寸变为7×7,第二个卷积层有64个卷积核,生成64类特征,因此,卷积最后输出为7×7×64. 
#tensor进入全连接层之前,先将64张二维图像变形为1维图像,便于计算。 
W_fc1 = weight_variable([7*7*64, 1024]) 
b_fc1 = bias_variable([1024]) 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
 
#对全连接层做dropot 
keep_prob = tf.placeholder(tf.float32) 
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob) 
 
 
#又一个全连接后foftmax分类 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
y_conv = tf.nn.softmax(tf.matmul(h_fc1_dropout, W_fc2) + b_fc2) 
 
 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv), reduction_indices=[1])) 
#AdamOptimizer:Adam优化函数 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
 
 
 
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
 
 
#训练,并且每100个batch计算一次精度 
tf.global_variables_initializer().run() 
for i in range(20000): 
  batch = mnist.train.next_batch(50) 
  if i%100 == 0: 
    train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0}) 
    print("step %d, training accuracy %g" %(i, train_accuracy)) 
  train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 
 
 
#在测试集上测试 
print("test accuracy %g"%accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

 注意的是书上开始运行的代码是tf.global_variables_initializer().run(),但是在敲到代码中就会报错,也不知道为什么,可能是因为版本的问题吧,上网搜了一下,改为sess.run(tf.initialiaze_all_variables)即可。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中replace方法实例分析
Aug 20 Python
python实现决策树
Dec 21 Python
Python tkinter实现的图片移动碰撞动画效果【附源码下载】
Jan 04 Python
Django添加sitemap的方法示例
Aug 06 Python
在Mac下使用python实现简单的目录树展示方法
Nov 01 Python
Python 窗体(tkinter)按钮 位置实例
Jun 13 Python
Ubuntu18.04下python版本完美切换的解决方法
Jun 14 Python
PyQt5响应回车事件的方法
Jun 25 Python
python实现H2O中的随机森林算法介绍及其项目实战
Aug 29 Python
GDAL 矢量属性数据修改方式(python)
Mar 10 Python
快速解决pymongo操作mongodb的时区问题
Dec 05 Python
python文件与路径操作神器 pathlib
Apr 01 Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
解决matplotlib库show()方法不显示图片的问题
May 24 #Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 #Python
解决seaborn在pycharm中绘图不出图的问题
May 24 #Python
快速解决PyCharm无法引用matplotlib的问题
May 24 #Python
Django rest framework实现分页的示例
May 24 #Python
You might like
浅析php设计模式之数据对象映射模式
2016/03/03 PHP
CI(CodeIgniter)框架实现图片上传的方法
2017/03/24 PHP
[原创]PHP实现字节数Byte转换为KB、MB、GB、TB的方法
2017/08/31 PHP
[原创]静态页面也可以实现预览 列表不同的显示方式
2006/10/14 Javascript
表单JS弹出填写提示效果代码
2011/04/16 Javascript
js控制的遮罩层实例介绍
2013/05/29 Javascript
js控制页面控件隐藏显示的两种方法介绍
2013/10/09 Javascript
JS制作手机端自适应缩放显示
2015/06/11 Javascript
jquery实现表单验证简单实例演示
2015/11/23 Javascript
基于jQuery仿淘宝产品图片放大镜特效
2020/10/19 Javascript
理解Javascript文件动态加载
2016/01/29 Javascript
Node.js上传文件功能之服务端如何获取文件上传进度
2018/02/05 Javascript
vue 中 beforeRouteEnter 死循环的问题
2019/04/23 Javascript
Element-ui DatePicker显示周数的方法示例
2019/07/19 Javascript
[01:00:25]NB vs Secret 2018国际邀请赛小组赛BO1 B组加赛 8.19
2018/08/21 DOTA
[54:30]Liquid vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/16 DOTA
[01:32:50]DOTA2-DPC中国联赛 正赛 DLG vs XG BO3 第一场 1月25日
2021/03/11 DOTA
Python的词法分析与语法分析
2013/05/18 Python
Python中使用dom模块生成XML文件示例
2015/04/05 Python
Python MySQLdb Linux下安装笔记
2015/05/09 Python
Python中pygame的mouse鼠标事件用法实例
2015/11/11 Python
十个Python程序员易犯的错误
2015/12/15 Python
浅谈Python爬取网页的编码处理
2016/11/04 Python
Python将DataFrame的某一列作为index的方法
2018/04/08 Python
基于DataFrame改变列类型的方法
2018/07/25 Python
python使用pygame框架实现推箱子游戏
2018/11/20 Python
Python实现的大数据分析操作系统日志功能示例
2019/02/11 Python
解决python3.5 正常安装 却不能直接使用Tkinter包的问题
2019/02/22 Python
python——全排列数的生成方式
2020/02/26 Python
类的返射机制中的包及核心类
2016/09/12 面试题
李开复演讲稿
2014/05/24 职场文书
2014幼儿教师个人工作总结
2014/12/03 职场文书
暂住证证明
2015/06/19 职场文书
新学期主题班会
2015/08/17 职场文书
CSS 实现多彩、智能的阴影效果
2021/05/12 HTML / CSS
pytorch中Schedule与warmup_steps的用法说明
2021/05/24 Python