Tensorflow实现AlexNet卷积神经网络及运算时间评测


Posted in Python onMay 24, 2018

本文实例为大家分享了Tensorflow实现AlexNet卷积神经网络的具体实现代码,供大家参考,具体内容如下

之前已经介绍过了AlexNet的网络构建了,这次主要不是为了训练数据,而是为了对每个batch的前馈(Forward)和反馈(backward)的平均耗时进行计算。在设计网络的过程中,分类的结果很重要,但是运算速率也相当重要。尤其是在跟踪(Tracking)的任务中,如果使用的网络太深,那么也会导致实时性不好。

from datetime import datetime
import math
import time
import tensorflow as tf

batch_size = 32
num_batches = 100

def print_activations(t):
 print(t.op.name, '', t.get_shape().as_list())

def inference(images):
 parameters = []

 with tf.name_scope('conv1') as scope:
  kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [64], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name = scope)
  print_activations(conv1)
  parameters += [kernel, biases]

  lrn1 = tf.nn.lrn(conv1, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn1')
  pool1 = tf.nn.max_pool(lrn1, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool1')
  print_activations(pool1)

 with tf.name_scope('conv2') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [192], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv2 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv2)

  lrn2 = tf.nn.lrn(conv2, 4, bias = 1.0, alpha = 0.001 / 9, beta = 0.75, name = 'lrn2')
  pool2 = tf.nn.max_pool(lrn2, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool2')
  print_activations(pool2)

 with tf.name_scope('conv3') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [384], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv3 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv3)

 with tf.name_scope('conv4') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv4 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv4)

 with tf.name_scope('conv5') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256], dtype = tf.float32, stddev = 1e-1), name = 'weights')
  conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding = 'SAME')
  biases = tf.Variable(tf.constant(0.0, shape = [256], dtype = tf.float32), trainable = True, name = 'biases')
  bias = tf.nn.bias_add(conv, biases)
  conv5 = tf.nn.relu(bias, name = scope)
  parameters += [kernel, biases]
  print_activations(conv5)

  pool5 = tf.nn.max_pool(conv5, ksize = [1, 3, 3, 1], strides = [1, 2, 2, 1], padding = 'VALID', name = 'pool5')
  print_activations(pool5)

  return pool5, parameters

def time_tensorflow_run(session, target, info_string):
 num_steps_burn_in = 10
 total_duration = 0.0
 total_duration_squared = 0.0

 for i in range(num_batches + num_steps_burn_in):
  start_time = time.time()
  _ = session.run(target)
  duration = time.time() - start_time
  if i >= num_steps_burn_in:
   if not i % 10:
    print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))
   total_duration += duration
   total_duration_squared += duration * duration

 mn = total_duration / num_batches
 vr = total_duration_squared / num_batches - mn * mn
 sd = math.sqrt(vr)
 print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))

def run_benchmark():
 with tf.Graph().as_default():
  image_size = 224
  images = tf.Variable(tf.random_normal([batch_size, image_size, image_size, 3], dtype = tf.float32, stddev = 1e-1))
  pool5, parameters = inference(images)

  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  time_tensorflow_run(sess, pool5, "Forward")

  objective = tf.nn.l2_loss(pool5)
  grad = tf.gradients(objective, parameters)
  time_tensorflow_run(sess, grad, "Forward-backward")


run_benchmark()

这里的代码都是之前讲过的,只是加了一个计算时间和现实网络的卷积核的函数,应该很容易就看懂了,就不多赘述了。我在GTX TITAN X上前馈大概需要0.024s, 反馈大概需要0.079s。哈哈,自己动手试一试哦。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python编写简单的定时器的方法
May 02 Python
python cx_Oracle的基础使用方法(连接和增删改查)
Nov 19 Python
机器学习之KNN算法原理及Python实现方法详解
Jul 09 Python
mac下如何将python2.7改为python3
Jul 13 Python
python实现求两个字符串的最长公共子串方法
Jul 20 Python
python之cv2与图像的载入、显示和保存实例
Dec 05 Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 Python
PyCharm 2020.2 安装详细教程
Sep 25 Python
pycharm如何设置官方中文(如何汉化)
Dec 29 Python
如何在Python中创建二叉树
Mar 30 Python
Python办公自动化之教你用Python批量识别发票并录入到Excel表格中
Jun 26 Python
分享7个 Python 实战项目练习
Mar 03 Python
Tensorflow卷积神经网络实例进阶
May 24 #Python
Tensorflow卷积神经网络实例
May 24 #Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 #Python
TensorFlow实现卷积神经网络
May 24 #Python
tensorflow实现简单的卷积神经网络
May 24 #Python
tensorflow实现简单的卷积网络
May 24 #Python
解决pandas 作图无法显示中文的问题
May 24 #Python
You might like
php 广告调用类代码(支持Flash调用)
2011/08/11 PHP
CI框架安全类Security.php源码分析
2014/11/04 PHP
php7安装mongoDB扩展的方法分析
2017/08/02 PHP
实例分析PHP中PHPMailer发邮件
2017/12/13 PHP
php从数据库中获取数据用ajax传送到前台的方法
2018/08/20 PHP
非常不错的一个javascript 类
2006/11/07 Javascript
JS实现根据当前文字选择返回被选中的文字
2014/05/21 Javascript
jquery禁止回车触发表单提交
2014/12/12 Javascript
简介JavaScript中的setTime()方法的使用
2015/06/11 Javascript
JavaScript浏览器对象之一Window对象详解
2016/06/03 Javascript
Vue.js第二天学习笔记(vue-router)
2016/12/01 Javascript
jQuery validate 验证radio实例
2017/03/01 Javascript
JSONP跨域请求
2017/03/02 Javascript
bootstrap3-dialog-master模态框使用详解
2017/08/22 Javascript
看看“疫苗查询”小程序有温度的代码
2018/07/31 Javascript
详解Node.js中path模块的resolve()和join()方法的区别
2018/10/29 Javascript
微信小程序rich-text富文本用法实例分析
2019/05/20 Javascript
js中let能否完全替代IIFE
2019/06/15 Javascript
云服务器部署Node.js项目的方法步骤(小白系列)
2020/03/23 Javascript
Vue中keep-alive组件的深入理解
2020/08/23 Javascript
python读取TXT到数组及列表去重后按原来顺序排序的方法
2015/06/26 Python
Python下调用Linux的Shell命令的方法
2018/06/12 Python
python实现单链表中删除倒数第K个节点的方法
2018/09/28 Python
django框架面向对象ORM模型继承用法实例分析
2019/07/29 Python
python多进程 主进程和子进程间共享和不共享全局变量实例
2020/04/25 Python
Python基于smtplib协议实现发送邮件
2020/06/03 Python
python进行OpenCV实战之画图(直线、矩形、圆形)
2020/08/27 Python
浅谈css3中的渐进增强和优雅降级
2017/12/01 HTML / CSS
浅析HTML5中的 History 模式
2017/06/22 HTML / CSS
Unineed旗下时尚轻奢网站:FABHunt
2019/05/13 全球购物
德国富尔达运动鞋店:43einhalb
2020/12/25 全球购物
违规违纪检讨书范文
2015/05/06 职场文书
礼仪培训心得体会
2016/01/22 职场文书
python glom模块的使用简介
2021/04/13 Python
vue完美实现el-table列宽自适应
2021/05/08 Vue.js
JavaScript实现九宫格拖拽效果
2022/06/28 Javascript