python K近邻算法的kd树实现


Posted in Python onSeptember 06, 2018

k近邻算法的介绍

k近邻算法是一种基本的分类和回归方法,这里只实现分类的k近邻算法。

k近邻算法的输入为实例的特征向量,对应特征空间的点;输出为实例的类别,可以取多类。

k近邻算法不具有显式的学习过程,实际上k近邻算法是利用训练数据集对特征向量空间进行划分。将划分的空间模型作为其分类模型。

k近邻算法的三要素

  • k值的选择:即分类决策时选择k个最近邻实例;
  • 距离度量:即预测实例点和训练实例点间的距离,一般使用L2距离即欧氏距离;
  • 分类决策规则。

下面对三要素进行一下说明:

1.欧氏距离即欧几里得距离,高中数学中用来计算点和点间的距离公式;

2.k值选择:k值选择会对k近邻法结果产生重大影响,如果选择较小的k值,相当于在较小的邻域中训练实例进行预测,这样有点是“近似误差”会减小,即只与输入实例较近(相似)的训练实例才会起作用,缺点是“估计误差”会增大,即对近邻的实例点很敏感。而k值过大则相反。实际中取较小的k值通过交叉验证的方法取最优k值。

3.k近邻法的分类决策规则往往采用多数表决的方式,这等价于“经验风险最小化”。

k近邻算法的实现:kd树

实现k近邻法是要考虑的主要问题是如何退训练数据进行快速的k近邻搜索,当训练实例数很大是显然通过一般的线性搜索方式效率低下,因此为了提高搜索效率,需要构造特殊的数据结构对训练实例进行存储。kd树就是一种不错的数据结构,可以大大提高搜索效率。

本质商kd树是对k维空间的一个划分,构造kd树相当与使用垂直于坐标轴的超平面将k维空间进行切分,构造一系列的超矩形,kd树的每一个结点对应一个这样的超矩形。

kd树本质上是一棵二叉树,当通过一定规则构造是他是平衡的。

下面是过早kd树的算法:

  • 开始:构造根结点,根节点对应包含所有训练实例的k为空间。 选择第1维为坐标轴,以所有训练实例的第一维数据的中位数为切分点,将根结点对应的超矩形切分为两个子区域。由根结点生成深度为1的左右子结点,左结点对应第一维坐标小于切分点的子区域,右子结点对应第一位坐标大于切分点的子区域。
  • 重复:对深度为j的结点选择第l维为切分坐标轴,l=j(modk)+1,以该区域中所有训练实例的第l维的中位数为切分点,重复第一步。
  • 直到两个子区域没有实例存在时停止。形成kd树。

以下是kd树的python实现

准备工作

#读取数据准备
def file2matrix(filename):
  fr = open(filename)
  returnMat = []     #样本数据矩阵
  for line in fr.readlines():
    line = line.strip().split('\t')
    returnMat.append([float(line[0]),float(line[1]),float(line[2]),float(line[3])])
  return returnMat
  
#将数据归一化,避免数据各维度间的差异过大
def autoNorm(data):
  #将data数据和类别拆分
  data,label = np.split(data,[3],axis=1)
  minVals = data.min(0)   #data各列的最大值
  maxVals = data.max(0)    #data各列的最小值
  ranges = maxVals - minVals
  normDataSet = np.zeros(np.shape(data))
  m = data.shape[0]
  #tile函数将变量内容复制成输入矩阵同样大小的矩阵
  normDataSet = data - np.tile(minVals,(m,1))    
  normDataSet = normDataSet/np.tile(ranges,(m,1))
  #拼接
  normDataSet = np.hstack((normDataSet,label))
  return normDataSet
//数据实例
40920  8.326976  0.953952  3
14488  7.153469  1.673904  2
26052  1.441871  0.805124  1
75136  13.147394  0.428964  1
38344  1.669788  0.134296  1
72993  10.141740  1.032955  1
35948  6.830792  1.213192  3
42666  13.276369  0.543880  3
67497  8.631577  0.749278  1
35483  12.273169  1.508053  3
//每一行是一个数据实例,前三维是数据值,第四维是类别标记

