使用keras实现densenet和Xception的模型融合


Posted in Python onMay 23, 2020

我正在参加天池上的一个竞赛,刚开始用的是DenseNet121但是效果没有达到预期,因此开始尝试使用模型融合,将Desenet和Xception融合起来共同提取特征。

代码如下:

def Multimodel(cnn_weights_path=None,all_weights_path=None,class_num=5,cnn_no_vary=False):
	'''
	获取densent121,xinception并联的网络
	此处的cnn_weights_path是个列表是densenet和xception的卷积部分的权值
	'''
	input_layer=Input(shape=(224,224,3))
	dense=DenseNet121(include_top=False,weights=None,input_shape=(224,224,3))
	xception=Xception(include_top=False,weights=None,input_shape=(224,224,3))
	#res=ResNet50(include_top=False,weights=None,input_shape=(224,224,3))

	if cnn_no_vary:
		for i,layer in enumerate(dense.layers):
			dense.layers[i].trainable=False
		for i,layer in enumerate(xception.layers):
			xception.layers[i].trainable=False
		#for i,layer in enumerate(res.layers):
		#	res.layers[i].trainable=False
 
	if cnn_weights_path!=None:
		dense.load_weights(cnn_weights_path[0])
		xception.load_weights(cnn_weights_path[1])
		#res.load_weights(cnn_weights_path[2])
	dense=dense(input_layer)
	xception=xception(input_layer)

	#对dense_121和xception进行全局最大池化
	top1_model=GlobalMaxPooling2D(data_format='channels_last')(dense)
	top2_model=GlobalMaxPooling2D(data_format='channels_last')(xception)
	#top3_model=GlobalMaxPool2D(input_shape=res.output_shape)(res.outputs[0])
	
	print(top1_model.shape,top2_model.shape)
	#把top1_model和top2_model连接起来
	t=keras.layers.Concatenate(axis=1)([top1_model,top2_model])
	#第一个全连接层
	top_model=Dense(units=512,activation="relu")(t)
	top_model=Dropout(rate=0.5)(top_model)
	top_model=Dense(units=class_num,activation="softmax")(top_model)
	
	model=Model(inputs=input_layer,outputs=top_model)
 
	#加载全部的参数
	if all_weights_path:
		model.load_weights(all_weights_path)
	return model

如下进行调用:

if __name__=="__main__":
 weights_path=["./densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5",
 "xception_weights_tf_dim_ordering_tf_kernels_notop.h5"]
 model=Multimodel(cnn_weights_path=weights_path,class_num=6)
 plot_model(model,to_file="G:/model.png")

最后生成的模型图如下:有点长,可以不看

使用keras实现densenet和Xception的模型融合

需要注意的一点是,如果dense=dense(input_layer)这里报错的话,说明你用的是tensorflow1.4以下的版本,解决的方法就是

1、升级tensorflow到1.4以上

2、改代码:

def Multimodel(cnn_weights_path=None,all_weights_path=None,class_num=5,cnn_no_vary=False):
	'''
	获取densent121,xinception并联的网络
	此处的cnn_weights_path是个列表是densenet和xception的卷积部分的权值
	'''
	dir=os.getcwd()
	input_layer=Input(shape=(224,224,3))
	
	dense=DenseNet121(include_top=False,weights=None,input_tensor=input_layer,
		input_shape=(224,224,3))
	xception=Xception(include_top=False,weights=None,input_tensor=input_layer,
		input_shape=(224,224,3))
	#res=ResNet50(include_top=False,weights=None,input_shape=(224,224,3))
 
	if cnn_no_vary:
		for i,layer in enumerate(dense.layers):
			dense.layers[i].trainable=False
		for i,layer in enumerate(xception.layers):
			xception.layers[i].trainable=False
		#for i,layer in enumerate(res.layers):
		#	res.layers[i].trainable=False
	if cnn_weights_path!=None:
		dense.load_weights(cnn_weights_path[0])
		xception.load_weights(cnn_weights_path[1])
 
	#print(dense.shape,xception.shape)
	#对dense_121和xception进行全局最大池化
	top1_model=GlobalMaxPooling2D(input_shape=(7,7,1024),data_format='channels_last')(dense.output)
	top2_model=GlobalMaxPooling2D(input_shape=(7,7,1024),data_format='channels_last')(xception.output)
	#top3_model=GlobalMaxPool2D(input_shape=res.output_shape)(res.outputs[0])
	
	print(top1_model.shape,top2_model.shape)
	#把top1_model和top2_model连接起来
	t=keras.layers.Concatenate(axis=1)([top1_model,top2_model])
	#第一个全连接层
	top_model=Dense(units=512,activation="relu")(t)
	top_model=Dropout(rate=0.5)(top_model)
	top_model=Dense(units=class_num,activation="softmax")(top_model)
	
	model=Model(inputs=input_layer,outputs=top_model)
 
	#加载全部的参数
	if all_weights_path:
		model.load_weights(all_weights_path)
	return model

