Tensorflow卷积神经网络实例进阶


Posted in Python onMay 24, 2018

在Tensorflow卷积神经网络实例这篇博客中,我们实现了一个简单的卷积神经网络,没有复杂的Trick。接下来,我们将使用CIFAR-10数据集进行训练。

CIFAR-10是一个经典的数据集,包含60000张32*32的彩色图像,其中训练集50000张,测试集10000张。CIFAR-10如同其名字,一共标注为10类,每一类图片6000张。

本文实现了进阶的卷积神经网络来解决CIFAR-10分类问题,我们使用了一些新的技巧:

  1. 对weights进行了L2的正则化
  2. 对图片进行了翻转、随机剪切等数据增强,制造了更多样本
  3. 在每个卷积-最大池化层后面使用了LRN(局部响应归一化层),增强了模型的泛化能力

首先需要下载Tensorflow models Tensorflow models,以便使用其中的CIFAR-10数据的类.进入目录models/tutorials/image/cifar10目录,执行以下代码

import cifar10
import cifar10_input
import tensorflow as tf
import numpy as np
import time

# 定义batch_size, 训练轮数max_steps, 以及下载CIFAR-10数据的默认路径
max_steps = 3000
batch_size = 128
data_dir = 'E:\\tmp\cifar10_data\cifar-10-batches-bin'

# 定义初始化weight的函数,定义的同时,对weight加一个L2 loss,放在集'losses'中
def variable_with_weight_loss(shape, stddev, w1):
  var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
  if w1 is not None:
    weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')
    tf.add_to_collection('losses', weight_loss)
  return var

# 使用cifar10类下载数据集,并解压、展开到其默认位置
#cifar10.maybe_download_and_extract()

# 在使用cifar10_input类中的distorted_inputs函数产生训练需要使用的数据。需要注意的是,返回的是已经封装好的tensor,
# 且对数据进行了Data Augmentation(水平翻转、随机剪切、设置随机亮度和对比度、对数据进行标准化)
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# 再使用cifar10_input.inputs函数生成测试数据,这里不需要进行太多处理
images_test, labels_test = cifar10_input.inputs(eval_data=True,
                        data_dir=data_dir,
                        batch_size=batch_size)

# 创建数据的placeholder
image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])
label_holder = tf.placeholder(tf.int32, [batch_size])

# 创建第一个卷积层
weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2,
                  w1=0.0)
kernel1 = tf.nn.conv2d(image_holder, weight1, strides=[1, 1, 1, 1], padding='SAME')
bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')
# LRN层对ReLU会比较有用,但不适合Sigmoid这种有固定边界并且能抑制过大值的激活函数
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

# 创建第二个卷积层
weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2,
                  w1=0.0)
kernel2 = tf.nn.conv2d(norm1, weight2, strides=[1, 1, 1, 1], padding='SAME')
bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
            padding='SAME')

# 使用一个全连接层
reshape = tf.reshape(pool2, [batch_size, -1])
dim = reshape.get_shape()[1].value
weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, w1=0.004)
bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)

# 再使用一个全连接层,隐含节点数下降了一半,只有192个,其他的超参数保持不变
weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, w1=0.004)
bias4 = tf.Variable(tf.constant(0.1, shape=[192]))
local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)

# 最后一层,将softmax放在了计算loss部分
weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, w1=0.0)
bias5 = tf.Variable(tf.constant(0.0, shape=[10]))
logits = tf.add(tf.matmul(local4, weight5), bias5)

# 定义loss
def loss(logits, labels):
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels,
                                  name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

# 获取最终的loss
loss = loss(logits, label_holder)

# 优化器
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# 使用tf.nn.in_top_k函数求输出结果中top k的准确率,默认使用top 1,也就是输出分数最高的那一类的准确率
top_k_op = tf.nn.in_top_k(logits, label_holder, 1)

# 使用tf.InteractiveSession创建默认的session,接着初始化全部模型参数
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# 启动图片数据增强线程
tf.train.start_queue_runners()

