Keras实现DenseNet结构操作


Posted in Python onJuly 06, 2020

DenseNet结构在16年由Huang Gao和Liu Zhuang等人提出,并且在CVRP2017中被评为最佳论文。网络的核心结构为如下所示的Dense块,在每一个Dense块中,存在多个Dense层,即下图所示的H1~H4。各Dense层之间彼此均相互连接,即H1的输入为x0,输出为x1,H2的输入即为[x0, x1],输出为x2,依次类推。最终Dense块的输出即为[x0, x1, x2, x3, x4]。这种结构个人感觉非常类似生物学里边的神经元连接方式,应该能够比较有效的提高了网络中特征信息的利用效率。

Keras实现DenseNet结构操作

DenseNet的其他结构就非常类似一般的卷积神经网络结构了,可以参考论文中提供的网路结构图(下图)。但是个人感觉,DenseNet的这种结构应该是存在进一步的优化方法的,比如可能不一定需要在Dense块中对每一个Dense层均直接进行相互连接,来缩小网络的结构;也可能可以在不相邻的Dense块之间通过简单的下采样操作进行连接,进一步提升网络对不同尺度的特征的利用效率。

Keras实现DenseNet结构操作

由于DenseNet的密集连接方式,在构建一个相同容量的网络时其所需的参数数量远小于其之前提出的如resnet等结构。进一步,个人感觉应该可以把Dense块看做对一个有较多参数的卷积层的高效替代。因此,其也可以结合U-Net等网络结构,来进一步优化网络性能,比如单纯的把U-net中的所有卷积层全部换成DenseNet的结构,就可以显著压缩网络大小。

下面基于Keras实现DenseNet-BC结构。首先定义Dense层,根据论文描述构建如下:

def DenseLayer(x, nb_filter, bn_size=4, alpha=0.0, drop_rate=0.2):
 
 # Bottleneck layers
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(bn_size*nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 
 # Composite function
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (3, 3), strides=(1,1), padding='same')(x)
 
 if drop_rate: x = Dropout(drop_rate)(x)
 
 return x

论文原文中提出使用1*1卷积核的卷积层作为bottleneck层来优化计算效率。原文中使用的激活函数全部为relu,但个人习惯是用leakyrelu进行构建,来方便调参。

之后是用Dense层搭建Dense块,如下:

def DenseBlock(x, nb_layers, growth_rate, drop_rate=0.2):
 
 for ii in range(nb_layers):
  conv = DenseLayer(x, nb_filter=growth_rate, drop_rate=drop_rate)
  x = concatenate([x, conv], axis=3)
 return x

如论文中所述,将每一个Dense层的输出与其输入融合之后作为下一Dense层的输入,来实现密集连接。

最后是各Dense块之间的过渡层,如下:

def TransitionLayer(x, compression=0.5, alpha=0.0, is_max=0):
 
 nb_filter = int(x.shape.as_list()[-1]*compression)
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 if is_max != 0: x = MaxPooling2D(pool_size=(2, 2), strides=2)(x)
 else: x = AveragePooling2D(pool_size=(2, 2), strides=2)(x)
 
 return x

论文中提出使用均值池化层来作下采样,不过在边缘特征提取方面,最大池化层效果应该更好,这里就加了相关接口。

将上述结构按照论文中提出的结构进行拼接,这里选择的参数是论文中提到的L=100,k=12,网络连接如下:

growth_rate = 12
inpt = Input(shape=(32,32,3))
 
