sklearn的predict_proba使用说明


Posted in Python onJune 28, 2020

发现个很有用的方法——predict_proba

今天在做数据预测的时候用到了,感觉很不错,所以记录分享一下,以后可能会经常用到。

我的理解:predict_proba不同于predict,它返回的预测值为,获得所有结果的概率。(有多少个分类结果,每行就有多少个概率,以至于它对每个结果都有一个可能,如0、1就有两个概率)

举例:

获取数据及预测代码:

from sklearn.linear_model import LogisticRegression
import numpy as np
 
train_X = np.array(np.random.randint(0,10,size=30).reshape(10,3))
train_y = np.array(np.random.randint(0,2,size=10))
test_X = np.array(np.random.randint(0,10,size=12).reshape(4,3))
 
model = LogisticRegression()
model.fit(train_X,train_y)
test_y = model.predict_proba(test_X)
 
print(train_X)
print(train_y)
print(test_y)

训练数据

[[2 9 8]
 [0 8 5]
 [7 1 2]
 [8 4 6]
 [8 8 3]
 [7 2 7]
 [6 4 3]
 [1 4 4]
 [1 9 3]
 [3 4 7]]

训练结果,与训练数据一一对应:

[1 1 1 0 1 1 0 0 0 1]

测试数据:

[[4 3 0]  #测试数据
 [3 0 4]
 [2 9 5]
 [2 8 5]]

测试结果,与测试数据一一对应:

[[0.48753831 0.51246169] 
 [0.58182694 0.41817306]
 [0.85361393 0.14638607]
 [0.57018655 0.42981345]]

可以看出,有四行两列,每行对应一条预测数据,两列分别对应 对于0、1的预测概率(左边概率大于0.5则为0,反之为1)

我们来看看使用predict方法获得的结果:

test_y = model.predict(test_X)
print(test_y)

输出结果:[1,0,0,0]

所以有的情况下predict_proba还是很有用的,它可以获得对每种可能结果的概率,使用predict则是直接获得唯一的预测结果,所以在使用的时候,应该灵活使用。

补充一个知识点:关于预测结果标签如何与原来标签相对应

predict_proba返回所有标签值可能性概率值,这些值是如何排序的呢?

返回模型中每个类的样本概率,其中类按类self.classes_进行排序。

其中关键的步骤为numpy的unique方法,即通过np.unique(Label)方法,对Label中的所有标签值进行从小到大的去重排序。得到一个从小到大唯一值的排序。这也就对应于predict_proba的行返回结果。

补充知识: python sklearn decision_function、predict_proba、predict

看代码~

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
X = np.array([[-1,-1],[-2,-1],[1,1],[2,1],[-1,1],[-1,2],[1,-1],[1,-2]])
y = np.array([0,0,1,1,2,2,3,3])
# y=np.array([1,1,2,2,3,3,4,4])
# clf = SVC(decision_function_shape="ovr",probability=True)
clf = SVC(probability=True)
clf.fit(X, y)
print(clf.decision_function(X))
'''
对于n分类,会有n个分类器,然后,任意两个分类器都可以算出一个分类界面,这样,用decision_function()时,对于任意一个样例,就会有n*(n-1)/2个值。
任意两个分类器可以算出一个分类界面,然后这个值就是距离分类界面的距离。
我想,这个函数是为了统计画图,对于二分类时最明显,用来统计每个点离超平面有多远,为了在空间中直观的表示数据以及画超平面还有间隔平面等。
decision_function_shape="ovr"时是4个值,为ovo时是6个值。
'''
print(clf.predict(X))
clf.predict_proba(X) #这个是得分,每个分类器的得分,取最大得分对应的类。
#画图
plot_step=0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
           np.arange(y_min, y_max, plot_step))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) #对坐标风格上的点进行预测,来画分界面。其实最终看到的类的分界线就是分界面的边界线。
Z = Z.reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
 
