Keras 利用sklearn的ROC-AUC建立评价函数详解


Posted in Python onJune 15, 2020

我就废话不多说了,大家还是直接看代码吧!

# 利用sklearn自建评价函数
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class RocAucEvaluation(Callback):
 def __init__(self, validation_data=(), interval=1):
 super(Callback, self).__init__()
 self.interval = interval
 self.x_val,self.y_val = validation_data
 def on_epoch_end(self, epoch, log={}):
 if epoch % self.interval == 0:
  y_pred = self.model.predict(self.x_val, verbose=0)
  score = roc_auc_score(self.y_val, y_pred)
  print('\n ROC_AUC - epoch:%d - score:%.6f \n' % (epoch+1, score))

x_train,y_train,x_label,y_label = train_test_split(train_feature, train_label, train_size=0.95, random_state=233)
RocAuc = RocAucEvaluation(validation_data=(y_train,y_label), interval=1)

hist = model.fit(x_train, x_label, batch_size=batch_size, epochs=epochs, validation_data=(y_train, y_label), callbacks=[RocAuc], verbose=2)

补充知识:keras用auc做metrics以及早停

我就废话不多说了,大家还是直接看代码吧!

import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
 return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:

def auc(y_true, y_pred):
 auc = tf.metrics.auc(y_true, y_pred)[1]
 K.get_session().run(tf.local_variables_initializer())
 return auc

def create_model_nn(in_dim,layer_size=200):
 model = Sequential()
 model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
 model.add(Dense(1, activation='sigmoid'))
 adam = optimizers.Adam(lr=0.01)
 model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc]) 
 return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
 print("fold n°{}".format(fold_))
 X_train = df_train.iloc[trn_idx][features]
 y_train = target2.iloc[trn_idx]
 X_valid = df_train.iloc[val_idx][features]
 y_valid = target2.iloc[val_idx]
 model_nn = create_model_nn(X_train.shape[1])
 callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
 history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
 print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
 predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

以上这篇Keras 利用sklearn的ROC-AUC建立评价函数详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxpython 学习笔记 第一天
Feb 09 Python
Python实现网站文件的全备份和差异备份
Nov 30 Python
python daemon守护进程实现
Aug 27 Python
使用python实现生成用户信息
Mar 20 Python
python之拟合的实现
Jul 19 Python
详解PyTorch手写数字识别(MNIST数据集)
Aug 16 Python
如何使用Python多线程测试并发漏洞
Dec 18 Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 Python
pycharm解决关闭flask后依旧可以访问服务的问题
Apr 03 Python
django-orm F对象的使用 按照两个字段的和,乘积排序实例
May 18 Python
django注册用邮箱发送验证码的实现
Apr 18 Python
Python 解决空列表.append() 输出为None的问题
May 23 Python
Python如何在windows环境安装pip及rarfile
Jun 15 #Python
keras训练曲线,混淆矩阵,CNN层输出可视化实例
Jun 15 #Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
You might like
解析Linux下Varnish缓存的配置优化
2013/06/20 PHP
Symfony2学习笔记之系统路由详解
2016/03/17 PHP
Ubuntu 16.04中Laravel5.4升级到5.6的步骤
2018/12/07 PHP
Javascript操纵Cookie实现购物车程序
2006/11/23 Javascript
document.body.scrollTop 值总为0的解决方法 比较常见的标准问题
2009/11/30 Javascript
JavaScript中setInterval的用法总结
2013/11/20 Javascript
jQuery动画效果实现图片无缝连续滚动
2016/01/12 Javascript
使用jquery获取url以及jquery获取url参数的实现方法
2016/05/25 Javascript
AngularJS通过$sce输出html的方法
2016/09/22 Javascript
js 判断附件后缀的简单实现方法
2016/10/11 Javascript
javascript判断firebug是否开启的方法
2016/11/23 Javascript
js时间控件只显示年月
2017/01/08 Javascript
jQuery实现移动端Tab选项卡效果
2017/03/15 Javascript
vue项目持久化存储数据的实现代码
2018/10/01 Javascript
总结Python中逻辑运算符的使用
2015/05/13 Python
Python实现的选择排序算法原理与用法实例分析
2017/11/22 Python
Python基础教程之利用期物处理并发
2018/03/29 Python
python 3.6.5 安装配置方法图文教程
2018/09/18 Python
python多个模块py文件的数据共享实例
2019/01/11 Python
对python 判断数字是否小于0的方法详解
2019/01/26 Python
python时间序列按频率生成日期的方法
2019/05/14 Python
Python matplotlib学习笔记之坐标轴范围
2019/06/28 Python
Python 绘制酷炫的三维图步骤详解
2019/07/12 Python
Python实现二叉搜索树BST的方法示例
2019/07/30 Python
python如何设置静态变量
2020/09/07 Python
Python classmethod装饰器原理及用法解析
2020/10/17 Python
关于HTML5语义标签的实践(blog页面)
2016/07/12 HTML / CSS
基于HTML5 WebGL的3D机房的示例
2018/03/16 HTML / CSS
购买美国制造的相框和画框架:Picture Frames
2018/08/14 全球购物
职务任命书范本
2014/06/05 职场文书
再婚婚前财产协议书范本
2014/10/19 职场文书
医院办公室主任岗位职责
2015/04/01 职场文书
创建文明城市倡议书
2015/04/28 职场文书
企业计划生育责任书
2015/05/09 职场文书
导游词之唐山景点
2019/12/18 职场文书
详解用Python把PDF转为Word方法总结
2021/04/27 Python