PYTHON InceptionV3模型的复现详解


Posted in Python onMay 06, 2022

学习前言

Inception系列的结构和其它的前向神经网络的结构不太一样,每一层的内容不是直直向下的,而是分了很多的块。

什么是InceptionV3模型

InceptionV3模型是谷歌Inception系列里面的第三代模型,其模型结构与InceptionV2模型放在了同一篇论文里,其实二者模型结构差距不大,相比于其它神经网络模型,Inception网络最大的特点在于将神经网络层与层之间的卷积运算进行了拓展。
如VGG,AlexNet网络,它就是一直卷积下来的,一层接着一层;
ResNet则是创新性的引入了残差网络的概念,使得靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分,后面的特征层的内容会有一部分由其前面的某一层线性贡献。
而Inception网络则是采用不同大小的卷积核,使得存在不同大小的感受野,最后实现拼接达到不同尺度特征的融合。
对于InceptionV3而言,其网络中存在着如下的结构。
这个结构使用不同大小的卷积核对输入进行卷积(这个结构主要在代码中的block1使用)。
PYTHON InceptionV3模型的复现详解
还存在着这样的结构,利用1x7的卷积和7x1的卷积代替7x7的卷积,这样可以只使用约(1x7 + 7x1) / (7x7) = 28.6%的计算开销;利用1x3的卷积和3x1的卷积代替3x3的卷积,这样可以只使用约(1x3 + 3x1) / (3x3) = 67%的计算开销。
下图利用1x7的卷积和7x1的卷积代替7x7的卷积(这个结构主要在代码中的block2使用)。
PYTHON InceptionV3模型的复现详解
下图利用1x3的卷积和3x1的卷积代替3x3的卷积(这个结构主要在代码中的block3使用)。
PYTHON InceptionV3模型的复现详解

InceptionV3网络部分实现代码

我一共将InceptionV3划分为3个block,对应着35x35、17x17,8x8维度大小的图像。每个block中间有许多的part,对应着不同的特征层深度,用于特征提取。

#-------------------------------------------------------------#
#   InceptionV3的网络部分
#-------------------------------------------------------------#
from __future__ import print_function
from __future__ import absolute_import

import warnings
import numpy as np

from keras.models import Model
from keras import layers
from keras.layers import Activation,Dense,Input,BatchNormalization,Conv2D,MaxPooling2D,AveragePooling2D
from keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D
from keras.engine.topology import get_source_inputs
from keras.utils.layer_utils import convert_all_kernels_in_model
from keras.utils.data_utils import get_file
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions
from keras.preprocessing import image


def conv2d_bn(x,
              filters,
              num_row,
              num_col,
              padding='same',
              strides=(1, 1),
              name=None):
    if name is not None:
        bn_name = name + '_bn'
        conv_name = name + '_conv'
    else:
        bn_name = None
        conv_name = None
    x = Conv2D(
        filters, (num_row, num_col),
        strides=strides,
        padding=padding,
        use_bias=False,
        name=conv_name)(x)
    x = BatchNormalization(scale=False, name=bn_name)(x)
    x = Activation('relu', name=name)(x)
    return x


