UDF

计算平均年龄（下面两行为输入文件 datas/sql/user.json 的示例数据，每行一个 JSON 对象）

  1. {"username": "zs" , "age" : 18}
  2. {"username": "ls" , "age" : 19}
  /**
   * Mutable aggregation buffer used while computing the average.
   *
   * Both fields are `var` because the aggregator updates the buffer in place.
   *
   * @param valCount   running sum of the input values
   * @param indexCount number of input values accumulated so far
   */
  case class Buff(
    var valCount: Long,
    var indexCount: Long
  )
  /**
   * Typed aggregator that computes the average of `Long` inputs.
   *
   * The `Aggregator` type parameters are [input value, buffer, output value].
   * The result is produced with `Long` division, so any fractional part of the
   * average is truncated.
   */
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {

    /** Initial buffer value: sum 0, count 0. */
    override def zero: Buff = Buff(0L, 0L)

    /** Folds one input value into the buffer; updates `buff` in place and returns it. */
    override def reduce(buff: Buff, a: Long): Buff = {
      buff.valCount += a
      buff.indexCount += 1
      buff
    }

    /** Merges two partial buffers (cross-partition aggregation); updates `b1` in place and returns it. */
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.valCount += b2.valCount
      b1.indexCount += b2.indexCount
      b1
    }

    /**
     * Produces the final output: accumulated sum divided by accumulated count.
     *
     * Returns 0 when nothing was aggregated (count is 0); without this guard the
     * `Long` division would throw `ArithmeticException` on empty input.
     */
    override def finish(res: Buff): Long =
      if (res.indexCount == 0L) 0L
      else res.valCount / res.indexCount

    // A custom (case class) buffer type must use the product encoder.
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    // For non-custom types, pick the encoder that matches the type.
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
  26. }
  /**
   * Entry point: builds a local SparkSession, loads the user JSON file,
   * registers `MyAvgUDAF` as the SQL function `userAvg`, and prints the
   * result of applying it to the `age` column.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL_UDAF1")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    try {
      // Run the job logic
      val df: DataFrame = spark.read.json("datas/sql/user.json")
      df.show()

      // Register the user-defined aggregate function
      spark.udf.register("userAvg", functions.udaf(new MyAvgUDAF))
      df.createOrReplaceTempView("user")
      spark.sql("select userAvg(age) from user").show()
    } finally {
      // Always release the session, even if reading or querying fails
      spark.close()
    }
  }