Python机器学习应用之基于线性判别模型的分类篇详解


Posted in Python onJanuary 18, 2022

一、Introduction

线性判别模型(LDA)在模式识别领域(比如人脸识别等图形图像识别领域)中有非常广泛的应用。LDA是一种监督学习的降维技术,也就是说它的数据集的每个样本是有类别输出的。这点和PCA不同。PCA是不考虑样本类别输出的无监督降维技术。 LDA的思想可以用一句话概括,就是“投影后类内方差最小,类间方差最大”。我们要将数据在低维度上进行投影,投影后希望每一种类别数据的投影点尽可能的接近,而不同类别的数据的类别中心之间的距离尽可能的大。即:将数据投影到维度更低的空间中,使得投影后的点,会形成按类别区分,一簇一簇的情况,相同类别的点,将会在投影后的空间中更接近方法。

1 LDA的优点

  • 在降维过程中可以使用类别的先验知识经验,而像PCA这样的无监督学习则无法使用类别先验知识;
  • LDA在样本分类信息依赖均值而不是方差的时候,比PCA之类的算法较优

2 LDA的缺点

  • LDA不适合对非高斯分布样本进行降维,PCA也有这个问题
  • LDA降维最多降到类别数 k-1 的维数,如果我们降维的维度大于 k-1,则不能使用 LDA。当然目前有一些LDA的进化版算法可以绕过这个问题
  • LDA在样本分类信息依赖方差而不是均值的时候,降维效果不好
  • LDA可能过度拟合数据

3 LDA在模式识别领域与自然语言处理领域的区别

在自然语言处理领域,LDA是隐含狄利克雷分布,它是一种处理文档的主题模型。本文讨论的是线性判别分析 LDA除了可以用于降维以外,还可以用于分类。一个常见的LDA分类基本思想是假设各个类别的样本数据符合高斯分布,这样利用LDA进行投影后,可以利用极大似然估计计算各个类别投影数据的均值和方差,进而得到该类别高斯分布的概率密度函数。当一个新的样本到来后,我们可以将它投影,然后将投影后的样本特征分别带入各个类别的高斯分布概率密度函数,计算它属于这个类别的概率,最大的概率对应的类别即为预测类别

二、Demo

#%%导入基本库
# 基础数组运算库导入
import numpy as np 
# 画图库导入
import matplotlib.pyplot as plt 
# 导入三维显示工具
from mpl_toolkits.mplot3d import Axes3D
# 导入LDA模型
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 导入demo数据制作方法
from sklearn.datasets import make_classification
#%%模型训练
# 制作四个类别的数据,每个类别100个样本
X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0,
                           n_classes=4, n_informative=2, n_clusters_per_class=1,
                           class_sep=3, random_state=10)
# 将四个类别的数据进行三维显示
fig = plt.figure()
ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=y)
plt.show()

Python机器学习应用之基于线性判别模型的分类篇详解

#%%建立 LDA 模型
lda = LinearDiscriminantAnalysis()
# 进行模型训练
lda.fit(X, y)
#%%查看lda的参数
print(lda.get_params())

Python机器学习应用之基于线性判别模型的分类篇详解

#%%数据可视化
#模型预测
X_new = lda.transform(X)
# 可视化预测数据
plt.scatter(X_new[:, 0], X_new[:, 1], marker='o', c=y)
plt.show()

Python机器学习应用之基于线性判别模型的分类篇详解

#%%使用新的数据进行测试
a = np.array([[-1, 0.1, 0.1]])
print(f"{a} 类别是: ", lda.predict(a))
print(f"{a} 类别概率分别是: ", lda.predict_proba(a))

a = np.array([[-12, -100, -91]])
print(f"{a} 类别是: ", lda.predict(a))
print(f"{a} 类别概率分别是: ", lda.predict_proba(a))

a = np.array([[-12, -0.1, -0.1]])
print(f"{a} 类别是: ", lda.predict(a))
print(f"{a} 类别概率分别是: ", lda.predict_proba(a))

a = np.array([[0.1, 90.1, 9.1]])
print(f"{a} 类别是: ", lda.predict(a))
print(f"{a} 类别概率分别是: ", lda.predict_proba(a))

Python机器学习应用之基于线性判别模型的分类篇详解

三、基于LDA 手写数字的分类

#%%导入库函数
# 导入手写数据集 MNIST
from sklearn.datasets import load_digits
# 导入训练集分割方法
from sklearn.model_selection import train_test_split
# 导入LDA模型
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# 导入预测指标计算函数和混淆矩阵计算函数
from sklearn.metrics import classification_report, confusion_matrix
# 导入绘图包
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
#%% 导入MNIST数据集
mnist = load_digits()
# 查看数据集信息
print('The Mnist dataeset:\n',mnist)

# 分割数据为训练集和测试集
x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=2)

Python机器学习应用之基于线性判别模型的分类篇详解

#%%## 输出示例图像
images = range(0,9)

plt.figure(dpi=100)
for i in images:
    plt.subplot(330 + 1 + i)
    plt.imshow(x[i].reshape(8, 8), cmap = matplotlib.cm.binary,interpolation="nearest")
# show the plot
plt.show()

Python机器学习应用之基于线性判别模型的分类篇详解

#%%利用LDA对手写数字进行训练与预测
m_lda = LinearDiscriminantAnalysis()# 建立 LDA 模型
# 进行模型训练
m_lda.fit(x, y)
# 进行模型预测
x_new = m_lda.transform(x)
# 可视化预测数据
plt.scatter(x_new[:, 0], x_new[:, 1], marker='o', c=y)
plt.title('MNIST with LDA Model')
plt.show()

