迭代算法出现在数据分析的许多领域,例如机器学习或图形分析。Flink程序通过定义步进函数并将其嵌入到特殊的迭代运算符中来实现迭代算法。实际应用中有两种变体:Bulk Iterate和Delta Iterate。两个运算符都在当前迭代状态下重复调用step函数,直到达到某个终止条件为止。

Bulk Iterations Delta Iterations
迭代输入 Partial Solution WorksetSolution Set
步进功能 任意数据流
状态更新 下一个
Partial Solution

- 下一个Workset
- Solution Set的更改
迭代结果 最后的
Partial Solution
最后一次迭代后的
Solution Set
终止条件
- 最大迭代次数(默认)
- 自定义的收敛判断

- 最大迭代次数或空Workset(默认)
- 自定义的收敛判断

Bulk Iterations

简介

要创建BulkIteration调用,iterate(int)应从迭代开始的DataSet方法开始。然后返回IterativeDataSet,在迭代中可以使用各个常规的算子进行转换。当单次迭代结束后,调用closeWith(DataSet)方法来定义迭代的返回点。用户可以通过定义最大迭代次数来指定迭代的结束条件,也可以使用另一个方式来指定终止条件,即closeWith(DataSet, DataSet),如果第二个DataSet为空,迭代将结束。

demo

我们通过Flink中的KMeans实现来了解Bulk Interactions的基本代码结构。k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是,预将数据分为K组,则随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。这个过程将不断重复直到满足某个终止条件。在本例中,我们人为指定了初始的聚类中心,而且结束条件定义为达到最大迭代次数后结束。
image.png

  1. public class BulkIteration {
  2. public static void main(String[] args) throws Exception {
  3. final ParameterTool params = ParameterTool.fromArgs(args);
  4. final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
  5. env.getConfig().setGlobalJobParameters(params);
  6. // 读取数据点文件和中心点文件
  7. DataSet<Point> points = getPointDataSet(env);
  8. DataSet<Centroid> centroids = getCentroidDataSet(env);
  9. // 设置KMeans最大迭代次数
  10. IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
  11. DataSet<Centroid> newCentroids = points
  12. // 计算距离每个点最近的中心点
  13. .map(new SelectNearestCenter())
  14. // 每个分区都接受全量的聚类中心数据
  15. .withBroadcastSet(loop, "centroids")
  16. // 将数据点归到每个新的聚类中心
  17. .map(new CountAppender())
  18. // 按聚类中心分组
  19. .groupBy(0)
  20. .reduce(new CentroidAccumulator())
  21. // 计算新的聚类中心
  22. .map(new CentroidAverager());
  23. // 返回新的聚类中心到下一次迭代中
  24. DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
  25. DataSet<Tuple2<Integer, Point>> clusteredPoints = points
  26. // 计算每个数据点的最终归属
  27. .map(new SelectNearestCenter())
  28. .withBroadcastSet(finalCentroids, "centroids");
  29. if (params.has("output")) {
  30. clusteredPoints.writeAsCsv(params.get("output"), "\n", ",", FileSystem.WriteMode.OVERWRITE);
  31. } else {
  32. System.out.println("Printing result to stdout. Use --output to specify output path.");
  33. clusteredPoints.printOnTaskManager("centroids:");
  34. }
  35. env.execute("KMeans Example");
  36. }
  37. private static DataSet<Centroid> getCentroidDataSet(ExecutionEnvironment env) throws IOException {
  38. DataSet<Centroid> centroids;
  39. URL fileUrl = BulkIteration.class.getClassLoader().getResource("centers");
  40. centroids = env.readCsvFile(fileUrl.getPath())
  41. .fieldDelimiter(" ")
  42. .pojoType(Centroid.class, "id", "x", "y");
  43. return centroids;
  44. }
  45. private static DataSet<Point> getPointDataSet(ExecutionEnvironment env) throws IOException {
  46. DataSet<Point> points;
  47. URL fileUrl = BulkIteration.class.getClassLoader().getResource("points");
  48. points = env.readCsvFile(fileUrl.getPath())
  49. .fieldDelimiter(" ")
  50. .pojoType(Point.class, "x", "y");
  51. return points;
  52. }
  53. public static class Point implements Serializable {
  54. public double x, y;
  55. public Point() {
  56. }
  57. public Point(double x, double y) {
  58. this.x = x;
  59. this.y = y;
  60. }
  61. public Point add(Point other) {
  62. x += other.x;
  63. y += other.y;
  64. return this;
  65. }
  66. public Point div(long val) {
  67. x /= val;
  68. y /= val;
  69. return this;
  70. }
  71. public double euclideanDistance(Point other) {
  72. return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
  73. }
  74. public void clear() {
  75. x = y = 0.0;
  76. }
  77. @Override
  78. public String toString() {
  79. return x + " " + y;
  80. }
  81. }
  82. public static class Centroid extends Point {
  83. public int id;
  84. public Centroid() {
  85. }
  86. public Centroid(int id, double x, double y) {
  87. super(x, y);
  88. this.id = id;
  89. }
  90. public Centroid(int id, Point p) {
  91. super(p.x, p.y);
  92. this.id = id;
  93. }
  94. @Override
  95. public String toString() {
  96. return id + " " + super.toString();
  97. }
  98. }
  99. /**
  100. * 计算距离某个数据点最近的聚类中心
  101. */
  102. @FunctionAnnotation.ForwardedFields("*->1")
  103. public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {
  104. private Collection<Centroid> centroids;
  105. /**
  106. * 在各分区上获取广播中的全量聚类中心数据
  107. */
  108. @Override
  109. public void open(Configuration parameters) throws Exception {
  110. this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
  111. }
  112. @Override
  113. public Tuple2<Integer, Point> map(Point p) throws Exception {
  114. double minDistance = Double.MAX_VALUE;
  115. int closestCentroidId = -1;
  116. for (Centroid centroid : centroids) {
  117. // 计算距离
  118. double distance = p.euclideanDistance(centroid);
  119. // 更新最近的距离
  120. if (distance < minDistance) {
  121. minDistance = distance;
  122. closestCentroidId = centroid.id;
  123. }
  124. }
  125. // 输出各个数据点以及其对应的最近聚类中心id
  126. return new Tuple2<>(closestCentroidId, p);
  127. }
  128. }
  129. /**
  130. * 附加新的数据点到聚类中心
  131. */
  132. @FunctionAnnotation.ForwardedFields("f0;f1")
  133. public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {
  134. @Override
  135. public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> t) {
  136. return new Tuple3<>(t.f0, t.f1, 1L);
  137. }
  138. }
  139. /**
  140. * 累加每个聚类中心的数据点信息
  141. */
  142. @FunctionAnnotation.ForwardedFields("0")
  143. public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {
  144. @Override
  145. public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> val1, Tuple3<Integer, Point, Long> val2) {
  146. return new Tuple3<>(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
  147. }
  148. }
  149. /**
  150. * 计算新的聚类中心
  151. */
  152. @FunctionAnnotation.ForwardedFields("0->id")
  153. public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {
  154. @Override
  155. public Centroid map(Tuple3<Integer, Point, Long> value) {
  156. return new Centroid(value.f0, value.f1.div(value.f2));
  157. }
  158. }
  159. }

