UDF
Spark's built-in functions: [http://spark.apache.org/docs/latest/sql-ref-functions-builtin.html](http://spark.apache.org/docs/latest/sql-ref-functions-builtin.html)

Spark also allows user-defined functions (UDFs), in three steps:
① write the function
② register the function with Spark
③ use the function
```scala
@Test
def testCustomUDF(): Unit = {
  // ① write the function
  def addAge(age: Int): Int = {
    age + 10
  }

  // ② register the function with Spark.
  //    The name "addAge10" is arbitrary; addAge _ is a reference to the function defined above.
  //session.udf.register("addAge10", addAge _)
  //    Alternatively, register an anonymous function:
  session.udf.register("addAge10", (age: Int) => age + 10)

  // ③ use the function: add 10 to each age
  val dataFrame1: DataFrame = session.read.json("input/people.json")
  dataFrame1.createTempView("emps")
  session.sql("select name, addAge10(age) from emps").show()
}
```
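Registering a name is only needed for SQL. When working in the DataFrame API directly, the same logic can be wrapped with `org.apache.spark.sql.functions.udf` and applied as a column expression. A minimal sketch, reusing the `session` and input file from the test above:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

// wrap the anonymous function as a column-level UDF (no SQL registration needed)
val addAge10 = udf((age: Int) => age + 10)

val df: DataFrame = session.read.json("input/people.json")
// apply the UDF directly in a select, instead of going through SQL
df.select(col("name"), addAge10(col("age"))).show()
```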
UDAF
`org.apache.spark.sql.expressions.Aggregator[-IN, BUF, OUT]`

A UDAF takes N rows and N columns as input and produces 1 row, 1 column as output.
- `IN`: the input type, e.g. `age: Int`
- `BUF`: the buffer type; either 1) a tuple `(sum: Double, count: Int)`, or 2) a case class wrapping the fields to keep, e.g. `MyBuf(sum: Double, count: Int)`
- `OUT`: the output type, e.g. `avgAge: Double`

As of Spark 3.0, an `Aggregator[IN, BUF, OUT]` should be registered as a UDF via the `functions.udaf(agg)` method.
Define an aggregate function class that extends Aggregator[IN, BUF, OUT]
```scala
package com.tcode

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
 * Created by Smexy on 2021/6/7
 */
class MyAvg extends Aggregator[Int, MyBuf, Double] {
  // initialize (construct) the buffer
  override def zero: MyBuf = MyBuf(0.0, 0)

  // within a single RDD partition, accumulate each input value into the buffer
  override def reduce(buffer: MyBuf, age: Int): MyBuf = {
    buffer.sum += age
    buffer.count += 1
    buffer
  }

  // merge the buffers of different partitions into the final buffer
  override def merge(b1: MyBuf, b2: MyBuf): MyBuf = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // compute the final result from the merged buffer
  override def finish(reduction: MyBuf): Double = {
    reduction.sum / reduction.count
  }

  // Encoder for the buffer type
  override def bufferEncoder: Encoder[MyBuf] = Encoders.product[MyBuf]

  // Encoder for the output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

case class MyBuf(var sum: Double, var count: Int)
```
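Besides registration as a UDF, an `Aggregator` can also be used directly on a strongly typed `Dataset` through `toColumn`, as long as the Dataset's element type matches the Aggregator's `IN`. A minimal sketch, assuming a local `SparkSession` and a small in-memory `Dataset[Int]` of ages (these names are illustrative, not from the original):

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

val session = SparkSession.builder()
  .master("local[*]")
  .appName("udaf-typed")
  .getOrCreate()
import session.implicits._

// a Dataset[Int]; its element type matches the Aggregator's IN type
val ages: Dataset[Int] = Seq(20, 30, 40).toDS()

// toColumn turns the Aggregator into a TypedColumn[Int, Double]
val avgAge: Double = ages.select(new MyAvg().toColumn).head()
println(avgAge) // 30.0
```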
Calling the UDAF
```scala
@Test
def testUDAF(): Unit = {
  // ① define the function
  val myAvg = new MyAvg()
  // ② register the function
  session.udf.register("myavg", functions.udaf(myAvg))
  // ③ use the function
  val dataFrame1: DataFrame = session.read.json("input/people.json")
  dataFrame1.createTempView("emps")
  session.sql("select myavg(age) from emps").show()
}
```
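As with the scalar UDF, the `UserDefinedFunction` returned by `functions.udaf` can also be applied as a column expression without SQL registration. A sketch, reusing `session` and the input file from above; passing an explicit `Encoders.scalaInt` input encoder is one way to satisfy the `udaf` signature without `import session.implicits._` in scope:

```scala
import org.apache.spark.sql.{functions, DataFrame, Encoders}

// functions.udaf turns the Aggregator into a UserDefinedFunction;
// the overload with an explicit input Encoder avoids needing implicit encoders
val myAvgUdf = functions.udaf(new MyAvg(), Encoders.scalaInt)

val df: DataFrame = session.read.json("input/people.json")
// apply the UDAF directly in a select, instead of going through SQL
df.select(myAvgUdf(df("age"))).show()
```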