K近邻算法(KNN)

描述

实现一个精密K最近邻居连接(exact k-nearest neighbors join)算法。假设有一个训练集 $A$ 和一个测试集 $B$,该算法返回

$$ KNNJ(A, B, k) = { \left( b, KNN(b, A, k) \right) \text{ where } b \in B \text{ and } KNN(b, A, k) \text{ are the k-nearest points to }b\text{ in }A } $$

该暴力方法的目标是计算每个训练点和测试点之间的距离。为了使计算每个训练点之间距离的暴力计算过程更加简化和平滑,本方法使用一个四叉树。该四叉树在训练点的数量上有很好的扩展性,但是在空间维度上的扩展性表现不佳。本算法会自动选择是否采用该四叉树,用户也可以通过设置一个参数来覆盖算法的决定,强制指定是否使用该四叉树。

操作

KNN 是一个 Predictor。 正如所示, 它支持 fitpredict 操作。

拟合

KNN 通过一个给定的 Vector 集来训练:

  • fit[T <: Vector]: DataSet[T] => Unit

预测

KNN 为所有的FlinkML的 Vector 的子类预测对应的类别标签:

  • predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])], 这里 (T, Array[Vector]) 元组对应 (test point, K-nearest training points)

参数

KNN的实现可以由以下参数控制:

参数 描述
K

定义要搜索的最近邻居数量。也就是说,对于每一个测试点,该算法会从训练集中找到K个最近邻居.(默认值: 5)

DistanceMetric

设置用来计算两点之间距离的度量标准。如果没有指定度量标准,则[[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] 被使用.(默认值: EuclideanDistanceMetric)

Blocks

设置输入数据将会被切分的块数。该数目至少应该被设置成与并行度相等。如果没有指定块数,则使用作为输入的 [[DataSet]] 的平行度作为块数.(默认值: None)

UseQuadTree

一个布尔参数,该参数用来指定是否使用能够对训练集进行分区,并且有可能简化平滑KNN搜索的四叉树。如果该值没有指定,则代码会自动决定是否使用一个四叉树。四叉树的使用在训练点和测试点的数量上有很好的扩展性,但在维度上的扩展性表现不佳.(默认值: None)

SizeHint

指定训练集或测试集是否小到能优化KNN搜索所需的向量乘操作。如果训练集小,该值应该是 CrossHint.FIRST_IS_SMALL,如果测试集小,则设置成 CrossHint.SECOND_IS_SMALL.(默认值: None)

示例

  1. import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
  2. import org.apache.flink.api.scala._
  3. import org.apache.flink.ml.nn.KNN
  4. import org.apache.flink.ml.math.Vector
  5. import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
  6. val env = ExecutionEnvironment.getExecutionEnvironment
  7. // 准备数据
  8. val trainingSet: DataSet[Vector] = ...
  9. val testingSet: DataSet[Vector] = ...
  10. val knn = KNN()
  11. .setK(3)
  12. .setBlocks(10)
  13. .setDistanceMetric(SquaredEuclideanDistanceMetric())
  14. .setUseQuadTree(false)
  15. .setSizeHint(CrossHint.SECOND_IS_SMALL)
  16. // 运行 knn join
  17. knn.fit(trainingSet)
  18. val result = knn.predict(testingSet).collect()

关于使用和不使用四叉树计算KNN的更多细节,参照该介绍: http://danielblazevski.github.io/