python实现KNN近邻算法


Posted in Python onDecember 30, 2020

示例:《电影类型分类》

获取数据来源

电影名称 打斗次数 接吻次数 电影类型
California Man 3 104 Romance
He's Not Really into Dudes 8 95 Romance
Beautiful Woman 1 81 Romance
Kevin Longblade 111 15 Action
Roob Slayer 3000 99 2 Action
Amped II 88 10 Action
Unknown 18 90 unknown

数据显示:肉眼判断电影类型unknown是什么

from matplotlib import pyplot as plt
​
# 用来正常显示中文标签
plt.rcParams["font.sans-serif"] = ["SimHei"]
# 电影名称
names = ["California Man", "He's Not Really into Dudes", "Beautiful Woman",
   "Kevin Longblade", "Robo Slayer 3000", "Amped II", "Unknown"]
# 类型标签
labels = ["Romance", "Romance", "Romance", "Action", "Action", "Action", "Unknown"]
colors = ["darkblue", "red", "green"]
colorDict = {label: color for (label, color) in zip(set(labels), colors)}
print(colorDict)
# 打斗次数,接吻次数
X = [3, 8, 1, 111, 99, 88, 18]
Y = [104, 95, 81, 15, 2, 10, 88]
​
plt.title("通过打斗次数和接吻次数判断电影类型", fontsize=18)
plt.xlabel("电影中打斗镜头出现的次数", fontsize=16)
plt.ylabel("电影中接吻镜头出现的次数", fontsize=16)
​
# 绘制数据
for i in range(len(X)):
 # 散点图绘制
 plt.scatter(X[i], Y[i], color=colorDict[labels[i]])
​
# 每个点增加描述信息
for i in range(0, 7):
 plt.text(X[i]+2, Y[i]-1, names[i], fontsize=14)
​
plt.show()

问题分析:根据已知信息分析电影类型unknown是什么

核心思想:

未标记样本的类别由距离其最近的K个邻居的类别决定

距离度量:

一般距离计算使用欧式距离(用勾股定理计算距离),也可以采用曼哈顿距离(水平上和垂直上的距离之和)、余弦值和相似度(这是距离的另一种表达方式)。相比于上述距离,马氏距离更为精确,因为它能考虑很多因素,比如单位,由于在求协方差矩阵逆矩阵的过程中,可能不存在,而且若碰见3维及3维以上,求解过程中极其复杂,故可不使用马氏距离

知识扩展

  • 马氏距离概念:表示数据的协方差距离
  • 方差:数据集中各个点到均值点的距离的平方的平均值
  • 标准差:方差的开方
  • 协方差cov(x, y):E表示均值,D表示方差,x,y表示不同的数据集,xy表示数据集元素对应乘积组成数据集

cov(x, y) = E(xy) - E(x)*E(y)

cov(x, x) = D(x)

cov(x1+x2, y) = cov(x1, y) + cov(x2, y)

cov(ax, by) = abcov(x, y)

  • 协方差矩阵:根据维度组成的矩阵,假设有三个维度,a,b,c

∑ij = [cov(a, a) cov(a, b) cov(a, c) cov(b, a) cov(b,b) cov(b, c) cov(c, a) cov(c, b) cov(c, c)]

算法实现:欧氏距离

编码实现

# 自定义实现 mytest1.py
import numpy as np
​
# 创建数据集
def createDataSet():
 features = np.array([[3, 104], [8, 95], [1, 81], [111, 15],
       [99, 2], [88, 10]])
 labels = ["Romance", "Romance", "Romance", "Action", "Action", "Action"]
 return features, labels
​
def knnClassify(testFeature, trainingSet, labels, k):
 """
 KNN算法实现,采用欧式距离
 :param testFeature: 测试数据集,ndarray类型,一维数组
 :param trainingSet: 训练数据集,ndarray类型,二维数组
 :param labels: 训练集对应标签,ndarray类型,一维数组
 :param k: k值,int类型
 :return: 预测结果,类型与标签中元素一致
 """
 dataSetsize = trainingSet.shape[0]
 """
 构建一个由dataSet[i] - testFeature的新的数据集diffMat
 diffMat中的每个元素都是dataSet中每个特征与testFeature的差值(欧式距离中差)
 """
 testFeatureArray = np.tile(testFeature, (dataSetsize, 1))
 diffMat = testFeatureArray - trainingSet
 # 对每个差值求平方
 sqDiffMat = diffMat ** 2
 # 计算dataSet中每个属性与testFeature的差的平方的和
 sqDistances = sqDiffMat.sum(axis=1)
 # 计算每个feature与testFeature之间的欧式距离
 distances = sqDistances ** 0.5
​
 """
 排序,按照从小到大的顺序记录distances中各个数据的位置
 如distance = [5, 9, 0, 2]
 则sortedStance = [2, 3, 0, 1]
 """
 sortedDistances = distances.argsort()
