sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ljust rjust center输出
Sep 06 Python
python处理文本文件并生成指定格式的文件
Jul 31 Python
Python 实现文件的全备份和差异备份详解
Dec 27 Python
Python实现的密码强度检测器示例
Aug 23 Python
Python使用numpy实现BP神经网络
Mar 10 Python
浅谈python写入大量文件的问题
Nov 09 Python
python TF-IDF算法实现文本关键词提取
May 29 Python
Python 字符串类型列表转换成真正列表类型过程解析
Aug 26 Python
详解python中eval函数的作用
Oct 22 Python
windows下Pycharm安装opencv的多种方法
Mar 05 Python
pyinstaller打包找不到文件的问题解决
Apr 15 Python
Python中常见的导入方式总结
May 06 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
PHP 编写的 25个游戏脚本
2009/05/11 PHP
smarty中post用法实例
2014/11/28 PHP
Codeigniter校验ip地址的方法
2015/03/21 PHP
php结合ACCESS的跨库查询功能
2015/06/12 PHP
调试WordPress中定时任务的相关PHP脚本示例
2015/12/10 PHP
Yii2框架中日志的使用方法分析
2017/05/22 PHP
CI框架(CodeIgniter)实现的数据库增删改查操作总结
2018/05/23 PHP
php 使用mpdf实现指定字段配置字体样式的方法
2019/07/29 PHP
JS基础之undefined与null的区别分析
2011/08/08 Javascript
javascript验证只能输入数字和一个小数点示例
2013/10/21 Javascript
jQuery实现企业网站横幅焦点图切换功能实例
2015/04/30 Javascript
极易被忽视的javascript面试题七问七答
2016/02/15 Javascript
js智能获取浏览器版本UA信息的方法
2016/08/08 Javascript
深入理解requestAnimationFrame的动画循环
2016/09/20 Javascript
信息滚动效果的实例讲解
2017/09/18 Javascript
vue2.0 循环遍历加载不同图片的方法
2018/03/06 Javascript
vue系列之requireJs中引入vue-router的方法
2018/07/18 Javascript
JavaScript利用键盘码控制div移动
2020/03/19 Javascript
VueCli生产环境打包部署跨域失败的解决
2020/11/13 Javascript
JavaScript实现alert弹框效果
2020/11/19 Javascript
Django认证系统实现的web页面实现代码
2019/08/12 Python
python3中利用filter函数输出小于某个数的所有回文数实例
2019/11/24 Python
python实现磁盘日志清理的示例
2020/11/05 Python
python SOCKET编程基础入门
2021/02/27 Python
澳大利亚最好的厨具店:Kitchen Warehouse
2018/03/13 全球购物
哄娃神器4moms商店:美国婴童用品品牌
2019/03/07 全球购物
小学少先队活动方案
2014/02/18 职场文书
党员干部一句话承诺
2014/05/30 职场文书
大学军训的体会
2014/11/08 职场文书
2015年依法行政工作总结
2015/04/29 职场文书
小学工作总结2015
2015/05/04 职场文书
煤矿安全生产工作总结
2015/08/13 职场文书
2016元旦文艺汇演主持词(开场白+结束语)
2015/12/03 职场文书
JavaScript实现简单计时器
2021/06/22 Javascript
vue项目如何打包之项目打包优化(让打包的js文件变小)
2022/04/30 Vue.js
Python+pyaudio实现音频控制示例详解
2022/07/23 Python