Spark SQL (4): UDF and UDAF functions

UDF functions

A UDF is a function that we, as users, define ourselves. We register one by calling the register(name: String, func) method of spark.udf on the SparkSession object, where name is the name our custom function will have in SQL and func is the custom function itself, which can take multiple parameters (A1, A2, A3, ...).

With a UDF we can apply targeted processing to a single column, or even a single cell, of data.

Case 1

Define a function that prepends "Name: " to the name field when its value is Andy.

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object UDFTest01 {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf()
    conf.setAppName("spark sql udf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("prefixName", (name: String) => {
      if (name.equals("Andy"))
        "Name: " + name
      else
        name
    })
    spark.sql("select prefixName(name) as name, age, sex from people").show()

    spark.stop()
  }
}

Here we define a custom UDF named prefixName: it checks whether the value of the name field is "Andy" and, if so, prepends "Name: " to the value.

Run results:

+----------+---+------+
|      name|age|   sex|
+----------+---+------+
|   Michael| 30|  Male|
|Name: Andy| 19|Female|
|    Justin| 19|  Male|
|Bernadette| 20|Female|
|  Gretchen| 23|Female|
|     David| 27|  Male|
|    Joseph| 33|Female|
|     Trish| 27|Female|
|      Alex| 33|Female|
|       Ben| 25|  Male|
+----------+---+------+
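
The same UDF can also be used without registering it by name, through the DataFrame DSL. Below is a minimal sketch (assuming the same df as in the case above; the val name prefixName is only illustrative) that wraps the function with org.apache.spark.sql.functions.udf and applies it in select:

import org.apache.spark.sql.functions.{col, udf}

// Wrap the Scala function as a column-level UDF
val prefixName = udf((name: String) =>
  if (name.equals("Andy")) "Name: " + name else name
)

// Apply it directly in the DSL; no registration by name is needed
df.select(prefixName(col("name")).as("name"), col("age"), col("sex")).show()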

UDAF functions

Both the strongly typed Dataset and the weakly typed DataFrame provide built-in aggregate functions, such as count, countDistinct, avg, max, and min.
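
For reference, a minimal sketch of these built-in aggregates on the same people data (assuming the df and temp view created earlier):

import org.apache.spark.sql.functions.{avg, col, countDistinct, max, min}

// Built-in aggregates through the DSL...
df.select(avg(col("age")), max(col("age")), min(col("age")), countDistinct(col("sex"))).show()

// ...or through SQL on the registered temp view
spark.sql("select avg(age), max(age), min(age), count(distinct sex) from people").show()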

A UDAF is a user-defined aggregate function. Aggregate functions such as avg and sum need to accumulate all of the rows into a buffer and then compute a single result from it.

Implementing a UDAF means writing a class that carries out the aggregation (its main task is the computation). We can extend UserDefinedAggregateFunction and implement its eight methods to build a weakly typed aggregate function. (This API is deprecated since Spark 3.0; strongly typed aggregate functions are recommended instead.)

We can extend Aggregator to implement a strongly typed aggregate function.

Case 1 – Average age

A case class can be constructed directly without new, because the compiler automatically generates a companion object with an apply method for it.
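
A quick illustration (the Person class here is just an example, not part of the Spark code below):

case class Person(name: String, age: Long)

// No `new` is needed: the compiler generates Person.apply for us
val p = Person("Andy", 19)   // equivalent to Person.apply("Andy", 19)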

Weakly typed implementation

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}

/**
 * Weak type
 */
object UDAFTest01 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("avgAge",new MyAvgUDAF())

    spark.sql("select avgAge(age) from people").show()

    spark.stop()
  }
}
class MyAvgUDAF extends UserDefinedAggregateFunction{

  // Structure of the input data (IN)
  override def inputSchema: StructType = {
   StructType(
     Array(StructField("age",LongType))
   )}

  // Structure of the aggregation buffer (BUFFER)
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total",LongType),
        StructField("count",LongType)
      )
    )}

  // Data type of function calculation result OUT
  override def dataType: DataType = LongType

  // Whether the function is deterministic (the same input always yields the same result)
  override def deterministic: Boolean = true

  //Buffer initialization
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // The two styles below are equivalent:
    // buffer(0) = 0L
    // buffer(1) = 0L
    buffer.update(0, 0L) // position 0 of the buffer (total) starts at 0L
    buffer.update(1, 0L) // position 1 of the buffer (count) starts at 0L
  }

  // How to update the buffer when a new input row arrives
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // The first argument is the buffer position (0 = total, 1 = count);
    // the second argument is the new value to store at that position.
    // buffer.getLong(0) reads the current total, and input.getLong(0) is the
    // incoming age (the only field in the input schema is age: Long).
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1) // count + 1 for every input row
  }

  // Combine data from multiple buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0,buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1,buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)/buffer.getLong(1)
  }
}

Run results:

+-----------+
|avgage(age)|
+-----------+
|         25|
+-----------+
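
Because dataType is LongType and evaluate divides two Longs, the average is truncated to a whole number (for the data above, 256 / 10 is truncated to 25). If a fractional result is wanted, one possible adjustment, sketched below assuming MyAvgUDAF from above is in scope (the class name MyAvgDoubleUDAF is only illustrative), is to override just the result type and the final computation:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DataType, DoubleType}

// Variant of MyAvgUDAF that keeps the same buffer but returns a fractional average
class MyAvgDoubleUDAF extends MyAvgUDAF {
  override def dataType: DataType = DoubleType
  override def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}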

Strongly typed implementation

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Strong typing
 */
object UDAFTest02 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")
    df.createOrReplaceTempView("people")

    spark.udf.register("avgAge",functions.udaf(new MyAvg_UDAF()))

    spark.sql("select avgAge(age) from people").show()

    spark.stop()
  }
}

/**
 * Custom aggregate function class:
 * 1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
 *    IN:  input data type, Long
 *    BUF: buffer data type, Buff
 *    OUT: output data type, Long
 * 2. Override its methods
 */
// Case class parameters are val by default, so they are declared var here because the buffer is mutated in place.
case class Buff(var total: Long,var count: Long)
class MyAvg_UDAF extends Aggregator[Long,Buff,Long]{

  // zero: the initial ("zero") value of the aggregation
  // Buff is our case class, used here to initialize the buffer
  override def zero: Buff = {
    Buff(0L,0L)
  }

  // Update the buffer with one input value and return the updated buffer
  override def reduce(buff: Buff, in: Long): Buff = {
    buff.total += in
    buff.count += 1
    buff
  }

  // Merge two buffers and return the merged result (buff1)
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total += buff2.total
    buff1.count += buff2.count
    buff1
  }

  // Compute the final result
  override def finish(buff: Buff): Long = {
    buff.total/buff.count
  }

  // Encoder for the buffer (needed to serialize it for network transmission)
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output value
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

Run results:

+-----------+
|avgage(age)|
+-----------+
|         25|
+-----------+
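
Since Spark 3.0, the wrapped Aggregator can also be used directly in the DataFrame DSL. A minimal sketch (assuming the df and MyAvg_UDAF defined above; avgAge is only an illustrative val name):

import org.apache.spark.sql.functions

// functions.udaf turns the strongly typed Aggregator into an untyped UDF
val avgAge = functions.udaf(new MyAvg_UDAF())

df.select(avgAge(df("age")).as("avg_age")).show()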

Early strongly typed UDAF aggregate functions

SQL is a structured query over the data, while the DSL is an object-oriented query style: it works with objects and methods and is type-aware, so it pairs naturally with strongly typed classes.

Before functions.udaf was introduced in Spark 3.0, strongly typed UDAFs could only be used through DSL operations.

Define a case class that matches the data, convert the DataFrame to a Dataset with as[User], and then turn the UDAF aggregator class into a column object with toColumn.

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Early UDAF strongly typed aggregate functions used DSL operations
 */
object UDAFTest03 {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("spark sql udaf")
      .setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("data/sql/people.json")

    val ds: Dataset[User] = df.as[User]

    //Convert UDAF strongly typed aggregate function into query class object
    val udafCol: TypedColumn[User, Long] = new OldAvg_UDAF().toColumn
    ds.select(udafCol).show()

    spark.stop()
  }
}

/**
 * Custom aggregate function class:
 * 1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
 *    IN:  input data type, User
 *    BUF: buffer data type, Buff
 *    OUT: output data type, Long
 * 2. Override its methods
 */
// Case class parameters are val by default, so Buff's fields are declared var here because the buffer is mutated in place.
case class User(name: String,age: Long,sex: String)
case class Buff(var total: Long,var count: Long)
class OldAvg_UDAF extends Aggregator[User,Buff,Long]{

  // zero: the initial ("zero") value of the aggregation
  // Buff is our case class, used here to initialize the buffer
  override def zero: Buff = {
    Buff(0L,0L)
  }

  // Update the buffer with one input row and return the updated buffer
  override def reduce(buff: Buff, in: User): Buff = {
    buff.total += in.age
    buff.count += 1
    buff
  }

  // Merge two buffers and return the merged result (buff1)
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.total += buff2.total
    buff1.count += buff2.count
    buff1
  }

  // Compute the final result
  override def finish(buff: Buff): Long = {
    buff.total/buff.count
  }

  // Encoder for the buffer (needed to serialize it for network transmission)
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output value
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

Run results:

+------------------------------------------+
|OldAvg_UDAF(com.study.spark.core.sql.User)|
+------------------------------------------+
|                                        25|
+------------------------------------------+