树结构定义

#构建kdTree将特征空间划分
class kd_tree:
  """
  定义结点
  value:节点值
  dimension:当前划分的维数
  left:左子树
  right:右子树
  """
  def __init__(self, value):
    self.value = value
    self.dimension = None    #记录划分的维数
    self.left = None
    self.right = None
  
  def setValue(self, value):
    self.value = value
  
  #类似Java的toString()方法
  def __str__(self):
    return str(self.value)

kd树构造

def creat_kdTree(dataIn, k, root, deep):
  """
  data:要划分的特征空间(即数据集)
  k:表示要选择k个近邻
  root:树的根结点
  deep:结点的深度
  """
  #选择x(l)(即为第l个特征)为坐标轴进行划分,找到x(l)的中位数进行划分
#   x_L = data[:,deep%k]    #这里选取第L个特征的所有数据组成一个列表
  #获取特征值中位数,这里是难点如果numpy没有提供的话
  
  if(dataIn.shape[0]>0):   #如果该区域还有实例数据就继续
    dataIn = dataIn[dataIn[:,int(deep%k)].argsort()]    #numpy的array按照某列进行排序
    data1 = None; data2 = None
    #拿取根据xL排序的中位数的数据作为该子树根结点的value
    if(dataIn.shape[0]%2 == 0):   #该数据集有偶数个数据
      mid = int(dataIn.shape[0]/2)
      root = kd_tree(dataIn[mid,:])
      root.dimension = deep%k
      dataIn = np.delete(dataIn,mid, axis = 0)
      data1,data2 = np.split(dataIn,[mid], axis=0) 
      #mid行元素分到data2中,删除放到根结点中
    elif(dataIn.shape[0]%2 == 1):
      mid = int((dataIn.shape[0]+1)/2 - 1)  #这里出现递归溢出,当shape为(1,4)时出现,原因是np.delete时没有赋值给dataIn
      root = kd_tree(dataIn[mid,:])
      root.dimension = deep%k
      dataIn = np.delete(dataIn,mid, axis = 0)
      data1,data2 = np.split(dataIn,[mid], axis=0) #mid行元素分到data1中,删除放到根结点中
    #深度加一
    deep+=1
    #递归构造子树
    #这里犯了严重错误,递归调用是将root传递进去,造成程序混乱,应该给None
    root.left = creat_kdTree(data1, k, None, deep)
    root.right = creat_kdTree(data2, k, None, deep)
  return root

前序遍历测试

#前序遍历kd树
def preorder(kd_tree,i):
  print(str(kd_tree.value)+" :"+str(kd_tree.dimension)+":"+str(i))
  if kd_tree.left != None:
    preorder(kd_tree.left,i+1)
  if kd_tree.right != None:
    preorder(kd_tree.right,i+1)

kd树的最近邻搜索

最近邻搜索算法,k近邻搜索在此基础上实现

原理:首先找到包含目标点的叶节点;然后从该也结点出发,一次退回到父节点,不断查找与目标点最近的结点,当确定不可能存在更近的结点是停止。

def findClosest(kdNode,closestPoint,x,minDis,i=0):
  """
  这里存在一个问题,当传递普通的不可变对象minDis时,递归退回第一次找到
  最端距离前,minDis改变,最后结果混乱,这里传递一个可变对象进来。
  kdNode:是构造好的kd树。
  closestPoint:是存储最近点的可变对象,这里是array
  x:是要预测的实例
  minDis:是当前最近距离。
  """
  if kdNode == None:
    return
  #计算欧氏距离
  curDis = (sum((kdNode.value[0:3]-x[0:3])**2))**0.5
  if minDis[0] < 0 or curDis < minDis[0] :
    i+=1
    minDis[0] = curDis 
    closestPoint[0] = kdNode.value[0]
    closestPoint[1] = kdNode.value[1]
    closestPoint[2] = kdNode.value[2]
    closestPoint[3] = kdNode.value[3]
    print(str(closestPoint)+" : "+str(i)+" : "+str(minDis))
  #递归查找叶节点
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findClosest(kdNode.left,closestPoint,x,minDis,i)
  else:
    findClosest(kdNode.right, closestPoint, x, minDis,i) 
  #计算测试点和分隔超平面的距离,如果相交进入另一个叶节点重复
  rang = abs(x[kdNode.dimension] - kdNode.value[kdNode.dimension])
  if rang > minDis[0] :
    return
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findClosest(kdNode.right,closestPoint,x,minDis,i)
  else:
    findClosest(kdNode.left, closestPoint, x, minDis,i)

