keras读取训练好的模型参数并把参数赋值给其它模型详解


Posted in Python onJune 15, 2020

介绍

本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model。

函数式模型

官网上给出的调用一个训练好模型,并输出任意层的feature。

model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool').output)

但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不同呢?比如我想建立一个输入是600x600x3的新model,但是训练好的model输入是200x200x3,而这时我又想调用训练好模型的卷积核参数,这时该怎么办呢?

其实想一下,用训练好的模型参数,即使输入的尺寸不同,但是这些模型参数仍然可以处理计算,只是输出的feature map大小不同。那到底怎么赋值呢?其实很简单

在定义新的model时,新的model层在定义时,需要加上名字,而这个名字就是训练好的模型的每层名字。如下代码所示:

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs)
X=BatchNormalization(name=“batch_normalization_1”)(X)
X=Activation(‘relu',name=“activation_1”)(X)

最后通过以下代码即可建立一个新的模型并拥有训练好模型的参数:

model=Model(inputs=inputs, outputs=X)
model.load_weights(‘model_halcon_resenet.h5', by_name=True)

源代码

from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from keras.layers import BatchNormalization,Activation
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
from keras.models import Model
import tensorflow as tf
from PIL import Image
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses

#ReadMe 该代码是参考fast rcnn系列,先对整幅图像提取特征feature map,然后从原图对应位置上映射到feature map,并对feature map进行
# 切片,从而提取对应某个位置上的特征,并把该特征送进后面的识别网络进行分类识别。
keras.losses.focal_loss = focal_loss#这句代码是为了引入定义的loss
base_model=load_model('model_halcon_resenet.h5')
base_model.summary()

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)
#第一个残差模块
X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)
X_1=BatchNormalization(name="batch_normalization_2")(X_1)
X_1= Activation('relu',name="activation_2")(X_1)
X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)
X_1 = BatchNormalization(name="batch_normalization_3")(X_1)
merge_data = merge([X_1, X], mode='sum',name="merge_1")
X = Activation('relu',name="activation_3")(merge_data)
#第一个残差模块结束
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)
X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)
X=BatchNormalization(name="batch_normalization_4")(X)
X=Activation('relu',name="activation_4")(X)
#第二个残差模块
X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)
X_2=BatchNormalization(name="batch_normalization_5")(X_2)
X_2= Activation('relu',name="activation_5")(X_2)
X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)
X_2 = BatchNormalization(name="batch_normalization_6")(X_2)
merge_data = merge([X_2, X], mode='sum',name="merge_2")
X = Activation('relu',name="activation_6")(merge_data)
#第二个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)
X=Conv2D(64, (3, 3),name="conv2d_7")(X)
X=BatchNormalization(name="batch_normalization_7")(X)
X=Activation('relu',name="activation_7")(X)
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)
#第三个残差模块开始
X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)
X_3=BatchNormalization(name="batch_normalization_8")(X_3)
X_3= Activation('relu',name="activation_8")(X_3)
X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)
X_3 = BatchNormalization(name="batch_normalization_9")(X_3)
merge_data = merge([X_3, X], mode='sum',name="merge_3")
X = Activation('relu',name="activation_9")(merge_data)
#第三个残差模块结束
X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)
X=BatchNormalization(name="batch_normalization_10")(X)
X=Activation('relu',name="activation_10")(X)
#第四个残差模块开始
X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)
X_4=BatchNormalization(name="batch_normalization_11")(X_4)
X_4= Activation('relu',name="activation_11")(X_4)
X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)
X_4 = BatchNormalization(name="batch_normalization_12")(X_4)
merge_data = merge([X_4, X], mode='sum',name="merge_4")
X = Activation('relu',name="activation_12")(merge_data)
#第四个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)
X = Conv2D(64, (3, 3),name="conv2d_13")(X)
X = BatchNormalization(name="batch_normalization_13")(X)
X = Activation('relu',name="activation_13")(X)
#第五个残差模块开始
X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)
X_5=BatchNormalization(name="batch_normalization_14")(X_5)
X_5= Activation('relu',name="activation_14")(X_5)
X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)
X_5 = BatchNormalization(name="batch_normalization_15")(X_5)
merge_data = merge([X_5, X], mode='sum',name="merge_5")
X = Activation('relu',name="activation_15")(merge_data)
#第五个残差模块结束
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
#读取指定图像数据
image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#利用第一个模型预测出特征数据,并对特征数据进行切片
feature_map=model.predict(x)
T=np.array(feature_map)
f_1=T[:,16:21,0:10,:]
print(f_1.shape)
print(feature_map.shape)
#第一个模型没有问题
#定义第二个模型
inputs_sec=Input(shape=(1,5,10,64))
X_= Flatten(name="flatten_1")(inputs_sec)
X_ = Dense(256, activation='relu',name="dense_1")(X_)
X_ = Dropout(0.5,name="dropout_1")(X_)
predictions = Dense(6, activation='softmax',name="dense_2")(X_)
model_sec=Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
#第二个模型定义结束
model_sec.summary()
#开始对整幅图像进行切片,并记录坐标位置
pic=cv2.imread(image_dir)
cor_list=[]
name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):
 for j in range(5):
 if(i==2):
  cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data=result[0].tolist()
  #如果置信度过低,则舍弃
  # if(max(result_data)<=0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name=name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x=cor_list[0]
  y=cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)
 else:
  cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]
  data = np.expand_dims(cut_feature, axis=0)
  result = model_sec.predict(data)
  print(result)
  result_data = result[0].tolist()
  #如果置信度过低,则舍弃
  # if (max(result_data) <= 0.7):
  # continue
  index_num = result_data.index(max(result_data))
  name = name_list[index_num]
  cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标
  x = cor_list[0]
  y = cor_list[1]
  cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)
  cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)

cv2.imshow('pic',pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#第二个模型可以完全预测,没有问题

补充知识:加载训练好的模型参数,但是权重一直变化

keras读取训练好的模型参数并把参数赋值给其它模型详解

变量初始化会导致权重发生变化,去掉就好了。

以上这篇keras读取训练好的模型参数并把参数赋值给其它模型详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之Import 模块
Oct 13 Python
python3生成随机数实例
Oct 20 Python
Python中List.count()方法的使用教程
May 20 Python
Python内置函数 next的具体使用方法
Nov 24 Python
分析Python读取文件时的路径问题
Feb 11 Python
python实现超市扫码仪计费
May 30 Python
详解django.contirb.auth-认证
Jul 16 Python
删除DataFrame中值全为NaN或者包含有NaN的列或行方法
Nov 06 Python
浅谈python的dataframe与series的创建方法
Nov 12 Python
python gdal安装与简单使用
Aug 01 Python
python+openCV调用摄像头拍摄和处理图片的实现
Aug 06 Python
django formset实现数据表的批量操作的示例代码
Dec 06 Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
Keras 实现加载预训练模型并冻结网络的层
Jun 15 #Python
Python StringIO及BytesIO包使用方法解析
Jun 15 #Python
You might like
OAuth认证协议中的HMACSHA1加密算法(实例)
2017/10/25 PHP
PHP Class SoapClient not found解决方法
2018/01/20 PHP
Centos7.7 64位利用本地完整安装包安装lnmp/lamp套件教程
2021/03/09 Servers
js 手机号码合法性验证代码集合
2012/09/29 Javascript
jquery实现炫酷的叠加层自动切换特效
2015/02/01 Javascript
javascript实现图片跟随鼠标移动效果的方法
2015/05/13 Javascript
js基于cookie记录来宾姓名的方法
2016/07/19 Javascript
jQuery插件Validation快速完成表单验证的方式
2016/07/28 Javascript
正则表达式,替换所有HTML标签的简单实例
2016/11/28 Javascript
详解Angular 自定义结构指令
2017/06/21 Javascript
详解.vue文件解析的实现
2018/06/11 Javascript
原生javascript制作贪吃蛇小游戏的方法分析
2020/02/26 Javascript
使用vue构建多页面应用的示例
2020/10/22 Javascript
python实现在pickling的时候压缩的方法
2014/09/25 Python
Python入门篇之数字
2014/10/20 Python
python爬虫之百度API调用方法
2017/06/11 Python
python web基础之加载静态文件实例
2018/03/20 Python
Python面向对象类的继承实例详解
2018/06/27 Python
Django中日期处理注意事项与自定义时间格式转换详解
2018/08/06 Python
python实现顺序表的简单代码
2018/09/28 Python
为什么从Python 3.6开始字典有序并效率更高
2019/07/15 Python
python实现全排列代码(回溯、深度优先搜索)
2020/02/26 Python
scrapy中如何设置应用cookies的方法(3种)
2020/09/22 Python
CSS3制作翻转效果_动力节点Java学院整理
2017/07/11 HTML / CSS
阿里健康大药房:阿里自营网上药店
2017/08/01 全球购物
Linux操作面试题
2012/05/16 面试题
销售总监工作职责
2013/11/21 职场文书
电台编导求职信
2014/05/06 职场文书
应届毕业生自荐信
2014/05/28 职场文书
2014年小学数学教师工作总结
2014/12/03 职场文书
2015年毕业实习工作总结
2014/12/12 职场文书
2014年便民服务中心工作总结
2014/12/20 职场文书
经费申请报告范文
2015/05/18 职场文书
毕业生登记表班级意见
2015/06/05 职场文书
2016情人节宣传语
2015/07/14 职场文书
Redisson实现Redis分布式锁的几种方式
2021/08/07 Redis