x = Conv2D(growth_rate*2, (3, 3), strides=1, padding='same')(inpt)
x = BatchNormalization(axis=3)(x)
x = LeakyReLU(alpha=0.1)(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = BatchNormalization(axis=3)(x)
x = GlobalAveragePooling2D()(x)
x = Dense(10, activation='softmax')(x)
 
model = Model(inpt, x)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

虽然我们已经完成了网络的架设,网络本身的参数数量也仅有0.5M,但由于以这种方式实现的网络在Dense块中,每一次concat均需要开辟一组全新的内存空间,导致实际需要的内存空间非常大。作者在17年的时候,还专门写了相关的技术报告:https://arxiv.org/abs/1707.06990来说明怎么节省内存空间,不过单纯用keras实现起来是比较麻烦。下一篇博客中将以pytorch框架来对其进行实现。

最后放出网络完整代码:

import numpy as np
import keras
from keras.models import Model, save_model, load_model
from keras.layers import Input, Dense, Dropout, BatchNormalization, LeakyReLU, concatenate
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D
 
## data
import pickle
 
data_batch_1 = pickle.load(open("cifar-10-batches-py/data_batch_1", 'rb'), encoding='bytes')
data_batch_2 = pickle.load(open("cifar-10-batches-py/data_batch_2", 'rb'), encoding='bytes')
data_batch_3 = pickle.load(open("cifar-10-batches-py/data_batch_3", 'rb'), encoding='bytes')
data_batch_4 = pickle.load(open("cifar-10-batches-py/data_batch_4", 'rb'), encoding='bytes')
data_batch_5 = pickle.load(open("cifar-10-batches-py/data_batch_5", 'rb'), encoding='bytes')
 
train_X_1 = data_batch_1[b'data']
train_X_1 = train_X_1.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_1 = data_batch_1[b'labels']
 
train_X_2 = data_batch_2[b'data']
train_X_2 = train_X_2.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_2 = data_batch_2[b'labels']
 
train_X_3 = data_batch_3[b'data']
train_X_3 = train_X_3.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_3 = data_batch_3[b'labels']
 
train_X_4 = data_batch_4[b'data']
train_X_4 = train_X_4.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_4 = data_batch_4[b'labels']
 
train_X_5 = data_batch_5[b'data']
train_X_5 = train_X_5.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
train_Y_5 = data_batch_5[b'labels']
 
train_X = np.row_stack((train_X_1, train_X_2))
train_X = np.row_stack((train_X, train_X_3))
train_X = np.row_stack((train_X, train_X_4))
train_X = np.row_stack((train_X, train_X_5))
 
train_Y = np.row_stack((train_Y_1, train_Y_2))
train_Y = np.row_stack((train_Y, train_Y_3))
train_Y = np.row_stack((train_Y, train_Y_4))
train_Y = np.row_stack((train_Y, train_Y_5))
train_Y = train_Y.reshape(50000, 1).transpose(0, 1).astype("int32")
train_Y = keras.utils.to_categorical(train_Y)
 
test_batch = pickle.load(open("cifar-10-batches-py/test_batch", 'rb'), encoding='bytes')
test_X = test_batch[b'data']
test_X = test_X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
test_Y = test_batch[b'labels']
test_Y = keras.utils.to_categorical(test_Y)
 
train_X /= 255
test_X /= 255
 
# model
 
def DenseLayer(x, nb_filter, bn_size=4, alpha=0.0, drop_rate=0.2):
 
 # Bottleneck layers
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(bn_size*nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 
 # Composite function
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (3, 3), strides=(1,1), padding='same')(x)
 
 if drop_rate: x = Dropout(drop_rate)(x)
 
 return x
 
def DenseBlock(x, nb_layers, growth_rate, drop_rate=0.2):
 
 for ii in range(nb_layers):
  conv = DenseLayer(x, nb_filter=growth_rate, drop_rate=drop_rate)
  x = concatenate([x, conv], axis=3)
  
 return x
 
def TransitionLayer(x, compression=0.5, alpha=0.0, is_max=0):
 
 nb_filter = int(x.shape.as_list()[-1]*compression)
 x = BatchNormalization(axis=3)(x)
 x = LeakyReLU(alpha=alpha)(x)
 x = Conv2D(nb_filter, (1, 1), strides=(1,1), padding='same')(x)
 if is_max != 0: x = MaxPooling2D(pool_size=(2, 2), strides=2)(x)
 else: x = AveragePooling2D(pool_size=(2, 2), strides=2)(x)
 
 return x
 
growth_rate = 12
 
inpt = Input(shape=(32,32,3))
 
x = Conv2D(growth_rate*2, (3, 3), strides=1, padding='same')(inpt)
x = BatchNormalization(axis=3)(x)
x = LeakyReLU(alpha=0.1)(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = TransitionLayer(x)
x = DenseBlock(x, 12, growth_rate, drop_rate=0.2)
x = BatchNormalization(axis=3)(x)
x = GlobalAveragePooling2D()(x)
x = Dense(10, activation='softmax')(x)
 
model = Model(inpt, x)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
 
model.summary()
 
for ii in range(10):
 print("Epoch:", ii+1)
 model.fit(train_X, train_Y, batch_size=100, epochs=1, verbose=1)
 score = model.evaluate(test_X, test_Y, verbose=1)
 print('Test loss =', score[0])
 print('Test accuracy =', score[1])
 
save_model(model, 'DenseNet.h5')
model = load_model('DenseNet.h5')
 
pred_Y = model.predict(test_X)
score = model.evaluate(test_X, test_Y, verbose=0)
print('Test loss =', score[0])
print('Test accuracy =', score[1])

以上这篇Keras实现DenseNet结构操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现文件名批量替换和内容替换
Mar 20 Python
python TKinter获取文本框内容的方法
Oct 11 Python
在python中利用GDAL对tif文件进行读写的方法
Nov 29 Python
python实现弹跳小球
May 13 Python
python字典嵌套字典的情况下找到某个key的value详解
Jul 10 Python
kali中python版本的切换方法
Jul 11 Python
django-filter和普通查询的例子
Aug 12 Python
tensorflow 模型权重导出实例
Jan 24 Python
Tensorflow不支持AVX2指令集的解决方法
Feb 03 Python
python实现ftp文件传输系统(案例分析)
Mar 20 Python
Python异常原理及异常捕捉实现过程解析
Mar 25 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
Jun 01 Python
基于Python和C++实现删除链表的节点
Jul 06 #Python
基于Python 的语音重采样函数解析
Jul 06 #Python
python interpolate插值实例
Jul 06 #Python
基于Python实现2种反转链表方法代码实例
Jul 06 #Python
简单了解Django项目应用创建过程
Jul 06 #Python
如何在mac下配置python虚拟环境
Jul 06 #Python
Python优秀开源项目Rich源码解析的流程分析
Jul 06 #Python
You might like
星际争霸秘籍
2020/03/04 星际争霸
融入意大利的咖啡文化
2021/03/03 咖啡文化
php ctype函数中文翻译和示例
2014/03/21 PHP
php自动给网址加上链接的方法
2015/06/02 PHP
070823更新的一个[消息提示框]组件 兼容ie7
2007/08/29 Javascript
论坛里点击别人帖子下面的回复,回复标题变成“回复 24# 的帖子”
2009/06/14 Javascript
基于jQuery的淡入淡出可自动切换的幻灯插件打包下载
2010/09/15 Javascript
打开新窗口关闭当前页面不弹出关闭提示js代码
2013/03/18 Javascript
js实现对table动态添加、删除和更新的方法
2015/02/10 Javascript
php结合imgareaselect实现图片裁剪
2015/07/05 Javascript
学JavaScript七大注意事项【必看】
2016/05/04 Javascript
ajax的分页查询示例(不刷新页面)
2017/01/11 Javascript
JS实现直接运行html代码的方法
2017/03/13 Javascript
CSS3+JavaScript实现翻页幻灯片效果
2017/06/28 Javascript
2种简单的js倒计时方式
2017/10/20 Javascript
解决node修改后需频繁手动重启的问题
2018/05/13 Javascript
使用koa-log4管理nodeJs日志笔记的使用方法
2018/11/30 NodeJs
js实现盒子滚动动画效果
2020/08/09 Javascript
Python网络爬虫出现乱码问题的解决方法
2017/01/05 Python
Python冲顶大会 快来答题!
2018/01/17 Python
Python实现二维数组输出为图片
2018/04/03 Python
Python简单爬虫导出CSV文件的实例讲解
2018/07/06 Python
Python 给定的经纬度标注在地图上的实现方法
2019/07/05 Python
PyCharm导入python项目并配置虚拟环境的教程详解
2019/10/13 Python
Python类的动态绑定实现原理
2020/03/21 Python
白色公司:The White Company
2017/10/11 全球购物
品管员岗位职责
2013/11/10 职场文书
质量保证书范本
2014/04/29 职场文书
优秀求职信
2014/05/29 职场文书
2014教师个人自我评价范文
2014/09/13 职场文书
2014年“向国旗敬礼”网上签名寄语活动方案
2014/09/27 职场文书
2015年全国科普日活动总结
2015/03/23 职场文书
2015年党小组工作总结
2015/05/26 职场文书
Go标准容器之Ring的使用说明
2021/05/05 Golang
深入解读Java三大集合之map list set的用法
2021/11/11 Java/Android
聊聊redis-dump工具安装问题
2022/01/18 Redis