A problem encountered in a Spark production job (accumulator-related)

Project scenario:

A Spark job upserts data into MySQL, and the goal is to drive control flow from the result of a foreachPartition call.

Description of the problem

import java.sql.Timestamp

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._

def insertOrUpdateDFtoDB(tableName: String, resultDataFrame: DataFrame, updateColumns: Array[String]): Boolean = {
    var count = 0
    var status = true
    val colNumbers = resultDataFrame.columns.length
    println("colNumbers: " + colNumbers)
    println("updateColumns length: " + updateColumns.length)
    val sql = getInsertOrUpdateSql(tableName, resultDataFrame.columns, updateColumns)
    val columnDataTypes = resultDataFrame.schema.fields.map(_.dataType)
    println(s"\n$sql")
    resultDataFrame.foreachPartition((partitionRecords: Iterator[Row]) => {
      val conn = getOnlineConn()
      val prepareStatement = conn.prepareStatement(sql)
      try {
        conn.setAutoCommit(false)
        partitionRecords.foreach(record => {
          // Bind the fields that need to be inserted
          for (i <- 1 to colNumbers) {
            val value = record.get(i - 1)
            val dataType = columnDataTypes(i - 1)
            if (value != null) {
              dataType match {
                case _: ByteType => prepareStatement.setInt(i, record.getAs[Int](i - 1))
                case _: ShortType => prepareStatement.setInt(i, record.getAs[Int](i - 1))
                case _: IntegerType => prepareStatement.setInt(i, record.getAs[Int](i - 1))
                case _: LongType => prepareStatement.setLong(i, record.getAs[Long](i - 1))
                case _: BooleanType => prepareStatement.setBoolean(i, record.getAs[Boolean](i - 1))
                case _: FloatType => prepareStatement.setFloat(i, record.getAs[Float](i - 1))
                case _: DoubleType => prepareStatement.setDouble(i, record.getAs[Double](i - 1))
                case _: StringType => prepareStatement.setString(i, record.getAs[String](i - 1))
                case _: TimestampType => prepareStatement.setTimestamp(i, record.getAs[Timestamp](i - 1))
                // case _: DateType => prepareStatement.setDate(i, record.getAs[Date](i - 1))
                case _ => prepareStatement.setString(i, value.toString) // fallback for unmatched types; avoids a MatchError
              }
            } else {
              prepareStatement.setNull(i, java.sql.Types.NULL) // a null parameter must still be bound
            }
          }
          // Bind the field values that need to be updated
          for (i <- 1 to updateColumns.length) {
            val fieldIndex = record.fieldIndex(updateColumns(i - 1))
            val value = record.get(fieldIndex)
            val dataType = columnDataTypes(fieldIndex)
            if (value != null) {
              dataType match {
                case _: ByteType => prepareStatement.setInt(colNumbers + i, record.getAs[Int](fieldIndex))
                case _: ShortType => prepareStatement.setInt(colNumbers + i, record.getAs[Int](fieldIndex))
                case _: IntegerType => prepareStatement.setInt(colNumbers + i, record.getAs[Int](fieldIndex))
                case _: LongType => prepareStatement.setLong(colNumbers + i, record.getAs[Long](fieldIndex))
                case _: BooleanType => prepareStatement.setBoolean(colNumbers + i, record.getAs[Boolean](fieldIndex))
                case _: FloatType => prepareStatement.setFloat(colNumbers + i, record.getAs[Float](fieldIndex))
                case _: DoubleType => prepareStatement.setDouble(colNumbers + i, record.getAs[Double](fieldIndex))
                case _: StringType => prepareStatement.setString(colNumbers + i, record.getAs[String](fieldIndex))
                case _: TimestampType => prepareStatement.setTimestamp(colNumbers + i, record.getAs[Timestamp](fieldIndex))
                // case _: DateType => prepareStatement.setDate(colNumbers + i, record.getAs[Date](fieldIndex))
                case _ => prepareStatement.setString(colNumbers + i, value.toString)
              }
            } else {
              prepareStatement.setNull(colNumbers + i, java.sql.Types.NULL)
            }
          }
          prepareStatement.addBatch()
          count += 1
          // flush every 2000 rows
          if (count % 2000 == 0) {
            prepareStatement.executeBatch()
            conn.commit()
          }
        })
      } catch {
        case e: Exception =>
          status = false // runs on an executor: only a serialized copy of status changes
          println(s"@@ $e")
      } finally {
        prepareStatement.executeBatch()
        conn.commit()
        prepareStatement.close()
        conn.close()
      }
    })
    status // still the driver-side initial value
  }

I want to use the final status to know whether the MySQL inserts and updates succeeded, but status is always true: even when the catch block runs, the value returned at the end is still true. There is therefore no way to tell whether data was lost or rows failed because of a bad format somewhere in the process.

Cause Analysis:

In Spark, foreachPartition runs on the executors. The closure is serialized and shipped to each executor, so the status it references is an executor-local copy of the driver-side variable. Assignments to that copy are never sent back to the driver, so the driver's status keeps its initial value, true.
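A minimal, self-contained sketch of the same effect (local mode; the object and app names are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object ClosureDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("closure-demo").getOrCreate()
    var status = true
    spark.sparkContext.parallelize(1 to 10, 2).foreachPartition { _ =>
      status = false // mutates a deserialized copy inside the task, not the driver's variable
    }
    println(status) // prints true: the driver never sees the executor-side assignment
    spark.stop()
  }
}
```

Even in local mode the closure is serialized before each task runs, which is why the assignment is invisible to the driver.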

Solution:

To get a success signal back from the executors, use an accumulator: each partition adds to it after finishing its work, and the driver reads the accumulated value once the foreachPartition action completes.
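The pattern in isolation looks like this (a sketch; the accumulator name and partition count are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object AccumulatorDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("acc-demo").getOrCreate()
    val sc = spark.sparkContext
    val okPartitions = sc.longAccumulator("ok_partitions")
    val rdd = sc.parallelize(1 to 10, 2)
    rdd.foreachPartition { _ =>
      // ... do the partition's work; if it completes without throwing:
      okPartitions.add(1L)
    }
    // compare against the real partition count rather than a hard-coded number
    val allSucceeded = okPartitions.value == rdd.getNumPartitions
    println(allSucceeded)
    spark.stop()
  }
}
```

Accumulator updates inside actions are propagated back to the driver, which is exactly what a plain captured variable cannot do.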

import java.sql.Timestamp

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._

def insertOrUpdateDFtoDB(sc: SparkContext, tableName: String, resultDataFrame: DataFrame, updateColumns: Array[String]): Boolean = {
    var count = 0
    val colNumbers = resultDataFrame.columns.length
    println("colNumbers: " + colNumbers)
    println("updateColumns length: " + updateColumns.length)
    val sql = getInsertOrUpdateSql(tableName, resultDataFrame.columns, updateColumns)
    val columnDataTypes = resultDataFrame.schema.fields.map(_.dataType)
    println(s"\n$sql")

    // Declare the accumulator on the driver
    val acc = sc.longAccumulator("acc_count")
    val numPartitions = resultDataFrame.rdd.getNumPartitions
    resultDataFrame.foreachPartition((partitionRecords: Iterator[Row]) => {
      val conn = getOnlineConn()
      val prepareStatement = conn.prepareStatement(sql)
      try {
        conn.setAutoCommit(false)
        partitionRecords.foreach(record => {
          // Bind the fields that need to be inserted
          for (i <- 1 to colNumbers) {
            val value = record.get(i - 1)
            val dataType = columnDataTypes(i - 1)
            if (value != null) {
              dataType match {
                case _: ByteType => prepareStatement.setInt(i, record.getAs[Int](i - 1))
                case _: ShortType => prepareStatement.setInt(i, record.getAs[Int](i - 1))
                case _: IntegerType => prepareStatement.setInt(i, record.getAs[Int](i - 1))
                case _: LongType => prepareStatement.setLong(i, record.getAs[Long](i - 1))
                case _: BooleanType => prepareStatement.setBoolean(i, record.getAs[Boolean](i - 1))
                case _: FloatType => prepareStatement.setFloat(i, record.getAs[Float](i - 1))
                case _: DoubleType => prepareStatement.setDouble(i, record.getAs[Double](i - 1))
                case _: StringType => prepareStatement.setString(i, record.getAs[String](i - 1))
                case _: TimestampType => prepareStatement.setTimestamp(i, record.getAs[Timestamp](i - 1))
                // case _: DateType => prepareStatement.setDate(i, record.getAs[Date](i - 1))
                case _ => prepareStatement.setString(i, value.toString) // fallback for unmatched types; avoids a MatchError
              }
            } else {
              prepareStatement.setNull(i, java.sql.Types.NULL) // a null parameter must still be bound
            }
          }
          // Bind the field values that need to be updated
          for (i <- 1 to updateColumns.length) {
            val fieldIndex = record.fieldIndex(updateColumns(i - 1))
            val value = record.get(fieldIndex)
            val dataType = columnDataTypes(fieldIndex)
            if (value != null) {
              dataType match {
                case _: ByteType => prepareStatement.setInt(colNumbers + i, record.getAs[Int](fieldIndex))
                case _: ShortType => prepareStatement.setInt(colNumbers + i, record.getAs[Int](fieldIndex))
                case _: IntegerType => prepareStatement.setInt(colNumbers + i, record.getAs[Int](fieldIndex))
                case _: LongType => prepareStatement.setLong(colNumbers + i, record.getAs[Long](fieldIndex))
                case _: BooleanType => prepareStatement.setBoolean(colNumbers + i, record.getAs[Boolean](fieldIndex))
                case _: FloatType => prepareStatement.setFloat(colNumbers + i, record.getAs[Float](fieldIndex))
                case _: DoubleType => prepareStatement.setDouble(colNumbers + i, record.getAs[Double](fieldIndex))
                case _: StringType => prepareStatement.setString(colNumbers + i, record.getAs[String](fieldIndex))
                case _: TimestampType => prepareStatement.setTimestamp(colNumbers + i, record.getAs[Timestamp](fieldIndex))
                // case _: DateType => prepareStatement.setDate(colNumbers + i, record.getAs[Date](fieldIndex))
                case _ => prepareStatement.setString(colNumbers + i, value.toString)
              }
            } else {
              prepareStatement.setNull(colNumbers + i, java.sql.Types.NULL)
            }
          }
          prepareStatement.addBatch()
          count += 1
          // flush every 2000 rows
          if (count % 2000 == 0) {
            prepareStatement.executeBatch()
            conn.commit()
          }
        })
        acc.add(1L) // reached only if the whole partition was processed without an exception
      } catch {
        case e: Exception =>
          println(s"@@ $e")
      } finally {
        prepareStatement.executeBatch()
        conn.commit()
        prepareStatement.close()
        conn.close()
      }
    })

    val result = acc.value
    // Each successful partition added 1, so the job succeeded only if the count
    // matches the partition count (read from the DataFrame instead of hard-coding it)
    val status = result == numPartitions
    status
  }