UDF

  1. spark.udf.register("addName", (x:String)=> "Name:"+x)
  2. spark.sql("Select addName(name), age from people").show()

UDAF

弱类型

适用于DF的sql操作方式

        //创建自定义函数对象
    val myAvg = new MyAvg
    //注册自定义函数
    spark.udf.register("myAvg",myAvg)
    //创建临时视图
    df.createOrReplaceTempView("user")
    //使用聚合函数进行查询
    spark.sql("select myAvg(age) from user").show()

class MyAvg extends UserDefinedAggregateFunction{
  //聚合函数的输入数据的类型
  override def inputSchema: StructType = {
    StructType(Array(StructField("age",IntegerType)))
  }
  //缓存数据的类型
  override def bufferSchema: StructType = {
    StructType(Array(StructField("sum",LongType),StructField("count",LongType)))
  }
  //聚合函数返回的数据类型
  override def dataType: DataType = DoubleType
  //稳定性  默认不处理,直接返回true    相同输入是否会得到相同的输出
  override def deterministic: Boolean = true
  //初始化  缓存设置到初始状态
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //让缓存中年龄总和归0
    buffer(0) = 0L
    //让缓存中总人数归0
    buffer(1) = 0L
  }
  //更新缓存数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(!buffer.isNullAt(0)){
      buffer(0) = buffer.getLong(0) + input.getInt(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }
  //分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  //计算逻辑
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble/buffer.getLong(1)
  }
}

强类型

适用于DS的DSL操作方式

        val myAvgNew = new MyAvgNew

    //将df转换为ds
    val ds: Dataset[User06] = df.as[User06]
    //将自定义函数对象转换为查询列
    val col: TypedColumn[User06, Double] = myAvgNew.toColumn

    //在进行查询的时候,会将查询出来的记录(User06类型)交给自定义的函数进行处理
    ds.select(col).show

//输入类型的样例类
case class User06(name:String,age:Long)
//缓存类型
case class AgeBuffer(var sum:Long,var count:Long)
class MyAvgNew extends Aggregator[User06,AgeBuffer,Double]{
  //对缓存数据进行初始化
  override def zero: AgeBuffer = {
    AgeBuffer(0L,0L)
  }
  //对当前分区内数据进行聚合
  override def reduce(b: AgeBuffer, a: User06): AgeBuffer = {
    b.sum += a.age
    b.count += 1
    b
  }
  //分区间合并
  override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  //返回计算结果
  override def finish(buffer: AgeBuffer): Double = {
    buffer.sum.toDouble/buffer.count
  }
  //DataSet的编码以及解码器  ,用于进行序列化,固定写法
  //用户自定义Ref类型  product       系统值类型,根据具体类型进行选择
  override def bufferEncoder: Encoder[AgeBuffer] = {
    Encoders.product
  }
  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}

UDTF

输入一行,返回多行(hive)
SparkSQL中没有UDTF,spark中用flatMap即可实现该功能