​
 # 选择距离最小的k个点
 classCount = {}
 for i in range(k):
  voteiLabel = labels[list(sortedDistances).index(i)]
  classCount[voteiLabel] = classCount.get(voteiLabel, 0) + 1
 # 对k个结果进行统计、排序,选取最终结果,将字典按照value值从大到小排序
 sortedclassCount = sorted(classCount.items(), key=lambda x: x[1], reverse=True)
 return sortedclassCount[0][0]
​
testFeature = np.array([100, 200])
features, labels = createDataSet()
res = knnClassify(testFeature, features, labels, 3)
print(res)
# 使用python包实现 mytest2.py
from sklearn.neighbors import KNeighborsClassifier
from .mytest1 import createDataSet
​
features, labels = createDataSet()
k = 5
clf = KNeighborsClassifier(k_neighbors=k)
clf.fit(features, labels)
​
# 样本值
my_sample = [[18, 90]]
res = clf.predict(my_sample)
print(res)

示例:《交友网站匹配效果预测》

数据来源:略

数据显示

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
​
# 数据加载
def loadDatingData(file):
 datingData = pd.read_table(file, header=None)
 datingData.columns = ["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek", "label"]
 datingTrainData = np.array(datingData[["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek"]])
 datingTrainLabel = np.array(datingData["label"])
 return datingData, datingTrainData, datingTrainLabel
​
# 3D图显示数据
def dataView3D(datingTrainData, datingTrainLabel):
 plt.figure(1, figsize=(8, 3))
 plt.subplot(111, projection="3d")
 plt.scatter(np.array([datingTrainData[x][0]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "smallDoses"]),
    np.array([datingTrainData[x][1]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "smallDoses"]),
    np.array([datingTrainData[x][2]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "smallDoses"]), c="red")
 plt.scatter(np.array([datingTrainData[x][0]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "didntLike"]),
    np.array([datingTrainData[x][1]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "didntLike"]),
    np.array([datingTrainData[x][2]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "didntLike"]), c="green")
 plt.scatter(np.array([datingTrainData[x][0]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "largeDoses"]),
    np.array([datingTrainData[x][1]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "largeDoses"]),
    np.array([datingTrainData[x][2]
       for x in range(len(datingTrainLabel))
       if datingTrainLabel[x] == "largeDoses"]), c="blue")
 plt.xlabel("飞行里程数", fontsize=16)
 plt.ylabel("视频游戏耗时百分比", fontsize=16)
 plt.clabel("冰淇凌消耗", fontsize=16)
 plt.show()
 
datingData, datingTrainData, datingTrainLabel = loadDatingData(FILEPATH1)
datingView3D(datingTrainData, datingTrainLabel)

问题分析:抽取数据集的前10%在数据集的后90%进行测试

编码实现

# 自定义方法实现
import pandas as pd
import numpy as np
​
# 数据加载
def loadDatingData(file):
 datingData = pd.read_table(file, header=None)
 datingData.columns = ["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek", "label"]
 datingTrainData = np.array(datingData[["FlightDistance", "PlaytimePreweek", "IcecreamCostPreweek"]])
 datingTrainLabel = np.array(datingData["label"])
 return datingData, datingTrainData, datingTrainLabel
​
# 数据归一化
def autoNorm(datingTrainData):
 # 获取数据集每一列的最值
 minValues, maxValues = datingTrainData.min(0), datingTrainData.max(0)
 diffValues = maxValues - minValues
 
 # 定义形状和datingTrainData相似的最小值矩阵和差值矩阵
 m = datingTrainData.shape(0)
 minValuesData = np.tile(minValues, (m, 1))
 diffValuesData = np.tile(diffValues, (m, 1))
 normValuesData = (datingTrainData-minValuesData)/diffValuesData
 return normValuesData
​
# 核心算法实现
def KNNClassifier(testData, trainData, trainLabel, k):
 m = trainData.shape(0)
 testDataArray = np.tile(testData, (m, 1))
 diffDataArray = (testDataArray - trainData) ** 2
 sumDataArray = diffDataArray.sum(axis=1) ** 0.5
 # 对结果进行排序
 sumDataSortedArray = sumDataArray.argsort()
 
 classCount = {}
 for i in range(k):
  labelName = trainLabel[list(sumDataSortedArray).index(i)]
  classCount[labelName] = classCount.get(labelName, 0)+1
 classCount = sorted(classCount.items(), key=lambda x: x[1], reversed=True)
 return classCount[0][0]
 
​
# 数据测试
def datingTest(file):
 datingData, datingTrainData, datingTrainLabel = loadDatingData(file)
 normValuesData = autoNorm(datingTrainData)
 
 
 errorCount = 0
 ratio = 0.10
 total = datingTrainData.shape(0)
 numberTest = int(total * ratio)
 for i in range(numberTest):
  res = KNNClassifier(normValuesData[i], normValuesData[numberTest:m], datingTrainLabel, 5)
  if res != datingTrainLabel[i]:
   errorCount += 1
 print("The total error rate is : {}\n".format(error/float(numberTest)))
