Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python写的服务监控程序实例
Jan 31 Python
Python通过PIL获取图片主要颜色并和颜色库进行对比的方法
Mar 19 Python
初步解析Python中的yield函数的用法
Apr 03 Python
Python3使用requests包抓取并保存网页源码的方法
Mar 15 Python
一个基于flask的web应用诞生(1)
Apr 11 Python
Python列表删除的三种方法代码分享
Oct 31 Python
解决Mac下首次安装pycharm无project interpreter的问题
Oct 29 Python
pycharm远程开发项目的实现步骤
Jan 20 Python
python之信息加密题目详解
Jun 26 Python
Django rstful登陆认证并检查session是否过期代码实例
Aug 13 Python
python调用API接口实现登陆短信验证
May 10 Python
Python-OpenCV实现图像缺陷检测的实例
Jun 11 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
PHP number_format() 函数定义和用法
2012/06/01 PHP
关于PHP实现异步操作的研究
2013/02/03 PHP
浅析PHP中call user func()函数及如何使用call user func调用自定义函数
2015/11/05 PHP
基于php实现七牛抓取远程图片
2015/12/01 PHP
如何批量清理系统临时文件(语言:C#、 C/C++、 php 、python 、java )
2016/02/01 PHP
Javascript学习笔记 delete运算符
2011/09/13 Javascript
Mac/Windows下如何安装Node.js
2013/11/22 Javascript
jquery 删除字符串最后一个字符的方法解析
2014/02/11 Javascript
jQuery中on()方法用法实例详解
2015/02/06 Javascript
jquery.mobile 共同布局遇到的问题小结
2015/02/10 Javascript
解决ajax不能访问本地文件问题(利用js跨域原理)
2017/01/24 Javascript
Ionic项目中Native Camera的使用方法
2017/06/07 Javascript
javascript+css3开发打气球小游戏完整代码
2017/11/28 Javascript
vue.js的computed,filter,get,set的用法及区别详解
2018/03/08 Javascript
微信小程序 冒泡事件原理解析
2019/09/27 Javascript
使用Promise封装小程序wx.request的实现方法
2019/11/13 Javascript
vue实现一个获取按键展示快捷键效果的Input组件
2021/01/13 Vue.js
[00:36]DOTA2上海特级锦标赛 LGD战队宣传片
2016/03/04 DOTA
[41:41]TFT vs Secret Supermajor小组赛C组 BO3 第一场 6.3
2018/06/04 DOTA
详细介绍Python语言中的按位运算符
2013/11/26 Python
python多重继承新算法C3介绍
2014/09/28 Python
用PyInstaller把Python代码打包成单个独立的exe可执行文件
2018/05/26 Python
关于python中remove的一些坑小结
2021/01/04 Python
ORACLE十问
2015/04/20 面试题
Hashtable 添加内容的方式有哪几种,有什么区别?
2012/04/08 面试题
计算机专业毕业生推荐信
2013/11/25 职场文书
文秘人员工作职责
2014/01/31 职场文书
聘任书的写作格式及范文
2014/03/29 职场文书
公司委托书范本
2014/04/04 职场文书
领导干部整治奢华浪费之风思想汇报
2014/10/07 职场文书
2014流动人口计划生育工作总结
2014/12/20 职场文书
三好学生个人总结
2015/02/15 职场文书
销售辞职信范文
2015/03/02 职场文书
详解Python 3.10 中的新功能和变化
2021/04/28 Python
Python中常见的反爬机制及其破解方法总结
2021/06/10 Python
【海涛DOTA】D-cup邀请赛NV.cn vs DT.Love
2022/04/01 DOTA