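User-Defined Functions

Translator: flink.sojb.cn

User-defined functions are an important feature because they significantly extend the expressiveness of queries.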

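Register User-Defined Functions

In most cases, a user-defined function must be registered before it can be used in a query. It is not necessary to register functions for the Scala Table API.

Functions are registered at the TableEnvironment by calling its registerFunction() method. When a user-defined function is registered, it is inserted into the function catalog of the TableEnvironment so that the Table API or SQL parser can recognize and properly translate it.

The following subsections show how to register and how to call each type of user-defined function (ScalarFunction, TableFunction, and AggregateFunction), with detailed examples.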

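Scalar Functions

If a required scalar function is not among the built-in functions, you can define a custom, user-defined scalar function for both the Table API and SQL. A user-defined scalar function maps zero, one, or multiple scalar values to a new scalar value.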

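To define a scalar function, extend the base class ScalarFunction in org.apache.flink.table.functions and implement one or more evaluation methods. The behavior of a scalar function is determined by its evaluation method. An evaluation method must be declared public and named eval. The parameter types and return type of the evaluation method also determine the parameter and return types of the scalar function. Evaluation methods can be overloaded by implementing multiple methods named eval. Evaluation methods can also support variable arguments, such as eval(String... strs).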
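The overloading rules above can be tried out in plain Java. The following is only a sketch: the class name ConcatSketch is made up for illustration, and the ScalarFunction base class is deliberately left out so that the snippet compiles without Flink on the classpath. It shows three eval overloads, one of them with variable arguments.

```java
// Stand-in sketch (hypothetical class): in a Flink job this class would extend
// org.apache.flink.table.functions.ScalarFunction. The base class is omitted
// here so the snippet compiles on its own.
class ConcatSketch {

    // Evaluation method taking two string arguments.
    public String eval(String a, String b) {
        return a + "-" + b;
    }

    // Overload with a different parameter type.
    public String eval(int n) {
        return "#" + n;
    }

    // Variable-argument overload: accepts any number of strings.
    public String eval(String... strs) {
        return String.join("+", strs);
    }
}
```

In plain Java, a call with exactly two strings resolves to the fixed two-argument method, while a call with three strings falls through to the variable-argument method. How Flink itself picks among overloads is not covered by this sketch.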

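The following example shows how to define your own hash code function, register it in the TableEnvironment, and call it in a query. Note that you can configure a scalar function via its constructor before it is registered: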

Java:

  import org.apache.flink.table.functions.ScalarFunction;

  public class HashCode extends ScalarFunction {
    private int factor = 12;

    public HashCode(int factor) {
      this.factor = factor;
    }

    public int eval(String s) {
      return s.hashCode() * factor;
    }
  }

  BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);

  // register the function
  tableEnv.registerFunction("hashCode", new HashCode(10));

  // use the function in Java Table API
  myTable.select("string, string.hashCode(), hashCode(string)");

  // use the function in SQL API
  tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable");

Scala:

  // must be defined in static/object context
  class HashCode(factor: Int) extends ScalarFunction {
    def eval(s: String): Int = {
      s.hashCode() * factor
    }
  }

  val tableEnv = TableEnvironment.getTableEnvironment(env)

  // use the function in Scala Table API
  val hashCode = new HashCode(10)
  myTable.select('string, hashCode('string))

  // register and use the function in SQL
  tableEnv.registerFunction("hashCode", new HashCode(10))
  tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable")

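By default, the result type of an evaluation method is determined by Flink's type extraction facilities. This is sufficient for basic types or simple POJOs but might be wrong for more complex, custom, or composite types. In these cases, the TypeInformation of the result type can be defined manually by overriding ScalarFunction#getResultType().

The following advanced example takes the internal timestamp representation and also returns the internal timestamp representation as a long value. By overriding ScalarFunction#getResultType(), we define that the returned long value should be interpreted as a Types.TIMESTAMP by the code generation.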

Java:

  public static class TimestampModifier extends ScalarFunction {
    public long eval(long t) {
      return t % 1000;
    }

    // tells the code generation to interpret the returned long as a timestamp
    public TypeInformation<?> getResultType(Class<?>[] signature) {
      return Types.TIMESTAMP;
    }
  }

Scala:

  object TimestampModifier extends ScalarFunction {
    def eval(t: Long): Long = {
      t % 1000
    }

    // tells the code generation to interpret the returned long as a timestamp
    override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
      Types.TIMESTAMP
    }
  }
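To make the arithmetic of the example concrete, here is a small plain-Java sketch. It assumes that the internal timestamp representation is the number of milliseconds since the epoch; that assumption, the class name TimestampModifierSketch, and the helper asTimestamp are not taken from the text above.

```java
import java.time.Instant;

// Hypothetical helper class, independent of Flink.
class TimestampModifierSketch {

    // Same arithmetic as the eval method above: keeps only the
    // sub-second part of an epoch-millisecond value.
    public static long eval(long t) {
        return t % 1000;
    }

    // Interprets a long as a timestamp, ASSUMING it counts milliseconds since the epoch.
    public static Instant asTimestamp(long millis) {
        return Instant.ofEpochMilli(millis);
    }
}
```

Under that assumption, an input of 1500000000123 yields 123, and reading 123 back as a timestamp gives an instant 123 milliseconds after 1970-01-01T00:00:00Z.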

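Table Functions

Similar to a user-defined scalar function, a user-defined table function takes zero, one, or multiple scalar values as input parameters. However, in contrast to a scalar function, it can return an arbitrary number of rows as output instead of a single value. The returned rows may consist of one or more columns.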

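To define a table function, extend the base class TableFunction in org.apache.flink.table.functions and implement one or more evaluation methods. The behavior of a table function is determined by its evaluation methods. An evaluation method must be declared public and named eval. A TableFunction can be overloaded by implementing multiple methods named eval. The parameter types of the evaluation methods determine all valid parameters of the table function. Evaluation methods can also support variable arguments, such as eval(String... strs). The type of the returned table is determined by the generic type of TableFunction. Evaluation methods emit output rows using the protected collect(T) method.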
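The emit-by-collect pattern can be illustrated without Flink. In the sketch below, TableFunctionSketch is a hypothetical stand-in for the real base class: it only buffers whatever collect(T) receives, and its drain() method is invented here for inspecting the result. TokensSketch then shows an eval method that emits zero or more rows, plus a variable-argument overload.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for TableFunction<T>: it only buffers emitted rows.
abstract class TableFunctionSketch<T> {
    private final List<T> emitted = new ArrayList<>();

    // Evaluation methods call collect(...) once per output row.
    protected void collect(T row) {
        emitted.add(row);
    }

    // Invented for this sketch: returns the rows emitted so far and clears the buffer.
    public List<T> drain() {
        List<T> rows = new ArrayList<>(emitted);
        emitted.clear();
        return rows;
    }
}

// The generic type String is the row type of the "returned table".
class TokensSketch extends TableFunctionSketch<String> {

    // Emits no row for an empty input, otherwise one row per '#'-separated token.
    public void eval(String str) {
        if (str.isEmpty()) {
            return;
        }
        for (String s : str.split("#")) {
            collect(s);
        }
    }

    // Variable-argument overload: tokenizes every argument in turn.
    public void eval(String... strs) {
        for (String s : strs) {
            eval(s);
        }
    }
}
```

Calling eval("a#b#c") followed by drain() returns the three tokens, while eval("") leaves the buffer empty. This is the "arbitrary number of rows" behavior described above.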

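In the Table API, a table function is used with .join(Expression) or .leftOuterJoin(Expression) for Scala users and with .join(String) or .leftOuterJoin(String) for Java users. The join operator (cross) joins each row from the outer table (the table on the left of the operator) with all rows produced by the table-valued function (which is on the right side of the operator). The leftOuterJoin operator joins each row from the outer table (the table on the left of the operator) with all rows produced by the table-valued function (which is on the right side of the operator) and preserves outer rows for which the table function returns an empty table. In SQL, use LATERAL TABLE(<TableFunction>) with CROSS JOIN or with LEFT JOIN and an ON TRUE join condition (see the examples below).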
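The difference between the two operators shows up only for outer rows whose function result is empty. The following plain-Java sketch simulates both semantics on in-memory lists. All names in it (LateralJoinSketch, join, leftOuterJoin, split) are illustrative and are not Flink API calls.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

// Hypothetical simulation of the two join semantics; not Flink code.
class LateralJoinSketch {

    // Cross join: each outer row is paired with every row the function produces for it.
    // Outer rows with an empty function result are dropped.
    static List<String[]> join(List<String> outer, Function<String, List<String>> fn) {
        List<String[]> out = new ArrayList<>();
        for (String row : outer) {
            for (String produced : fn.apply(row)) {
                out.add(new String[] {row, produced});
            }
        }
        return out;
    }

    // Left outer join: same pairing, but an outer row with an empty result
    // is preserved once and padded with null.
    static List<String[]> leftOuterJoin(List<String> outer, Function<String, List<String>> fn) {
        List<String[]> out = new ArrayList<>();
        for (String row : outer) {
            List<String> produced = fn.apply(row);
            if (produced.isEmpty()) {
                out.add(new String[] {row, null});
            }
            for (String p : produced) {
                out.add(new String[] {row, p});
            }
        }
        return out;
    }

    // Illustrative table function: splits on '#', yields no rows for an empty string.
    static List<String> split(String s) {
        if (s.isEmpty()) {
            return Collections.emptyList();
        }
        return Arrays.asList(s.split("#"));
    }
}
```

With the outer rows "a#b" and "", join yields two result rows, whereas leftOuterJoin yields three: the two pairs for "a#b" plus one null-padded row for the empty string.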

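The following example shows how to define a table-valued function, register it in the TableEnvironment, and call it in a query. Note that you can configure a table function via its constructor before it is registered: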

Java:

  import org.apache.flink.api.java.tuple.Tuple2;
  import org.apache.flink.table.functions.TableFunction;

  // The generic type "Tuple2<String, Integer>" determines the schema of the returned table as (String, Integer).
  public class Split extends TableFunction<Tuple2<String, Integer>> {
    private String separator = " ";

    public Split(String separator) {
      this.separator = separator;
    }

    public void eval(String str) {
      for (String s : str.split(separator)) {
        // use collect(...) to emit a row
        collect(new Tuple2<String, Integer>(s, s.length()));
      }
    }
  }

  BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);
  Table myTable = ... // table schema: [a: String]

  // Register the function.
  tableEnv.registerFunction("split", new Split("#"));

  // Use the table function in the Java Table API. "as" specifies the field names of the table.
  myTable.join("split(a) as (word, length)").select("a, word, length");
  myTable.leftOuterJoin("split(a) as (word, length)").select("a, word, length");

  // Use the table function in SQL with LATERAL and TABLE keywords.
  // CROSS JOIN a table function (equivalent to "join" in Table API).
  tableEnv.sqlQuery("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)");
  // LEFT JOIN a table function (equivalent to "leftOuterJoin" in Table API).
  tableEnv.sqlQuery("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE");

Scala:

  // The generic type "(String, Int)" determines the schema of the returned table as (String, Integer).
  class Split(separator: String) extends TableFunction[(String, Int)] {
    def eval(str: String): Unit = {
      // use collect(...) to emit a row.
      str.split(separator).foreach(x => collect((x, x.length)))
    }
  }

  val tableEnv = TableEnvironment.getTableEnvironment(env)
  val myTable = ... // table schema: [a: String]

  // Use the table function in the Scala Table API (Note: No registration required in Scala Table API).
  val split = new Split("#")
  // "as" specifies the field names of the generated table.
  myTable.join(split('a) as ('word, 'length)).select('a, 'word, 'length)
  myTable.leftOuterJoin(split('a) as ('word, 'length)).select('a, 'word, 'length)

  // Register the table function to use it in SQL queries.
  tableEnv.registerFunction("split", new Split("#"))

  // Use the table function in SQL with LATERAL and TABLE keywords.
  // CROSS JOIN a table function (equivalent to "join" in Table API)
  tableEnv.sqlQuery("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
  // LEFT JOIN a table function (equivalent to "leftOuterJoin" in Table API)
  tableEnv.sqlQuery("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")

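IMPORTANT: Do not implement TableFunction as a Scala object. A Scala object is a singleton and will cause concurrency issues.

Note that POJO types do not have a deterministic field order. Therefore, you cannot rename the fields of a POJO returned by a table function using AS.

By default, the result type of a TableFunction is determined by Flink's automatic type extraction facilities. This works well for basic types and simple POJOs but might be wrong for more complex, custom, or composite types. In such a case, the result type can be specified manually by overriding TableFunction#getResultType(), which returns its TypeInformation.

The following example shows a TableFunction that returns a Row type, which requires explicit type information. We define that the returned table type should be RowTypeInfo(String, Integer) by overriding TableFunction#getResultType().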

Java:

  import org.apache.flink.api.common.typeinfo.TypeInformation;
  import org.apache.flink.table.api.Types;
  import org.apache.flink.table.functions.TableFunction;
  import org.apache.flink.types.Row;

  public class CustomTypeSplit extends TableFunction<Row> {
    public void eval(String str) {
      for (String s : str.split(" ")) {
        Row row = new Row(2);
        row.setField(0, s);
        row.setField(1, s.length());
        collect(row);
      }
    }

    @Override
    public TypeInformation<Row> getResultType() {
      return Types.ROW(Types.STRING(), Types.INT());
    }
  }

Scala:

  class CustomTypeSplit extends TableFunction[Row] {
    def eval(str: String): Unit = {
      str.split(" ").foreach({ s =>
        val row = new Row(2)
        row.setField(0, s)
        row.setField(1, s.length)
        collect(row)
      })
    }

    override def getResultType: TypeInformation[Row] = {
      Types.ROW(Types.STRING, Types.INT)
    }
  }

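Aggregation Functions

User-defined aggregate functions (UDAGGs) aggregate a table (one or more rows with one or more attributes) to a scalar value.

[Figure: UDAGG mechanism]

The figure above shows an example of an aggregation. Assume you have a table that contains data about beverages. The table consists of three columns (id, name, and price) and 5 rows. Imagine you need to find the highest price of all beverages in the table, i.e., perform a max() aggregation. You would need to check each of the 5 rows, and the result would be a single numeric value.

User-defined aggregation functions are implemented by extending the AggregateFunction class. An AggregateFunction works as follows. First, it needs an accumulator, which is the data structure that holds the intermediate result of the aggregation. An empty accumulator is created by calling the createAccumulator() method of the AggregateFunction. Subsequently, the accumulate() method of the function is called for each input row to update the accumulator. Once all rows have been processed, the getValue() method of the function is called to compute and return the final result.

The following methods are mandatory for each AggregateFunction: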

  • createAccumulator()
  • accumulate()
  • getValue()
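The three mandatory calls can be walked through with the max() aggregation over beverage prices mentioned above. This is a sketch in plain Java, not the Flink class: MaxAccum, MaxPriceSketch, and the aggregate driver are hypothetical names, and the Flink base class is left out so that the snippet is self-contained.

```java
// Accumulator: holds the intermediate result of the aggregation.
class MaxAccum {
    public double max = 0;
    public boolean seen = false; // false until the first row has been accumulated
}

// Hypothetical stand-in for an AggregateFunction<Double, MaxAccum>.
class MaxPriceSketch {

    // Step 1: create an empty accumulator.
    public MaxAccum createAccumulator() {
        return new MaxAccum();
    }

    // Step 2: called once per input row to update the accumulator.
    public void accumulate(MaxAccum acc, double price) {
        if (!acc.seen || price > acc.max) {
            acc.max = price;
            acc.seen = true;
        }
    }

    // Step 3: compute the final result; null if no row was seen.
    public Double getValue(MaxAccum acc) {
        if (!acc.seen) {
            return null;
        }
        return acc.max;
    }

    // Invented driver that performs the three steps in order.
    static Double aggregate(double[] prices) {
        MaxPriceSketch fn = new MaxPriceSketch();
        MaxAccum acc = fn.createAccumulator();
        for (double p : prices) {
            fn.accumulate(acc, p);
        }
        return fn.getValue(acc);
    }
}
```

For five made-up prices such as 2.5, 4.0, 3.0, 1.5, and 3.5, the driver checks every row and returns the single value 4.0.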

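Flink's type extraction facilities can fail to identify complex data types, e.g., if they are not basic types or simple POJOs. So, similar to ScalarFunction and TableFunction, AggregateFunction provides methods to specify the TypeInformation of the result type (through AggregateFunction#getResultType()) and the type of the accumulator (through AggregateFunction#getAccumulatorType()).

Besides the above methods, there are a few contracted methods that can be optionally implemented. While some of these methods allow the system to execute queries more efficiently, others are mandatory for certain use cases. For instance, the merge() method is mandatory if the aggregation function should be applied in the context of a session group window (the accumulators of two session windows need to be joined when a row is observed that "connects" them).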

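The following methods of AggregateFunction are required depending on the use case:

  • retract() is required for aggregations on bounded OVER windows.
  • merge() is required for many batch aggregations and session window aggregations.
  • resetAccumulator() is required for many batch aggregations.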
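What these three methods are expected to do can be sketched on a simple sum. The classes below (SumAccum, SumSketch) are hypothetical and do not extend the Flink base class. They only illustrate that retract() undoes an earlier accumulate(), that merge() folds other accumulators into the given one without replacing it, and that resetAccumulator() returns the accumulator to its empty state.

```java
// Accumulator for a running sum.
class SumAccum {
    public long sum = 0;
    public long count = 0;
}

// Hypothetical stand-in for an AggregateFunction<Long, SumAccum>.
class SumSketch {

    public SumAccum createAccumulator() {
        return new SumAccum();
    }

    public void accumulate(SumAccum acc, long value) {
        acc.sum += value;
        acc.count++;
    }

    // Undoes a previously accumulated value, e.g. when a row leaves a bounded OVER window.
    public void retract(SumAccum acc, long value) {
        acc.sum -= value;
        acc.count--;
    }

    // Folds the other accumulators into acc; acc keeps its own previous content.
    public void merge(SumAccum acc, Iterable<SumAccum> others) {
        for (SumAccum other : others) {
            acc.sum += other.sum;
            acc.count += other.count;
        }
    }

    // Returns the accumulator to its empty state.
    public void resetAccumulator(SumAccum acc) {
        acc.sum = 0;
        acc.count = 0;
    }

    public Long getValue(SumAccum acc) {
        if (acc.count == 0) {
            return null;
        }
        return acc.sum;
    }
}
```

Accumulating 1, 2, and 3 and then retracting 1 leaves a sum of 5. Merging in a second accumulator that holds 10 raises it to 15.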

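All methods of AggregateFunction must be declared public, must not be static, and must be named exactly as mentioned above. The methods createAccumulator, getValue, getResultType, and getAccumulatorType are defined in the AggregateFunction abstract class, while the others are contracted methods. To define an aggregate function, extend the base class org.apache.flink.table.functions.AggregateFunction and implement one or more accumulate methods. The accumulate method can be overloaded with different parameter types and supports variable arguments.

Detailed documentation for all methods of AggregateFunction is given below.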

Java:

  /**
   * Base class for aggregation functions.
   *
   * @param <T>   the type of the aggregation result
   * @param <ACC> the type of the aggregation accumulator. The accumulator is used to keep the
   *              aggregated values which are needed to compute an aggregation result.
   *              AggregateFunction represents its state using accumulator, thereby the state of the
   *              AggregateFunction must be put into the accumulator.
   */
  public abstract class AggregateFunction<T, ACC> extends UserDefinedFunction {

    /**
     * Creates and init the Accumulator for this [[AggregateFunction]].
     *
     * @return the accumulator with the initial value
     */
    public ACC createAccumulator(); // MANDATORY

    /**
     * Processes the input values and update the provided accumulator instance. The method
     * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
     * requires at least one accumulate() method.
     *
     * @param accumulator           the accumulator which contains the current aggregated results
     * @param [user defined inputs] the input value (usually obtained from a new arrived data).
     */
    public void accumulate(ACC accumulator, [user defined inputs]); // MANDATORY

    /**
     * Retracts the input values from the accumulator instance. The current design assumes the
     * inputs are the values that have been previously accumulated. The method retract can be
     * overloaded with different custom types and arguments. This function must be implemented for
     * datastream bounded over aggregate.
     *
     * @param accumulator           the accumulator which contains the current aggregated results
     * @param [user defined inputs] the input value (usually obtained from a new arrived data).
     */
    public void retract(ACC accumulator, [user defined inputs]); // OPTIONAL

    /**
     * Merges a group of accumulator instances into one accumulator instance. This function must be
     * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
     *
     * @param accumulator  the accumulator which will keep the merged aggregate results. It should
     *                     be noted that the accumulator may contain the previous aggregated
     *                     results. Therefore user should not replace or clean this instance in the
     *                     custom merge method.
     * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
     *                     merged.
     */
    public void merge(ACC accumulator, java.lang.Iterable<ACC> its); // OPTIONAL

    /**
     * Called every time when an aggregation result should be materialized.
     * The returned value could be either an early and incomplete result
     * (periodically emitted as data arrive) or the final result of the
     * aggregation.
     *
     * @param accumulator the accumulator which contains the current
     *                    aggregated results
     * @return the aggregation result
     */
    public T getValue(ACC accumulator); // MANDATORY

    /**
     * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
     * dataset grouping aggregate.
     *
     * @param accumulator  the accumulator which needs to be reset
     */
    public void resetAccumulator(ACC accumulator); // OPTIONAL

    /**
     * Returns true if this AggregateFunction can only be applied in an OVER window.
     *
     * @return true if the AggregateFunction requires an OVER window, false otherwise.
     */
    public Boolean requiresOver() { return false; } // PRE-DEFINED

    /**
     * Returns the TypeInformation of the AggregateFunction's result.
     *
     * @return The TypeInformation of the AggregateFunction's result or null if the result type
     *         should be automatically inferred.
     */
    public TypeInformation<T> getResultType() { return null; } // PRE-DEFINED

    /**
     * Returns the TypeInformation of the AggregateFunction's accumulator.
     *
     * @return The TypeInformation of the AggregateFunction's accumulator or null if the
     *         accumulator type should be automatically inferred.
     */
    public TypeInformation<ACC> getAccumulatorType() { return null; } // PRE-DEFINED
  }

Scala:

  /**
   * Base class for aggregation functions.
   *
   * @tparam T   the type of the aggregation result
   * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
   *             aggregated values which are needed to compute an aggregation result.
   *             AggregateFunction represents its state using accumulator, thereby the state of the
   *             AggregateFunction must be put into the accumulator.
   */
  abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {

    /**
     * Creates and init the Accumulator for this [[AggregateFunction]].
     *
     * @return the accumulator with the initial value
     */
    def createAccumulator(): ACC // MANDATORY

    /**
     * Processes the input values and update the provided accumulator instance. The method
     * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
     * requires at least one accumulate() method.
     *
     * @param accumulator           the accumulator which contains the current aggregated results
     * @param [user defined inputs] the input value (usually obtained from a new arrived data).
     */
    def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

    /**
     * Retracts the input values from the accumulator instance. The current design assumes the
     * inputs are the values that have been previously accumulated. The method retract can be
     * overloaded with different custom types and arguments. This function must be implemented for
     * datastream bounded over aggregate.
     *
     * @param accumulator           the accumulator which contains the current aggregated results
     * @param [user defined inputs] the input value (usually obtained from a new arrived data).
     */
    def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

    /**
     * Merges a group of accumulator instances into one accumulator instance. This function must be
     * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
     *
     * @param accumulator  the accumulator which will keep the merged aggregate results. It should
     *                     be noted that the accumulator may contain the previous aggregated
     *                     results. Therefore user should not replace or clean this instance in the
     *                     custom merge method.
     * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
     *                     merged.
     */
    def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL

    /**
     * Called every time when an aggregation result should be materialized.
     * The returned value could be either an early and incomplete result
     * (periodically emitted as data arrive) or the final result of the
     * aggregation.
     *
     * @param accumulator the accumulator which contains the current
     *                    aggregated results
     * @return the aggregation result
     */
    def getValue(accumulator: ACC): T // MANDATORY

    /**
     * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
     * dataset grouping aggregate.
     *
     * @param accumulator  the accumulator which needs to be reset
     */
    def resetAccumulator(accumulator: ACC): Unit // OPTIONAL

    /**
     * Returns true if this AggregateFunction can only be applied in an OVER window.
     *
     * @return true if the AggregateFunction requires an OVER window, false otherwise.
     */
    def requiresOver: Boolean = false // PRE-DEFINED

    /**
     * Returns the TypeInformation of the AggregateFunction's result.
     *
     * @return The TypeInformation of the AggregateFunction's result or null if the result type
     *         should be automatically inferred.
     */
    def getResultType: TypeInformation[T] = null // PRE-DEFINED

    /**
     * Returns the TypeInformation of the AggregateFunction's accumulator.
     *
     * @return The TypeInformation of the AggregateFunction's accumulator or null if the
     *         accumulator type should be automatically inferred.
     */
    def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
  }

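The following example shows how to

  • define an AggregateFunction that calculates the weighted average of a given column,
  • register the function in the TableEnvironment, and
  • use the function in a query.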

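To calculate a weighted average, the accumulator needs to store the weighted sum and the count of all the data that has been accumulated. In our example, we define a class WeightedAvgAccum to be the accumulator. Accumulators are automatically backed up by Flink's checkpointing mechanism and restored in case of a failure to ensure exactly-once semantics.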
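As a worked illustration of the arithmetic, the following plain-Java sketch computes a weighted average from a weighted sum and a count, using the same integer types as the accumulator in this section. The class and method names are made up for this illustration.

```java
// Hypothetical helper that mirrors the accumulator arithmetic outside of Flink.
class WeightedAvgArithmetic {

    // values[i] is weighted by weights[i]; returns null when the total weight is zero.
    static Long weightedAvg(long[] values, int[] weights) {
        long sum = 0;  // weighted sum of all values seen so far
        int count = 0; // sum of all weights seen so far
        for (int i = 0; i < values.length; i++) {
            sum += values[i] * weights[i];
            count += weights[i];
        }
        if (count == 0) {
            return null;
        }
        // Integer division: the fractional part of the average is truncated.
        return sum / count;
    }
}
```

For the values 10 and 20 with weights 1 and 3, the weighted sum is 10 * 1 + 20 * 3 = 70 and the count is 4, so the result is 70 / 4 = 17 after truncation.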

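The accumulate() method of our WeightedAvg AggregateFunction has three inputs. The first one is the WeightedAvgAccum accumulator; the other two are user-defined inputs: the input value iValue and the weight of the input iWeight. Although the retract(), merge(), and resetAccumulator() methods are not mandatory for most aggregation types, we provide them below as examples. Please note that we use Java primitive types and define the getResultType() and getAccumulatorType() methods in the Scala example because Flink's type extraction does not work very well for Scala types.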

Java:

  import java.util.Iterator;
  import org.apache.flink.table.functions.AggregateFunction;

  /**
   * Accumulator for WeightedAvg.
   */
  public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
  }

  /**
   * Weighted Average user-defined aggregate function.
   */
  public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
      return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
      if (acc.count == 0) {
        return null;
      } else {
        return acc.sum / acc.count;
      }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
      acc.sum += iValue * iWeight;
      acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
      acc.sum -= iValue * iWeight;
      acc.count -= iWeight;
    }

    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
      Iterator<WeightedAvgAccum> iter = it.iterator();
      while (iter.hasNext()) {
        WeightedAvgAccum a = iter.next();
        acc.count += a.count;
        acc.sum += a.sum;
      }
    }

    public void resetAccumulator(WeightedAvgAccum acc) {
      acc.count = 0;
      acc.sum = 0L;
    }
  }

  // register function
  StreamTableEnvironment tEnv = ...
  tEnv.registerFunction("wAvg", new WeightedAvg());

  // use function
  tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");

Scala:

  import java.lang.{Long => JLong, Integer => JInteger}
  import org.apache.flink.api.common.typeinfo.TypeInformation
  import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
  import org.apache.flink.api.java.typeutils.TupleTypeInfo
  import org.apache.flink.table.api.Types
  import org.apache.flink.table.functions.AggregateFunction

  /**
   * Accumulator for WeightedAvg.
   * Field f0 holds the weighted sum, field f1 holds the count.
   */
  class WeightedAvgAccum extends JTuple2[JLong, JInteger] {
    f0 = 0L
    f1 = 0
  }

  /**
   * Weighted Average user-defined aggregate function.
   */
  class WeightedAvg extends AggregateFunction[JLong, WeightedAvgAccum] {

    override def createAccumulator(): WeightedAvgAccum = {
      new WeightedAvgAccum
    }

    override def getValue(acc: WeightedAvgAccum): JLong = {
      if (acc.f1 == 0) {
        null
      } else {
        acc.f0 / acc.f1
      }
    }

    def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
      acc.f0 += iValue * iWeight
      acc.f1 += iWeight
    }

    def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
      acc.f0 -= iValue * iWeight
      acc.f1 -= iWeight
    }

    def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
      val iter = it.iterator()
      while (iter.hasNext) {
        val a = iter.next()
        acc.f1 += a.f1
        acc.f0 += a.f0
      }
    }

    def resetAccumulator(acc: WeightedAvgAccum): Unit = {
      acc.f1 = 0
      acc.f0 = 0L
    }

    override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {
      new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)
    }

    override def getResultType: TypeInformation[JLong] = Types.LONG
  }

  // register function
  val tEnv: StreamTableEnvironment = ???
  tEnv.registerFunction("wAvg", new WeightedAvg())

  // use function
  tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")

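Best Practices for Implementing UDFs

The Table API and SQL code generation internally tries to work with primitive values as much as possible. A user-defined function can introduce considerable overhead through object creation, casting, and (un)boxing. Therefore, it is highly recommended to declare parameters and result types as primitive types instead of their boxed classes. Types.DATE and Types.TIME can also be represented as int. Types.TIMESTAMP can be represented as long.

We recommend that user-defined functions be written in Java instead of Scala, as Scala types pose a challenge for Flink's type extractor.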
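As a small illustration of the primitive-type recommendation, the sketch below declares both the parameter and the result as int. It assumes that a DATE value passed as int counts the days since 1970-01-01. That encoding is an assumption of this sketch and is not stated above, and the class name DayOfWeekSketch is made up. The ScalarFunction base class is omitted so that the snippet compiles on its own.

```java
import java.time.LocalDate;

// Hypothetical stand-in for a scalar function that takes a date as a primitive int.
class DayOfWeekSketch {

    // Primitive parameter and primitive result: no boxing on either side.
    // ASSUMPTION: daysSinceEpoch counts days since 1970-01-01.
    public int eval(int daysSinceEpoch) {
        // ISO numbering: 1 = Monday ... 7 = Sunday.
        return LocalDate.ofEpochDay(daysSinceEpoch).getDayOfWeek().getValue();
    }
}
```

Under that assumption, eval(0) corresponds to 1970-01-01, a Thursday, and returns 4.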

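Integrating UDFs with the Runtime

Sometimes a user-defined function needs to get global runtime information or do some setup or clean-up work before or after the actual work. User-defined functions provide open() and close() methods that can be overridden and that offer functionality similar to the methods of RichFunction in the DataSet or DataStream API.

The open() method is called once before the evaluation method. The close() method is called after the last call to the evaluation method.

The open() method provides a FunctionContext that contains information about the context in which user-defined functions are executed, such as the metric group, the distributed cache files, or the global job parameters.

The following information can be obtained by calling the corresponding methods of FunctionContext: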

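Method                                Description
getMetricGroup()                      Metric group for this parallel subtask.
getCachedFile(name)                   Local temporary file copy of a distributed cache file.
getJobParameter(name, defaultValue)   Global job parameter value associated with the given key.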
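The call order and the default-value behavior of getJobParameter can be sketched without Flink. In the snippet below, ContextSketch is a hypothetical stand-in for FunctionContext that is backed by a plain map, and ScaledHashSketch is a made-up function class. Neither is part of the Flink API.

```java
import java.util.Map;

// Hypothetical stand-in for FunctionContext, backed by a plain map.
class ContextSketch {
    private final Map<String, String> params;

    ContextSketch(Map<String, String> params) {
        this.params = params;
    }

    // Returns the value stored for the key, or defaultValue if the key is absent.
    String getJobParameter(String key, String defaultValue) {
        return params.getOrDefault(key, defaultValue);
    }
}

// Hypothetical function class following the open / eval / close life cycle.
class ScaledHashSketch {
    private int factor = 0;
    private boolean opened = false;

    // Called once before the first eval call: reads the configuration.
    public void open(ContextSketch context) {
        factor = Integer.parseInt(context.getJobParameter("hashcode_factor", "12"));
        opened = true;
    }

    public int eval(String s) {
        if (!opened) {
            throw new IllegalStateException("open() must be called before eval()");
        }
        return s.hashCode() * factor;
    }

    // Called after the last eval call: releases whatever open() set up.
    public void close() {
        opened = false;
    }
}
```

With an empty parameter map the default "12" applies. With "hashcode_factor" set to "31", the same input is scaled by 31 instead.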

The following example snippet shows how to use FunctionContext in a scalar function to access a global job parameter:

Java:

  import org.apache.flink.configuration.Configuration;
  import org.apache.flink.table.functions.FunctionContext;
  import org.apache.flink.table.functions.ScalarFunction;

  public class HashCode extends ScalarFunction {

    private int factor = 0;

    @Override
    public void open(FunctionContext context) throws Exception {
      // access "hashcode_factor" parameter
      // "12" would be the default value if parameter does not exist
      factor = Integer.valueOf(context.getJobParameter("hashcode_factor", "12"));
    }

    public int eval(String s) {
      return s.hashCode() * factor;
    }
  }

  ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
  BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);

  // set job parameter
  Configuration conf = new Configuration();
  conf.setString("hashcode_factor", "31");
  env.getConfig().setGlobalJobParameters(conf);

  // register the function
  tableEnv.registerFunction("hashCode", new HashCode());

  // use the function in Java Table API
  myTable.select("string, string.hashCode(), hashCode(string)");

  // use the function in SQL
  tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable");

Scala:

  object hashCode extends ScalarFunction {

    var hashcode_factor = 12

    override def open(context: FunctionContext): Unit = {
      // access "hashcode_factor" parameter
      // "12" would be the default value if parameter does not exist
      hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt
    }

    def eval(s: String): Int = {
      s.hashCode() * hashcode_factor
    }
  }

  val tableEnv = TableEnvironment.getTableEnvironment(env)

  // use the function in Scala Table API
  myTable.select('string, hashCode('string))

  // register and use the function in SQL
  tableEnv.registerFunction("hashCode", hashCode)
  tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable")