Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python列出目录下指定文件与子目录的方法
Jul 03 Python
详谈python3 numpy-loadtxt的编码问题
Apr 29 Python
pandas.DataFrame.to_json按行转json的方法
Jun 05 Python
使用pandas实现csv/excel sheet互相转换的方法
Dec 10 Python
python正则表达式去除两个特殊字符间的内容方法
Dec 24 Python
Python3安装Pillow与PIL的方法
Apr 03 Python
Python实现钉钉订阅消息功能
Jan 14 Python
Windows+Anaconda3+PyTorch+PyCharm的安装教程图文详解
Apr 03 Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 Python
Python-opencv实现红绿两色识别操作
Jun 04 Python
如何对python的字典进行排序
Jun 19 Python
Python编写nmap扫描工具
Jul 21 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
基于qmail的完整WEBMAIL解决方案安装详解
2006/10/09 PHP
php记录日志的实现代码
2011/08/08 PHP
匹配csdn用户数据库与官方用户的重合度并将重叠部分的用户筛选出来
2011/12/25 PHP
使用JavaScript创建新样式表和新样式规则
2016/06/14 PHP
PHP数组生成XML格式数据的封装类实例
2016/11/10 PHP
PHP实现蛇形矩阵,回环矩阵及数字螺旋矩阵的方法分析
2017/05/29 PHP
PHP进阶学习之垃圾回收机制详解
2019/06/18 PHP
JavaScript 组件之旅(二)编码实现和算法
2009/10/28 Javascript
JQuery的ready函数与JS的onload的区别详解
2013/11/21 Javascript
一段非常简单的js判断浏览器的内核
2014/08/17 Javascript
完美兼容各大浏览器的jQuery插件实现图片切换特效
2014/12/12 Javascript
angularJS 中input示例分享
2015/02/09 Javascript
js实现鼠标经过时图片滚动停止的方法
2015/02/16 Javascript
深入理解jQuery事件绑定
2016/06/02 Javascript
js实现常见的工具条效果
2017/03/02 Javascript
js实现简单的获取验证码按钮效果
2017/03/03 Javascript
Node.js设置CORS跨域请求中多域名白名单的方法
2017/03/28 Javascript
使用vue中的v-for遍历二维数组的方法
2018/03/07 Javascript
javascript匿名函数中的'return function()'作用
2018/10/15 Javascript
JS浅拷贝和深拷贝原理与实现方法分析
2019/02/28 Javascript
js变量值传到php过程详解 将php解析成数据
2019/06/26 Javascript
微信小程序:报错(in promise) MiniProgramError
2020/10/30 Javascript
Centos5.x下升级python到python2.7版本教程
2015/02/14 Python
pygame播放音乐的方法
2015/05/19 Python
Python yield 使用浅析
2015/05/28 Python
python处理xml文件的方法小结
2017/05/02 Python
python3+dlib实现人脸识别和情绪分析
2018/04/21 Python
python替换字符串中的子串图文步骤
2019/06/19 Python
对Django 中request.get和request.post的区别详解
2019/08/12 Python
python实现截取屏幕保存文件,删除N天前截图的例子
2019/08/27 Python
成教自我鉴定
2013/10/27 职场文书
工作表扬信的范文
2014/01/10 职场文书
新闻报道策划方案
2014/06/11 职场文书
项目投资意向书范本
2015/05/09 职场文书
React Native项目框架搭建的一些心得体会
2021/05/28 Javascript
MySQL中in和exists区别详解
2021/06/03 MySQL