Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则式使用心得
May 07 Python
Python的Django框架中自定义模版标签的示例
Jul 20 Python
Python的Flask框架应用调用Redis队列数据的方法
Jun 06 Python
Python中的迭代器与生成器高级用法解析
Jun 28 Python
Python实现判断一个字符串是否包含子串的方法总结
Nov 21 Python
python3.x上post发送json数据
Mar 04 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 Python
VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法详解
Jul 01 Python
Python实现word2Vec model过程解析
Dec 16 Python
Tensorflow轻松实现XOR运算的方式
Feb 03 Python
Python Tricks 使用 pywinrm 远程控制 Windows 主机的方法
Jul 21 Python
如何更换python默认编辑器的背景色
Aug 10 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
PHP的一个完整SMTP类(解决邮件服务器需要验证时的问题)
2006/10/09 PHP
解决控件遮挡问题:关于有窗口元素和无窗口元素
2007/01/28 PHP
用PHP实现Ftp用户的在线管理的代码
2007/03/06 PHP
PHPMYADMIN 简明安装教程 推荐
2010/03/07 PHP
调整优化您的LAMP应用程序的5种简单方法
2011/06/26 PHP
dhtmlxTree目录树增加右键菜单以及拖拽排序的实现方法
2013/04/26 PHP
PHP中UNIX时间戳和日期间的转换与计算实例
2014/11/19 PHP
完美解决thinkphp验证码出错无法显示的方法
2014/12/09 PHP
简单谈谈 php 文件锁
2017/02/19 PHP
PHP回调函数概念与用法实例分析
2017/11/03 PHP
php记录搜索引擎爬行记录的实现代码
2018/03/02 PHP
PHP预定义超全局数组变量小结
2018/08/20 PHP
javascript 打开页面window.location和window.open的区别
2010/03/17 Javascript
Jquery ui css framework
2010/06/28 Javascript
详解jQuery移动页面开发中的ui-grid网格布局使用
2015/12/03 Javascript
两种方法解决javascript url post 特殊字符转义 + &amp; #
2016/04/13 Javascript
微信小程序 122100版本更新问题解决方案
2016/12/22 Javascript
bootstrapValidator表单验证插件学习
2016/12/30 Javascript
JavaScript实现256色转灰度图
2017/02/22 Javascript
在 Angular-cli 中使用 simple-mock 实现前端开发 API Mock 接口数据模拟功能的方法
2018/11/28 Javascript
Vue触发隐藏input file的方法实例详解
2019/08/14 Javascript
Antd下拉选择,自动匹配功能的实现
2020/10/24 Javascript
浅谈es6中的元编程
2020/12/01 Javascript
[03:40]DOTA2亚洲邀请赛小组赛第二日 赛事回顾
2015/01/31 DOTA
[00:33]2018DOTA2亚洲邀请赛TNC出场
2018/04/04 DOTA
Python计算回文数的方法
2015/03/11 Python
python验证码识别的实例详解
2016/09/09 Python
Python实现将多个空格换为一个空格.md的方法
2018/12/20 Python
Django框架HttpResponse对象用法实例分析
2019/11/01 Python
pytorch 自定义卷积核进行卷积操作方式
2019/12/30 Python
pycharm工具连接mysql数据库失败问题
2020/04/01 Python
人力资源管理专业应届生求职信
2014/04/24 职场文书
道德演讲稿
2014/05/21 职场文书
写给同学的新学期寄语
2015/02/27 职场文书
2015年信访工作总结
2015/04/07 职场文书
Java基础——Map集合
2022/04/01 Java/Android