测试:

data = file2matrix("datingTestSet2.txt")
data = np.array(data)
normDataSet = autoNorm(data)
sys.setrecursionlimit(10000)      #设置递归深度为10000
trainSet,testSet = np.split(normDataSet,[900],axis=0) 
kdTree = creat_kdTree(trainSet, 3, None, 0)
newData = testSet[1,0:3]
closestPoint = np.zeros(4)
minDis = np.array([-1.0])
findClosest(kdTree, closestPoint, newData, minDis)
print(closestPoint)
print(testSet[1,:])
print(minDis)

测试结果

[0.35118819 0.43961918 0.67110669 3.        ] : 1 : [0.40348346]
[0.11482037 0.13448927 0.48293309 2.        ] : 2 : [0.30404792]
[0.12227055 0.07902201 0.57826697 2.        ] : 3 : [0.22272422]
[0.0645755  0.10845299 0.83274698 2.        ] : 4 : [0.07066192]
[0.10020488 0.15196271 0.76225551 2.        ] : 5 : [0.02546591]
[0.10020488 0.15196271 0.76225551 2.        ]
[0.08959933 0.15442555 0.78527657 2.        ]
[0.02546591]

k近邻搜索实现

在最近邻的基础上进行改进得到:

这里的closestPoint和minDis合并,一同处理

#k近邻搜索
def findKNode(kdNode, closestPoints, x, k):
  """
  k近邻搜索,kdNode是要搜索的kd树
  closestPoints:是要搜索的k近邻点集合,将minDis放入closestPoints最后一列合并
  x:预测实例
  minDis:是最近距离
  k:是选择k个近邻
  """
  if kdNode == None:
    return
  #计算欧式距离
  curDis = (sum((kdNode.value[0:3]-x[0:3])**2))**0.5
  #将closestPoints按照minDis列排序,这里存在一个问题,排序后返回一个新对象
  #不能将其直接赋值给closestPoints
  tempPoints = closestPoints[closestPoints[:,4].argsort()]
  for i in range(k):
    closestPoints[i] = tempPoints[i]
  #每次取最后一行元素操作
  if closestPoints[k-1][4] >=10000 or closestPoints[k-1][4] > curDis:
    closestPoints[k-1][4] = curDis
    closestPoints[k-1,0:4] = kdNode.value 
    
  #递归搜索叶结点
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findKNode(kdNode.left, closestPoints, x, k)
  else:
    findKNode(kdNode.right, closestPoints, x, k)
  #计算测试点和分隔超平面的距离,如果相交进入另一个叶节点重复
  rang = abs(x[kdNode.dimension] - kdNode.value[kdNode.dimension])
  if rang > closestPoints[k-1][4]:
    return
  if kdNode.value[kdNode.dimension] >= x[kdNode.dimension]:
    findKNode(kdNode.right, closestPoints, x, k)
  else:
    findKNode(kdNode.left, closestPoints, x, k)

测试

data = file2matrix("datingTestSet2.txt")
data = np.array(data)
normDataSet = autoNorm(data)
sys.setrecursionlimit(10000)      #设置递归深度为10000
trainSet,testSet = np.split(normDataSet,[900],axis=0) 
kdTree = creat_kdTree(trainSet, 3, None, 0)
newData = testSet[1,0:3]
print("预测实例点:"+str(newData))
closestPoints = np.zeros((3,5))     #初始化参数
closestPoints[:,4] = 10000.0      #给minDis列赋值
findKNode(kdTree, closestPoints, newData, 3)
print("k近邻结果:"+str(closestPoints))

测试结果

预测实例点:[0.08959933 0.15442555 0.78527657]