# 正式开始训练
for step in range(max_steps):
  start_time = time.time()
  image_batch, label_batch = sess.run([images_train, labels_train])
  _, loss_value = sess.run([train_op, loss], feed_dict={image_holder: image_batch, label_holder: label_batch})
  duration = time.time() - start_time
  if step % 10 == 0:
    example_per_sec = batch_size / duration
    sec_per_batch = float(duration)
    format_str = 'step %d, loss=%.2f ,%.1f examples/sec, %.3f sec/batch'
    print(format_str % (step, loss_value, example_per_sec, sec_per_batch))

num_examples = 10000
import math
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
step = 0
while step < num_iter:
  image_batch, label_batch = sess.run([images_test, labels_test])
  predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch, label_holder: label_holder})
  true_count += np.sum(predictions)
  step += 1

precision = true_count / total_sample_count
print('precision @ 1 = %.3f'%precision)

运行结果:

Tensorflow卷积神经网络实例进阶

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 随机生成中文验证码的实例代码
Mar 20 Python
python自带的http模块详解
Nov 06 Python
Python 数据结构之旋转链表
Feb 25 Python
Django视图和URL配置详解
Jan 31 Python
对Python中的条件判断、循环以及循环的终止方法详解
Feb 08 Python
更新pip3与pyttsx3文字语音转换的实现方法
Aug 08 Python
通过python检测字符串的字母
Feb 18 Python
利用python在excel中画图的实现方法
Mar 17 Python
python要安装在哪个盘
Jun 15 Python
Django url 路由匹配过程详解
Jan 22 Python
python 算法题——快乐数的多种解法
May 27 Python
Python集合set()使用的方法详解
Mar 18 Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
TensorFlow实现简单卷积神经网络
May 24 #Python
You might like
php操作csv文件代码实例汇总
2014/09/22 PHP
PHP中实现crontab代码分享
2015/03/26 PHP
PHP判断来访是搜索引擎蜘蛛还是普通用户的代码小结
2015/09/14 PHP
PHP实现二叉树的深度优先与广度优先遍历方法
2015/09/28 PHP
学习ExtJS fit布局使用说明
2009/10/08 Javascript
基于Jquery插件开发之图片放大镜效果(仿淘宝)
2011/11/19 Javascript
JavaScript中日期函数的相关操作知识
2016/08/03 Javascript
BootStrap使用file-input插件上传图片的方法
2016/09/05 Javascript
Angular的$http与$location
2016/12/26 Javascript
详解vue2.0 使用动态组件实现 Tab 标签页切换效果(vue-cli)
2017/08/30 Javascript
使用JS动态显示文本
2017/09/09 Javascript
Node.js 使用流实现读写同步边读边写功能
2017/09/11 Javascript
360提示[高危]使用存在漏洞的JQuery版本的解决方法
2017/10/27 jQuery
webpack实现一个行内样式px转vw的loader示例
2018/09/13 Javascript
微信小程序实现的日期午别医生排班表功能示例
2019/01/09 Javascript
详解TypeScript+Vue 插件 vue-class-component的使用总结
2019/02/18 Javascript
简单了解JavaScript sort方法
2019/11/25 Javascript
vue 百度地图(vue-baidu-map)绘制方向箭头折线实例代码详解
2020/04/28 Javascript
JS实现小米轮播图
2020/09/21 Javascript
[01:20]2018DOTA2亚洲邀请赛总决赛战队LGD晋级之路
2018/04/07 DOTA
python翻译软件实现代码(使用google api完成)
2013/11/26 Python
Python爬虫实现爬取京东手机页面的图片(实例代码)
2017/11/30 Python
python生成器,可迭代对象,迭代器区别和联系
2018/02/04 Python
Django项目中用JS实现加载子页面并传值的方法
2018/05/28 Python
python函数enumerate,operator和Counter使用技巧实例小结
2020/02/22 Python
基于python实现监听Rabbitmq系统日志代码示例
2020/11/28 Python
深入探究HTML5的History API
2015/07/09 HTML / CSS
美特斯邦威官方商城:邦购网
2016/10/13 全球购物
致800米运动员广播稿
2014/02/16 职场文书
教师师德考核自我评价
2014/09/13 职场文书
保证书格式
2015/01/16 职场文书
休假证明书
2015/06/24 职场文书
2019公司管理制度
2019/04/19 职场文书
创业计划书之酒吧
2019/12/02 职场文书
python ConfigParser库的使用及遇到的坑
2022/02/12 Python
pandas时间序列之pd.to_datetime()的实现
2022/06/16 Python