Flink所有的内部函数都定义在类org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable
中,其中有很多的函数是复用类org.apache.calcite.sql.fun.SqlStdOperatorTable
中的函数。
生成对应Java代码的逻辑在类org.apache.flink.table.planner.codegen.calls.FunctionGenerator
中
函数类型及实现
Aggregate Functions
函数的实现在org.apache.flink.table.planner.functions.aggfunctions
中,在类org.apache.flink.table.planner.plan.utils.AggFunctionFactory
中实现函数定义到函数实现逻辑的转化,具体代码如下
函数定义代码
// Class: org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable
// AGGREGATE OPERATORS
// Every constant below is a plain alias of the same-named field on SqlStdOperatorTable:
// no new operator is declared here, the existing SqlAggFunction instance is simply
// re-exposed under this class.
public static final SqlAggFunction SUM = SqlStdOperatorTable.SUM;
public static final SqlAggFunction SUM0 = SqlStdOperatorTable.SUM0;
public static final SqlAggFunction COUNT = SqlStdOperatorTable.COUNT;
public static final SqlAggFunction COLLECT = SqlStdOperatorTable.COLLECT;
public static final SqlAggFunction MIN = SqlStdOperatorTable.MIN;
public static final SqlAggFunction MAX = SqlStdOperatorTable.MAX;
public static final SqlAggFunction AVG = SqlStdOperatorTable.AVG;
public static final SqlAggFunction STDDEV = SqlStdOperatorTable.STDDEV;
public static final SqlAggFunction STDDEV_POP = SqlStdOperatorTable.STDDEV_POP;
public static final SqlAggFunction STDDEV_SAMP = SqlStdOperatorTable.STDDEV_SAMP;
public static final SqlAggFunction VARIANCE = SqlStdOperatorTable.VARIANCE;
public static final SqlAggFunction VAR_POP = SqlStdOperatorTable.VAR_POP;
public static final SqlAggFunction VAR_SAMP = SqlStdOperatorTable.VAR_SAMP;
public static final SqlAggFunction SINGLE_VALUE = SqlStdOperatorTable.SINGLE_VALUE;
函数实现代码
// Class: org.apache.flink.table.planner.plan.utils.AggFunctionFactory
/**
 * Picks the function implementation that backs a single [[AggregateCall]].
 *
 * The argument types are looked up from `inputType` using the field indexes in
 * `call.getArgList` and converted with `FlinkTypeFactory.toLogicalType`. The call's
 * aggregation operator is then matched, top to bottom, against the supported operator
 * classes; the first matching case decides which `create*AggFunction` factory is invoked.
 * The order of the cases is significant because several of them share a class and differ
 * only in their guard.
 *
 * @param call  the aggregate call to translate
 * @param index position of the call, forwarded unchanged to the factories that accept it
 * @return the function implementing the aggregation
 * @throws TableException if COUNT is given more than one argument, or if no case
 *                        recognises the aggregation operator
 */
def createAggFunction(call: AggregateCall, index: Int): UserDefinedFunction = {
  val argFieldTypes =
    call.getArgList.map(argIdx => inputType.getFieldList.get(argIdx).getType)
  val argTypes: Array[LogicalType] =
    argFieldTypes.map(FlinkTypeFactory.toLogicalType).toArray
  val argCount = call.getArgList.size()

  call.getAggregation match {
    case avgFn: SqlAvgAggFunction if avgFn.kind == SqlKind.AVG =>
      createAvgAggFunction(argTypes)

    case _: SqlSumAggFunction =>
      createSumAggFunction(argTypes, index)

    case _: SqlSumEmptyIsZeroAggFunction =>
      createSum0AggFunction(argTypes)

    case minMaxFn: SqlMinMaxAggFunction if minMaxFn.getKind == SqlKind.MIN =>
      createMinAggFunction(argTypes, index)

    case minMaxFn: SqlMinMaxAggFunction if minMaxFn.getKind == SqlKind.MAX =>
      createMaxAggFunction(argTypes, index)

    // COUNT: reject multi-argument calls first, then split COUNT(*) from COUNT(field).
    case _: SqlCountAggFunction if argCount > 1 =>
      throw new TableException("We now only support the count of one field.")

    // TODO supports ApproximateCountDistinctAggFunction and CountDistinctAggFunction

    case _: SqlCountAggFunction if argCount == 0 =>
      createCount1AggFunction(argTypes)

    case _: SqlCountAggFunction =>
      createCountAggFunction(argTypes)

    case rankFn: SqlRankFunction if rankFn.getKind == SqlKind.ROW_NUMBER =>
      createRowNumberAggFunction(argTypes)

    case rankFn: SqlRankFunction if rankFn.getKind == SqlKind.RANK =>
      createRankAggFunction(argTypes)

    case rankFn: SqlRankFunction if rankFn.getKind == SqlKind.DENSE_RANK =>
      createDenseRankAggFunction(argTypes)

    case _: SqlLeadLagAggFunction =>
      createLeadLagAggFunction(argTypes, index)

    case _: SqlSingleValueAggFunction =>
      createSingleValueAggFunction(argTypes)

    case firstLastFn: SqlFirstLastValueAggFunction
        if firstLastFn.getKind == SqlKind.FIRST_VALUE =>
      createFirstValueAggFunction(argTypes, index)

    case firstLastFn: SqlFirstLastValueAggFunction
        if firstLastFn.getKind == SqlKind.LAST_VALUE =>
      createLastValueAggFunction(argTypes, index)

    // LISTAGG: one argument uses the plain variant, two arguments the "Ws" variant.
    case _: SqlListAggFunction if argCount == 1 =>
      createListAggFunction(argTypes, index)

    case _: SqlListAggFunction if argCount == 2 =>
      createListAggWsFunction(argTypes, index)

    // TODO supports SqlCardinalityCountAggFunction

    case collectFn: SqlAggFunction if collectFn.getKind == SqlKind.COLLECT =>
      createCollectAggFunction(argTypes)

    case udaf: AggSqlFunction =>
      // Can not touch the literals, Calcite make them in previous RelNode.
      // In here, all inputs are input refs.
      // Hence one null placeholder per argument is passed as the constants.
      val nullConstants = new Array[AnyRef](argTypes.length)
      udaf.makeFunction(nullConstants, argTypes)

    case other: SqlAggFunction =>
      throw new TableException(s"Unsupported Function: '${other.getName}'")
  }
}
为聚合函数生成代码的Code Generator方法如下
// Class: org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator
/**
 * Prepares this generator's per-aggregate state from the given [[AggregateInfoList]].
 *
 * Fields assigned or extended here: `accTypeInfo`, `aggBufferSize`,
 * `mergedAccExternalTypes` (only when it is still null), `aggBufferCodeGens`,
 * `aggActionCodeGens` and `ignoreAggValues`.
 *
 * One code generator is created per entry in `aggInfoList.aggInfos`, followed by one
 * per entry in `aggInfoList.distinctInfos`. A running `aggBufferOffset` is handed to each
 * generator and advanced as the generators are created, so creation order matters.
 *
 * @param aggInfoList the aggregate and distinct descriptions to build generators for
 */
