UDF
Spark's built-in functions: [http://spark.apache.org/docs/latest/sql-ref-functions-builtin.html](http://spark.apache.org/docs/latest/sql-ref-functions-builtin.html)
Spark also lets you define your own UDFs: <br />① write a function<br /> ② register the function with Spark<br /> ③ use the function
@Test
def testCustomUDF(): Unit = {
  // ① write a function
  def addAge(age: Int): Int = {
    age + 10
  }
  // ② register the function with Spark. addAge _ obtains a reference to the function;
  //    the name "addAge10" is arbitrary, but addAge _ must match the function written in step ①.
  //session.udf.register("addAge10", addAge _)
  // Or pass an anonymous function directly:
  session.udf.register("addAge10", (age: Int) => age + 10)
  val dataFrame1: DataFrame = session.read.json("input/people.json")
  dataFrame1.createTempView("emps")
  // ③ use the function: add 10 to each age
  session.sql("select name, addAge10(age) from emps").show()
}
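For reference, the same UDF can also be used from the DataFrame API without any SQL registration. A minimal sketch, assuming the `session` and `dataFrame1` from the test above (and that `age` fits in an Int):

import org.apache.spark.sql.functions.{col, udf}

// Wrap the lambda as a column-level UDF; no name registration needed
val addAge10 = udf((age: Int) => age + 10)
dataFrame1.select(col("name"), addAge10(col("age"))).show()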
UDAF
org.apache.spark.sql.expressions.Aggregator[-IN, BUF, OUT]<br /> A UDAF takes N rows and N columns as input and outputs 1 row, 1 column.<br /> IN: the input type<br /> BUF: the buffer type<br /> OUT: the output type of the function<br /> For an average-age function:<br /> IN: age: Int<br /> BUF: 1) a tuple (sum: Double, count: Int), shown in the sketch after the class below<br /> 2) a case class wrapping the fields to keep: MyBuf(sum: Double, count: Int)<br /> OUT: avgAge: Double
Note that an Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method (Spark 3.0+).
Define an aggregate function class that extends Aggregator[IN, BUF, OUT]:
package com.tcode

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
 * Created by Smexy on 2021/6/7
 */
class MyAvg extends Aggregator[Int, MyBuf, Double] {
  // Initialize (construct) the buffer
  override def zero: MyBuf = MyBuf(0.0, 0)

  // Within a single Spark RDD partition, fold each input value into the buffer
  override def reduce(buffer: MyBuf, age: Int): MyBuf = {
    buffer.sum += age
    buffer.count += 1
    buffer
  }

  // Merge the buffers from different partitions into the final buffer
  override def merge(b1: MyBuf, b2: MyBuf): MyBuf = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Produce the final result from the merged buffer
  override def finish(reduction: MyBuf): Double = {
    reduction.sum / reduction.count
  }

  // Provide the Encoder for the buffer type
  override def bufferEncoder: Encoder[MyBuf] = Encoders.product[MyBuf]

  // Provide the Encoder for the output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

case class MyBuf(var sum: Double, var count: Int)
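For comparison, buffer option 1 from the list above (a plain tuple instead of a case class) would look roughly like this; a minimal sketch with the same average semantics, using Encoders.tuple to build the buffer encoder:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class MyAvgTuple extends Aggregator[Int, (Double, Int), Double] {
  override def zero: (Double, Int) = (0.0, 0)
  override def reduce(b: (Double, Int), age: Int): (Double, Int) = (b._1 + age, b._2 + 1)
  override def merge(b1: (Double, Int), b2: (Double, Int)): (Double, Int) = (b1._1 + b2._1, b1._2 + b2._2)
  override def finish(r: (Double, Int)): Double = r._1 / r._2
  // Encoders.tuple composes the primitive encoders for the tuple buffer
  override def bufferEncoder: Encoder[(Double, Int)] = Encoders.tuple(Encoders.scalaDouble, Encoders.scalaInt)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

The case-class version is usually preferable because the named fields (sum, count) are easier to read than _1 and _2.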
Calling the UDAF
@Test
def testUDAF(): Unit = {
  // requires: import org.apache.spark.sql.functions
  // ① define the function
  val myAvg = new MyAvg()
  // ② register the function (Aggregators are registered via functions.udaf)
  session.udf.register("myavg", functions.udaf(myAvg))
  // ③ use the function
  val dataFrame1: DataFrame = session.read.json("input/people.json")
  dataFrame1.createTempView("emps")
  session.sql("select myavg(age) from emps").show()
}
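The same Aggregator can also be used in the typed Dataset API without SQL registration, via toColumn. A minimal sketch, assuming the session, dataFrame1, and myAvg from above; the cast is an assumption because spark.read.json reads JSON numbers as bigint, while the Aggregator's IN type is Int:

import session.implicits._

// Cast bigint -> int so the Dataset matches the Aggregator's IN type
val ages = dataFrame1.select($"age".cast("int")).as[Int]
ages.select(myAvg.toColumn).show()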