def InceptionV3(input_shape=[299,299,3],
                classes=1000):


    img_input = Input(shape=input_shape)

    x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='valid')
    x = conv2d_bn(x, 32, 3, 3, padding='valid')
    x = conv2d_bn(x, 64, 3, 3)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv2d_bn(x, 80, 1, 1, padding='valid')
    x = conv2d_bn(x, 192, 3, 3, padding='valid')
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    #--------------------------------#
    #   Block1 35x35
    #--------------------------------#
    # Block1 part1
    # 35 x 35 x 192 -> 35 x 35 x 256
    branch1x1 = conv2d_bn(x, 64, 1, 1)

    branch5x5 = conv2d_bn(x, 48, 1, 1)
    branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)

    branch3x3dbl = conv2d_bn(x, 64, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)

    branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 32, 1, 1)
    x = layers.concatenate(
        [branch1x1, branch5x5, branch3x3dbl, branch_pool],
        axis=3,
        name='mixed0')

    # Block1 part2
    # 35 x 35 x 256 -> 35 x 35 x 288
    branch1x1 = conv2d_bn(x, 64, 1, 1)

    branch5x5 = conv2d_bn(x, 48, 1, 1)
    branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)

    branch3x3dbl = conv2d_bn(x, 64, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)

    branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 64, 1, 1)
    x = layers.concatenate(
        [branch1x1, branch5x5, branch3x3dbl, branch_pool],
        axis=3,
        name='mixed1')

    # Block1 part3
    # 35 x 35 x 288 -> 35 x 35 x 288
    branch1x1 = conv2d_bn(x, 64, 1, 1)

    branch5x5 = conv2d_bn(x, 48, 1, 1)
    branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)

    branch3x3dbl = conv2d_bn(x, 64, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)

    branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 64, 1, 1)
    x = layers.concatenate(
        [branch1x1, branch5x5, branch3x3dbl, branch_pool],
        axis=3,
        name='mixed2')

    #--------------------------------#
    #   Block2 17x17
    #--------------------------------#
    # Block2 part1
    # 35 x 35 x 288 -> 17 x 17 x 768
    branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='valid')

    branch3x3dbl = conv2d_bn(x, 64, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)
    branch3x3dbl = conv2d_bn(
        branch3x3dbl, 96, 3, 3, strides=(2, 2), padding='valid')

    branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x)
    x = layers.concatenate(
        [branch3x3, branch3x3dbl, branch_pool], axis=3, name='mixed3')

    # Block2 part2
    # 17 x 17 x 768 -> 17 x 17 x 768
    branch1x1 = conv2d_bn(x, 192, 1, 1)

    branch7x7 = conv2d_bn(x, 128, 1, 1)
    branch7x7 = conv2d_bn(branch7x7, 128, 1, 7)
    branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)

    branch7x7dbl = conv2d_bn(x, 128, 1, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)

    branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
    x = layers.concatenate(
        [branch1x1, branch7x7, branch7x7dbl, branch_pool],
        axis=3,
        name='mixed4')

    # Block2 part3 and part4
    # 17 x 17 x 768 -> 17 x 17 x 768 -> 17 x 17 x 768
    for i in range(2):
        branch1x1 = conv2d_bn(x, 192, 1, 1)

        branch7x7 = conv2d_bn(x, 160, 1, 1)
        branch7x7 = conv2d_bn(branch7x7, 160, 1, 7)
        branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)

        branch7x7dbl = conv2d_bn(x, 160, 1, 1)
        branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)
        branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7)
        branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)
        branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)

        branch_pool = AveragePooling2D(
            (3, 3), strides=(1, 1), padding='same')(x)
        branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
        x = layers.concatenate(
            [branch1x1, branch7x7, branch7x7dbl, branch_pool],
            axis=3,
            name='mixed' + str(5 + i))

    # Block2 part5
    # 17 x 17 x 768 -> 17 x 17 x 768
    branch1x1 = conv2d_bn(x, 192, 1, 1)

    branch7x7 = conv2d_bn(x, 192, 1, 1)
    branch7x7 = conv2d_bn(branch7x7, 192, 1, 7)
    branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)

    branch7x7dbl = conv2d_bn(x, 192, 1, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)

    branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
    x = layers.concatenate(
        [branch1x1, branch7x7, branch7x7dbl, branch_pool],
        axis=3,
        name='mixed7')

    #--------------------------------#
    #   Block3 8x8
    #--------------------------------#
    # Block3 part1
    # 17 x 17 x 768 -> 8 x 8 x 1280
    branch3x3 = conv2d_bn(x, 192, 1, 1)
    branch3x3 = conv2d_bn(branch3x3, 320, 3, 3,
                          strides=(2, 2), padding='valid')

    branch7x7x3 = conv2d_bn(x, 192, 1, 1)
    branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7)
    branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1)
    branch7x7x3 = conv2d_bn(
        branch7x7x3, 192, 3, 3, strides=(2, 2), padding='valid')

    branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x)
    x = layers.concatenate(
        [branch3x3, branch7x7x3, branch_pool], axis=3, name='mixed8')

    # Block3 part2 part3
    # 8 x 8 x 1280 -> 8 x 8 x 2048 -> 8 x 8 x 2048
    for i in range(2):
        branch1x1 = conv2d_bn(x, 320, 1, 1)

        branch3x3 = conv2d_bn(x, 384, 1, 1)
        branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3)
        branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1)
        branch3x3 = layers.concatenate(
            [branch3x3_1, branch3x3_2], axis=3, name='mixed9_' + str(i))

        branch3x3dbl = conv2d_bn(x, 448, 1, 1)
        branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3)
        branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3)
        branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1)
        branch3x3dbl = layers.concatenate(
            [branch3x3dbl_1, branch3x3dbl_2], axis=3)

        branch_pool = AveragePooling2D(
            (3, 3), strides=(1, 1), padding='same')(x)
        branch_pool = conv2d_bn(branch_pool, 192, 1, 1)
        x = layers.concatenate(
            [branch1x1, branch3x3, branch3x3dbl, branch_pool],
            axis=3,
            name='mixed' + str(9 + i))
    # 平均池化后全连接。
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)


    inputs = img_input

    model = Model(inputs, x, name='inception_v3')

    return model

图片预测

建立网络后,可以用以下的代码进行预测。

def preprocess_input(x):
    x /= 255.
    x -= 0.5
    x *= 2.
    return x


