使用keras实现densenet和Xception的模型融合


Posted in Python onMay 23, 2020

我正在参加天池上的一个竞赛,刚开始用的是DenseNet121但是效果没有达到预期,因此开始尝试使用模型融合,将Desenet和Xception融合起来共同提取特征。

代码如下:

def Multimodel(cnn_weights_path=None,all_weights_path=None,class_num=5,cnn_no_vary=False):
	'''
	获取densent121,xinception并联的网络
	此处的cnn_weights_path是个列表是densenet和xception的卷积部分的权值
	'''
	input_layer=Input(shape=(224,224,3))
	dense=DenseNet121(include_top=False,weights=None,input_shape=(224,224,3))
	xception=Xception(include_top=False,weights=None,input_shape=(224,224,3))
	#res=ResNet50(include_top=False,weights=None,input_shape=(224,224,3))

	if cnn_no_vary:
		for i,layer in enumerate(dense.layers):
			dense.layers[i].trainable=False
		for i,layer in enumerate(xception.layers):
			xception.layers[i].trainable=False
		#for i,layer in enumerate(res.layers):
		#	res.layers[i].trainable=False
 
	if cnn_weights_path!=None:
		dense.load_weights(cnn_weights_path[0])
		xception.load_weights(cnn_weights_path[1])
		#res.load_weights(cnn_weights_path[2])
	dense=dense(input_layer)
	xception=xception(input_layer)

	#对dense_121和xception进行全局最大池化
	top1_model=GlobalMaxPooling2D(data_format='channels_last')(dense)
	top2_model=GlobalMaxPooling2D(data_format='channels_last')(xception)
	#top3_model=GlobalMaxPool2D(input_shape=res.output_shape)(res.outputs[0])
	
	print(top1_model.shape,top2_model.shape)
	#把top1_model和top2_model连接起来
	t=keras.layers.Concatenate(axis=1)([top1_model,top2_model])
	#第一个全连接层
	top_model=Dense(units=512,activation="relu")(t)
	top_model=Dropout(rate=0.5)(top_model)
	top_model=Dense(units=class_num,activation="softmax")(top_model)
	
	model=Model(inputs=input_layer,outputs=top_model)
 
	#加载全部的参数
	if all_weights_path:
		model.load_weights(all_weights_path)
	return model

如下进行调用:

if __name__=="__main__":
 weights_path=["./densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5",
 "xception_weights_tf_dim_ordering_tf_kernels_notop.h5"]
 model=Multimodel(cnn_weights_path=weights_path,class_num=6)
 plot_model(model,to_file="G:/model.png")

最后生成的模型图如下:有点长,可以不看

使用keras实现densenet和Xception的模型融合

需要注意的一点是,如果dense=dense(input_layer)这里报错的话,说明你用的是tensorflow1.4以下的版本,解决的方法就是

1、升级tensorflow到1.4以上

2、改代码:

def Multimodel(cnn_weights_path=None,all_weights_path=None,class_num=5,cnn_no_vary=False):
	'''
	获取densent121,xinception并联的网络
	此处的cnn_weights_path是个列表是densenet和xception的卷积部分的权值
	'''
	dir=os.getcwd()
	input_layer=Input(shape=(224,224,3))
	
	dense=DenseNet121(include_top=False,weights=None,input_tensor=input_layer,
		input_shape=(224,224,3))
	xception=Xception(include_top=False,weights=None,input_tensor=input_layer,
		input_shape=(224,224,3))
	#res=ResNet50(include_top=False,weights=None,input_shape=(224,224,3))
 
	if cnn_no_vary:
		for i,layer in enumerate(dense.layers):
			dense.layers[i].trainable=False
		for i,layer in enumerate(xception.layers):
			xception.layers[i].trainable=False
		#for i,layer in enumerate(res.layers):
		#	res.layers[i].trainable=False
	if cnn_weights_path!=None:
		dense.load_weights(cnn_weights_path[0])
		xception.load_weights(cnn_weights_path[1])
 
	#print(dense.shape,xception.shape)
	#对dense_121和xception进行全局最大池化
	top1_model=GlobalMaxPooling2D(input_shape=(7,7,1024),data_format='channels_last')(dense.output)
	top2_model=GlobalMaxPooling2D(input_shape=(7,7,1024),data_format='channels_last')(xception.output)
	#top3_model=GlobalMaxPool2D(input_shape=res.output_shape)(res.outputs[0])
	
	print(top1_model.shape,top2_model.shape)
	#把top1_model和top2_model连接起来
	t=keras.layers.Concatenate(axis=1)([top1_model,top2_model])
	#第一个全连接层
	top_model=Dense(units=512,activation="relu")(t)
	top_model=Dropout(rate=0.5)(top_model)
	top_model=Dense(units=class_num,activation="softmax")(top_model)
	
	model=Model(inputs=input_layer,outputs=top_model)
 
	#加载全部的参数
	if all_weights_path:
		model.load_weights(all_weights_path)
	return model

