利用Python实现kNN算法的代码


Posted in Python onAugust 16, 2019

邻近算法(k-NearestNeighbor) 是机器学习中的一种分类(classification)算法,也是机器学习中最简单的算法之一了。虽然很简单,但在解决特定问题时却能发挥很好的效果。因此,学习kNN算法是机器学习入门的一个很好的途径。

kNN算法的思想非常的朴素,它选取k个离测试点最近的样本点,输出在这k个样本点中数量最多的标签(label)。我们假设每一个样本有m个特征值(property),则一个样本的可以用一个m维向量表示: X =( x1,x2,... , xm ),  同样地,测试点的特征值也可表示成:Y =( y1,y2,... , ym )。那我们怎么定义这两者之间的“距离”呢?

在二维空间中,有:d2 = ( x1 - y1 )2 + ( x2 - y2 )2 ,  在三维空间中,两点的距离被定义为:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ( x3 - y3 )2 。我们可以据此推广到m维空间中,定义m维空间的距离:d2 = ( x1 - y1 )2 + ( x2 - y2 )2  + ...... + ( xm - ym )2 。要实现kNN算法,我们只需要计算出每一个样本点与测试点的距离,选取距离最近的k个样本,获取他们的标签(label) ,然后找出k个样本中数量最多的标签,返回该标签。

在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d2 = ( 1.70 - 1.81 )2 + ( 60 - 64 )2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d2 = ( 1.80 - 1.81 )2 + ( 70 - 64 )2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。我们写一个normData函数来执行标准化数据集的工作:

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals

 然后开始实现kNN算法:

def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
  # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label

注意,在testData作为参数传入kNN函数之前,需要经过标准化。

我们用几个小数据验证一下kNN函数是否能正常工作:

if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

结果输出 a ,与预期结果一致。

完整代码:

import numpy as np
from math import sqrt
import operator as opt

def normData(dataSet):
  maxVals = dataSet.max(axis=0)
  minVals = dataSet.min(axis=0)
  ranges = maxVals - minVals
  retData = (dataSet - minVals) / ranges
  return retData, ranges, minVals


def kNN(dataSet, labels, testData, k):
  distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
  distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
  distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
  sortedIndices = distances.argsort() # 排序,得到排序后的下标
  indices = sortedIndices[:k] # 取最小的k个
  labelCount = {} # 存储每个label的出现次数
  for i in indices:
    label = labels[i]
    labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
  sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) # 对label出现的次数从大到小进行排序
  return sortedCount[0][0] # 返回出现次数最大的label



if __name__ == "__main__":
  dataSet = np.array([[2, 3], [6, 8]])
  normDataSet, ranges, minVals = normData(dataSet)
  labels = ['a', 'b']
  testData = np.array([3.9, 5.5])
  normTestData = (testData - minVals) / ranges
  result = kNN(normDataSet, labels, normTestData, 1)
  print(result)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
布同 Python中文问题解决方法(总结了多位前人经验,初学者必看)
Mar 13 Python
python快速查找算法应用实例
Sep 26 Python
在Python中操作时间之tzset()方法的使用教程
May 22 Python
Python实现以时间换空间的缓存替换算法
Feb 19 Python
python得到单词模式的示例
Oct 15 Python
Python实现将Excel转换成为image的方法
Oct 23 Python
numpy.random模块用法总结
May 27 Python
浅析python redis的连接及相关操作
Nov 07 Python
python集合删除多种方法详解
Feb 10 Python
Python根据字典的值查询出对应的键的方法
Sep 30 Python
Python超简单容易上手的画图工具库推荐
May 10 Python
利用Matlab绘制各类特殊图形的实例代码
Jul 16 Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
python爬虫 爬取超清壁纸代码实例
Aug 16 #Python
Python PO设计模式的具体使用
Aug 16 #Python
python使用sessions模拟登录淘宝的方式
Aug 16 #Python
Django错误:TypeError at / 'bool' object is not callable解决
Aug 16 #Python
Python facenet进行人脸识别测试过程解析
Aug 16 #Python
Python Web框架之Django框架Model基础详解
Aug 16 #Python
You might like
linux下使用crontab实现定时PHP计划任务失败的原因分析
2014/07/05 PHP
yii2 resetful 授权验证详解
2017/05/18 PHP
PHP基于SimpleXML生成和解析xml的方法示例
2017/07/17 PHP
PHP实现将多个文件压缩成zip格式并下载到本地的方法示例
2018/05/23 PHP
ThinkPHP5分页paginate代码实例解析
2020/11/10 PHP
用js实现的一个Flash滚动轮换显示图片代码生成器
2007/03/14 Javascript
HTML node相关的一些资料整理
2010/01/01 Javascript
Js+Flash实现访问剪切板操作
2012/11/20 Javascript
jquery+CSS实现的多级竖向展开树形TRee菜单效果
2015/08/24 Javascript
基于JS判断iframe是否加载成功的方法(多种浏览器)
2016/05/13 Javascript
Jquery Easyui日历组件Calender使用详解(23)
2016/12/18 Javascript
Boostrap栅格系统与自己额外定义的媒体查询的冲突问题
2017/02/19 Javascript
JS条形码(一维码)插件JsBarcode用法详解【编码类型、参数、属性】
2017/04/19 Javascript
解决jquery appaend元素中id绑定事件失效的问题
2017/09/12 jQuery
浅谈用Webpack路径压缩图片上传尺寸获取的问题
2018/02/22 Javascript
vue中,在本地缓存中读写数据的方法
2018/09/21 Javascript
React+Antd+Redux实现待办事件的方法
2019/03/14 Javascript
解决antd datepicker 获取时间默认少8个小时的问题
2020/10/29 Javascript
对于Python的Django框架使用的一些实用建议
2015/04/03 Python
python多进程实现进程间通信实例
2017/11/24 Python
深入解析HTML5 Canvas控制图形矩阵变换的方法
2016/03/24 HTML / CSS
使用iframe+postMessage实现页面跨域通信的示例代码
2020/01/14 HTML / CSS
SHEIN香港:价格实惠的女性时尚服装
2018/08/14 全球购物
德国婴儿服装和婴儿用品购买网站:Baby Sweets
2019/12/08 全球购物
介绍一下Java中的Class类
2015/04/10 面试题
毕业生自我鉴定
2013/11/05 职场文书
简短证婚人证婚词
2014/01/09 职场文书
公司司机岗位职责范本
2014/03/03 职场文书
洗车工岗位职责
2014/03/15 职场文书
2014年政教处工作总结
2014/12/20 职场文书
受资助学生感谢信
2015/01/21 职场文书
英文商务邀请函范文
2015/01/31 职场文书
领导干部学习三严三实心得体会
2016/01/05 职场文书
2016大学生暑期三下乡心得体会
2016/01/23 职场文书
漫改真人电影「萌系男友是燃燃的橘色」公开先导视觉图
2022/03/21 日漫
windows server2008 开启端口的实现方法
2022/06/25 Servers