【Question Title】: SPARK SQL - update MySql table using DataFrames and JDBC
【Posted on】: 2016-06-08 16:09:18
【Question】:

I am trying to insert and update some data in MySQL using Spark SQL DataFrames and a JDBC connection.

I have successfully inserted new data using SaveMode.Append. Is there a way to update the data that already exists in the MySQL table from Spark SQL?

My code for inserting is:

myDataFrame.write.mode(SaveMode.Append).jdbc(JDBCurl,mySqlTable,connectionProperties)

If I switch to SaveMode.Overwrite it deletes the full table and creates a new one; I am looking for something like the "ON DUPLICATE KEY UPDATE" that is available in MySQL.

【Question Comments】:

    Tags: jdbc apache-spark apache-spark-sql


    【Solution 1】:

    This is not possible. As for now (Spark 1.6.0 / 2.2.0 SNAPSHOT), Spark DataFrameWriter supports only four write modes:

    • SaveMode.Overwrite: overwrite the existing data.
    • SaveMode.Append: append the data.
    • SaveMode.Ignore: ignore the operation (i.e. no-op).
    • SaveMode.ErrorIfExists: the default option; throws an exception at runtime.

    You can insert manually, for example using mapPartitions (since an UPSERT operation should be idempotent and as such easy to implement; see the sketch right below), write to a temporary table and execute the upsert manually, or use triggers.
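    A minimal Scala sketch of the first approach, assuming a table mytable(id, value) with a primary key on id; the table name, column names, and the jdbcUser/jdbcPassword variables are illustrative, and foreachPartition is used here as the write-side counterpart of mapPartitions:

    import java.sql.DriverManager
    import org.apache.spark.sql.Row

    myDataFrame.foreachPartition { (rows: Iterator[Row]) =>
      // one connection and one prepared statement per partition
      val conn = DriverManager.getConnection(JDBCurl, jdbcUser, jdbcPassword)
      val stmt = conn.prepareStatement(
        "INSERT INTO mytable (id, value) VALUES (?, ?) " +
          "ON DUPLICATE KEY UPDATE value = VALUES(value)")
      try {
        rows.foreach { row =>
          stmt.setLong(1, row.getAs[Long]("id"))
          stmt.setString(2, row.getAs[String]("value"))
          stmt.addBatch()
        }
        stmt.executeBatch() // flush once per partition; chunk this for very large partitions
      } finally {
        stmt.close()
        conn.close()
      }
    }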

    In general, achieving upsert behavior for batch operations while keeping decent performance is far from trivial. You have to remember that in general there will be multiple concurrent transactions in place (one per partition), so you have to ensure that there are no write conflicts (typically by using application-specific partitioning) or provide appropriate recovery procedures. In practice it may be better to perform batched writes to a temporary table and resolve the upsert part directly in the database, as sketched below.
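    A sketch of that temporary-table variant, under the same illustrative names as above (the staging table mytable_staging is an assumption): Spark does the bulk write, and a single set-based statement lets MySQL resolve the upsert.

    import java.sql.DriverManager
    import org.apache.spark.sql.SaveMode

    // 1. Bulk-write the incoming rows into a staging table.
    myDataFrame.write
      .mode(SaveMode.Overwrite)
      .jdbc(JDBCurl, "mytable_staging", connectionProperties)

    // 2. Resolve the upsert inside MySQL in one statement.
    val conn = DriverManager.getConnection(JDBCurl, jdbcUser, jdbcPassword)
    try {
      val stmt = conn.createStatement()
      stmt.executeUpdate(
        "INSERT INTO mytable (id, value) " +
          "SELECT id, value FROM mytable_staging " +
          "ON DUPLICATE KEY UPDATE value = VALUES(value)")
      stmt.close()
    } finally {
      conn.close()
    }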

    【Comments】:

      【Solution 2】:

      zero323's answer is right; I just wanted to add that you could use the JayDeBeApi package as a workaround: https://pypi.python.org/pypi/JayDeBeApi/

      to update data in your MySQL table. It might be a low-hanging fruit since you already have the MySQL JDBC driver installed.

      The JayDeBeApi module allows you to connect from Python code to databases using Java JDBC. It provides a Python DB-API v2.0 interface to that database.

      We use Anaconda's distribution of Python, and the JayDeBeApi python package comes standard.

      See the examples at the link above.

      【Comments】:

        【Solution 3】:

        Sadly, there is no SaveMode.Upsert mode in Spark for such a very common case as upserting.

        zero323 is right in general, but I think it should be possible (with a compromise in performance) to offer such a replace feature.

        I also wanted to provide some Java code for this case. Of course, it is not as performant as Spark's built-in writer - but it should be a good basis for your requirements. Just modify it towards your needs:

        // repartition returns a new Dataset - assign it, or the call is a no-op
        Dataset<Row> repartitioned = myDF.repartition(20); // one connection per partition, see below

        repartitioned.foreachPartition((Iterator<Row> t) -> {
            Connection conn = DriverManager.getConnection(
                    Constants.DB_JDBC_CONN,
                    Constants.DB_JDBC_USER,
                    Constants.DB_JDBC_PASS);

            // with autocommit enabled, never call conn.commit() (see the comment below)
            conn.setAutoCommit(true);
            Statement statement = conn.createStatement();

            final int batchSize = 100000;
            int i = 0;
            while (t.hasNext()) {
                Row row = t.next();
                try {
                    // better than REPLACE INTO: fewer round trips
                    statement.addBatch("INSERT INTO mytable VALUES ("
                            + "'" + row.getAs("_id") + "', "
                            + "'" + row.getStruct(1).get(0) + "'"
                            + ") ON DUPLICATE KEY UPDATE _id='" + row.getAs("_id") + "';");

                    if (++i % batchSize == 0) {
                        statement.executeBatch(); // flush a full batch
                    }
                } catch (SQLIntegrityConstraintViolationException e) {
                    // should not occur with ON DUPLICATE KEY UPDATE, nevertheless
                } catch (SQLException e) {
                    e.printStackTrace();
                }
            }
            int[] ret = statement.executeBatch(); // flush the remaining rows

            System.out.println("Ret val: " + Arrays.toString(ret));
            System.out.println("Update count: " + statement.getUpdateCount());

            statement.close();
            conn.close();
        });

        【Comments】:

        • This worked well for me. One small correction I had to make was to comment out the conn.commit(); line before statement.close();. Otherwise it throws this error: java-sql-sqlexception-cant-call-commit-when-autocommit-true

        【Solution 4】:

        Based on org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.scala, override the generated "insert into" statement with "replace into":

        import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, SQLException}
        
        import scala.collection.JavaConverters._
        import scala.util.control.NonFatal
        import com.typesafe.scalalogging.Logger
        import org.apache.spark.sql.catalyst.InternalRow
        import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper, JDBCOptions}
        import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
        import org.apache.spark.sql.types._
        import org.apache.spark.sql.{DataFrame, Row}
        
        /**
          * Util functions for JDBC tables.
          */
        object UpdateJdbcUtils {
        
          val logger = Logger(this.getClass)
        
          /**
            * Returns a factory for creating connections to the given JDBC URL.
            *
            * @param options - JDBC options that contains url, table and other information.
            */
          def createConnectionFactory(options: JDBCOptions): () => Connection = {
            val driverClass: String = options.driverClass
            () => {
              DriverRegistry.register(driverClass)
              val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
                case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
                case d if d.getClass.getCanonicalName == driverClass => d
              }.getOrElse {
                throw new IllegalStateException(
                  s"Did not find registered driver with class $driverClass")
              }
              driver.connect(options.url, options.asConnectionProperties)
            }
          }
        
          /**
            * Returns a PreparedStatement that inserts a row into table via conn.
            */
          def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
          : PreparedStatement = {
            val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
            val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
            val sql = s"REPLACE INTO $table ($columns) VALUES ($placeholders)"
            conn.prepareStatement(sql)
          }
        
          /**
            * Retrieve standard jdbc types.
            *
            * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
            * @return The default JdbcType for this DataType
            */
          def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
            dt match {
              case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
              case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
              case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
              case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
              case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
              case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
              case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
              case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
              case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
              case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
              case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
              case t: DecimalType => Option(
                JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
              case _ => None
            }
          }
        
          private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
            dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
              throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
          }
        
          // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
          // for `MutableRow`. The last argument `Int` means the index for the value to be set in
          // the row and also used for the value in `ResultSet`.
          private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit
        
          // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
          // `PreparedStatement`. The last argument `Int` means the index for the value to be set
          // in the SQL statement and also used for the value in `Row`.
          private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
        
          /**
            * Saves a partition of a DataFrame to the JDBC database.  This is done in
            * a single database transaction (unless isolation level is "NONE")
            * in order to avoid repeatedly inserting data as much as possible.
            *
            * It is still theoretically possible for rows in a DataFrame to be
            * inserted into the database more than once if a stage somehow fails after
            * the commit occurs but before the stage can return successfully.
            *
            * This is not a closure inside saveTable() because apparently cosmetic
            * implementation changes elsewhere might easily render such a closure
            * non-Serializable.  Instead, we explicitly close over all variables that
            * are used.
            */
          def savePartition(
                             getConnection: () => Connection,
                             table: String,
                             iterator: Iterator[Row],
                             rddSchema: StructType,
                             nullTypes: Array[Int],
                             batchSize: Int,
                             dialect: JdbcDialect,
                             isolationLevel: Int): Iterator[Byte] = {
            val conn = getConnection()
            var committed = false
        
            var finalIsolationLevel = Connection.TRANSACTION_NONE
            if (isolationLevel != Connection.TRANSACTION_NONE) {
              try {
                val metadata = conn.getMetaData
                if (metadata.supportsTransactions()) {
                  // Update to at least use the default isolation, if any transaction level
                  // has been chosen and transactions are supported
                  val defaultIsolation = metadata.getDefaultTransactionIsolation
                  finalIsolationLevel = defaultIsolation
                  if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
                    // Finally update to actually requested level if possible
                    finalIsolationLevel = isolationLevel
                  } else {
                    logger.warn(s"Requested isolation level $isolationLevel is not supported; " +
                      s"falling back to default isolation level $defaultIsolation")
                  }
                } else {
                  logger.warn(s"Requested isolation level $isolationLevel, but transactions are unsupported")
                }
              } catch {
                case NonFatal(e) => logger.warn("Exception while detecting transaction support", e)
              }
            }
            val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
        
            try {
              if (supportsTransactions) {
                conn.setAutoCommit(false) // Everything in the same db transaction.
                conn.setTransactionIsolation(finalIsolationLevel)
              }
              val stmt = insertStatement(conn, table, rddSchema, dialect)
              val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
                .map(makeSetter(conn, dialect, _))
              val numFields = rddSchema.fields.length
        
              try {
                var rowCount = 0
                while (iterator.hasNext) {
                  val row = iterator.next()
                  var i = 0
                  while (i < numFields) {
                    if (row.isNullAt(i)) {
                      stmt.setNull(i + 1, nullTypes(i))
                    } else {
                      setters(i).apply(stmt, row, i)
                    }
                    i = i + 1
                  }
                  stmt.addBatch()
                  rowCount += 1
                  if (rowCount % batchSize == 0) {
                    stmt.executeBatch()
                    rowCount = 0
                  }
                }
                if (rowCount > 0) {
                  stmt.executeBatch()
                }
              } finally {
                stmt.close()
              }
              if (supportsTransactions) {
                conn.commit()
              }
              committed = true
              Iterator.empty
            } catch {
              case e: SQLException =>
                val cause = e.getNextException
                if (cause != null && e.getCause != cause) {
                  if (e.getCause == null) {
                    e.initCause(cause)
                  } else {
                    e.addSuppressed(cause)
                  }
                }
                throw e
            } finally {
              if (!committed) {
                // The stage must fail.  We got here through an exception path, so
                // let the exception through unless rollback() or close() want to
                // tell the user about another problem.
                if (supportsTransactions) {
                  conn.rollback()
                }
                conn.close()
              } else {
                // The stage must succeed.  We cannot propagate any exception close() might throw.
                try {
                  conn.close()
                } catch {
                  case e: Exception => logger.warn("Transaction succeeded, but closing failed", e)
                }
              }
            }
          }
        
          /**
            * Saves the RDD to the database in a single transaction.
            */
          def saveTable(
                         df: DataFrame,
                         url: String,
                         table: String,
                         options: JDBCOptions) {
            val dialect = JdbcDialects.get(url)
            val nullTypes: Array[Int] = df.schema.fields.map { field =>
              getJdbcType(field.dataType, dialect).jdbcNullType
            }
        
            val rddSchema = df.schema
            val getConnection: () => Connection = createConnectionFactory(options)
            val batchSize = options.batchSize
            val isolationLevel = options.isolationLevel
            df.foreachPartition(iterator => savePartition(
              getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
            )
          }
        
          private def makeSetter(
                                  conn: Connection,
                                  dialect: JdbcDialect,
                                  dataType: DataType): JDBCValueSetter = dataType match {
            case IntegerType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setInt(pos + 1, row.getInt(pos))
        
            case LongType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setLong(pos + 1, row.getLong(pos))
        
            case DoubleType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setDouble(pos + 1, row.getDouble(pos))
        
            case FloatType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setFloat(pos + 1, row.getFloat(pos))
        
            case ShortType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setInt(pos + 1, row.getShort(pos))
        
            case ByteType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setInt(pos + 1, row.getByte(pos))
        
            case BooleanType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setBoolean(pos + 1, row.getBoolean(pos))
        
            case StringType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setString(pos + 1, row.getString(pos))
        
            case BinaryType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
        
            case TimestampType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
        
            case DateType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
        
            case t: DecimalType =>
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
        
            case ArrayType(et, _) =>
              // remove type length parameters from end of type name
              val typeName = getJdbcType(et, dialect).databaseTypeDefinition
                .toLowerCase.split("\\(")(0)
              (stmt: PreparedStatement, row: Row, pos: Int) =>
                val array = conn.createArrayOf(
                  typeName,
                  row.getSeq[AnyRef](pos).toArray)
                stmt.setArray(pos + 1, array)
        
            case _ =>
              (_: PreparedStatement, _: Row, pos: Int) =>
                throw new IllegalArgumentException(
                  s"Can't translate non-null value for field $pos")
          }
        }
        

        Usage:

        val url = s"jdbc:mysql://$host/$database?useUnicode=true&characterEncoding=UTF-8"
        
        val parameters: Map[String, String] = Map(
          "url" -> url,
          "dbtable" -> table,
          "driver" -> "com.mysql.jdbc.Driver",
          "numPartitions" -> numPartitions.toString,
          "user" -> user,
          "password" -> password
        )
        val options = new JDBCOptions(parameters)
        
        for (d <- data) {
          UpdateJdbcUtils.saveTable(d, url, table, options)
        }
        

        PS: watch out for deadlocks; do not update data frequently, just use this for re-runs in emergency cases. I guess that is why Spark does not support it officially.

        【Comments】:

        • I get this error when trying to run the code: Caused by: java.io.NotSerializableException: UpdateJdbcUtils$ Serialization stack: - object not serializable (class: UpdateJdbcUtils$, value: UpdateJdbcUtils$@4f87e8f9) - field (class: UpdateJdbcUtils$$anonfun$saveTable$1, name: $outer, type: class UpdateJdbcUtils$) - object (class UpdateJdbcUtils$$anonfun$saveTable$1, <function1>) at org.apache.spark.serializer.SerializationDebugger$.improveException(SerializationDebugger.scala:40)

        【Solution 5】:

        I could not manage to do this in PySpark, so I decided to use odbc instead.

        url = "jdbc:sqlserver://xxx:1433;databaseName=xxx;user=xxx;password=xxx"
        df.write.jdbc(url=url, table="__TableInsert", mode='overwrite')
        cnxn  = pyodbc.connect('Driver={ODBC Driver 17 for SQL Server};Server=xxx;Database=xxx;Uid=xxx;Pwd=xxx;', autocommit=False) 
        try:
            crsr = cnxn.cursor()
            # DO UPSERTS OR WHATEVER YOU WANT
            crsr.execute("DELETE FROM Table")
            crsr.execute("INSERT INTO Table (Field) SELECT Field FROM __TableInsert")
            cnxn.commit()
        except:
            cnxn.rollback()
        cnxn.close()
        

        【Comments】:

          【Solution 6】:

          If your table is small, you can read the existing SQL data into a Spark DataFrame, perform the upsert there, and overwrite the existing SQL table, as in the sketch below.
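          A hedged Scala sketch of that idea, assuming Spark 2.3+ (for unionByName), a key column id, and the url/table/properties from the question. Since the merged result is derived from the table it overwrites, it must be materialized first:

          import org.apache.spark.sql.SaveMode

          // Upsert entirely in Spark: keep the existing rows whose key is not
          // being replaced, add all incoming rows, then overwrite the small table.
          val existing = spark.read.jdbc(JDBCurl, mySqlTable, connectionProperties)
          val merged = existing
            .join(myDataFrame.select("id"), Seq("id"), "left_anti") // drop rows being replaced
            .unionByName(myDataFrame)
            .cache()
          merged.count() // materialize before overwriting the source table

          merged.write.mode(SaveMode.Overwrite).jdbc(JDBCurl, mySqlTable, connectionProperties)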

          【Comments】:
