keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python日期操作学习笔记
Oct 07 Python
Python过滤函数filter()使用自定义函数过滤序列实例
Aug 26 Python
Python re模块介绍
Nov 30 Python
教你使用python实现微信每天给女朋友说晚安
Mar 23 Python
pandas系列之DataFrame 行列数据筛选实例
Apr 12 Python
python实现列表中由数值查到索引的方法
Jun 27 Python
用Python将结果保存为xlsx的方法
Jan 28 Python
python利用re,bs4,requests模块获取股票数据
Jul 29 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 Python
解决python -m pip install --upgrade pip 升级不成功问题
Mar 05 Python
理解Django 中Call Stack机制的小Demo
Sep 01 Python
如何在Python3中使用telnetlib模块连接网络设备
Sep 21 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
PHP使用适合阅读的格式显示文件大小的方法
2015/03/05 PHP
Centos6.5和Centos7 php环境搭建方法
2016/05/27 PHP
Yii2――使用数据库操作汇总(增删查改、事务)
2016/12/19 PHP
laravel 中如何使用ajax和vue总结
2017/08/16 PHP
PHP children()函数讲解
2019/02/03 PHP
php array_map()函数实例用法
2021/03/03 PHP
javascript写的一个链表实现代码
2009/10/25 Javascript
Ext 今日学习总结
2010/09/19 Javascript
javascript验证上传文件的类型限制必须为某些格式
2013/11/14 Javascript
使用JavaScript的AngularJS库编写hello world的方法
2015/06/23 Javascript
使用requestAnimationFrame实现js动画性能好
2015/08/06 Javascript
Jquery实现select multiple左右添加和删除功能的简单实例
2016/05/26 Javascript
jQuery实现拖拽页面元素并将其保存到cookie的方法
2016/06/12 Javascript
动态生成的DOM不会触发onclick事件的原因及解决方法
2016/08/06 Javascript
jQuery插件echarts设置折线图中折线线条颜色和折线点颜色的方法
2017/03/03 Javascript
jQuery插件FusionCharts实现的MSBar2D图效果示例【附demo源码】
2017/03/24 jQuery
详解Angular-cli生成组件修改css成less或sass的实例
2017/07/27 Javascript
vue 通过下拉框组件学习vue中的父子通讯
2017/12/19 Javascript
vue利用axios来完成数据的交互
2018/03/23 Javascript
windows10系统中安装python3.x+scrapy教程
2016/11/08 Python
python 简单照相机调用系统摄像头实现方法 pygame
2018/08/03 Python
python ftp 按目录结构上传下载的实现代码
2018/09/12 Python
python对于requests的封装方法详解
2019/01/03 Python
Python读取指定日期邮件的实例
2019/02/01 Python
django settings.py 配置文件及介绍
2019/07/15 Python
python 获取计算机的网卡信息
2021/02/18 Python
AmazeUI 平滑滚动效果的示例代码
2020/08/20 HTML / CSS
兰芝美国网上商城:购买LANEIGE睡眠面膜等
2017/06/30 全球购物
英国著名药妆店:Superdrug
2021/02/13 全球购物
医学院校毕业生自荐信范文
2014/01/01 职场文书
电气自动化求职信
2014/06/24 职场文书
计算机专业自荐信范文
2015/03/26 职场文书
行政助理岗位职责范本
2015/04/11 职场文书
2015年领导干部廉洁自律工作总结
2015/05/26 职场文书
大学升旗仪式主持词
2015/07/04 职场文书
田径运动会通讯稿
2015/07/18 职场文书