Flink 所有的内部函数都定义在类 org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable 中,其中有很多函数直接复用了类 org.apache.calcite.sql.fun.SqlStdOperatorTable 中的函数。
生成对应 Java 代码的逻辑在类 org.apache.flink.table.planner.codegen.calls.FunctionGenerator 中。

函数类型及实现

Aggregate Functions

函数的实现在包 org.apache.flink.table.planner.functions.aggfunctions 中;类 org.apache.flink.table.planner.plan.utils.AggFunctionFactory 负责把函数定义转化为对应的函数实现,具体代码如下:
函数定义代码

  1. // Class org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable
  2. // AGGREGATE OPERATORS: every constant below is assigned the SqlStdOperatorTable field of the same name.
  3. public static final SqlAggFunction SUM = SqlStdOperatorTable.SUM;
  4. public static final SqlAggFunction SUM0 = SqlStdOperatorTable.SUM0;
  5. public static final SqlAggFunction COUNT = SqlStdOperatorTable.COUNT;
  6. public static final SqlAggFunction COLLECT = SqlStdOperatorTable.COLLECT;
  7. public static final SqlAggFunction MIN = SqlStdOperatorTable.MIN;
  8. public static final SqlAggFunction MAX = SqlStdOperatorTable.MAX;
  9. public static final SqlAggFunction AVG = SqlStdOperatorTable.AVG;
  10. public static final SqlAggFunction STDDEV = SqlStdOperatorTable.STDDEV;
  11. public static final SqlAggFunction STDDEV_POP = SqlStdOperatorTable.STDDEV_POP;
  12. public static final SqlAggFunction STDDEV_SAMP = SqlStdOperatorTable.STDDEV_SAMP;
  13. public static final SqlAggFunction VARIANCE = SqlStdOperatorTable.VARIANCE;
  14. public static final SqlAggFunction VAR_POP = SqlStdOperatorTable.VAR_POP;
  15. public static final SqlAggFunction VAR_SAMP = SqlStdOperatorTable.VAR_SAMP;
  16. public static final SqlAggFunction SINGLE_VALUE = SqlStdOperatorTable.SINGLE_VALUE;

