\newcommand{\R}{\mathbb{R}} \newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} \newcommand{\av}{\mathbf{\alpha}} \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} \newcommand{\ind}{\mathbf{1}} \newcommand{\0}{\mathbf{0}} \newcommand{\unit}{\mathbf{e}} \newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \newcommand\rfrac[2]{^{#1}!/_{#2}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert}

k-Nearest Neighbors Join

描述

实现一个精确的k-nearest neighbors join算法。给定一个训练集$A$和一个测试集$B$,算法返回:

蛮力法是计算每个训练点和测试点之间的距离。为了简化蛮力计算,利用四叉树计算每个训练点之间的距离。四叉树在训练点的数量上具有良好的扩展能力,但在空间维度上的扩展能力较差。该算法将自动选择是否使用四叉树,尽管用户可以通过设置一个参数强制使用或不使用四叉树来覆盖该决定。

操作

KNN是一个预测因子。因此,它支持“拟合”和“预测”操作。

拟合

KNN由给定的一组“Vector”训练:

  • fit[T <: Vector]: DataSet[T] => Unit

预测

KNN预测FlinkML的’Vector’所有子类型对应的类标签:

  • predict[T <: Vector]: DataSet[T] => DataSet[(T, Array[Vector])], where the (T, Array[Vector]) tuple corresponds to (test point, K-nearest training points)

参数

KNN的实现可以有下列参数进行控制:

参数 描述
K 定义要搜索的最近邻居的数量。也就是说,对于每个测试点,算法都在训练集中找到k个最近邻居 (默认值: 5)
DistanceMetric 设置我们用来计算两点之间距离的距离度量。如果没有指定度量,那么[[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]]将默认使用. (默认值: EuclideanDistanceMetric)
Blocks 设置将输入数据切分到的块的数量。这个数字至少应该设置为并行度。如果没有指定值,则使用输入[[DataSet]]的并行度作为块的数量 (默认值: None)
UseQuadTree 一个布尔变量,是否使用四叉树对训练集进行分区,以潜在地简化KNN搜索。如果没有指定值,代码将自动决定是否使用四叉树。使用四叉树可以很好地扩展训练点和测试点的数量,但是在维数方面却很差。 (默认值: None)
SizeHint 指定训练集或测试集是否小,以优化KNN搜索所需的叉乘操作。如果训练集很小,这个应该设置为’CrossHint.FIRST_IS_SMALL’,当测试集很小的时候,应该设置为’CrossHint.SECOND_IS_SMALL’ (默认值: None)

Examples

  1. import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
  2. import org.apache.flink.api.scala._
  3. import org.apache.flink.ml.nn.KNN
  4. import org.apache.flink.ml.math.Vector
  5. import org.apache.flink.ml.metrics.distances.SquaredEuclideanDistanceMetric
  6. val env = ExecutionEnvironment.getExecutionEnvironment
  7. // prepare data val trainingSet: DataSet[Vector] = ...
  8. val testingSet: DataSet[Vector] = ...
  9. val knn = KNN()
  10. .setK(3)
  11. .setBlocks(10)
  12. .setDistanceMetric(SquaredEuclideanDistanceMetric())
  13. .setUseQuadTree(false)
  14. .setSizeHint(CrossHint.SECOND_IS_SMALL)
  15. // run knn join knn.fit(trainingSet)
  16. val result = knn.predict(testingSet).collect()

更多关于计算KNN带不带四叉树的细节,这里有一个演示:http://danielblazevski.github.io/