UDF
// Register a UDF named "addName" that prefixes its string argument with "Name:".
// FIX: the original had both statements fused on one line with no separator,
// which does not parse; they are split here.
spark.udf.register("addName", (x: String) => "Name:" + x)
// Call the registered UDF from a SQL query.
spark.sql("Select addName(name), age from people").show()
UDAF（用户自定义聚合函数，User-Defined Aggregate Function）
弱类型（weakly typed）
适用于 DF 的 SQL 操作方式（used with the DataFrame SQL-style API）
// Instantiate the user-defined aggregate function.
val avgUdaf = new MyAvg
// Register it under the SQL-visible name "myAvg".
spark.udf.register("myAvg", avgUdaf)
// Expose the DataFrame as a temporary view so SQL can query it.
df.createOrReplaceTempView("user")
// Run the aggregate through a Spark SQL query.
spark.sql("select myAvg(age) from user").show()
// Weakly-typed average aggregate: sums Int ages into a (sum, count) buffer
// and returns sum/count as a Double.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
// in favor of Aggregator + functions.udaf — consider migrating.
class MyAvg extends UserDefinedAggregateFunction {
  // Schema of the input column(s): a single Int column "age".
  override def inputSchema: StructType = {
    StructType(Array(StructField("age", IntegerType)))
  }
  // Schema of the aggregation buffer: running sum and row count.
  override def bufferSchema: StructType = {
    StructType(Array(StructField("sum", LongType), StructField("count", LongType)))
  }
  // Result type of the aggregate: the average as a Double.
  override def dataType: DataType = DoubleType
  // Deterministic: the same input always produces the same output.
  override def deterministic: Boolean = true
  // Reset the buffer to its initial state.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // running sum of ages starts at 0
    buffer(0) = 0L
    // row count starts at 0
    buffer(1) = 0L
  }
  // Fold one input row into the buffer.
  // FIX: guard on the INPUT being non-null. The original checked
  // buffer.isNullAt(0), which is always false (the buffer is initialized to
  // 0L and never set to null), so a null input age would hit input.getInt(0)
  // and fail; null ages should simply be skipped.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }
  // Merge two partial buffers (combines results across partitions).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Produce the final result from the merged buffer.
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
强类型（strongly typed）
适用于 DS 的 DSL 操作方式（used with the Dataset typed DSL API）
// Instantiate the strongly-typed aggregator.
val typedAvg = new MyAvgNew
// Convert the DataFrame into a typed Dataset of User06 records.
val ds: Dataset[User06] = df.as[User06]
// Turn the aggregator into a TypedColumn usable in select().
val col: TypedColumn[User06, Double] = typedAvg.toColumn
// During the query, each User06 record is fed to the aggregator.
ds.select(col).show
// Input record type: one user row carrying a name and an age.
case class User06(name:String,age:Long)
// Aggregation buffer: running sum of ages and row count.
// Fields are vars because reduce/merge update the buffer in place.
case class AgeBuffer(var sum:Long,var count:Long)
// Strongly-typed average aggregator: folds User06 records into an
// AgeBuffer (sum, count) and finishes with sum/count as a Double.
class MyAvgNew extends Aggregator[User06, AgeBuffer, Double] {
  // Zero value: an empty buffer with sum 0 and count 0.
  override def zero: AgeBuffer = AgeBuffer(0L, 0L)

  // Fold one record into the buffer within a partition.
  override def reduce(b: AgeBuffer, a: User06): AgeBuffer =
    AgeBuffer(b.sum + a.age, b.count + 1)

  // Combine two partial buffers from different partitions.
  override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer =
    AgeBuffer(b1.sum + b2.sum, b1.count + b2.count)

  // Final result: the mean age.
  override def finish(buffer: AgeBuffer): Double =
    buffer.sum.toDouble / buffer.count

  // Encoder for the intermediate buffer; Encoders.product covers case classes.
  override def bufferEncoder: Encoder[AgeBuffer] = Encoders.product

  // Encoder for the primitive Double result.
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
UDTF（用户自定义表生成函数，User-Defined Table-generating Function）
输入一行，返回多行（Hive 中支持）
Spark SQL 中没有内置 UDTF；在 Spark 中用 flatMap 即可实现同样的功能
