keras读取训练好的模型参数并把参数赋值给其它模型详解


Posted in Python onJune 15, 2020

介绍

本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model。

函数式模型

官网上给出的调用一个训练好模型,并输出任意层的feature。

model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool').output)

但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不同呢?比如我想建立一个输入是600x600x3的新model,但是训练好的model输入是200x200x3,而这时我又想调用训练好模型的卷积核参数,这时该怎么办呢?

其实想一下,用训练好的模型参数,即使输入的尺寸不同,但是这些模型参数仍然可以处理计算,只是输出的feature map大小不同。那到底怎么赋值呢?其实很简单

在定义新的model时,新的model层在定义时,需要加上名字,而这个名字就是训练好的模型的每层名字。如下代码所示:

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs)
X=BatchNormalization(name=“batch_normalization_1”)(X)
X=Activation(‘relu',name=“activation_1”)(X)

最后通过以下代码即可建立一个新的模型并拥有训练好模型的参数:

model=Model(inputs=inputs, outputs=X)
model.load_weights(‘model_halcon_resenet.h5', by_name=True)

源代码

from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from keras.layers import BatchNormalization,Activation
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
from keras.models import Model
import tensorflow as tf
from PIL import Image
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses

#ReadMe 该代码是参考fast rcnn系列,先对整幅图像提取特征feature map,然后从原图对应位置上映射到feature map,并对feature map进行
# 切片,从而提取对应某个位置上的特征,并把该特征送进后面的识别网络进行分类识别。
keras.losses.focal_loss = focal_loss#这句代码是为了引入定义的loss
base_model=load_model('model_halcon_resenet.h5')
base_model.summary()

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)
#第一个残差模块
X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)
X_1=BatchNormalization(name="batch_normalization_2")(X_1)
X_1= Activation('relu',name="activation_2")(X_1)
X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)
X_1 = BatchNormalization(name="batch_normalization_3")(X_1)
merge_data = merge([X_1, X], mode='sum',name="merge_1")
X = Activation('relu',name="activation_3")(merge_data)
#第一个残差模块结束
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)
X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)
X=BatchNormalization(name="batch_normalization_4")(X)
X=Activation('relu',name="activation_4")(X)
#第二个残差模块
X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)
X_2=BatchNormalization(name="batch_normalization_5")(X_2)
X_2= Activation('relu',name="activation_5")(X_2)
X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)
X_2 = BatchNormalization(name="batch_normalization_6")(X_2)
merge_data = merge([X_2, X], mode='sum',name="merge_2")
X = Activation('relu',name="activation_6")(merge_data)
#第二个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)
X=Conv2D(64, (3, 3),name="conv2d_7")(X)
X=BatchNormalization(name="batch_normalization_7")(X)
X=Activation('relu',name="activation_7")(X)
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)
#第三个残差模块开始
X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)
X_3=BatchNormalization(name="batch_normalization_8")(X_3)
X_3= Activation('relu',name="activation_8")(X_3)
X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)
X_3 = BatchNormalization(name="batch_normalization_9")(X_3)
merge_data = merge([X_3, X], mode='sum',name="merge_3")
X = Activation('relu',name="activation_9")(merge_data)
#第三个残差模块结束
X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)
X=BatchNormalization(name="batch_normalization_10")(X)
X=Activation('relu',name="activation_10")(X)
#第四个残差模块开始
X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)
X_4=BatchNormalization(name="batch_normalization_11")(X_4)
X_4= Activation('relu',name="activation_11")(X_4)
X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)
X_4 = BatchNormalization(name="batch_normalization_12")(X_4)
merge_data = merge([X_4, X], mode='sum',name="merge_4")
X = Activation('relu',name="activation_12")(merge_data)
#第四个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)
X = Conv2D(64, (3, 3),name="conv2d_13")(X)
X = BatchNormalization(name="batch_normalization_13")(X)
X = Activation('relu',name="activation_13")(X)
#第五个残差模块开始
X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)
X_5=BatchNormalization(name="batch_normalization_14")(X_5)
X_5= Activation('relu',name="activation_14")(X_5)
X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)
X_5 = BatchNormalization(name="batch_normalization_15")(X_5)
merge_data = merge([X_5, X], mode='sum',name="merge_5")
X = Activation('relu',name="activation_15")(merge_data)
#第五个残差模块结束
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
#读取指定图像数据
image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#利用第一个模型预测出特征数据,并对特征数据进行切片
feature_map=model.predict(x)
T=np.array(feature_map)
f_1=T[:,16:21,0:10,:]
print(f_1.shape)
print(feature_map.shape)
#第一个模型没有问题
#定义第二个模型
inputs_sec=Input(shape=(1,5,10,64))
X_= Flatten(name="flatten_1")(inputs_sec)
X_ = Dense(256, activation='relu',name="dense_1")(X_)
X_ = Dropout(0.5,name="dropout_1")(X_)
predictions = Dense(6, activation='softmax',name="dense_2")(X_)
model_sec=Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
#第二个模型定义结束
model_sec.summary()
#开始对整幅图像进行切片,并记录坐标位置
pic=cv2.imread(image_dir)
cor_list=[]
name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):
 for j in range(5):
 if(i==2):
  cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data=result[0].tolist()
  #如果置信度过低,则舍弃
  # if(max(result_data)<=0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name=name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x=cor_list[0]
  y=cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)
 else:
  cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data = result[0].tolist()
  #如果置信度过低,则舍弃
  # if (max(result_data) <= 0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name = name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x = cor_list[0]
  y = cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)

