利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解设计模式中的工厂方法模式在Python程序中的运用
Mar 02 Python
在Python中使用AOP实现Redis缓存示例
Jul 11 Python
Python实现列表删除重复元素的三种常用方法分析
Nov 24 Python
python实现用户管理系统
Jan 10 Python
python使用turtle绘制国际象棋棋盘
May 23 Python
python 求定积分和不定积分示例
Nov 20 Python
Python如何使用BeautifulSoup爬取网页信息
Nov 26 Python
pycharm 2019 最新激活方式(pycharm破解、激活)
Sep 22 Python
详解字符串在Python内部是如何省内存的
Feb 03 Python
解决启动django,浏览器显示“服务器拒绝访问”的问题
May 13 Python
Python如何合并多个字典或映射
Jul 24 Python
python中实现词云图的示例
Dec 19 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
星际争霸兵种名称对照表
2020/03/04 星际争霸
php学习笔记 面向对象的构造与析构方法
2011/06/13 PHP
一个基于PDO的数据库操作类(新) 一个PDO事务实例
2011/07/03 PHP
PHP两种去掉数组重复值的方法比较
2014/06/19 PHP
PHP查询快递信息的方法
2015/03/07 PHP
PHP模拟asp.net的StringBuilder类实现方法
2015/08/08 PHP
PHP 实现文件压缩解压操作的方法
2019/06/14 PHP
JS的数组的扩展实例代码
2008/07/09 Javascript
预加载css或javascript的js代码
2010/04/23 Javascript
jQuery 菜单随滚条改为以定位方式(固定要浏览器顶部)
2012/05/24 Javascript
跟我学习javascript的异步脚本加载
2015/11/20 Javascript
微信小程序之picker日期和时间选择器
2017/02/09 Javascript
浅谈webpack下的AOP式无侵入注入
2017/11/12 Javascript
jQuery实现的点击标题文字切换字体效果示例【测试可用】
2018/04/26 jQuery
JavaScript数组,JSON对象实现动态添加、修改、删除功能示例
2018/05/26 Javascript
原生JS封装_new函数实现new关键字的功能
2018/08/12 Javascript
JavaScript onclick事件使用方法详解
2020/05/15 Javascript
python实现pdf转换成word/txt纯文本文件
2018/06/07 Python
Python切片操作去除字符串首尾的空格
2019/04/22 Python
python实现大文件分割与合并
2019/07/22 Python
Python图像处理之图片文字识别功能(OCR)
2019/07/30 Python
django写用户登录判定并跳转制定页面的实例
2019/08/21 Python
python 使用raw socket进行TCP SYN扫描实例
2020/05/05 Python
X/HTML5 和 XHTML2
2008/10/17 HTML / CSS
使用纯HTML5编写一款网页上的时钟的代码分享
2015/11/16 HTML / CSS
美国彩妆品牌:Coastal Scents
2017/04/01 全球购物
Erwin Müller穆勒家居瑞士官网:您整个家庭的邮购公司
2019/12/28 全球购物
戴森比利时官方网站:Dyson BE
2020/10/03 全球购物
同事打架检讨书
2014/02/04 职场文书
管理失职检讨书
2014/02/12 职场文书
《春雨》教学反思
2014/04/24 职场文书
音乐学专业求职信
2014/07/22 职场文书
毕业生自荐材料范文
2014/12/30 职场文书
2015年党员干部承诺书
2015/01/21 职场文书
行政处罚告知书
2015/07/01 职场文书
浅谈MySQL之select优化方案
2021/08/07 MySQL