kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。
kd树是二叉树,表示对kd树 - 图1维空间的一个划分(partition)。构造kd树相当于不断地用垂直于坐标轴的超平面将kd树 - 图2维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个kd树 - 图3维超矩形区域。

构造kd树的方法如下:
构造根结点,使根结点对应于kd树 - 图4维空间中包含所有实例点的超矩形区域;通过下面的递归方法,不断地对kd树 - 图5维空间进行切分,生成子结点。在超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域 (子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。
通常,依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数 (median)为切分点,这样得到的kd树是平衡的。注意,平衡的kd树搜索时的效率未必是最优的。

构造平衡kd树算法


输入:kd树 - 图6维空间数据集kd树 - 图7
其中kd树 - 图8kd树 - 图9
输出:kd树。

(1)开始:构造根结点,根结点对应于包含kd树 - 图10kd树 - 图11维空间的超矩形区域。
选择kd树 - 图12为坐标轴,以T中所有实例的kd树 - 图13坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴kd树 - 图14垂直的超平面实现。
由根结点生成深度为1的左、右子结点:左子结点对应坐标kd树 - 图15小于切分点的子区域, 右子结点对应于坐标kd树 - 图16大于切分点的子区域。
将落在切分超平面上的实例点保存在根结点。
(2)重复:对深度为kd树 - 图17的结点,选择kd树 - 图18为切分的坐标轴,kd树 - 图19,以该结点的区域中所有实例的kd树 - 图20坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴kd树 - 图21垂直的超平面实现。
由该结点生成深度为kd树 - 图22的左、右子结点:左子结点对应坐标kd树 - 图23小于切分点的子区域,右子结点对应坐标kd树 - 图24大于切分点的子区域。
将落在切分超平面上的实例点保存在该结点。
(3)直到两个子区域没有实例存在时停止。从而形成kd树的区域划分。

代码

  1. # kd-tree每个结点中主要包含的数据结构如下
  2. class KdNode(object):
  3. def __init__(self, dom_elt, split, left, right):
  4. self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)
  5. self.split = split # 整数(进行分割维度的序号)
  6. self.left = left # 该结点分割超平面左子空间构成的kd-tree
  7. self.right = right # 该结点分割超平面右子空间构成的kd-tree
  8. class KdTree(object):
  9. def __init__(self, data):
  10. k = len(data[0]) # 数据维度
  11. def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNode
  12. if not data_set: # 数据集为空
  13. return None
  14. # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
  15. # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
  16. #data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
  17. data_set.sort(key=lambda x: x[split])
  18. split_pos = len(data_set) // 2 # //为Python中的整数除法
  19. median = data_set[split_pos] # 中位数分割点
  20. split_next = (split + 1) % k # cycle coordinates
  21. # 递归的创建kd树
  22. return KdNode(
  23. median,
  24. split,
  25. CreateNode(split_next, data_set[:split_pos]), # 创建左子树
  26. CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树
  27. self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点
  28. # KDTree的前序遍历
  29. def preorder(root):
  30. print(root.dom_elt)
  31. if root.left: # 节点不为空
  32. preorder(root.left)
  33. if root.right:
  34. preorder(root.right)

_

  1. # 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
  2. from math import sqrt
  3. from collections import namedtuple
  4. # 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
  5. result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")
  6. def find_nearest(tree, point):
  7. k = len(point) # 数据维度
  8. def travel(kd_node, target, max_dist):
  9. if kd_node is None:
  10. return result([0] * k, float("inf"),
  11. 0) # python中用float("inf")和float("-inf")表示正负无穷
  12. nodes_visited = 1
  13. s = kd_node.split # 进行分割的维度
  14. pivot = kd_node.dom_elt # 进行分割的“轴”
  15. if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
  16. nearer_node = kd_node.left # 下一个访问节点为左子树根节点
  17. further_node = kd_node.right # 同时记录下右子树
  18. else: # 目标离右子树更近
  19. nearer_node = kd_node.right # 下一个访问节点为右子树根节点
  20. further_node = kd_node.left
  21. temp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域
  22. nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”
  23. dist = temp1.nearest_dist # 更新最近距离
  24. nodes_visited += temp1.nodes_visited
  25. if dist < max_dist:
  26. max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内
  27. temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离
  28. if max_dist < temp_dist: # 判断超球体是否与超平面相交
  29. return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断
  30. #----------------------------------------------------------------------
  31. # 计算目标点与分割点的欧氏距离
  32. temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))
  33. if temp_dist < dist: # 如果“更近”
  34. nearest = pivot # 更新最近点
  35. dist = temp_dist # 更新最近距离
  36. max_dist = dist # 更新超球体半径
  37. # 检查另一个子结点对应的区域是否有更近的点
  38. temp2 = travel(further_node, target, max_dist)
  39. nodes_visited += temp2.nodes_visited
  40. if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离
  41. nearest = temp2.nearest_point # 更新最近点
  42. dist = temp2.nearest_dist # 更新最近距离
  43. return result(nearest, dist, nodes_visited)
  44. return travel(tree.root, point, float("inf")) # 从根节点开始递归

例题3.2

image.png

  1. data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
  2. kd = KdTree(data)
  3. preorder(kd.root)
  4. [7, 2]
  5. [5, 4]
  6. [2, 3]
  7. [4, 7]
  8. [9, 6]
  9. [8, 1]
  10. # ----------------------------------------------------------
  11. ret = find_nearest(kd, [3,4.5])
  12. print (ret)
  13. Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)

随机情况

  1. from time import clock
  2. from random import random
  3. # 产生一个k维随机向量,每维分量值在0~1之间
  4. def random_point(k):
  5. return [random() for _ in range(k)]
  6. # 产生n个k维随机向量
  7. def random_points(k, n):
  8. return [random_point(k) for _ in range(n)]
  9. N = 400000
  10. t0 = clock()
  11. kd2 = KdTree(random_points(3, N)) # 构建包含四十万个3维空间样本点的kd树
  12. ret2 = find_nearest(kd2, [0.1,0.5,0.8]) # 四十万个样本点中寻找离目标最近的点
  13. t1 = clock()
  14. print ("time: ",t1-t0, "s")
  15. print (ret2)
  16. time: 5.4623788 s
  17. Result_tuple(nearest_point=[0.09929288205798159, 0.4954936771850429, 0.8005722800665575], nearest_dist=0.004597223680778027, nodes_visited=42)