keras自定义损失函数并且模型加载的写法介绍


Posted in Python onJune 15, 2020

keras自定义函数时候,正常在模型里自己写好自定义的函数,然后在模型编译的那行代码里写上接口即可。如下所示,focal_loss和fbeta_score是我们自己定义的两个函数,在model.compile加入它们,metrics里‘accuracy'是keras自带的度量函数。

def focal_loss():
 ...
 return xx
def fbeta_score():
 ...
 return yy
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],metrics=['accuracy',fbeta_score] )

训练好之后,模型加载也需要再额外加一行,通过load_model里的custom_objects将我们定义的两个函数以字典的形式加入就能正常加载模型啦。

weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})

补充知识:keras如何使用自定义的loss及评价函数进行训练及预测

1.有时候训练模型,现有的损失及评估函数并不足以科学的训练评估模型,这时候就需要自定义一些损失评估函数,比如focal loss损失函数及dice评价函数 for unet的训练。

2.在训练建模中导入自定义loss及评估函数。

#模型编译时加入自定义loss及评估函数
model.compile(optimizer = Adam(lr=1e-4), loss=[binary_focal_loss()],
    metrics=['accuracy',dice_coef])

#自定义loss及评估函数
def binary_focal_loss(gamma=2, alpha=0.25):
 """
 Binary form of focal loss.
 适用于二分类问题的focal loss
 focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
  where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
 References:
  https://arxiv.org/pdf/1708.02002.pdf
 Usage:
  model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
 """
 alpha = tf.constant(alpha, dtype=tf.float32)
 gamma = tf.constant(gamma, dtype=tf.float32)

 def binary_focal_loss_fixed(y_true, y_pred):
  """
  y_true shape need be (None,1)
  y_pred need be compute after sigmoid
  """
  y_true = tf.cast(y_true, tf.float32)
  alpha_t = y_true * alpha + (K.ones_like(y_true) - y_true) * (1 - alpha)

  p_t = y_true * y_pred + (K.ones_like(y_true) - y_true) * (K.ones_like(y_true) - y_pred) + K.epsilon()
  focal_loss = - alpha_t * K.pow((K.ones_like(y_true) - p_t), gamma) * K.log(p_t)
  return K.mean(focal_loss)

 return binary_focal_loss_fixed

#'''
#smooth 参数防止分母为0
def dice_coef(y_true, y_pred, smooth=1):
 intersection = K.sum(y_true * y_pred, axis=[1,2,3])
 union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
 return K.mean( (2. * intersection + smooth) / (union + smooth), axis=0)

注意在模型保存时,记录的loss函数名称:你猜是哪个

a:binary_focal_loss()

b:binary_focal_loss_fixed

3.模型预测时,也要加载自定义loss及评估函数,不然会报错。

该告诉上面的答案了,保存在模型中loss的名称为:binary_focal_loss_fixed,在模型预测时,定义custom_objects字典,key一定要与保存在模型中的名称一致,不然会找不到loss function。所以自定义函数时,尽量避免使用我这种函数嵌套的方式,免得带来一些意想不到的烦恼。

model = load_model('./unet_' + label + '_20.h5',custom_objects={'binary_focal_loss_fixed': binary_focal_loss(),'dice_coef': dice_coef})

以上这篇keras自定义损失函数并且模型加载的写法介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用新浪微博api上传图片到微博示例
Jan 10 Python
Python 3.6 读取并操作文件内容的实例
Apr 23 Python
对python内置map和six.moves.map的区别详解
Dec 19 Python
Django学习笔记之为Model添加Action
Apr 30 Python
python使用threading.Condition交替打印两个字符
May 07 Python
python傅里叶变换FFT绘制频谱图
Jul 19 Python
python 视频逐帧保存为图片的完整实例
Dec 10 Python
python print 格式化输出,动态指定长度的实现
Apr 12 Python
Python selenium如何打包静态网页并下载
Aug 12 Python
scrapy利用selenium爬取豆瓣阅读的全步骤
Sep 20 Python
python中实现栈的三种方法
Dec 19 Python
自己搭建resnet18网络并加载torchvision自带权重的操作
May 13 Python
python语言是免费还是收费的?
Jun 15 #Python
DataFrame.groupby()所见的各种用法详解
Jun 14 #Python
详解pandas.DataFrame.plot() 画图函数
Jun 14 #Python
Pandas把dataframe或series转换成list的方法
Jun 14 #Python
详解pandas获取Dataframe元素值的几种方法
Jun 14 #Python
Pandas对DataFrame单列/多列进行运算(map, apply, transform, agg)
Jun 14 #Python
Python脚本破解压缩文件口令实例教程(zipfile)
Jun 14 #Python
You might like
PHP+DBM的同学录程序(2)
2006/10/09 PHP
php使用异或实现的加密解密实例
2013/09/04 PHP
codeigniter教程之上传视频并使用ffmpeg转flv示例
2014/02/13 PHP
Php中使用Select 查询语句的实例
2014/02/19 PHP
10个实用的PHP正则表达式汇总
2014/10/23 PHP
PHP中isset与array_key_exists的区别实例分析
2015/06/02 PHP
ThinkPHP实现递归无级分类――代码少
2015/07/29 PHP
yii2中dropDownList实现二级和三级联动写法
2017/04/26 PHP
PHP数据对象映射模式实例分析
2019/03/29 PHP
JS 页面自动加载函数(兼容多浏览器)
2009/05/18 Javascript
jQuery 页面 Mask实现代码
2010/01/09 Javascript
juqery 学习之四 筛选查找
2010/11/30 Javascript
JS获取DropDownList的value值与text值的示例代码
2014/01/07 Javascript
JQuery判断radio是否选中并获取选中值的示例代码
2014/10/17 Javascript
AngularJS Module方法详解
2015/12/08 Javascript
Vue多种方法实现表头和首列固定的示例代码
2018/02/02 Javascript
webpack4简单入门实例
2018/09/06 Javascript
nodejs中函数的调用实例详解
2018/10/31 NodeJs
vue中实现弹出层动画效果的示例代码
2020/09/25 Javascript
原生小程序封装跑马灯效果
2020/10/21 Javascript
[02:43]2014DOTA2国际邀请赛 官方Alliance战队纪录片
2014/07/14 DOTA
Python高效编程技巧
2013/01/07 Python
使用Python+Splinter自动刷新抢12306火车票
2018/01/03 Python
对tensorflow 中tile函数的使用详解
2020/02/07 Python
Python将二维列表list的数据输出(TXT,Excel)
2020/04/23 Python
Anaconda3中的Jupyter notebook添加目录插件的实现
2020/05/18 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
2020/05/26 Python
南非领先的在线旅行社:Travelstart南非
2016/09/04 全球购物
现代家居用品及礼品:LBC Modern
2018/06/24 全球购物
美国打印机墨水和碳粉购物网站:QuikShip Toner
2018/08/29 全球购物
药物学专业学生的自我评价
2013/10/27 职场文书
单位消防安全制度
2014/01/12 职场文书
2014村书记党建工作汇报材料
2014/11/02 职场文书
小学生光盘行动倡议书
2015/04/28 职场文书
Go语言基础map用法及示例详解
2021/11/17 Golang
Python PIL按比例裁剪图片
2022/05/11 Python