UDAF (user-defined aggregate function)
Compute the average age. Sample input file datas/sql/user.json (one JSON object per line):
{"username": "zs", "age": 18}
{"username": "ls", "age": 19}
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}

// Enclosing object so the example compiles as a single file
object SparkSQL_UDAF1 {

  // Aggregation buffer: running sum of the ages and number of rows seen so far
  case class Buff(var valCount: Long, var indexCount: Long)

  // Strongly typed aggregator: [input type, buffer type, output type]
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {

    // Initial (zero) value of the buffer
    override def zero: Buff = Buff(0L, 0L)

    // Fold one input value into the buffer (within a partition)
    override def reduce(buff: Buff, a: Long): Buff = {
      buff.valCount += a
      buff.indexCount += 1
      buff
    }

    // Merge buffers coming from different partitions
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.valCount += b2.valCount
      b1.indexCount += b2.indexCount
      b1
    }

    // Final output: sum / count (Long division, so the result is truncated)
    override def finish(res: Buff): Long = {
      res.valCount / res.indexCount
    }

    // Encoder for the buffer: a custom case class type must use Encoders.product
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    // Encoder for the output: for built-in types pick the matching encoder
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
  def main(args: Array[String]): Unit = {
    // TODO Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL_UDAF1")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    // TODO Business logic
    val df: DataFrame = spark.read.json("datas/sql/user.json")
    df.show()

    // Register the user-defined aggregate function (functions.udaf wraps the Aggregator, Spark 3.x)
    spark.udf.register("userAvg", functions.udaf(new MyAvgUDAF))

    df.createOrReplaceTempView("user")
    spark.sql("select userAvg(age) from user").show()

    // TODO Close the environment
    spark.close()
  }
}
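
Aside (not in the original notes): the same strongly typed Aggregator can also be applied through the Dataset DSL without registering it in SQL, by converting it into a TypedColumn with toColumn. A minimal sketch, assuming it is placed inside main after df has been read; the column name "avgAge" is just an illustrative label:

    import spark.implicits._

    // Assumes the single "age" column (inferred as bigint) can be read as a Dataset[Long]
    val ds = df.select("age").as[Long]

    // toColumn turns the Aggregator into a TypedColumn that the typed select can apply
    ds.select(new MyAvgUDAF().toColumn.name("avgAge")).show()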
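
Because finish divides two Long values, (18 + 19) / 2 comes out as 18 rather than 18.5. A possible variant (a sketch, not from the original notes; AvgBuff and MyAvgDoubleUDAF are hypothetical names): change the output type to Double and use the matching encoder, as the outputEncoder comment above suggests.

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders}

    // Same technique as MyAvgUDAF, but keeps the fractional part of the average
    case class AvgBuff(var sum: Long, var cnt: Long)

    class MyAvgDoubleUDAF extends Aggregator[Long, AvgBuff, Double] {
      override def zero: AvgBuff = AvgBuff(0L, 0L)
      override def reduce(b: AvgBuff, a: Long): AvgBuff = { b.sum += a; b.cnt += 1; b }
      override def merge(b1: AvgBuff, b2: AvgBuff): AvgBuff = { b1.sum += b2.sum; b1.cnt += b2.cnt; b1 }
      // Convert to Double before dividing so the result is 18.5 instead of 18
      override def finish(res: AvgBuff): Double = res.sum.toDouble / res.cnt
      override def bufferEncoder: Encoder[AvgBuff] = Encoders.product
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }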