Delta Iterations

简介

在Bulk Interactions中,每次迭代中,所有的输入数据都会重新参与计算,直至形成新的输出结构。但在某些算法不会在每次迭代中更改解决输入数据集中的每个数据点,Delta Interactions正是适用于这类算法。Delta Iterations在迭代中有两个数据集,一个称为WorkSet,另一个称为SolutionSet。在每次迭代后,将返回之前WorkSet的部分数据,不再参与计算的将不返回,同时返回更新后的SolutionSet。迭代计算的结果是最后一次迭代后的SolutionSet。要创建DeltaIteration,请分别调用iterateDelta(DataSet, int, int)(或iterateDelta(DataSet, int, int[]))。在输入数据集上调用上述方法,参数分别是初始数据集,最大迭代次数和键位置,返回的 DeltaIteration对象,可以通过访问iteration.getWorkset()和iteration.getSolutionSet()来获取所需的集合并附加新的算子。

demo

本例利用Delta Iterations来实现在连通图中传播最小值的目的。连通图中每个顶点都有一个ID和一个值,该值最终应等于此点与其周围临近点中ID最小的点的值。每个顶点会将其顶点ID传播到相邻的顶点。该算法目标是将最小ID分配给子图的每个顶点。如果接收到的ID小于当前ID,则它将更改为具有接收ID的顶点值。每个顶点的初始值默认设为自己的ID,通过多次迭代,将每个顶点的值调整为最终的临近最小ID。在迭代过程中,如果某个顶点的值没有发生变化,则它将不再参加下一次迭代中,这就是Delta Iterations的优势,它可以不断缩减WorkSet的大小,直至为空,最后退出迭代。而SolutionSet在此过程中不断更新,成为最终的目标结果。
image.png

  1. @SuppressWarnings("serial")
  2. public class DeltaIteration {
  3. public static void main(String[] args) throws Exception {
  4. final ParameterTool params = ParameterTool.fromArgs(args);
  5. final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
  6. final int maxIterations = params.getInt("iterations", 10);
  7. env.getConfig().setGlobalJobParameters(params);
  8. DataSet<Long> vertices = getVertexDataSet(env, params);
  9. // 将初始数据集中的单向边补全对应反向的边,形成完整的无向图边集合
  10. DataSet<Tuple2<Long, Long>> edges = getEdgeDataSet(env, params).flatMap(new UndirectEdge());
  11. // 初始分配各个顶点的临近最小值为自身
  12. DataSet<Tuple2<Long, Long>> verticesWithInitialId = vertices.map(new DuplicateValue<Long>());
  13. // 开启增量迭代
  14. org.apache.flink.api.java.operators.DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration = verticesWithInitialId
  15. .iterateDelta(verticesWithInitialId, maxIterations, 0);
  16. // 选取附近的最小值更新到自己的值
  17. DataSet<Tuple2<Long, Long>> changes = iteration
  18. .getWorkset()
  19. .join(edges)
  20. .where(0)
  21. .equalTo(0)
  22. // 将当前数据点的值传播到所有临近点
  23. .with(new NeighborWithComponentIDJoin())
  24. .groupBy(0)
  25. // 在所有临近点的值中选取最小值
  26. .aggregate(Aggregations.MIN, 1)
  27. .join(iteration.getSolutionSet())
  28. .where(0)
  29. .equalTo(0)
  30. // 如果临近点的最小值小于当前数据点的目标值,则更新
  31. .with(new ComponentIdFilter());
  32. // 单次迭代闭环
  33. DataSet<Tuple2<Long, Long>> result = iteration.closeWith(changes, changes);
  34. if (params.has("output")) {
  35. result.writeAsCsv(params.get("output"), "\n", " ");
  36. env.execute("Connected Components Example");
  37. } else {
  38. System.out.println("Printing result to stdout. Use --output to specify output path.");
  39. result.print();
  40. }
  41. }
  42. /**
  43. * 初始化(点-值)元组
  44. */
  45. @FunctionAnnotation.ForwardedFields("*->f0")
  46. public static final class DuplicateValue<T> implements MapFunction<T, Tuple2<T, T>> {
  47. @Override
  48. public Tuple2<T, T> map(T vertex) {
  49. return new Tuple2<T, T>(vertex, vertex);
  50. }
  51. }
  52. /**
  53. * 将单向边映射为反向边,即将有向图变成无向图
  54. */
  55. public static final class UndirectEdge implements FlatMapFunction<Tuple2<Long, Long>, Tuple2<Long, Long>> {
  56. Tuple2<Long, Long> invertedEdge = new Tuple2<Long, Long>();
  57. @Override
  58. public void flatMap(Tuple2<Long, Long> edge, Collector<Tuple2<Long, Long>> out) {
  59. invertedEdge.f0 = edge.f1;
  60. invertedEdge.f1 = edge.f0;
  61. out.collect(edge);
  62. out.collect(invertedEdge);
  63. }
  64. }
  65. /**
  66. * 读取所有临近点的值
  67. */
  68. @FunctionAnnotation.ForwardedFieldsFirst("f1->f1")
  69. @FunctionAnnotation.ForwardedFieldsSecond("f1->f0")
  70. public static final class NeighborWithComponentIDJoin implements JoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
  71. @Override
  72. public Tuple2<Long, Long> join(Tuple2<Long, Long> vertexWithComponent, Tuple2<Long, Long> edge) {
  73. return new Tuple2<Long, Long>(edge.f1, vertexWithComponent.f1);
  74. }
  75. }
  76. /**
  77. * 如果候选值小于当前数据点的目标值,则输出新的(点-值)元组。
  78. */
  79. @FunctionAnnotation.ForwardedFieldsFirst("*")
  80. public static final class ComponentIdFilter implements FlatJoinFunction<Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
  81. @Override
  82. public void join(Tuple2<Long, Long> candidate, Tuple2<Long, Long> old, Collector<Tuple2<Long, Long>> out) {
  83. if (candidate.f1 < old.f1) {
  84. out.collect(candidate);
  85. }
  86. }
  87. }
  88. private static DataSet<Long> getVertexDataSet(ExecutionEnvironment env, ParameterTool params) {
  89. if (params.has("vertices")) {
  90. return env.readCsvFile(params.get("vertices")).types(Long.class).map(
  91. new MapFunction<Tuple1<Long>, Long>() {
  92. @Override
  93. public Long map(Tuple1<Long> value) {
  94. return value.f0;
  95. }
  96. });
  97. } else {
  98. System.out.println("Executing Connected Components example with default vertices data set.");
  99. System.out.println("Use --vertices to specify file input.");
  100. return ConnectedComponentsData.getDefaultVertexDataSet(env);
  101. }
  102. }
  103. private static DataSet<Tuple2<Long, Long>> getEdgeDataSet(ExecutionEnvironment env, ParameterTool params) {
  104. if (params.has("edges")) {
  105. return env.readCsvFile(params.get("edges")).fieldDelimiter(" ").types(Long.class, Long.class);
  106. } else {
  107. System.out.println("Executing Connected Components example with default edges data set.");
  108. System.out.println("Use --edges to specify file input.");
  109. return ConnectedComponentsData.getDefaultEdgeDataSet(env);
  110. }
  111. }
  112. }