k近邻结果:[[0.10020488 0.15196271 0.76225551 2.         0.02546591]
 [0.10664709 0.13172159 0.83777837 2.         0.05968697]
 [0.09616206 0.20475001 0.75047289 2.         0.06153793]]

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python装饰器使用方法实例
Nov 21 Python
python通过urllib2爬网页上种子下载示例
Feb 24 Python
批处理与python代码混合编程的方法
May 19 Python
使用python生成目录树
Mar 29 Python
python使用zip将list转为json的方法
Dec 31 Python
python交易记录整合交易类详解
Jul 03 Python
python3反转字符串的3种方法(小结)
Nov 07 Python
python调用接口的4种方式代码实例
Nov 19 Python
Python BeautifulReport可视化报告代码实例
Apr 13 Python
python+selenium 简易地疫情信息自动打卡签到功能的实现代码
Aug 22 Python
Python return语句如何实现结果返回调用
Oct 15 Python
使用Python爬取Json数据的示例代码
Dec 07 Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
Python多线程编程之多线程加锁操作示例
Sep 06 #Python
python中将\\uxxxx转换为Unicode字符串的方法
Sep 06 #Python
Python json模块dumps、loads操作示例
Sep 06 #Python
Python 字符串换行的多种方式
Sep 06 #Python
Python使用logging模块实现打印log到指定文件的方法
Sep 05 #Python
Python使用try except处理程序异常的三种常用方法分析
Sep 05 #Python
You might like
PHP数据库操作面向对象的优点
2006/10/09 PHP
php防攻击代码升级版
2010/12/29 PHP
Youku 视频绝对地址获取的方法详解
2013/06/26 PHP
PHP使用CURL实现对带有验证码的网站进行模拟登录的方法
2014/07/23 PHP
PHP解析RSS的方法
2015/03/05 PHP
检测是否已安装 .NET Framework 3.5的js脚本
2009/02/14 Javascript
JS中的this变量的使用介绍
2013/10/21 Javascript
用unescape反编码得出汉字示例
2014/04/24 Javascript
jQuery中大家不太了解的几个方法
2015/03/04 Javascript
JavaScript实现动画打开半透明提示层的方法
2015/04/21 Javascript
javascript实现数字倒计时特效
2016/03/30 Javascript
EasyUI在表单提交之前进行验证的实例代码
2016/06/24 Javascript
js获取一组日期中最近连续的天数
2017/05/25 Javascript
详解vue组件通信的三种方式
2017/06/30 Javascript
vue 2.0项目中如何引入element-ui详解
2017/09/06 Javascript
详解vue中组件参数
2018/07/09 Javascript
js获取 gif 的帧数的代码实例
2019/09/10 Javascript
js实现全选和全不选功能
2020/07/28 Javascript
[03:19]2016国际邀请赛中国区预选赛第四日TOP10镜头集锦
2016/07/01 DOTA
详解Python中的__init__和__new__
2014/03/12 Python
Python 实现简单的电话本功能
2015/08/09 Python
Window 64位下python3.6.2环境搭建图文教程
2018/09/19 Python
使用python对文件中的单词进行提取的方法示例
2018/12/21 Python
pyqt5利用pyqtDesigner实现登录界面
2019/03/28 Python
python 处理微信对账单数据的实例代码
2019/07/19 Python
解决jupyter notebook import error但是命令提示符import正常的问题
2020/04/15 Python
英国领先的办公用品供应商:Viking
2016/08/01 全球购物
局部内部类是否可以访问非final变量?
2013/04/20 面试题
英文版销售经理个人求职信
2013/11/20 职场文书
2014年十一国庆向国旗敬礼寄语
2014/04/11 职场文书
和谐拯救危机观后感
2015/06/15 职场文书
消夏晚会主持词
2015/06/30 职场文书
班主任工作经验交流会总结
2015/11/02 职场文书
详解Redis实现限流的三种方式
2021/04/27 Redis
修改并编译golang源码的操作步骤
2021/07/25 Golang
企业开发CSS命名BEM代码规范实践
2022/02/12 HTML / CSS