函数实现代码

  1. // Class org.apache.flink.table.planner.plan.utils.AggFunctionFactory
  2. /**
  3. * The entry point to create an aggregate function from the given AggregateCall
      *
      * The argument types are looked up from `inputType` using the field indexes in
      * `call.getArgList`, then the aggregation operator is matched case by case; the
      * first case whose type pattern and guard both hold decides the result.
      *
      * @param call  the aggregate call whose operator and argument list are inspected
      * @param index forwarded unchanged to the factory methods that accept it
      * @return the function produced by the matching `create...` helper, or by
      *         `makeFunction` for an `AggSqlFunction`
      * @throws TableException for COUNT with more than one argument, and for any
      *                        `SqlAggFunction` that no earlier case handles
  4. */
  5. def createAggFunction(call: AggregateCall, index: Int): UserDefinedFunction = {
      // Resolve each argument index to its field type in inputType, converted to a LogicalType.
  6. val argTypes: Array[LogicalType] = call.getArgList
  7. .map(inputType.getFieldList.get(_).getType)
  8. .map(FlinkTypeFactory.toLogicalType)
  9. .toArray
  10. call.getAggregation match {
      // Only an SqlAvgAggFunction whose kind is AVG is handled here; other kinds fall through.
  11. case a: SqlAvgAggFunction if a.kind == SqlKind.AVG => createAvgAggFunction(argTypes)
  12. case _: SqlSumAggFunction => createSumAggFunction(argTypes, index)
  13. case _: SqlSumEmptyIsZeroAggFunction => createSum0AggFunction(argTypes)
      // MIN and MAX share one operator class and are told apart by their kind.
  14. case a: SqlMinMaxAggFunction if a.getKind == SqlKind.MIN =>
  15. createMinAggFunction(argTypes, index)
  16. case a: SqlMinMaxAggFunction if a.getKind == SqlKind.MAX =>
  17. createMaxAggFunction(argTypes, index)
      // COUNT: the multi-argument form is rejected first, then the no-argument form is
      // separated from the single-argument form. The order of these three cases matters.
  18. case _: SqlCountAggFunction if call.getArgList.size() > 1 =>
  19. throw new TableException("We now only support the count of one field.")
  20. // TODO supports ApproximateCountDistinctAggFunction and CountDistinctAggFunction
  21. case _: SqlCountAggFunction if call.getArgList.isEmpty => createCount1AggFunction(argTypes)
  22. case _: SqlCountAggFunction => createCountAggFunction(argTypes)
      // The three SqlRankFunction cases are distinguished by kind.
  23. case a: SqlRankFunction if a.getKind == SqlKind.ROW_NUMBER =>
  24. createRowNumberAggFunction(argTypes)
  25. case a: SqlRankFunction if a.getKind == SqlKind.RANK =>
  26. createRankAggFunction(argTypes)
  27. case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
  28. createDenseRankAggFunction(argTypes)
  29. case _: SqlLeadLagAggFunction =>
  30. createLeadLagAggFunction(argTypes, index)
  31. case _: SqlSingleValueAggFunction =>
  32. createSingleValueAggFunction(argTypes)
  33. case a: SqlFirstLastValueAggFunction if a.getKind == SqlKind.FIRST_VALUE =>
  34. createFirstValueAggFunction(argTypes, index)
  35. case a: SqlFirstLastValueAggFunction if a.getKind == SqlKind.LAST_VALUE =>
  36. createLastValueAggFunction(argTypes, index)
      // SqlListAggFunction: one argument and two arguments map to different factories.
      // Any other argument count matches neither guard and falls through to the later cases.
  37. case _: SqlListAggFunction if call.getArgList.size() == 1 =>
  38. createListAggFunction(argTypes, index)
  39. case _: SqlListAggFunction if call.getArgList.size() == 2 =>
  40. createListAggWsFunction(argTypes, index)
  41. // TODO supports SqlCardinalityCountAggFunction
  42. case a: SqlAggFunction if a.getKind == SqlKind.COLLECT =>
  43. createCollectAggFunction(argTypes)
  44. case udagg: AggSqlFunction =>
  45. // Can not touch the literals, Calcite make them in previous RelNode.
  46. // In here, all inputs are input refs.
      // Hence the constants list holds one null entry per argument type.
  47. val constants = new util.ArrayList[AnyRef]()
  48. argTypes.foreach(t => constants.add(null))
  49. udagg.makeFunction(
  50. constants.toArray,
  51. argTypes)
      // Catch-all for any remaining SqlAggFunction: report it by name.
  52. case unSupported: SqlAggFunction =>
  53. throw new TableException(s"Unsupported Function: '${unSupported.getName}'")
  54. }
  55. }

