kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。
kd树是二叉树,表示对维空间的一个划分(partition)。构造kd树相当于不断地用垂直于坐标轴的超平面将
维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个
维超矩形区域。
构造kd树的方法如下:
构造根结点,使根结点对应于维空间中包含所有实例点的超矩形区域;通过下面的递归方法,不断地对
维空间进行切分,生成子结点。在超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域 (子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。
通常,依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数 (median)为切分点,这样得到的kd树是平衡的。注意,平衡的kd树搜索时的效率未必是最优的。
构造平衡kd树算法
输入:维空间数据集
,
其中 ,
;
输出:kd树。
(1)开始:构造根结点,根结点对应于包含的
维空间的超矩形区域。
选择为坐标轴,以T中所有实例的
坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴
垂直的超平面实现。
由根结点生成深度为1的左、右子结点:左子结点对应坐标小于切分点的子区域, 右子结点对应于坐标
大于切分点的子区域。
将落在切分超平面上的实例点保存在根结点。
(2)重复:对深度为的结点,选择
为切分的坐标轴,
,以该结点的区域中所有实例的
坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴
垂直的超平面实现。
由该结点生成深度为的左、右子结点:左子结点对应坐标
小于切分点的子区域,右子结点对应坐标
大于切分点的子区域。
将落在切分超平面上的实例点保存在该结点。
(3)直到两个子区域没有实例存在时停止。从而形成kd树的区域划分。
代码
# kd-tree每个结点中主要包含的数据结构如下class KdNode(object):def __init__(self, dom_elt, split, left, right):self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)self.split = split # 整数(进行分割维度的序号)self.left = left # 该结点分割超平面左子空间构成的kd-treeself.right = right # 该结点分割超平面右子空间构成的kd-treeclass KdTree(object):def __init__(self, data):k = len(data[0]) # 数据维度def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNodeif not data_set: # 数据集为空return None# key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较# operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号#data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序data_set.sort(key=lambda x: x[split])split_pos = len(data_set) // 2 # //为Python中的整数除法median = data_set[split_pos] # 中位数分割点split_next = (split + 1) % k # cycle coordinates# 递归的创建kd树return KdNode(median,split,CreateNode(split_next, data_set[:split_pos]), # 创建左子树CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点# KDTree的前序遍历def preorder(root):print(root.dom_elt)if root.left: # 节点不为空preorder(root.left)if root.right:preorder(root.right)
_
# 对构建好的kd树进行搜索,寻找与目标点最近的样本点:from math import sqrtfrom collections import namedtuple# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")def find_nearest(tree, point):k = len(point) # 数据维度def travel(kd_node, target, max_dist):if kd_node is None:return result([0] * k, float("inf"),0) # python中用float("inf")和float("-inf")表示正负无穷nodes_visited = 1s = kd_node.split # 进行分割的维度pivot = kd_node.dom_elt # 进行分割的“轴”if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)nearer_node = kd_node.left # 下一个访问节点为左子树根节点further_node = kd_node.right # 同时记录下右子树else: # 目标离右子树更近nearer_node = kd_node.right # 下一个访问节点为右子树根节点further_node = kd_node.lefttemp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”dist = temp1.nearest_dist # 更新最近距离nodes_visited += temp1.nodes_visitedif dist < max_dist:max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离if max_dist < temp_dist: # 判断超球体是否与超平面相交return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断#----------------------------------------------------------------------# 计算目标点与分割点的欧氏距离temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))if temp_dist < dist: # 如果“更近”nearest = pivot # 更新最近点dist = temp_dist # 更新最近距离max_dist = dist # 更新超球体半径# 检查另一个子结点对应的区域是否有更近的点temp2 = travel(further_node, target, max_dist)nodes_visited += temp2.nodes_visitedif temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离nearest = temp2.nearest_point # 更新最近点dist = temp2.nearest_dist # 更新最近距离return result(nearest, dist, nodes_visited)return travel(tree.root, point, float("inf")) # 从根节点开始递归
例题3.2

data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]kd = KdTree(data)preorder(kd.root)[7, 2][5, 4][2, 3][4, 7][9, 6][8, 1]# ----------------------------------------------------------ret = find_nearest(kd, [3,4.5])print (ret)Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4)
随机情况
from time import clockfrom random import random# 产生一个k维随机向量,每维分量值在0~1之间def random_point(k):return [random() for _ in range(k)]# 产生n个k维随机向量def random_points(k, n):return [random_point(k) for _ in range(n)]N = 400000t0 = clock()kd2 = KdTree(random_points(3, N)) # 构建包含四十万个3维空间样本点的kd树ret2 = find_nearest(kd2, [0.1,0.5,0.8]) # 四十万个样本点中寻找离目标最近的点t1 = clock()print ("time: ",t1-t0, "s")print (ret2)time: 5.4623788 sResult_tuple(nearest_point=[0.09929288205798159, 0.4954936771850429, 0.8005722800665575], nearest_dist=0.004597223680778027, nodes_visited=42)
