UDF functions
A UDF (user-defined function) is a function that users define themselves. We register one by calling the register(name: String, func(A1, A2, A3…)) method of udf on the SparkSession object. Here, name is the name the function will have in SQL, and func is the function itself, which can take multiple parameters.
With a UDF we can apply custom processing to a column of data, one cell (row value) at a time.
Case 1
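Define a function that prefixes the value of the name field with "Name: " when the name is Andy.
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// The object name is illustrative; any object with a main method works.
object UDFTest01 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    // Register the function under the name used in SQL.
    // == is null-safe in Scala, so a null name is returned unchanged
    // instead of throwing a NullPointerException.
    spark.udf.register("prefixName", (name: String) => {
      if (name == "Andy") "Name: " + name else name
    })

    spark.sql("select prefixName(name) as name, age, sex from people").show()
    spark.stop()
  }
}
Here we define a custom UDF, prefixName, which checks whether the value of the name field is "Andy". If so, it prepends "Name: " to the value; otherwise it returns the value unchanged.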
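Because a UDF is applied to one value at a time, its logic can be tried out as an ordinary Scala function before it is registered. The following is a minimal sketch with no Spark dependency; the object name PrefixNameSketch and the sample names are illustrative assumptions, not part of the original program.

```scala
object PrefixNameSketch {
  // Same check as the UDF body. == is null-safe in Scala,
  // so a null name is passed through unchanged.
  val prefixName: String => String =
    name => if (name == "Andy") "Name: " + name else name

  def main(args: Array[String]): Unit = {
    // Hypothetical sample values standing in for the name column.
    val names: Seq[String] = Seq("Michael", "Andy", "Justin", null)
    names.map(prefixName).foreach(println)
  }
}
```

A function value like this could then be passed to spark.udf.register in place of the inline lambda.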
Run results:
+----------+---+------+
|      name|age|   sex|
+----------+---+------+
|   Michael| 30|  Male|
|Name: Andy| 19|Female|
|    Justin| 19|  Male|
|Bernadette| 20|Female|
|  Gretchen| 23|Female|
|     David| 27|  Male|
|    Joseph| 33|Female|
|     Trish| 27|Female|
|      Alex| 33|Female|
|       Ben| 25|  Male|
+----------+---+------+
UDAF functions
Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions, such as count, countDistinct, avg, max, and min.
A UDAF is a user-defined aggregate function. Aggregate functions, like avg and sum, cannot produce a result row by row: they accumulate all the data into a buffer and then compute a single result from it.
To implement a UDAF we write our own aggregate function class, whose main job is the calculation. For a weakly typed aggregate function, we extend UserDefinedAggregateFunction and implement its eight methods. (This class is deprecated as of Spark 3.0; strongly typed aggregate functions are recommended instead.)
For a strongly typed aggregate function, we extend Aggregator.
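The buffer-based lifecycle described above (start from an empty buffer, fold each input into it, merge partial buffers, then compute the result) can be sketched in plain Scala without Spark. This is only an illustration of the idea: the names Buf and averageAge, the split into three partitions, and returning 0 for empty input are all assumptions made for the sketch.

```scala
object AggregationLifecycleSketch {
  // Buffer holding the running total and the number of values seen.
  final case class Buf(total: Long, count: Long)

  val zero: Buf = Buf(0L, 0L)                                   // empty buffer
  def reduce(b: Buf, age: Long): Buf = Buf(b.total + age, b.count + 1)
  def merge(a: Buf, b: Buf): Buf = Buf(a.total + b.total, a.count + b.count)
  // Integer division; returning 0 for an empty buffer is a choice made here.
  def finish(b: Buf): Long = if (b.count == 0) 0L else b.total / b.count

  // Each inner Seq stands in for one partition of the data.
  def averageAge(partitions: Seq[Seq[Long]]): Long = {
    val partials = partitions.map(_.foldLeft(zero)(reduce))     // one buffer per partition
    finish(partials.foldLeft(zero)(merge))                      // merge, then compute
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(30L, 19L, 19L, 20L), Seq(23L, 27L, 33L), Seq(27L, 33L, 25L))
    println(averageAge(partitions)) // prints 25 (256 / 10 with integer division)
  }
}
```

The four functions play the same roles as the initialization, update, merge, and result methods that a Spark UDAF has to provide.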
Case 1 – Average age
A case class instance can be constructed without new, because the compiler automatically generates a companion object with an apply method for every case class.
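As a small illustration of that point, the sketch below builds a buffer-like case class with and without the explicit apply call, and mutates its var fields. The object name CaseClassSketch is an assumption made for the example.

```scala
object CaseClassSketch {
  // Case class parameters are val by default; var makes them reassignable.
  case class Buff(var total: Long, var count: Long)

  def main(args: Array[String]): Unit = {
    val a = Buff(0L, 0L)        // sugar for Buff.apply(0L, 0L); no `new` needed
    val b = Buff.apply(0L, 0L)  // the explicit form of the same call

    a.total += 30               // allowed only because the fields are var
    a.count += 1

    println(a)      // prints Buff(30,1)
    println(a == b) // prints false: case class equality compares field values
  }
}
```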
Weakly typed implementation
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

/**
 * Weakly typed UDAF
 */
object UDAFTest01 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("avgAge", new MyAvgUDAF())
    spark.sql("select avgAge(age) from people").show()
    spark.stop()
  }
}

class MyAvgUDAF extends UserDefinedAggregateFunction {
  // Structure of the input data (IN)
  override def inputSchema: StructType = {
    StructType(
      Array(StructField("age", LongType))
    )
  }

  // Structure of the buffer data (BUFFER)
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total", LongType),
        StructField("count", LongType)
      )
    )
  }

  // Data type of the function's result (OUT)
  override def dataType: DataType = LongType

  // Whether the function is deterministic (the same input always gives the same result)
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // buffer(0) = 0L and buffer.update(0, 0L) are equivalent ways of writing this
    buffer.update(0, 0L) // field 0 of the buffer: total
    buffer.update(1, 0L) // field 1 of the buffer: count
  }

  // Update the buffer each time a row comes in
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // First argument: index of the buffer field (0 = total, 1 = count).
    // Second argument: the new value for that field.
    // input.getLong(0) is the only input field, age.
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1) // count + 1 for every row
  }

  // Merge two buffers; the result is written into buffer1
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Compute the final result.
  // Long / Long is integer division, so the fractional part is dropped.
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1)
  }
}
Run results:
+-----------+
|avgage(age)|
+-----------+
|         25|
+-----------+
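The result is a whole number because the function divides one Long by another. The sketch below reproduces the arithmetic in plain Scala, assuming the ten rows shown in the UDF example output are the entire people.json file; the object and method names are illustrative.

```scala
object AverageTruncationSketch {
  // Ages copied from the sample output of the UDF example (assumed to be the full data set).
  val ages: Seq[Long] = Seq(30L, 19L, 19L, 20L, 23L, 27L, 33L, 27L, 33L, 25L)

  // Long divided by Int stays a Long: the fractional part is dropped.
  def truncatedAvg(xs: Seq[Long]): Long = xs.sum / xs.size

  // Converting to Double first keeps the fractional part.
  def exactAvg(xs: Seq[Long]): Double = xs.sum.toDouble / xs.size

  def main(args: Array[String]): Unit = {
    println(truncatedAvg(ages)) // prints 25
    println(exactAvg(ages))     // prints 25.6
  }
}
```

To get the fractional value out of the UDAF itself, the result type would presumably have to change from Long to a floating-point type (for example DoubleType in the weakly typed version), with the division done on a Double.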
Strongly typed implementation
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Strongly typed UDAF
 */
object UDAFTest02 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    // functions.udaf wraps the Aggregator so it can be registered for SQL
    spark.udf.register("avgAge", functions.udaf(new MyAvg_UDAF()))
    spark.sql("select avgAge(age) from people").show()
    spark.stop()
  }
}

/**
 * Custom aggregate function class:
 * 1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
 *    IN:  input data type, Long
 *    BUF: buffer data type, Buff
 *    OUT: output data type, Long
 * 2. Override the methods
 */
// Case class parameters are val by default, so they must be declared var to be updated in place.
case class Buff(var total: Long, var count: Long)

class MyAvg_UDAF extends Aggregator[Long, Buff, Long] {
  // zero is the initial (empty) value of the buffer
  override def zero: Buff = {
    Buff(0L, 0L)
  }

  // Update the buffer with one input value and return the buffer
  override def reduce(buff: Buff, in: Long): Buff = {
    buff.total += in
    buff.count += 1
    buff
  }

  // Merge two buffers and return the first one
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total += buff2.total
    buff1.count += buff2.count
    buff1
  }

  // Compute the result (integer division)
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Encoder for the buffer, needed to serialize it for network transfer
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output value
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
Run results:
+-----------+
|avgage(age)|
+-----------+
|         25|
+-----------+
Early UDAF strongly typed aggregate function
Spark SQL offers two query styles. SQL queries structured data with SQL statements. The DSL is object-oriented: queries are built from objects and method calls and are tied to types, so strongly typed functions are used together with DSL statements.
Early (pre-Spark 3.0) strongly typed UDAFs were used through DSL operations rather than SQL statements.
Define a case class that matches the data, convert the DataFrame to a Dataset with the as[T] method, and then convert the UDAF aggregator into a column object with toColumn.
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Early strongly typed UDAFs were used through DSL operations
 */
object UDAFTest03 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._ // required for df.as[User]

    val df = spark.read.json("data/sql/people.json")
    val ds: Dataset[User] = df.as[User]

    // Convert the strongly typed UDAF into a column object for the query
    val udafCol: TypedColumn[User, Long] = new OldAvg_UDAF().toColumn
    ds.select(udafCol).show()
    spark.stop()
  }
}

/**
 * Custom aggregate function class:
 * 1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
 *    IN:  input data type, User
 *    BUF: buffer data type, Buff
 *    OUT: output data type, Long
 * 2. Override the methods
 */
case class User(name: String, age: Long, sex: String)

// Case class parameters are val by default, so they must be declared var to be updated in place.
case class Buff(var total: Long, var count: Long)

class OldAvg_UDAF extends Aggregator[User, Buff, Long] {
  // zero is the initial (empty) value of the buffer
  override def zero: Buff = {
    Buff(0L, 0L)
  }

  // Update the buffer with one input row and return the buffer
  override def reduce(buff: Buff, in: User): Buff = {
    buff.total += in.age
    buff.count += 1
    buff
  }

  // Merge two buffers and return the first one
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total += buff2.total
    buff1.count += buff2.count
    buff1
  }

  // Compute the result (integer division)
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Encoder for the buffer, needed to serialize it for network transfer
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output value
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
Run results:
+------------------------------------------+
|OldAvg_UDAF(com.study.spark.core.sql.User)|
+------------------------------------------+
|                                        25|
+------------------------------------------+