class_names="ABCD"
plot_colors="rybg"
for i, n, c in zip(range(4), class_names, plot_colors):
  idx = np.where(y == i) #i为0或者1,两个类
  plt.scatter(X[idx, 0], X[idx, 1],
        c=c, cmap=plt.cm.Paired,
        label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
plt.show()

sklearn的predict_proba使用说明

以上这篇sklearn的predict_proba使用说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于multiprocessing的多进程创建方法
Jun 04 Python
Python 绘图和可视化详细介绍
Feb 11 Python
python 读取txt中每行数据,并且保存到excel中的实例
Apr 29 Python
python递归函数绘制分形树的方法
Jun 22 Python
django中的图片验证码功能
Sep 18 Python
Python django搭建layui提交表单,表格,图标的实例
Nov 18 Python
Python注释、分支结构、循环结构、伪“选择结构”用法实例分析
Jan 09 Python
python 实现单例模式的5种方法
Sep 23 Python
Python关于拓扑排序知识点讲解
Jan 04 Python
TensorFlow低版本代码自动升级为1.0版本
Feb 20 Python
python 如何获取页面所有a标签下href的值
May 06 Python
聊一聊python常用的编程模块
May 14 Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
解决Pytorch自定义层出现多Variable共享内存错误问题
Jun 28 #Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 #Python
PyTorch的torch.cat用法
Jun 28 #Python
使用pytorch 筛选出一定范围的值
Jun 28 #Python
You might like
PHP树的代码,可以嵌套任意层
2006/10/09 PHP
php实现文件下载功能的几个代码分享
2014/05/10 PHP
在WordPress的后台中添加顶级菜单和子菜单的函数详解
2016/01/11 PHP
详解PHP实现支付宝小程序用户授权的工具类
2018/12/25 PHP
php redis setnx分布式锁简单原理解析
2020/10/23 PHP
List the Stored Procedures in a SQL Server database
2007/06/20 Javascript
javascript removeChild 使用注意事项
2009/04/11 Javascript
ie下jquery.getJSON的缓存问题的处理方法
2013/03/29 Javascript
简单实现jQuery进度条轮播实例代码
2016/06/20 Javascript
Bootstrap的Refresh Icon也spin起来
2016/07/13 Javascript
老生常谈Javascript中的原型和this指针
2016/10/09 Javascript
jQuery旋转插件jqueryrotate用法详解
2016/10/13 Javascript
JavaScript中click和onclick本质区别与用法分析
2018/06/07 Javascript
layer.close()关闭进度条和Iframe窗的方法
2018/08/17 Javascript
浅析webpack-bundle-analyzer在vue-cli3中的使用
2019/10/23 Javascript
javascript使用正则表达式实现注册登入校验
2020/09/23 Javascript
关于Node.js中频繁修改代码重启服务器的问题
2020/10/15 Javascript
[03:14]2014DOTA2西雅图国际邀请赛 EG战队巡礼
2014/07/07 DOTA
python连接sql server乱码的解决方法
2013/01/28 Python
MySQL适配器PyMySQL详解
2017/09/20 Python
Python过滤txt文件内重复内容的方法
2018/10/21 Python
Python基于正则表达式实现计算器功能
2020/07/13 Python
Python爬虫之Selenium实现窗口截图
2020/12/04 Python
python 视频下载神器(you-get)的具体使用
2021/01/06 Python
Python图像处理之膨胀与腐蚀的操作
2021/02/07 Python
CSS3感应鼠标的背景闪烁和图片缩放动画效果
2014/05/14 HTML / CSS
6种非常炫酷的CSS3按钮边框动画特效
2016/03/16 HTML / CSS
Sneaker Studio法国:购买运动鞋
2018/06/08 全球购物
护理专科毕业生自荐书范文
2014/02/19 职场文书
保护环境建议书
2014/03/12 职场文书
超市开业庆典策划方案
2014/05/14 职场文书
高中运动会广播稿
2014/09/16 职场文书
党的群众路线教育实践活动个人对照检查材料(教师)
2014/11/04 职场文书
党员廉洁自律个人总结
2015/02/13 职场文书
2016年度师德标兵先进事迹材料
2016/02/26 职场文书
分析JVM源码之Thread.interrupt系统级别线程打断
2021/06/29 Java/Android