cv2.imshow('pic',pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#第二个模型可以完全预测,没有问题

补充知识:加载训练好的模型参数,但是权重一直变化

keras读取训练好的模型参数并把参数赋值给其它模型详解

变量初始化会导致权重发生变化,去掉就好了。

以上这篇keras读取训练好的模型参数并把参数赋值给其它模型详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python创建关联数组(字典)的方法
May 04 Python
详解常用查找数据结构及算法(Python实现)
Dec 09 Python
Python 实现数据库更新脚本的生成方法
Jul 09 Python
python实现人脸识别代码
Nov 08 Python
python 获取字符串MD5值方法
May 29 Python
解决Python pandas df 写入excel 出现的问题
Jul 04 Python
基于Python的微信机器人开发 微信登录和获取好友列表实现解析
Aug 21 Python
详解Pycharm出现out of memory的终极解决方法
Mar 03 Python
PyCharm永久激活方式(推荐)
Sep 22 Python
Python通过socketserver处理多个链接
Mar 18 Python
Python如何使用队列方式实现多线程爬虫
May 12 Python
Python安装使用Scrapy框架
Apr 12 Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
Keras 实现加载预训练模型并冻结网络的层
Jun 15 #Python
Python StringIO及BytesIO包使用方法解析
Jun 15 #Python
You might like
PHP之图片上传类实例代码(加了缩略图)
2016/06/30 PHP
PHP编程计算日期间隔天数的方法
2017/04/26 PHP
js 通过html()及text()方法获取并设置p标签的显示值
2014/05/14 Javascript
JS或jQuery获取ASP.NET服务器控件ID的方法
2015/06/08 Javascript
基于JS实现新闻列表无缝向上滚动实例代码
2016/01/22 Javascript
nodejs加密Crypto的实例代码
2016/07/07 NodeJs
浅谈Angular路由复用策略
2017/10/04 Javascript
详解vuex的简单使用
2018/03/12 Javascript
vue 开发一个按钮组件的示例代码
2018/03/27 Javascript
安装vue-cli的简易过程
2018/05/22 Javascript
vue写h5页面的方法总结
2019/02/12 Javascript
详解VS Code使用之Vue工程配置format代码格式化
2019/03/20 Javascript
Python中字典(dict)和列表(list)的排序方法实例
2014/06/16 Python
简单介绍Python中的readline()方法的使用
2015/05/24 Python
一波神奇的Python语句、函数与方法的使用技巧总结
2015/12/08 Python
python中实现精确的浮点数运算详解
2017/11/02 Python
python实现抖音视频批量下载
2018/06/20 Python
浅析python的优势和不足之处
2018/11/20 Python
33个Python爬虫项目实战(推荐)
2019/07/08 Python
python中如何实现将数据分成训练集与测试集的方法
2019/09/13 Python
python next()和iter()函数原理解析
2020/02/07 Python
Python virtualenv虚拟环境实现过程解析
2020/04/18 Python
Python3使用 GitLab API 进行批量合并分支
2020/10/15 Python
python 下划线的不同用法
2020/10/24 Python
css3学习之2D转换功能详解
2016/12/23 HTML / CSS
欧洲著名的珠宝和手表网上商城:uhrcenter
2017/04/10 全球购物
乌克兰电子和家用电器商店:Foxtrot
2019/07/23 全球购物
思想政治自我鉴定
2013/10/06 职场文书
开展党的群众路线教育实践活动方案
2014/02/05 职场文书
运动会稿件50字
2014/02/17 职场文书
地理科学专业自荐信
2014/09/01 职场文书
公司收款委托书范本
2014/09/20 职场文书
周末问候语大全
2015/11/10 职场文书
怎样做好公众演讲能力?
2019/08/28 职场文书
python批量创建变量并赋值操作
2021/06/03 Python
Win11使用CAD卡顿或者致命错误怎么办?Win11无法正常使用CAD的解决方法
2022/07/23 数码科技