​
if __name__ == "__main__":
 FILEPATH = "./datingTestSet1.txt"
 datingTest(FILEPATH)
# python 第三方包实现
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
​
if __name__ == "__main__":
 FILEPATH = "./datingTestSet1.txt"
 datingData, datingTrainData, datingTrainLabel = loadDatingData(FILEPATH)
 normValuesData = autoNorm(datingTrainData)
 errorCount = 0
 ratio = 0.10
 total = normValuesData.shape[0]
 numberTest = int(total * ratio)
 
 k = 5
 clf = KNeighborsClassifier(n_neighbors=k)
 clf.fit(normValuesData[numberTest:total], datingTrainLabel[numberTest:total])
 
 for i in range(numberTest):
  res = clf.predict(normValuesData[i].reshape(1, -1))
  if res != datingTrainLabel[i]:
   errorCount += 1
 print("The total error rate is : {}\n".format(errorCount/float(numberTest)))

以上就是python实现KNN近邻算法的详细内容,更多关于python实现KNN近邻算法的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python生成随机验证码(中文验证码)示例
Apr 03 Python
在服务器端实现无间断部署Python应用的教程
Apr 16 Python
Python脚本实现自动发带图的微博
Apr 27 Python
Python中基础的socket编程实战攻略
Jun 01 Python
python实现QQ邮箱/163邮箱的邮件发送
Jan 22 Python
Python 调用 Outlook 发送邮件过程解析
Aug 08 Python
python科学计算之scipy——optimize用法
Nov 25 Python
python加载自定义词典实例
Dec 06 Python
深入浅析Python 命令行模块 Click
Mar 11 Python
django 外键创建注意事项说明
May 20 Python
PyQt中使用QtSql连接MySql数据库的方法
Jul 28 Python
python 使用三引号时容易犯的小错误
Oct 21 Python
python 实现逻辑回归
Dec 30 #Python
Python 随机按键模拟2小时
Dec 30 #Python
Python的scikit-image模块实例讲解
Dec 30 #Python
用Python实现职工信息管理系统
Dec 30 #Python
python实现双人五子棋(终端版)
Dec 30 #Python
pandas 数据类型转换的实现
Dec 29 #Python
python中xlutils库用法浅析
Dec 29 #Python
You might like
php中一个有意思的日期逻辑处理
2012/03/25 PHP
Nginx下配置codeigniter框架方法
2015/04/07 PHP
php序列化函数serialize() 和 unserialize() 与原生函数对比
2015/05/08 PHP
使用Composer安装Yii框架的方法
2016/03/15 PHP
php版微信开发Token验证失败或请求URL超时问题的解决方法
2016/09/23 PHP
php连接MSsql server的五种方法总结
2018/03/04 PHP
漂亮的提示信息(带箭头)
2007/03/21 Javascript
基于promise.js实现nodejs的promises库
2014/07/06 NodeJs
通过JS动态创建一个html DOM元素并显示
2014/10/15 Javascript
Js实现自定义右键行为
2015/03/26 Javascript
JavaScript中this详解
2015/09/01 Javascript
浅析JavaScript访问对象属性和方法及区别
2015/11/16 Javascript
Javascript中arguments对象的详解与使用方法
2016/10/04 Javascript
详解js的事件代理(委托)
2016/12/22 Javascript
详解JavaScript中js对象与JSON格式字符串的相互转换
2017/02/14 Javascript
jQuery加密密码到cookie的实现代码
2017/04/18 jQuery
详解利用 Vue.js 实现前后端分离的RBAC角色权限管理
2017/09/15 Javascript
vue + axios get下载文件功能
2019/09/25 Javascript
nuxt.js 在middleware(中间件)中实现路由鉴权操作
2020/11/06 Javascript
python对html过滤处理的方法
2018/10/21 Python
详解Python基础random模块随机数的生成
2019/03/23 Python
Python线程threading模块用法详解
2020/02/26 Python
python爬取”顶点小说网“《纯阳剑尊》的示例代码
2020/10/16 Python
python 利用openpyxl读取Excel表格中指定的行或列教程
2021/02/06 Python
使用CSS3的appearance属性改变元素的外观的方法
2015/12/12 HTML / CSS
深入剖析webstorage[html5的本地数据处理]
2016/07/11 HTML / CSS
Linux的主要特性
2014/10/06 面试题
常用UNIX 命令(Linux的常用命令)
2015/12/26 面试题
工程专业应届生求职信
2014/02/19 职场文书
《李时珍夜宿古寺》教学反思
2014/04/09 职场文书
2014年高中班主任工作总结
2014/11/08 职场文书
民事诉讼答辩状范文
2015/05/21 职场文书
基于Redis过期事件实现订单超时取消
2021/05/08 Redis
Python趣味挑战之给幼儿园弟弟生成1000道算术题
2021/05/28 Python
MySQL数据库实验实现简单数据库应用系统设计
2022/06/21 MySQL
MySQL下载安装配置详细教程 附下载资源
2022/09/23 MySQL