json数据

  1. [{"name":"张三" ,"age":18} ,{"name":"李四" ,"age":15}]

代码编写

  1. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  2. import org.apache.spark.sql.types._
  3. import org.apache.spark.sql.{Row, SparkSession}
  4. import scala.collection.immutable.Nil
  5. object UDAFDemo {
  6. def main(args: Array[String]): Unit = {
  7. // 在sql中, 聚合函数如何使用
  8. val spark: SparkSession = SparkSession
  9. .builder()
  10. .master("local[*]")
  11. .appName("UDAFDemo")
  12. .getOrCreate()
  13. val df = spark.read.json("E:\\ZJJ_SparkSQL\\demo01\\src\\main\\resources\\users.json")
  14. df.createOrReplaceTempView("user")
  15. // 注册聚合函数
  16. spark.udf.register("myAvg", new MyAvg)
  17. spark.sql("select myAvg(age) from user").show
  18. spark.close()
  19. }
  20. }
  21. /**
  22. * 求平均值
  23. */
  24. class MyAvg extends UserDefinedAggregateFunction {
  25. // 输入的数据类型 10.1 12.2 100
  26. override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
  27. // 缓冲区的类型
  28. // 求平均值需要两个值运算,一个是年龄的和,另外一个是多少个年龄参与运算,
  29. // 所以这就是求平均值了.
  30. override def bufferSchema: StructType =
  31. StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  32. // 最终聚合解结果的类型
  33. override def dataType: DataType = DoubleType
  34. // 相同的输入是否返回相同的输出
  35. override def deterministic: Boolean = true
  36. // 对缓冲区初始化
  37. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  38. // avg初始化是一个值,个数初始化也得是一个值.
  39. // 在缓冲集合中初始化和
  40. buffer(0) = 0D // 等价于 buffer.update(0, 0D)
  41. buffer(1) = 0L //带个L,不然就存成int类型了.
  42. }
  43. // 分区内聚合
  44. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  45. input match {
  46. case Row(age: Double) =>
  47. buffer(0) = buffer.getDouble(0) + age //年龄进行相加
  48. buffer(1) = buffer.getLong(1) + 1L //个数累加碰到一个就加1
  49. case _ =>
  50. }
  51. /*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row
  52. if (!input.isNullAt(0)) { // 考虑到传字段可能是null
  53. val v = input.getAs[Double](0) // getDouble(0)
  54. buffer(0) = buffer.getDouble(0) + v
  55. buffer(1) = buffer.getLong(1) + 1L
  56. }*/
  57. }
  58. // 分区间的聚合
  59. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  60. buffer2 match {
  61. case Row(sum: Double, count: Long) =>
  62. // 缓冲区和要集合
  63. buffer1(0) = buffer1.getDouble(0) + sum
  64. //个数也要聚合
  65. buffer1(1) = buffer1.getLong(1) + count
  66. case _ =>
  67. }
  68. // 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1
  69. /*buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  70. buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)*/
  71. }
  72. // 返回最终的输出值
  73. // 就是累加的总数除以个数,就是平均值了.
  74. override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
  75. }
  76. /**
  77. * 求和的函数
  78. */
  79. class MySum extends UserDefinedAggregateFunction {
  80. // 输入的数据类型 10.1 12.2 100
  81. override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
  82. // 缓冲区的类型,求和的时候需要缓冲,计算聚合的是一定要有缓冲
  83. override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)
  84. // 最终聚合解结果的类型
  85. //因为你是聚合,聚合的结果只有一个
  86. override def dataType: DataType = DoubleType
  87. // 相同的输入是否返回相同的输出
  88. //用的时候几乎永远都是true
  89. override def deterministic: Boolean = true
  90. // 对缓冲区初始化
  91. //初始化的就是在缓冲里面去初始化一个值,用来计算聚合的值的.
  92. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  93. // 在缓冲集合中初始化和
  94. buffer(0) = 0D // 等价于 buffer.update(0, 0D)
  95. }
  96. // 分区内聚合
  97. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  98. // 需要先进行非空判断.
  99. // 如果不为空的话就取出来进行计算.
  100. input match {
  101. case Row(age: Double) =>
  102. // 获取0位置的
  103. buffer(0) = buffer.getDouble(0) + age
  104. case _ =>
  105. }
  106. /*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row
  107. if (!input.isNullAt(0)) { // 考虑到传字段可能是null
  108. val v = input.getAs[Double](0) // getDouble(0)
  109. buffer(0) = buffer.getDouble(0) + v
  110. }*/
  111. }
  112. // 分区间的聚合
  113. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  114. // 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1
  115. //这里不需要判断非空,因为缓冲区初始化(initialize方法)是一定有值的.
  116. buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  117. }
  118. // 集合函数返回最终的输出值
  119. override def evaluate(buffer: Row): Any = buffer.getDouble(0)
  120. }
  1. +--------------------------+
  2. |myavg(CAST(age AS DOUBLE))|
  3. +--------------------------+
  4. | 16.5|
  5. +--------------------------+