json数据
[{"name":"张三" ,"age":18} ,{"name":"李四" ,"age":15}]
代码编写
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import scala.collection.immutable.Nil
object UDAFDemo {
def main(args: Array[String]): Unit = {
// 在sql中, 聚合函数如何使用
val spark: SparkSession = SparkSession
.builder()
.master("local[*]")
.appName("UDAFDemo")
.getOrCreate()
val df = spark.read.json("E:\\ZJJ_SparkSQL\\demo01\\src\\main\\resources\\users.json")
df.createOrReplaceTempView("user")
// 注册聚合函数
spark.udf.register("myAvg", new MyAvg)
spark.sql("select myAvg(age) from user").show
spark.close()
}
}
/**
* 求平均值
*/
class MyAvg extends UserDefinedAggregateFunction {
// 输入的数据类型 10.1 12.2 100
override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
// 缓冲区的类型
// 求平均值需要两个值运算,一个是年龄的和,另外一个是多少个年龄参与运算,
// 所以这就是求平均值了.
override def bufferSchema: StructType =
StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
// 最终聚合解结果的类型
override def dataType: DataType = DoubleType
// 相同的输入是否返回相同的输出
override def deterministic: Boolean = true
// 对缓冲区初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// avg初始化是一个值,个数初始化也得是一个值.
// 在缓冲集合中初始化和
buffer(0) = 0D // 等价于 buffer.update(0, 0D)
buffer(1) = 0L //带个L,不然就存成int类型了.
}
// 分区内聚合
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
input match {
case Row(age: Double) =>
buffer(0) = buffer.getDouble(0) + age //年龄进行相加
buffer(1) = buffer.getLong(1) + 1L //个数累加碰到一个就加1
case _ =>
}
/*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row
if (!input.isNullAt(0)) { // 考虑到传字段可能是null
val v = input.getAs[Double](0) // getDouble(0)
buffer(0) = buffer.getDouble(0) + v
buffer(1) = buffer.getLong(1) + 1L
}*/
}
// 分区间的聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer2 match {
case Row(sum: Double, count: Long) =>
// 缓冲区和要集合
buffer1(0) = buffer1.getDouble(0) + sum
//个数也要聚合
buffer1(1) = buffer1.getLong(1) + count
case _ =>
}
// 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1
/*buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)*/
}
// 返回最终的输出值
// 就是累加的总数除以个数,就是平均值了.
override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
}
/**
* 求和的函数
*/
class MySum extends UserDefinedAggregateFunction {
// 输入的数据类型 10.1 12.2 100
override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
// 缓冲区的类型,求和的时候需要缓冲,计算聚合的是一定要有缓冲
override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)
// 最终聚合解结果的类型
//因为你是聚合,聚合的结果只有一个
override def dataType: DataType = DoubleType
// 相同的输入是否返回相同的输出
//用的时候几乎永远都是true
override def deterministic: Boolean = true
// 对缓冲区初始化
//初始化的就是在缓冲里面去初始化一个值,用来计算聚合的值的.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// 在缓冲集合中初始化和
buffer(0) = 0D // 等价于 buffer.update(0, 0D)
}
// 分区内聚合
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 需要先进行非空判断.
// 如果不为空的话就取出来进行计算.
input match {
case Row(age: Double) =>
// 获取0位置的
buffer(0) = buffer.getDouble(0) + age
case _ =>
}
/*// input是指的使用聚合函数的时候, 缓过来的参数封装到了Row
if (!input.isNullAt(0)) { // 考虑到传字段可能是null
val v = input.getAs[Double](0) // getDouble(0)
buffer(0) = buffer.getDouble(0) + v
}*/
}
// 分区间的聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// 把buffer1和buffer2 的缓冲弄聚合到一起, 然后再把值写回到buffer1
//这里不需要判断非空,因为缓冲区初始化(initialize方法)是一定有值的.
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
}
// 集合函数返回最终的输出值
override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
+--------------------------+
|myavg(CAST(age AS DOUBLE))|
+--------------------------+
| 16.5|
+--------------------------+