if __name__ == '__main__':
    model = InceptionV3()

    model.load_weights("inception_v3_weights_tf_dim_ordering_tf_kernels.h5")
    
    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(299, 299))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)

    x = preprocess_input(x)

    preds = model.predict(x)
    print('Predicted:', decode_predictions(preds))

预测所需的已经训练好的InceptionV3模型可以在https://github.com/fchollet/deep-learning-models/releases下载。非常方便。
预测结果为:

Predicted: [[('n02504458', 'African_elephant', 0.50874853), ('n01871265', 'tusker', 0.19524273), ('n02504013', 'Indian_elephant', 0.1566972), ('n01917289', 'brain_coral', 0.0008956835), ('n01695060', 'Komodo_dragon', 0.0008260256)]]

这里我推荐一个很不错的blog讲InceptionV3的结构的深度神经网络Google Inception Net-V3结构图里面有每一层的结构图,非常清晰。


Tags in this post...

Python 相关文章推荐
Python中多线程thread与threading的实现方法
Aug 18 Python
Python实现控制台中的进度条功能代码
Dec 22 Python
tensorflow中next_batch的具体使用
Feb 02 Python
Python实现ping指定IP的示例
Jun 04 Python
解决pip install xxx报错SyntaxError: invalid syntax的问题
Nov 30 Python
Python函数装饰器实现方法详解
Dec 22 Python
django框架model orM使用字典作为参数,保存数据的方法分析
Jun 24 Python
Django 中自定义 Admin 样式与功能的实现方法
Jul 04 Python
使用TFRecord存取多个数据案例
Feb 17 Python
Python json解析库jsonpath原理及使用示例
Nov 25 Python
使用python操作lmdb对数据读取的实例
Dec 11 Python
python实现图片批量压缩
Apr 24 Python
代码复现python目标检测yolo3详解预测
讲解Python实例练习逆序输出字符串
May 06 #Python
python turtle绘图
May 04 #Python
python blinker 信号库
May 04 #Python
python三子棋游戏
May 04 #Python
python神经网络 使用Keras构建RNN训练
May 04 #Python
python神经网络学习 使用Keras进行回归运算
May 04 #Python
You might like
PHP在线生成二维码代码(google api)
2013/06/03 PHP
php通过baihui网API实现读取word文档并展示
2015/06/22 PHP
php实现mysql数据库连接操作及用户管理
2015/11/08 PHP
PHP addAttribute()函数讲解
2019/02/03 PHP
让checkbox不选中即将选中的checkbox不选中
2014/07/11 Javascript
jQuery操作表单常用控件方法小结
2015/03/23 Javascript
详解Document.Cookie
2015/12/25 Javascript
js仿百度登录页实现拖动窗口效果
2016/03/11 Javascript
用js动态添加html元素,以及属性的简单实例
2016/07/19 Javascript
JavaScript中英文字符长度统计方法示例【按照中文占2个字符】
2017/01/17 Javascript
js获取当前页的URL与window.location.href简单方法
2017/02/13 Javascript
vue如何获取点击事件源的方法
2017/08/10 Javascript
详解jquery插件jquery.viewport.js学习使用方法
2017/09/08 jQuery
vue2.0设置proxyTable使用axios进行跨域请求的方法
2017/10/19 Javascript
使用mpvue搭建一个初始小程序及项目配置方法
2018/12/03 Javascript
vue+elementUI组件table实现前端分页功能
2020/11/15 Javascript
vue-cli创建的项目中的gitHooks原理解析
2020/02/14 Javascript
[08:08]DOTA2-DPC中国联赛2月28日Recap集锦
2021/03/11 DOTA
[05:09]DOTA2-DPC中国联赛2月22日Recap集锦
2021/03/11 DOTA
python爬虫入门教程之糗百图片爬虫代码分享
2014/09/02 Python
用python实现简单EXCEL数据统计的实例
2017/01/24 Python
Python之自动获取公网IP的实例讲解
2017/10/01 Python
详解python上传文件和字符到PHP服务器
2017/11/24 Python
python模块smtplib实现纯文本邮件发送功能
2018/05/22 Python
django创建简单的页面响应实例教程
2019/09/06 Python
python return逻辑判断表达式实现解析
2019/12/02 Python
HTML5的Video标签有部分MP4无法播放的问题解析(多图)
2017/08/18 HTML / CSS
Nike英国官网:Nike.com (UK)
2017/02/13 全球购物
Skip Hop官网:好莱坞宝宝挚爱品牌
2018/06/17 全球购物
JDBC操作数据库的基本流程是什么
2014/10/28 面试题
如何安装ruby on rails
2014/02/09 面试题
销售类求职信
2014/06/13 职场文书
个人四风问题整改措施
2014/10/24 职场文书
三好学生评选事迹材料(2016精选版)
2016/02/25 职场文书
Python基础详解之邮件处理
2021/04/28 Python
总结python多进程multiprocessing的相关知识
2021/06/29 Python