这个bug我也是在服务器上跑的时候才出现的,找了半天,而实验室的cuda和cudnn又改不了,tensorflow无法升级,因此只能改代码了。

如下所示,是最后画出的模型图:(很长,底下没内容了)

使用keras实现densenet和Xception的模型融合

以上这篇使用keras实现densenet和Xception的模型融合就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在类Unix系统上开始Python3编程入门
Aug 20 Python
VSCode下好用的Python插件及配置
Apr 06 Python
Python lxml解析HTML并用xpath获取元素的方法
Jan 02 Python
python变量的存储原理详解
Jul 10 Python
python的一些加密方法及python 加密模块
Jul 11 Python
python3 深浅copy对比详解
Aug 12 Python
python实现的发邮件功能示例
Sep 11 Python
python实现把二维列表变为一维列表的方法分析
Oct 08 Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 Python
Python如何输出警告信息
Jul 30 Python
python各种excel写入方式的速度对比
Nov 10 Python
Python ellipsis 的用法详解
Nov 20 Python
在keras下实现多个模型的融合方式
May 23 #Python
Keras使用ImageNet上预训练的模型方式
May 23 #Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
You might like
PHP网页游戏学习之Xnova(ogame)源码解读(十)
2014/06/24 PHP
PHP curl CURLOPT_RETURNTRANSFER参数的作用使用实例
2015/02/07 PHP
PHP abstract 抽象类定义与用法示例
2018/05/29 PHP
php 调用百度sms来发送短信的实现示例
2018/11/02 PHP
js 调用本地exe的例子(支持IE内核的浏览器)
2012/12/26 Javascript
前台js对象在后台转化java对象的问题探讨
2013/12/20 Javascript
SeaJS入门教程系列之完整示例(三)
2014/03/03 Javascript
Javascript获取CSS伪元素属性的实现代码
2014/09/28 Javascript
JavaScript给按钮绑定点击事件(onclick)的方法
2015/04/07 Javascript
详解JavaScript操作HTML DOM的基本方式
2015/10/21 Javascript
javascript获取select标签选中的值
2016/06/04 Javascript
jQuery实现的浮动层div浏览器居中显示效果
2017/02/03 Javascript
详谈表单重复提交的三种情况及解决方法
2017/08/16 Javascript
详解Chai.js断言库API中文文档
2018/01/31 Javascript
JS实现判断图片是否加载完成的方法分析
2018/07/31 Javascript
Vue 实现展开折叠效果的示例代码
2018/08/27 Javascript
Node.js 进程平滑离场剖析小结
2019/01/24 Javascript
JavaScript遍历数组的三种方法map、forEach与filter实例详解
2019/02/27 Javascript
利用JS响应式修改vue实现页面的input值
2019/09/02 Javascript
ES2020 新特性(种草)
2020/01/12 Javascript
python计算最小优先级队列代码分享
2013/12/18 Python
Python Trie树实现字典排序
2014/03/28 Python
Python标准库笔记struct模块的使用
2018/02/22 Python
python将秒数转化为时间格式的实例
2018/09/16 Python
解决python中无法自动补全代码的问题
2018/12/04 Python
在python中pandas读文件,有中文字符的方法
2018/12/12 Python
python3 pathlib库Path类方法总结
2019/12/26 Python
pandas抽取行列数据的几种方法
2020/12/13 Python
python openpyxl模块的使用详解
2021/02/25 Python
CSS3文本换行word-wrap解决英文文本超过固定宽度不换行
2013/10/10 HTML / CSS
美国皮靴公司自1863年:The Frye Company
2016/11/30 全球购物
司机岗位职责
2013/11/15 职场文书
社团活动总结书
2014/06/27 职场文书
2016年“9.22”世界无车日活动小结
2016/04/05 职场文书
基于PyTorch实现一个简单的CNN图像分类器
2021/05/29 Python
Python中的 No Module named ***问题及解决
2022/07/23 Python