keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python开发vim插件及心得分享
Nov 04 Python
将字典转换为DataFrame并进行频次统计的方法
Apr 08 Python
Python中偏函数用法示例
Jun 07 Python
python判断完全平方数的方法
Nov 13 Python
对Python3+gdal 读取tiff格式数据的实例讲解
Dec 04 Python
python得到一个excel的全部sheet标签值方法
Dec 10 Python
Python 控制终端输出文字的实例
Jul 12 Python
Django基础三之视图函数的使用方法
Jul 18 Python
python爬虫 批量下载zabbix文档代码实例
Aug 21 Python
python计算二维矩形IOU实例
Jan 18 Python
TensorFlow保存TensorBoard图像操作
Jun 23 Python
Django跨域请求原理及实现代码
Nov 14 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
超神学院:天使彦公认最美的三个视角,网友:我的天使快下凡吧!
2020/03/02 国漫
用PHP制作静态网站的模板框架(二)
2006/10/09 PHP
php 运行效率总结(提示程序速度)
2009/11/26 PHP
php获取通过http协议post提交过来xml数据及解析xml
2012/12/16 PHP
PHP容易忘记的知识点分享
2013/04/30 PHP
php中file_exists函数使用详解
2015/05/08 PHP
YII Framework框架教程之国际化实现方法
2016/03/14 PHP
php实现文件管理与基础功能操作
2017/03/21 PHP
Laravel接收前端ajax传来的数据的实例代码
2017/07/20 PHP
PHP使用SMTP邮件服务器发送邮件示例
2018/08/28 PHP
php实现单笔转账到支付宝功能
2018/10/09 PHP
Some tips of wmi scripting in jscript (1)
2007/04/03 Javascript
模拟select的代码
2011/10/19 Javascript
IE与FireFox中的childNodes区别
2011/10/20 Javascript
JavaScript实现QueryString获取GET参数的方法
2013/07/02 Javascript
JS获取鼠标坐标的实例方法
2013/07/18 Javascript
js 三级关联菜单效果实例
2013/08/13 Javascript
判断文档离浏览器顶部的距离的方法
2014/01/08 Javascript
用jquery仿做发微博功能示例
2014/04/18 Javascript
javascript实现复制与粘贴操作实例
2014/10/16 Javascript
Eclipse引入jquery报错如何解决
2015/12/01 Javascript
Extjs4.0 ComboBox如何实现三级联动
2016/05/11 Javascript
使用vue.js2.0 + ElementUI开发后台管理系统详细教程(二)
2017/01/21 Javascript
node.js爬虫爬取拉勾网职位信息
2017/03/14 Javascript
基于vue的换肤功能的示例代码
2017/10/10 Javascript
原生JS封装_new函数实现new关键字的功能
2018/08/12 Javascript
浅谈python为什么不需要三目运算符和switch
2016/06/17 Python
django 通过ajax完成邮箱用户注册、激活账号的方法
2018/04/17 Python
Python for循环与range函数的使用详解
2019/03/23 Python
详解从Django Allauth中进行登录改造小结
2019/12/18 Python
python 多进程队列数据处理详解
2019/12/23 Python
python 实现倒计时功能(gui界面)
2020/11/11 Python
跑步、骑行和铁人三项的高性能眼镜和服装:ROKA
2018/07/06 全球购物
2014小学语文教学工作总结
2014/12/17 职场文书
golang 在windows中设置环境变量的操作
2021/04/29 Golang
Go语言编译原理之源码调试
2022/08/05 Golang