UDF
计算平均年龄
{"username": "zs" , "age" : 18}
{"username": "ls" , "age" : 19}
/**
 * Aggregation buffer used by [[MyAvgUDAF]].
 *
 * The fields are `var` because the aggregator updates the buffer in place
 * instead of allocating a new instance for every input value.
 *
 * @param valCount   running sum of the input values
 * @param indexCount number of input values folded into the buffer
 */
case class Buff(var valCount: Long, var indexCount: Long)

/**
 * Typed aggregator that computes the integer average of `Long` inputs.
 *
 * Type parameters of `Aggregator` are, in order: input value, buffer, output value.
 */
class MyAvgUDAF extends Aggregator[Long, Buff, Long] {

  /** Initial (empty) buffer: sum 0, count 0. */
  override def zero: Buff = Buff(0L, 0L)

  /**
   * Folds one input value into the buffer.
   *
   * @param buff buffer to update (mutated and returned)
   * @param a    input value to add
   * @return the same buffer instance, updated
   */
  override def reduce(buff: Buff, a: Long): Buff = {
    buff.valCount += a
    buff.indexCount += 1
    buff
  }

  /**
   * Combines two partial buffers (e.g. from different partitions).
   *
   * @param b1 first buffer (mutated and returned)
   * @param b2 second buffer, whose sum and count are added into `b1`
   * @return `b1` holding the combined sum and count
   */
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.valCount += b2.valCount
    b1.indexCount += b2.indexCount
    b1
  }

  /**
   * Produces the final result from the fully merged buffer.
   *
   * Uses `Long` division, so the average is truncated toward zero.
   * An empty buffer (no input values) yields `0L` instead of failing with
   * `ArithmeticException: / by zero`.
   *
   * @param res merged buffer
   * @return truncated average of the inputs, or `0L` when there were none
   */
  override def finish(res: Buff): Long =
    if (res.indexCount == 0L) 0L else res.valCount / res.indexCount

  /** Encoder for the buffer: `Buff` is a case class, so the product encoder is used. */
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  /** Encoder for the output: built-in encoder matching the Scala `Long` result type. */
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
/**
 * Entry point: reads the user JSON file, registers [[MyAvgUDAF]] as the SQL
 * function `userAvg`, and prints the result of applying it to the `age` column.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  // Create the Spark SQL runtime environment.
  val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL_UDAF1")
  val spark = SparkSession.builder().config(sparkConf).getOrCreate()

  try {
    // Load the input data and print it.
    val df: DataFrame = spark.read.json("datas/sql/user.json")
    df.show()

    // Register the user-defined aggregate function so it can be called from SQL.
    spark.udf.register("userAvg", functions.udaf(new MyAvgUDAF))

    df.createOrReplaceTempView("user")
    spark.sql("select userAvg(age) from user").show()
  } finally {
    // Always release the session, even if reading or querying fails.
    spark.close()
  }
}