json数据
[{"name":"张三" ,"age":18} ,{"name":"李四" ,"age":15}]
代码编写
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types._import org.apache.spark.sql.{Row, SparkSession}import scala.collection.immutable.Nilobject UDAFDemo { def main(args: Array[String]): Unit = { // 在sql中, 聚合函数如何使用 val spark: SparkSession = SparkSession .builder() .master("local[*]") .appName("UDAFDemo") .getOrCreate() val df = spark.read.json("E:\\ZJJ_SparkSQL\\demo01\\src\\main\\resources\\users.json") df.createOrReplaceTempView("user") // 注册聚合函数 spark.udf.register("myAvg", new MyAvg) spark.sql("select myAvg(age) from user").show spark.close() }}/** * 求平均值 */class MyAvg extends UserDefinedAggregateFunction { // 输入的数据类型 10.1 12.2 100 override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil) // 缓冲区的类型 // 求平均值需要两个值运算,一个是年龄的和,另外一个是多少个年龄参与运算, // 所以这就是求平均值了. override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil) // 最终聚合解结果的类型 override def dataType: DataType = DoubleType // 相同的输入是否返回相同的输出 override def deterministic: Boolean = true // 对缓冲区初始化 override def initialize(buffer: MutableAggregationBuffer): Unit = { // avg初始化是一个值,个数初始化也得是一个值. // 在缓冲集合中初始化和 buffer(0) = 0D // 等价于 buffer.update(0, 0D) buffer(1) = 0L //带个L,不然就存成int类型了. } // 分区内聚合 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { input match { case Row(age: Double) => buffer(0) = buffer.getDouble(0) + age //年龄进行相加 buffer(1) = buffer.getLong(1) + 1L //个数累加碰到一个就加1 case _ => } /*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row if (!input.isNullAt(0)) { // 考虑到传字段可能是null val v = input.getAs[Double](0) // getDouble(0) buffer(0) = buffer.getDouble(0) + v buffer(1) = buffer.getLong(1) + 1L }*/ } // 分区间的聚合 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer2 match { case Row(sum: Double, count: Long) => // 缓冲区和要集合 buffer1(0) = buffer1.getDouble(0) + sum //个数也要聚合 buffer1(1) = buffer1.getLong(1) + count case _ => } // 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1 /*buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)*/ } // 返回最终的输出值 // 就是累加的总数除以个数,就是平均值了. override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)}/** * 求和的函数 */class MySum extends UserDefinedAggregateFunction { // 输入的数据类型 10.1 12.2 100 override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil) // 缓冲区的类型,求和的时候需要缓冲,计算聚合的是一定要有缓冲 override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil) // 最终聚合解结果的类型 //因为你是聚合,聚合的结果只有一个 override def dataType: DataType = DoubleType // 相同的输入是否返回相同的输出 //用的时候几乎永远都是true override def deterministic: Boolean = true // 对缓冲区初始化 //初始化的就是在缓冲里面去初始化一个值,用来计算聚合的值的. override def initialize(buffer: MutableAggregationBuffer): Unit = { // 在缓冲集合中初始化和 buffer(0) = 0D // 等价于 buffer.update(0, 0D) } // 分区内聚合 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { // 需要先进行非空判断. // 如果不为空的话就取出来进行计算. input match { case Row(age: Double) => // 获取0位置的 buffer(0) = buffer.getDouble(0) + age case _ => } /*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row if (!input.isNullAt(0)) { // 考虑到传字段可能是null val v = input.getAs[Double](0) // getDouble(0) buffer(0) = buffer.getDouble(0) + v }*/ } // 分区间的聚合 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { // 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1 //这里不需要判断非空,因为缓冲区初始化(initialize方法)是一定有值的. buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0) } // 集合函数返回最终的输出值 override def evaluate(buffer: Row): Any = buffer.getDouble(0)}
+--------------------------+|myavg(CAST(age AS DOUBLE))|+--------------------------+| 16.5|+--------------------------+