使用tensorflow实现AlexNet


Posted in Python onNovember 20, 2017

AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中所使用的很多技巧,而且非常有助于提升我们使用深度学习工具箱的熟练度。尤其是我刚入门深度学习,迫切需要一个能让自己熟悉tensorflow的小练习,于是就有了这个小玩意儿......

先放上我的代码:https://github.com/hjptriplebee/AlexNet_with_tensorflow

如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。

模型结构

使用tensorflow实现AlexNet

关于模型结构网上的文献很多,我这里不赘述,一会儿都在代码里解释。

有一点需要注意,AlexNet将网络分成了上下两个部分,在论文中两部分结构完全相同,唯一不同的是他们放在不同GPU上训练,因为每一层的feature map之间都是独立的(除了全连接层),所以这相当于是提升训练速度的一种方法。很多AlexNet的复现都将上下两部分合并了,因为他们都是在单个GPU上运行的。虽然我也是在单个GPU上运行,但是我还是很想将最原始的网络结构还原出来,所以我的代码里也是分开的。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"): 
  """max-pooling""" 
  return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1], 
             strides = [1, strideX, strideY, 1], padding = padding, name = name) 
 
def dropout(x, keepPro, name = None): 
  """dropout""" 
  return tf.nn.dropout(x, keepPro, name) 
 
def LRN(x, R, alpha, beta, name = None, bias = 1.0): 
  """LRN""" 
  return tf.nn.local_response_normalization(x, depth_radius = R, alpha = alpha, 
                       beta = beta, bias = bias, name = name) 
 
def fcLayer(x, inputD, outputD, reluFlag, name): 
  """fully-connect""" 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float") 
    b = tf.get_variable("b", [outputD], dtype = "float") 
    out = tf.nn.xw_plus_b(x, w, b, name = scope.name) 
    if reluFlag: 
      return tf.nn.relu(out) 
    else: 
      return out 
 
def convLayer(x, kHeight, kWidth, strideX, strideY, 
       featureNum, name, padding = "SAME", groups = 1):#group为2时等于AlexNet中分上下两部分 
  """convlutional""" 
  channel = int(x.get_shape()[-1])#获取channel 
  conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)#定义卷积的匿名函数 
  with tf.variable_scope(name) as scope: 
    w = tf.get_variable("w", shape = [kHeight, kWidth, channel/groups, featureNum]) 
    b = tf.get_variable("b", shape = [featureNum]) 
 
    xNew = tf.split(value = x, num_or_size_splits = groups, axis = 3)#划分后的输入和权重 
    wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3) 
 
    featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] #分别提取feature map 
    mergeFeatureMap = tf.concat(axis = 3, values = featureMap) #feature map整合 
    # print mergeFeatureMap.shape 
    out = tf.nn.bias_add(mergeFeatureMap, b) 
    return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name) #relu后的结果

定义了卷积、pooling、LRN、dropout、全连接五个模块,其中卷积模块因为将网络的上下两部分分开了,所以比较复杂。接下来定义AlexNet。

class alexNet(object): 
  """alexNet model""" 
  def __init__(self, x, keepPro, classNum, skip, modelPath = "bvlc_alexnet.npy"): 
    self.X = x 
    self.KEEPPRO = keepPro 
    self.CLASSNUM = classNum 
    self.SKIP = skip 
    self.MODELPATH = modelPath 
    #build CNN 
    self.buildCNN() 
 
  def buildCNN(self): 
    """build model""" 
    conv1 = convLayer(self.X, 11, 11, 4, 4, 96, "conv1", "VALID") 
    pool1 = maxPoolLayer(conv1, 3, 3, 2, 2, "pool1", "VALID") 
    lrn1 = LRN(pool1, 2, 2e-05, 0.75, "norm1") 
 
    conv2 = convLayer(lrn1, 5, 5, 1, 1, 256, "conv2", groups = 2) 
    pool2 = maxPoolLayer(conv2, 3, 3, 2, 2, "pool2", "VALID") 
    lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2") 
 
    conv3 = convLayer(lrn2, 3, 3, 1, 1, 384, "conv3") 
 
    conv4 = convLayer(conv3, 3, 3, 1, 1, 384, "conv4", groups = 2) 
 
    conv5 = convLayer(conv4, 3, 3, 1, 1, 256, "conv5", groups = 2) 
    pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID") 
 
    fcIn = tf.reshape(pool5, [-1, 256 * 6 * 6]) 
    fc1 = fcLayer(fcIn, 256 * 6 * 6, 4096, True, "fc6") 
    dropout1 = dropout(fc1, self.KEEPPRO) 
 
    fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7") 
    dropout2 = dropout(fc2, self.KEEPPRO) 
 
    self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8") 
 
  def loadModel(self, sess): 
    """load model""" 
    wDict = np.load(self.MODELPATH, encoding = "bytes").item() 
    #for layers in model 
    for name in wDict: 
      if name not in self.SKIP: 
        with tf.variable_scope(name, reuse = True): 
          for p in wDict[name]: 
            if len(p.shape) == 1:  
              #bias 只有一维 
              sess.run(tf.get_variable('b', trainable = False).assign(p)) 
            else: 
              #weights  
              sess.run(tf.get_variable('w', trainable = False).assign(p))

buildCNN函数完全按照alexnet的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。
至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的AlexNet有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的渣土车图片:

使用tensorflow实现AlexNet

然后编写测试代码:

