keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python数据类型详解(一)字符串
May 08 Python
Python可变参数用法实例分析
Apr 02 Python
python如何使用正则表达式的前向、后向搜索及前向搜索否定模式详解
Nov 08 Python
Python之Scrapy爬虫框架安装及使用详解
Nov 16 Python
浅谈Tensorflow模型的保存与恢复加载
Apr 26 Python
python matplotlib 在指定的两个点之间连线方法
May 25 Python
python仿抖音表白神器
Apr 08 Python
python队列Queue的详解
May 10 Python
Django工程的分层结构详解
Jul 18 Python
Python使用sqlite3模块内置数据库
May 07 Python
Python爬虫爬取新闻资讯案例详解
Jul 14 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
Oct 23 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
php更改目录及子目录下所有的文件后缀的代码
2010/09/24 PHP
php中定时计划任务的实现原理
2013/01/08 PHP
PHP实现将科学计数法转换为原始数字字符串的方法
2014/12/16 PHP
PHP中empty和isset对于参数结构的判断及empty()和isset()的区别
2015/11/15 PHP
php、java、android、ios通用的3des方法(推荐)
2016/09/09 PHP
php从数据库中获取数据用ajax传送到前台的方法
2018/08/20 PHP
javascript的键盘控制事件说明
2008/04/15 Javascript
克隆javascript对象的三个方法小结
2011/01/12 Javascript
JQuery扩展插件Validate 5添加自定义验证方法
2011/09/05 Javascript
jquery创建并行对象或者合并对象的实现代码
2012/10/10 Javascript
一个简单的弹性返回顶部JS代码实现介绍
2013/06/09 Javascript
点击页面其它地方隐藏该div的两种思路
2013/11/18 Javascript
Javascript之String对象详解
2016/06/08 Javascript
jQuery+ajax实现实用的点赞插件代码
2016/07/06 Javascript
JavaScript重定向URL参数的两种方法小结
2016/10/19 Javascript
通过jsonp获取json数据实现AJAX跨域请求
2017/01/22 Javascript
微信小程序 仿猫眼实现实例代码
2017/03/14 Javascript
vue2.0实现倒计时的插件(时间戳 刷新 跳转 都不影响)
2017/03/30 Javascript
jquery.uploadView 实现图片预览上传功能
2017/08/10 jQuery
详解angular路由高亮之RouterLinkActive
2018/04/28 Javascript
vue 实现cli3.0中使用proxy进行代理转发
2019/10/30 Javascript
用Python的Django框架来制作一个RSS阅读器
2015/07/22 Python
总结Python编程中函数的使用要点
2016/03/20 Python
Python中模块与包有相同名字的处理方法
2017/05/05 Python
python实现发送邮件及附件功能
2021/03/02 Python
python 自定义异常和异常捕捉的方法
2018/10/18 Python
python 叠加等边三角形的绘制的实现
2019/08/14 Python
keras的三种模型实现与区别说明
2020/07/03 Python
美国摄影爱好者购物网站:Focus Camera
2016/10/21 全球购物
ghd法国官方网站:英国最受欢迎的美发工具品牌
2019/04/18 全球购物
W Hamond官网:始于1979年的钻石专家
2020/07/20 全球购物
在子网210.27.48.21/30种有多少个可用地址?分别是什么?
2014/07/27 面试题
计算机专业应届毕业生自荐信
2013/09/26 职场文书
文明倡议书范文
2014/04/15 职场文书
主持人演讲稿
2014/05/13 职场文书
班子四风对照检查材料
2014/08/21 职场文书