在您的代码中，decimalType 实际上不是 Scala 类型标识符——它是 DecimalType 类的一个实例（即一个值）。所以，您不能在编译器需要类型标识符的地方使用它。
编写 UDF 时，您只需使用 java.math.BigDecimal 作为参数类型，无需指定精度和比例。但是，如果您确实需要在 UDF 的计算中控制这些值，可以通过 MathContext 指定精度（及舍入模式），并通过 BigDecimal.setScale 指定比例——MathContext 本身并不包含比例信息。
package HelloSpec.parser
import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}
import org.scalatest.FlatSpec
/**
 * Single-column record used to build the test DataFrame.
 *
 * The field keeps the snake_case name `price_usd` because Spark derives the column name
 * from the field name, and it must equal `Fields.PRICE_USD`. Its type is
 * `scala.math.BigDecimal`, which the test expects Spark to encode as `DecimalType(38, 18)`.
 */
final case class SKU(price_usd: BigDecimal)
/** Column names referenced by the test DataFrames. */
object Fields {

  /** Name of the decimal price column; identical to the `price_usd` field name of `SKU`. */
  val PRICE_USD: String = "price_usd"
}
/**
 * Shows how Spark handles decimal prices end to end:
 *  - the schema Spark derives for a `scala.math.BigDecimal` field,
 *  - casting that column down to `DecimalType(18, 2)` (overflow and rounding),
 *  - passing the decimal column into a typed UDF that takes `java.math.BigDecimal`.
 */
class TestSo extends FlatSpec with DataFrameSuiteBase with SharedSparkContext {
  import Fields._

  it should "not fail" in {
    import spark.implicits._

    val df = Seq(
      SKU(BigDecimal("1.12")),
      SKU(BigDecimal("1234567890123456.12")),
      SKU(BigDecimal("1234567890123456.123")),
      SKU(BigDecimal("12345678901234567.12"))
    ).toDF
    df.printSchema()
    df.show(truncate = false)

    // A scala.math.BigDecimal field is encoded as DecimalType(38, 18).
    assert(
      df.schema ==
        StructType(Seq(StructField(name = PRICE_USD, dataType = DecimalType(38, 18))))
    )

    val castedTo18_2 = df.withColumn(PRICE_USD, df(PRICE_USD).cast(DecimalType(18, 2)))
    castedTo18_2.printSchema()
    castedTo18_2.show(truncate = false)
    assert(
      castedTo18_2.schema ==
        StructType(Seq(StructField(name = PRICE_USD, dataType = DecimalType(18, 2))))
    )

    assert {
      // collect() keeps its empty parens: Dataset.collect() is declared with them, and
      // auto-application to () is deprecated in Scala 2.13 and rejected by Scala 3.
      castedTo18_2.as[Option[BigDecimal]].collect().toSeq.sorted == Seq(
        // This was 12345678901234567.12 before the cast, but its 17 integer digits
        // exceed the 18 - 2 = 16 allowed by DecimalType(18, 2), so the value became null.
        None,
        Some(BigDecimal("1.12")),
        Some(BigDecimal("1234567890123456.12")),
        // Note that 1234567890123456.123 was rounded to 1234567890123456.12.
        Some(BigDecimal("1234567890123456.12"))
      )
    }

    import org.apache.spark.sql.functions.{udf, col}
    val processBigDecimal = udf(
      // The argument type has to be java.math.BigDecimal, not scala.math.BigDecimal,
      // which is imported by default. A null column value arrives as a null reference,
      // so it is wrapped in Option and mapped back to a null result with orNull.
      (bd: java.math.BigDecimal) =>
        Option(bd)
          .map(b => s"${b.getClass} with precision ${b.precision}, scale ${b.scale} and value $b")
          .orNull
    )

    val withUdfApplied = castedTo18_2
      .withColumn("udf_result", processBigDecimal(col(PRICE_USD)))
    withUdfApplied.printSchema()
    withUdfApplied.show(truncate = false)

    // The UDF sees scale 18 rather than the column's scale 2. Presumably Spark casts the
    // input to the decimal type it infers for java.math.BigDecimal, DecimalType(38, 18).
    assert(
      withUdfApplied.as[(Option[BigDecimal], String)].collect().toSeq.sorted == Seq(
        None -> null,
        Some(BigDecimal("1.12")) -> "class java.math.BigDecimal with precision 19, scale 18 and value 1.120000000000000000",
        Some(BigDecimal("1234567890123456.12")) -> "class java.math.BigDecimal with precision 34, scale 18 and value 1234567890123456.120000000000000000",
        Some(BigDecimal("1234567890123456.12")) -> "class java.math.BigDecimal with precision 34, scale 18 and value 1234567890123456.120000000000000000"
      )
    )
  }
}