Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python时间模块中的datetime模块
Jan 13 Python
Python模拟百度登录实例详解
Jan 20 Python
python实现聚类算法原理
Feb 12 Python
python K近邻算法的kd树实现
Sep 06 Python
对Python发送带header的http请求方法详解
Jan 02 Python
python 字典 setdefault()和get()方法比较详解
Aug 07 Python
详解django实现自定义manage命令的扩展
Aug 13 Python
PHP统计代码行数的小代码
Sep 19 Python
tensorflow保持每次训练结果一致的简单实现
Feb 17 Python
Python基于jieba, wordcloud库生成中文词云
May 13 Python
什么是python类属性
Jun 10 Python
python 装饰器的基本使用
Jan 13 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
DC动画很好看?新作烂得令人发指,名叫《红色之子》
2020/04/09 欧美动漫
创建配置文件 用PHP写出自己的BLOG系统 2
2010/04/12 PHP
php获取后台Job管理的实现代码
2011/06/10 PHP
PHP之sprintf函数用法详解
2014/11/12 PHP
php里array_work用法实例分析
2015/07/13 PHP
cakephp2.X多表联合查询join及使用分页查询的方法
2017/02/23 PHP
JavaScript使用cookie
2007/02/02 Javascript
javascript的键盘控制事件说明
2008/04/15 Javascript
查看源码的工具 学习jQuery源码不错的工具
2011/12/26 Javascript
javascript图片滑动效果实现
2021/01/28 Javascript
ES6的新特性概览
2016/03/10 Javascript
早该知道的7个JavaScript技巧
2016/06/21 Javascript
js HTML5手机刮刮乐代码
2020/09/29 Javascript
javaScript语法总结
2016/11/25 Javascript
JavaScript 字符串常用操作小结(非常实用)
2016/11/30 Javascript
详解Immutable及 React 中实践
2018/03/01 Javascript
React中获取数据的3种方法及优缺点
2020/02/18 Javascript
python 正则式 概述及常用字符
2009/05/07 Python
python数据结构之二叉树的统计与转换实例
2014/04/29 Python
Python列表list数组array用法实例解析
2014/10/28 Python
PyQt 线程类 QThread使用详解
2017/07/16 Python
python3+PyQt5实现自定义流体混合窗口部件
2018/04/24 Python
python3.X 抓取火车票信息【修正版】
2018/06/19 Python
Python提取频域特征知识点浅析
2019/03/04 Python
Python 面向对象之封装、继承、多态操作实例分析
2019/11/21 Python
jupyter notebook 多环境conda kernel配置方式
2020/04/10 Python
使用pymysql查询数据库,把结果保存为列表并获取指定元素下标实例
2020/05/15 Python
世界上最大的巴士旅游观光公司:Big Bus Tours
2016/10/20 全球购物
美国时尚女装在线:Missguided
2016/12/03 全球购物
大学生旅游业创业计划书
2014/01/29 职场文书
拾金不昧锦旗标语
2014/06/27 职场文书
股指期货心得体会
2014/09/10 职场文书
思想作风整顿个人剖析材料
2014/10/06 职场文书
影视后期实训报告
2014/11/05 职场文书
统计工作个人总结
2015/03/03 职场文书
2015年银行个人工作总结
2015/05/14 职场文书