为聚合函数创建代码生成器(code generator)的方法如下:

  1. // Class org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator
      // Initializes this generator's aggregate state from aggInfoList: the accumulator row
      // type and buffer size, one code generator per aggregate and per distinct info, the
      // buffer/action generator lists, and the aggregate indexes to leave out of the result.
  2. private def initialAggregateInformation(aggInfoList: AggregateInfoList): Unit = {
      // The accumulator row type has one field per accumulator type in aggInfoList.
  3. this.accTypeInfo = RowType.of(
  4. aggInfoList.getAccTypes.map(fromDataTypeToLogicalType): _*)
  5. this.aggBufferSize = accTypeInfo.getFieldCount
      // Running offset handed to each code generator; advanced inside the two map closures below.
  6. var aggBufferOffset: Int = 0
      // Default to this list's accumulator types when no merged external types were set.
  7. if (mergedAccExternalTypes == null) {
  8. mergedAccExternalTypes = aggInfoList.getAccTypes
  9. }
      // One code generator per aggregate.
      // NOTE(review): the closure mutates aggBufferOffset, so this relies on map visiting
      // aggInfos eagerly and in order - confirm the collection type.
  10. val aggCodeGens = aggInfoList.aggInfos.map { aggInfo =>
  11. val filterExpr = createFilterExpression(
  12. aggInfo.agg.filterArg,
  13. aggInfo.aggIndex,
  14. aggInfo.agg.name)
      // Pick the code generator by the kind of aggregate function.
      // NOTE(review): there is no default case; a function of neither type would raise a
      // MatchError - confirm that cannot happen.
  15. val codegen = aggInfo.function match {
  16. case _: DeclarativeAggregateFunction =>
  17. new DeclarativeAggCodeGen(
  18. ctx,
  19. aggInfo,
  20. filterExpr,
  21. mergedAccOffset,
  22. aggBufferOffset,
  23. aggBufferSize,
  24. inputFieldTypes,
  25. constants,
  26. relBuilder)
  27. case _: UserDefinedAggregateFunction[_, _] =>
  28. new ImperativeAggCodeGen(
  29. ctx,
  30. aggInfo,
  31. filterExpr,
  32. mergedAccOffset,
  33. aggBufferOffset,
  34. aggBufferSize,
  35. inputFieldTypes,
  36. constantExprs,
  37. relBuilder,
  38. hasNamespace,
  39. mergedAccOnHeap,
  40. mergedAccExternalTypes(aggBufferOffset),
  41. copyInputField)
  42. }
      // Advance the offset by the number of external accumulator types of this aggregate.
  43. aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length
  44. codegen
  45. }
      // One code generator per distinct info, wrapping the generators of the aggregates it references.
  46. val distinctCodeGens = aggInfoList.distinctInfos.zipWithIndex.map {
  47. case (distinctInfo, index) =>
  48. val innerCodeGens = distinctInfo.aggIndexes.map(aggCodeGens(_)).toArray
      // Distinct generators are indexed after all regular aggregate generators.
  49. val distinctIndex = aggCodeGens.length + index
  50. val filterExpr = distinctInfo.filterArgs.map(
  51. createFilterExpression(_, distinctIndex, "distinct aggregate"))
  52. val codegen = new DistinctAggCodeGen(
  53. ctx,
  54. distinctInfo,
  55. index,
  56. innerCodeGens,
  57. filterExpr.toArray,
  58. mergedAccOffset,
  59. aggBufferOffset,
  60. aggBufferSize,
  61. hasNamespace,
  62. isMergeNeeded,
  63. mergedAccOnHeap,
  64. distinctInfo.consumeRetraction,
  65. copyInputField,
  66. relBuilder)
  67. // distinct agg buffer occupies only one field
  68. aggBufferOffset += 1
  69. codegen
  70. }
      // Aggregates referenced by some distinct info are excluded from the action list below.
  71. val distinctAggIndexes = aggInfoList.distinctInfos.flatMap(_.aggIndexes)
  72. val nonDistinctAggIndexes = aggCodeGens.indices.filter(!distinctAggIndexes.contains(_)).toArray
      // Buffer generators: every aggregate generator, then every distinct generator.
  73. this.aggBufferCodeGens = aggCodeGens ++ distinctCodeGens
      // Action generators: only the non-distinct aggregate generators, then every distinct generator.
  74. this.aggActionCodeGens = nonDistinctAggIndexes.map(aggCodeGens(_)) ++ distinctCodeGens
  75. // when input contains retractions, we inserted a count1 agg in the agg list
  76. // the count1 agg value shouldn't be in the aggregate result
      // The .get below is guarded by the nonEmpty check in the same condition.
  77. if (aggInfoList.indexOfCountStar.nonEmpty && aggInfoList.countStarInserted) {
  78. ignoreAggValues ++= Array(aggInfoList.indexOfCountStar.get)
  79. }
  80. // the distinct value shouldn't be in the aggregate result
      // Distinct generator i is ignored under index aggCodeGens.length + i.
  81. if (aggInfoList.distinctInfos.nonEmpty) {
  82. ignoreAggValues ++= distinctCodeGens.indices.map(_ + aggCodeGens.length)
  83. }
  84. }

其它function

代码生成在类 org.apache.flink.table.planner.codegen.calls.FunctionGenerator 中。