keras用auc做metrics以及早停实例


Posted in Python onJuly 02, 2020

我就废话不多说了,大家还是直接看代码吧~

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...

model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
  model.add(Dense(layer_size))
  model.add(BatchNormalization())
  model.add(Activation('relu'))
  model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

补充知识:Keras可使用的评价函数

1:binary_accuracy(对二分类问题,计算在所有预测值上的平均正确率)

binary_accuracy(y_true, y_pred)

2:categorical_accuracy(对多分类问题,计算在所有预测值上的平均正确率)

categorical_accuracy(y_true, y_pred)

3:sparse_categorical_accuracy(与categorical_accuracy相同,在对稀疏的目标值预测时有用 )

sparse_categorical_accuracy(y_true, y_pred)

4:top_k_categorical_accuracy(计算top-k正确率,当预测值的前k个值中存在目标类别即认为预测正确 )

top_k_categorical_accuracy(y_true, y_pred, k=5)

5:sparse_top_k_categorical_accuracy(与top_k_categorical_accracy作用相同,但适用于稀疏情况)

sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)

以上这篇keras用auc做metrics以及早停实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django+tornado实现实时查看远程日志的方法
Aug 12 Python
python实现实时视频流播放代码实例
Jan 11 Python
Python使用turtle库绘制小猪佩奇(实例代码)
Jan 16 Python
PyTorch中Tensor的数据统计示例
Feb 17 Python
python GUI库图形界面开发之PyQt5表单布局控件QFormLayout详细使用方法与实例
Mar 06 Python
Python 实现使用空值进行赋值 None
Mar 12 Python
解决在keras中使用model.save()函数保存模型失败的问题
May 21 Python
python中数字是否为可变类型
Jul 08 Python
详解Python3 定义一个跨越多行的字符串的多种方法
Sep 06 Python
Pycharm新手使用教程(图文详解)
Sep 17 Python
如何Tkinter模块编写Python图形界面
Oct 14 Python
python开发一个解析protobuf文件的简单编译器
Nov 17 Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
keras.utils.to_categorical和one hot格式解析
Jul 02 #Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 #Python
浅谈keras中的keras.utils.to_categorical用法
Jul 02 #Python
Python使用OpenPyXL处理Excel表格
Jul 02 #Python
You might like
Netflix将与CLAMP、乙一以及冲方丁等6名知名制作人合伙展开原创动画计划!
2020/03/06 日漫
PHP调用三种数据库的方法(2)
2006/10/09 PHP
php日历[测试通过]
2008/03/27 PHP
php处理斐波那契数列非递归方法
2012/02/04 PHP
PHP时间戳格式全部汇总 (获取时间、时间戳)
2016/06/13 PHP
php实现xml转换数组的方法示例
2017/02/03 PHP
PHP _construct()函数讲解
2019/02/03 PHP
YII2框架中使用RBAC对模块,控制器,方法的权限控制及规则的使用示例
2020/03/18 PHP
Mootools 1.2教程 Fx.Morph、Fx选项和Fx事件
2009/09/15 Javascript
常见效果实现之返回顶部(结合淡入、淡出、减速滚动)
2012/01/04 Javascript
40个新鲜出炉的jQuery 插件和免费教程[上]
2012/07/24 Javascript
jquery mobile动态添加元素之后不能正确渲染解决方法说明
2014/03/05 Javascript
jQuery鼠标事件汇总
2015/08/30 Javascript
jquery实现简单的banner轮播效果【实例】
2016/03/30 Javascript
Vue.js 表单校验插件
2016/08/14 Javascript
react配合antd组件实现的管理系统示例代码
2018/04/24 Javascript
layer关闭弹出窗口触发表单提交问题的处理方法
2019/09/25 Javascript
vue使用screenfull插件实现全屏功能
2020/09/17 Javascript
[41:08]TNC vs VG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python轻松实现代码编码格式转换
2015/03/26 Python
Python中urllib+urllib2+cookielib模块编写爬虫实战
2016/01/20 Python
简单实现python收发邮件功能
2018/01/05 Python
python文件拆分与重组实例
2018/12/10 Python
Python数据可视化之画图
2019/01/15 Python
对pandas通过索引提取dataframe的行方法详解
2019/02/01 Python
python设计tcp数据包协议类的例子
2019/07/23 Python
Python自省及反射原理实例详解
2020/07/06 Python
CSS3美化表单控件全集
2016/06/29 HTML / CSS
使用Html5中的cavas画一面国旗
2019/09/25 HTML / CSS
详解HTML5 Canvas标签及基本使用
2020/01/10 HTML / CSS
韩国演唱会订票网站:StubHub韩国
2019/01/17 全球购物
Quiksilver荷兰官方网站:冲浪和滑雪板
2019/11/16 全球购物
Perfume’s Club中文官网:西班牙美妆在线零售品牌
2020/08/24 全球购物
什么是静态路由?什么是动态路由?各自的特点是什么?
2015/09/16 面试题
MySQL官方导出工具mysqlpump的使用
2021/05/21 MySQL
浅谈如何保证Mysql主从一致
2022/03/13 MySQL