Custom UDF and UDAF functions for UV statistics in Spark SQL (using bitmaps)

When computing UV in real work, count(distinct userId) is the usual way to count users, but it is not efficient. Suppose you compute statistics across multiple dimensions and one day want to roll those dimensions up: you would have to recompute from the raw layer, which takes a long time when the data volume is large. Instead, the most fine-grained aggregation results can be reused for roll-up statistics. This requires a custom aggregate function that collects user ids into a bitmap and serializes the bitmap as a byte array.
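
The examples below depend on the RoaringBitmap library, and functions.udaf requires Spark 3.0+. A minimal build.sbt sketch (the versions shown are assumptions; adjust them to your environment):

// build.sbt (sketch; versions are assumptions)
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"      % "3.1.2" % "provided",
  "org.roaringbitmap" %  "RoaringBitmap" % "0.9.23"
)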

1) First-level aggregation

package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap
/**
 * @author shydow
 * @date 2021/12/13 22:55
 */

class BitmapGenUDAF extends Aggregator[Long, Array[Byte], Array[Byte]] {

  // The buffer is a serialized bitmap because Spark needs an Encoder for the
  // buffer type, and Encoders.BINARY is available out of the box.
  override def zero: Array[Byte] = {
    // Construct an empty bitmap and serialize it into a byte array
    val bm: RoaringBitmap = RoaringBitmap.bitmapOf()
    BitmapUtil.serBitmap(bm)
  }

  override def reduce(b: Array[Byte], a: Long): Array[Byte] = {
    // Deserialize the buffer into a bitmap, add the id, and serialize it back.
    // The input type is Long to match the id column; RoaringBitmap stores
    // 32-bit integers, so ids must fit into an Int
    // (Roaring64NavigableMap covers the full Long range if needed).
    val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(b)
    bitmap.add(a.toInt)
    BitmapUtil.serBitmap(bitmap)
  }

  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = {
    val bitmap1: RoaringBitmap = BitmapUtil.deSerBitmap(b1)
    val bitmap2: RoaringBitmap = BitmapUtil.deSerBitmap(b2)
    // In-place union: bitmap1 absorbs bitmap2's members
    bitmap1.or(bitmap2)
    BitmapUtil.serBitmap(bitmap1)
  }

  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY
}
package org.shydow.UDF

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:45
 */
object BitmapUtil {

  /**
   * Serialize bitmap
   */
  def serBitmap(bm: RoaringBitmap): Array[Byte] = {
    val stream = new ByteArrayOutputStream()
    val dataOutput = new DataOutputStream(stream)
    bm.serialize(dataOutput)
    stream.toByteArray
  }

  /**
   * Deserialize a byte array back into a bitmap
   */
  def deSerBitmap(bytes: Array[Byte]): RoaringBitmap = {
    val bm: RoaringBitmap = RoaringBitmap.bitmapOf()
    val stream = new ByteArrayInputStream(bytes)
    val inputStream = new DataInputStream(stream)
    bm.deserialize(inputStream)
    bm
  }
}
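
A quick round-trip check of these helpers (a minimal sketch; the sample values are arbitrary):

package org.shydow.UDF

import org.roaringbitmap.RoaringBitmap

object BitmapUtilRoundTrip {
  def main(args: Array[String]): Unit = {
    val original: RoaringBitmap = RoaringBitmap.bitmapOf(1, 2, 3, 100)
    val bytes: Array[Byte] = BitmapUtil.serBitmap(original)
    val restored: RoaringBitmap = BitmapUtil.deSerBitmap(bytes)
    assert(restored.equals(original))     // the same members survive the round trip
    assert(restored.getCardinality == 4)  // cardinality is the UV count
  }
}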
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession, TypedColumn}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:25
 */

object TestBehaviorAnalysis {
  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    import spark.implicits._

    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("eventType", StringType),
      StructField("code", StringType),
      StructField("timestamp", LongType))
    )
    val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
    frame.createOrReplaceTempView("order_log")

    /**
     * Compute UV with count(distinct) as a baseline
     */
    spark.sql(
      s"""
         |select
         | eventType,
         | count(1) as pv,
         | count(distinct id) as uv
         |from order_log
         |group by eventType
         |""".stripMargin).show()

    /**
     * Compute UV with the custom UDAF
     */
    import org.apache.spark.sql.functions.udaf
    spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF)) //The output of this function is a byte array. If you want to calculate the specific base, you must write a udf

    def card(byteArray: Array[Byte]): Int = {
      val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
      bitmap.getCardinality
    }
    spark.udf.register("get_card", card _)

    spark.sql(
      s"""
         |select
         | eventType,
         | count(1) as pv,
         | gen_bitmap(id) as uv_arr,
         | get_card(gen_bitmap(id)) as uv
         |from order_log
         |group by eventType
         |""".stripMargin).show()

    spark.close()
  }
}
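
The TypedColumn import in the file above hints at the DataFrame DSL; the same UDAF can also be used without SQL. A sketch, assuming it runs inside the same main after frame is defined:

    import org.apache.spark.sql.functions.{count, udaf}

    val genBitmap = udaf(new BitmapGenUDAF)
    frame.groupBy($"eventType")
      .agg(count("*").as("pv"), genBitmap($"id").as("uv_arr"))
      .show()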

2) Roll-up aggregation

package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/14 8:36
 */

class BitmapOrMergeUDAF extends Aggregator[Array[Byte], Array[Byte], Array[Byte]] {
  override def zero: Array[Byte] = {
    val bitmap: RoaringBitmap = RoaringBitmap.bitmapOf()
    BitmapUtil.serBitmap(bitmap)
  }

  override def reduce(b: Array[Byte], a: Array[Byte]): Array[Byte] = {
    // The input rows are already serialized bitmaps, so reduce and merge
    // perform the same operation: union two bitmaps
    val bitmap1: RoaringBitmap = BitmapUtil.deSerBitmap(b)
    val bitmap2: RoaringBitmap = BitmapUtil.deSerBitmap(a)
    bitmap1.or(bitmap2)
    BitmapUtil.serBitmap(bitmap1)
  }

  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = {
    val bitmap1: RoaringBitmap = BitmapUtil.deSerBitmap(b1)
    val bitmap2: RoaringBitmap = BitmapUtil.deSerBitmap(b2)
    bitmap1.or(bitmap2)
    BitmapUtil.serBitmap(bitmap1)
  }

  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY
}
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession, TypedColumn}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:25
 */

object TestBehaviorAnalysis {
  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    import spark.implicits._

    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("eventType", StringType),
      StructField("code", StringType),
      StructField("timestamp", LongType))
    )
    val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
    frame.createOrReplaceTempView("order_log")

    /**
     * Compute UV with count(distinct) as a baseline
     */
    spark.sql(
      s"""
         |select
         | eventType,
         | code,
         | count(1) as pv,
         | count(distinct id) as uv
         |from order_log
         |where code is not null
         |group by eventType, code
         |""".stripMargin).show()

    /**
     * Compute UV with the custom UDAF
     */
    import org.apache.spark.sql.functions.udaf
    spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF)) //The output of this function is a byte array. If you want to calculate the specific base, you must write a udf

    def card(byteArray: Array[Byte]): Int = {
      val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
      bitmap.getCardinality
    }
    spark.udf.register("get_card", card _)

    val res: DataFrame = spark.sql(
      s"""
         |select
         | eventType,
         | code,
         | count(1) as pv,
         | gen_bitmap(id) as uv_arr,
         | get_card(gen_bitmap(id)) as uv
         |from order_log
         |where code is not null
         |group by eventType, code
         |""".stripMargin)
    res.createTempView("dws_stat")

    spark.udf.register("bitmapOr", udaf(new BitmapOrMergeUDAF))
    spark.sql(
      s"""
        |select
        | eventType,
        | sum(pv) as total_pv,
        | bitmapOr(uv_arr) as uv_arr,
        | get_card(bitmapOr(uv_arr)) as total_uv
        |from dws_stat
        |group by eventType
        |""".stripMargin).show()


    spark.close()
  }
}
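
In real pipelines the fine-grained result would be persisted rather than kept in a temp view, so later roll-ups never touch the raw layer again. A minimal sketch (the output path is hypothetical):

    // Persist the finest-grained aggregate once; uv_arr is a plain BINARY column
    res.write.mode("overwrite").parquet("data/dws_stat")

    // Any later roll-up reads the small pre-aggregated table instead of the raw log
    spark.read.parquet("data/dws_stat").createOrReplaceTempView("dws_stat_from_disk")
    spark.sql(
      s"""
         |select
         | eventType,
         | sum(pv) as total_pv,
         | get_card(bitmapOr(uv_arr)) as total_uv
         |from dws_stat_from_disk
         |group by eventType
         |""".stripMargin).show()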