这个bug我也是在服务器上跑的时候才出现的,找了半天,而实验室的cuda和cudnn又改不了,tensorflow无法升级,因此只能改代码了。

如下所示,是最后画出的模型图:(很长,底下没内容了)

使用keras实现densenet和Xception的模型融合

以上这篇使用keras实现densenet和Xception的模型融合就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分享Python开发中要注意的十个小贴士
Aug 30 Python
浅析Python中return和finally共同挖的坑
Aug 18 Python
Python实现的十进制小数与二进制小数相互转换功能
Oct 12 Python
django实现登录时候输入密码错误5次锁定用户十分钟
Nov 05 Python
Django入门使用示例
Dec 12 Python
Python SQLite3简介
Feb 22 Python
Django uwsgi Nginx 的生产环境部署详解
Feb 02 Python
基于torch.where和布尔索引的速度比较
Jan 02 Python
Pycharm如何导入python文件及解决报错问题
May 10 Python
基于python实现获取网页图片过程解析
May 11 Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 Python
python virtualenv虚拟环境配置与使用教程详解
Jul 13 Python
在keras下实现多个模型的融合方式
May 23 #Python
Keras使用ImageNet上预训练的模型方式
May 23 #Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
You might like
文章推荐系统(二)
2006/10/09 PHP
社区(php&&mysql)六
2006/10/09 PHP
杏林同学录(三)
2006/10/09 PHP
PHP防CC攻击实现代码
2011/12/29 PHP
ThinkPHP惯例配置文件详解
2014/07/14 PHP
php自动识别文字编码并转换为目标编码的方法
2015/08/08 PHP
针对thinkPHP5框架存储过程bug重写的存储过程扩展类完整实例
2018/06/16 PHP
PHP PDOStatement::rowCount讲解
2019/02/01 PHP
javascript 跳转代码集合
2009/12/03 Javascript
javascript:void(0)使用探讨
2013/08/27 Javascript
在页面上用action传递参数到后台出现乱码的解决方法
2013/12/31 Javascript
jQuery之ajax删除详解
2014/02/27 Javascript
Node.js利用Net模块实现多人命令行聊天室的方法
2016/12/23 Javascript
基于BootStrap的前端分页带省略号和上下页效果
2017/05/18 Javascript
SpringBoot+Vue前后端分离,使用SpringSecurity完美处理权限问题的解决方法
2018/01/09 Javascript
微信小程序开发实现消息推送
2020/11/18 Javascript
Vue+axios+WebApi+NPOI导出Excel文件实例方法
2019/06/05 Javascript
JavaScript装箱及拆箱boxing及unBoxing用法解析
2020/06/15 Javascript
JS sort方法基于数组对象属性值排序
2020/07/10 Javascript
[42:20]2014 DOTA2华西杯精英邀请赛5 24 DK VS NewBee
2014/05/25 DOTA
[52:37]完美世界DOTA2联赛循环赛 Forest vs DM BO2第一场 10.29
2020/10/29 DOTA
python每隔N秒运行指定函数的方法
2015/03/16 Python
Python中urllib+urllib2+cookielib模块编写爬虫实战
2016/01/20 Python
基于python OpenCV实现动态人脸检测
2018/05/25 Python
python将回车作为输入内容的实例
2018/06/23 Python
flask框架自定义url转换器操作详解
2020/01/25 Python
Python xlrd模块导入过程及常用操作
2020/06/10 Python
python使用多线程查询数据库的实现示例
2020/08/17 Python
欧洲领先的电子和电信零售商和服务提供商:Currys PC World Business
2017/12/05 全球购物
Envie de Fraise意大利:法国网上推出的孕妇装品牌
2020/10/18 全球购物
师范大学毕业自我鉴定
2013/11/21 职场文书
七年级政治教学反思
2014/02/03 职场文书
项目施工员岗位职责
2014/03/09 职场文书
电教室标语
2014/06/20 职场文书
婚礼答谢礼品
2015/01/20 职场文书
vue实现拖拽交换位置
2022/04/07 Vue.js