#some params 
dropoutPro = 1 
classNum = 1000 
skip = [] 
#get testImage 
testPath = "testModel" 
testImg = [] 
for f in os.listdir(testPath): 
  testImg.append(cv2.imread(testPath + "/" + f)) 
 
imgMean = np.array([104, 117, 124], np.float) 
x = tf.placeholder("float", [1, 227, 227, 3]) 
 
model = alexnet.alexNet(x, dropoutPro, classNum, skip) 
score = model.fc3 
softmax = tf.nn.softmax(score) 
 
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  model.loadModel(sess) #加载模型 
 
  for i, img in enumerate(testImg): 
    #img preprocess 
    test = cv2.resize(img.astype(np.float), (227, 227)) #resize成网络输入大小 
    test -= imgMean #去均值 
    test = test.reshape((1, 227, 227, 3)) #拉成tensor 
    maxx = np.argmax(sess.run(softmax, feed_dict = {x: test})) 
    res = caffe_classes.class_names[maxx] #取概率最大类的下标 
    #print(res) 
    font = cv2.FONT_HERSHEY_SIMPLEX 
    cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2)#绘制类的名字 
    cv2.imshow("demo", img)  
    cv2.waitKey(5000) #显示5秒

如上代码所示,首先需要设置一些参数,然后读取指定路径下的测试图像,再对模型做一个初始化,最后是真正测试代码。测试结果如下:

使用tensorflow实现AlexNet

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之循环介绍
Aug 29 Python
Linux下使用python调用top命令获得CPU利用率
Mar 10 Python
在Linux系统上安装Python的Scrapy框架的教程
Jun 11 Python
python实现微信远程控制电脑
Feb 22 Python
python 创建一个空dataframe 然后添加行数据的实例
Jun 07 Python
bluepy 一款python封装的BLE利器简单介绍
Jun 25 Python
简单易懂Pytorch实战实例VGG深度网络
Aug 27 Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 Python
django 外键创建注意事项说明
May 20 Python
pytorch 多分类问题,计算百分比操作
Jul 09 Python
Python实现区域填充的示例代码
Feb 03 Python
python requests模块的使用示例
Apr 07 Python
Django在win10下的安装并创建工程
Nov 20 #Python
Python2与python3中 for 循环语句基础与实例分析
Nov 20 #Python
Python3中类、模块、错误与异常、文件的简易教程
Nov 20 #Python
Python实现将HTML转换成doc格式文件的方法示例
Nov 20 #Python
python中学习K-Means和图片压缩
Nov 20 #Python
深入理解Python中的super()方法
Nov 20 #Python
python实现读取excel写入mysql的小工具详解
Nov 20 #Python
You might like
PHP新手上路(十四)
2006/10/09 PHP
安装APACHE
2007/01/15 PHP
php下几个常用的去空、分组、调试数组函数
2009/02/22 PHP
简单的PHP多图上传小程序代码
2011/07/17 PHP
跟我学Laravel之视图 & Response
2014/10/15 PHP
关于JavaScript定义类和对象的几种方式
2010/11/09 Javascript
js querySelector和getElementById通过id获取元素的区别
2012/04/20 Javascript
JS获取月的最后一天与JS得到一个月份最大天数的实例代码
2013/12/16 Javascript
js 弹出新页面避免被浏览器、ad拦截的一种新方法
2014/04/30 Javascript
javascript进行数组追加方法小结
2014/06/16 Javascript
jQuery+PHP打造滑动开关效果
2014/12/16 Javascript
jquery实现在光标位置插入内容的方法
2015/02/05 Javascript
C#中使用迭代器处理等待任务
2015/07/13 Javascript
javascript运算符语法全面概述
2016/07/14 Javascript
jQuery插件实现可输入和自动匹配的下拉框
2016/10/24 Javascript
基于bootstrap风格的弹框插件
2016/12/28 Javascript
js eval函数使用,js对象和字符串互转实例
2017/03/06 Javascript
详解Angular 4.x NgIf 的用法
2017/05/22 Javascript
js禁止浏览器页面后退功能的实例(推荐)
2017/09/01 Javascript
ES6 系列之 WeakMap的使用示例
2018/08/06 Javascript
JS判断用户用的哪个浏览器实例详解
2018/10/09 Javascript
JavaScript中reduce()的5个基本用法示例
2020/07/19 Javascript
[50:05]VGJ.S vs OG 2018国际邀请赛淘汰赛BO3 第二场 8.22
2018/08/23 DOTA
Python多进程编程技术实例分析
2014/09/16 Python
实践Python的爬虫框架Scrapy来抓取豆瓣电影TOP250
2016/01/20 Python
Python操作列表常用方法实例小结【创建、遍历、统计、切片等】
2019/10/25 Python
css3实现椭圆轨迹旋转的示例代码
2018/10/29 HTML / CSS
皮尔·卡丹巴西官方商店:Pierre Cardin
2017/07/21 全球购物
应届毕业生通用的自荐书范文
2014/02/07 职场文书
小学生竞选班长演讲稿
2014/04/24 职场文书
设备收款委托书范本
2014/10/02 职场文书
2015年七一建党节活动总结
2015/03/20 职场文书
厉行节约工作总结
2015/08/12 职场文书
感恩教师节主题班会
2015/08/12 职场文书
mysql timestamp比较查询遇到的坑及解决
2021/11/27 MySQL
Python 中面向接口编程
2022/05/20 Python