Python机器学习应用之基于线性判别模型的分类篇详解

#%% 进行测试集数据的类别预测
y_test_pred = m_lda.predict(test_x)
print("测试集的真实标签:\n", test_y)
print("测试集的预测标签:\n", y_test_pred)
#%% 进行预测结果指标统计 统计每一类别的预测准确率、召回率、F1分数
print(classification_report(test_y, y_test_pred))
# 计算混淆矩阵
C2 = confusion_matrix(test_y, y_test_pred)
# 打混淆矩阵
print(C2)

# 将混淆矩阵以热力图的防线显示
sns.set()
f, ax = plt.subplots()
# 画热力图
sns.heatmap(C2, cmap="YlGnBu_r", annot=True, ax=ax)  
# 标题 
ax.set_title('confusion matrix')
# x轴为预测类别
ax.set_xlabel('predict')  
# y轴实际类别
ax.set_ylabel('true')  
plt.show()

Python机器学习应用之基于线性判别模型的分类篇详解

Python机器学习应用之基于线性判别模型的分类篇详解

Python机器学习应用之基于线性判别模型的分类篇详解

四、小结

LDA适用于线性可分数据,在非线性数据上要谨慎使用。 886~~~

到此这篇关于Python机器学习应用之基于线性判别模型的分类篇详解的文章就介绍到这了,更多相关Python 线性判别模型的分类内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python基础入门详解(文件输入/输出 内建类型 字典操作使用方法)
Dec 08 Python
pyqt4教程之widget使用示例分享
Mar 07 Python
Python常用列表数据结构小结
Aug 06 Python
Python中分数的相关使用教程
Mar 30 Python
利用python将图片版PDF转文字版PDF
May 03 Python
Django 权限认证(根据不同的用户,设置不同的显示和访问权限)
Jul 24 Python
Python3 文章标题关键字提取的例子
Aug 26 Python
Python continue语句实例用法
Feb 06 Python
后端开发使用pycharm的技巧(推荐)
Mar 27 Python
Python实时监控网站浏览记录实现过程详解
Jul 14 Python
python实现数字炸弹游戏程序
Jul 17 Python
python中remove函数的踩坑记录
Jan 04 Python
68行Python代码实现带难度升级的贪吃蛇
Jan 18 #Python
如何利用Python实现n*n螺旋矩阵
Jan 18 #Python
聊聊Python String型列表求最值的问题
Jan 18 #Python
Python的三个重要函数详解
Jan 18 #Python
python多线程方法详解
Jan 18 #Python
用Python生成会跳舞的美女
基于Pygame实现简单的贪吃蛇游戏
Dec 06 #Python
You might like
PHP简单选择排序算法实例
2015/01/26 PHP
php选择排序法实现数组排序实例分析
2015/02/16 PHP
yii2中的rules 自定义验证规则详解
2016/04/19 PHP
深入理解PHP的远程多会话调试
2017/09/21 PHP
漂亮的仿flash菜单,来自蓝色经典
2006/06/26 Javascript
jquery ajax abort()的使用方法
2010/10/28 Javascript
解决jQuery插件tipswindown与hintbox冲突
2010/11/05 Javascript
javascript获取选中的文本的方法代码
2013/10/30 Javascript
轻松创建nodejs服务器(2):nodejs服务器的构成分析
2014/12/18 NodeJs
JavaScript的类型、值和变量小结
2015/07/09 Javascript
简单实现jQuery多选框功能
2017/01/09 Javascript
vue2.0设置proxyTable使用axios进行跨域请求的方法
2017/10/19 Javascript
jquery+css实现Tab栏切换的代码实例
2019/05/14 jQuery
JavaScript实现随机五位数验证码
2019/09/27 Javascript
JavaScript实现网页动态生成表格
2020/11/25 Javascript
python基础教程之获取本机ip数据包示例
2014/02/10 Python
python cx_Oracle模块的安装和使用详细介绍
2017/02/13 Python
python将字典内容存入mysql实例代码
2018/01/18 Python
Python实现的自定义多线程多进程类示例
2018/03/23 Python
python 利用栈和队列模拟递归的过程
2018/05/29 Python
浅谈Pycharm中的Python Console与Terminal
2019/01/17 Python
python按照多个条件排序的方法
2019/02/08 Python
python实现维吉尼亚加密法
2019/03/20 Python
Jupyter Notebook的连接密码 token查询方式
2020/04/21 Python
Python eval函数原理及用法解析
2020/11/14 Python
python中绕过反爬虫的方法总结
2020/11/25 Python
使用Python制作一盏 3D 花灯喜迎元宵佳节
2021/02/26 Python
Pytorch - TORCH.NN.INIT 参数初始化的操作
2021/02/27 Python
从一次项目重构说起CSS3自定义变量在项目的使用方法
2021/03/01 HTML / CSS
超30万乐谱下载:Musicnotes.com
2016/09/24 全球购物
《花瓣飘香》教学反思
2014/04/15 职场文书
交通事故赔偿协议书怎么写
2014/10/04 职场文书
2014年团支部年度工作总结
2014/12/24 职场文书
go mod 安装依赖 unkown revision问题的解决方案
2021/05/06 Golang
Java面试题冲刺第十五天--设计模式
2021/08/07 面试题
中国古风插画师排行榜:夏达第一,第三是阴阳师姑获鸟皮肤创作者
2022/03/18 国漫