private def initialAggregateInformation(aggInfoList: AggregateInfoList): Unit = {
// Row type built from the logical types of all accumulator types.
this.accTypeInfo = RowType.of(
aggInfoList.getAccTypes.map(fromDataTypeToLogicalType): _*)
this.aggBufferSize = accTypeInfo.getFieldCount
// Running offset passed to each generator; mutated inside the two map closures below.
var aggBufferOffset: Int = 0
// Fall back to the list's own accumulator types when none were set beforehand.
if (mergedAccExternalTypes == null) {
mergedAccExternalTypes = aggInfoList.getAccTypes
}
// First pass: one generator per aggregate, chosen by the kind of its function.
// NOTE(review): the closure mutates aggBufferOffset, so this relies on aggInfos being
// mapped eagerly and in order -- confirm the collection type.
val aggCodeGens = aggInfoList.aggInfos.map { aggInfo =>
val filterExpr = createFilterExpression(
aggInfo.agg.filterArg,
aggInfo.aggIndex,
aggInfo.agg.name)
// NOTE(review): there is no fallback case, so a function of any other type would end in
// a MatchError. Also, the declarative branch passes `constants` while the imperative
// branch passes `constantExprs`; both are defined outside this method -- confirm intent.
val codegen = aggInfo.function match {
case _: DeclarativeAggregateFunction =>
new DeclarativeAggCodeGen(
ctx,
aggInfo,
filterExpr,
mergedAccOffset,
aggBufferOffset,
aggBufferSize,
inputFieldTypes,
constants,
relBuilder)
case _: UserDefinedAggregateFunction[_, _] =>
new ImperativeAggCodeGen(
ctx,
aggInfo,
filterExpr,
mergedAccOffset,
aggBufferOffset,
aggBufferSize,
inputFieldTypes,
constantExprs,
relBuilder,
hasNamespace,
mergedAccOnHeap,
mergedAccExternalTypes(aggBufferOffset),
copyInputField)
}
// Advance the offset by the number of external accumulator types of this aggregate.
aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length
codegen
}
// Second pass: one generator per distinct info, wrapping the generators of the
// aggregates listed in distinctInfo.aggIndexes.
val distinctCodeGens = aggInfoList.distinctInfos.zipWithIndex.map {
case (distinctInfo, index) =>
val innerCodeGens = distinctInfo.aggIndexes.map(aggCodeGens(_)).toArray
// Distinct entries are numbered after all regular aggregates.
val distinctIndex = aggCodeGens.length + index
val filterExpr = distinctInfo.filterArgs.map(
createFilterExpression(_, distinctIndex, "distinct aggregate"))
val codegen = new DistinctAggCodeGen(
ctx,
distinctInfo,
index,
innerCodeGens,
filterExpr.toArray,
mergedAccOffset,
aggBufferOffset,
aggBufferSize,
hasNamespace,
isMergeNeeded,
mergedAccOnHeap,
distinctInfo.consumeRetraction,
copyInputField,
relBuilder)
// distinct agg buffer occupies only one field
aggBufferOffset += 1
codegen
}
// Indexes of aggregates that no distinct info refers to.
val distinctAggIndexes = aggInfoList.distinctInfos.flatMap(_.aggIndexes)
val nonDistinctAggIndexes = aggCodeGens.indices.filter(!distinctAggIndexes.contains(_)).toArray
// Buffer generators: every aggregate generator followed by every distinct generator.
this.aggBufferCodeGens = aggCodeGens ++ distinctCodeGens
// Action generators: only the non-distinct aggregates, followed by the distinct
// generators (which were given their inner aggregate generators above).
this.aggActionCodeGens = nonDistinctAggIndexes.map(aggCodeGens(_)) ++ distinctCodeGens
// when input contains retractions, we inserted a count1 agg in the agg list
// the count1 agg value shouldn't be in the aggregate result
if (aggInfoList.indexOfCountStar.nonEmpty && aggInfoList.countStarInserted) {
ignoreAggValues ++= Array(aggInfoList.indexOfCountStar.get)
}
// the distinct value shouldn't be in the aggregate result
if (aggInfoList.distinctInfos.nonEmpty) {
ignoreAggValues ++= distinctCodeGens.indices.map(_ + aggCodeGens.length)
}
}
其它function
代码生成在类org.apache.flink.table.planner.codegen.calls.FunctionGenerator
中