Parcourir la source

[LIVY-455][REPL] Fix json4s doesn't support java.math.BigDecimal issue

## What changes were proposed in this pull request?

Livy's SQLInterpreter will throw exception when rows contain java.math.BigDecimal data. This is because current version of json4s doesn't treat java.math.BigDecimal type as primitive type. On the contrary, json4s supports Scala BigDecimal as primitive type. So the fix is to convert java BigDecimal to Scala BigDecimal.

## How was this patch tested?

Unit test is added.

Author: jerryshao <sshao@hortonworks.com>

Closes #85 from jerryshao/LIVY-455.
jerryshao il y a 7 ans
Parent
commit
7e4bb3bd68

+ 9 - 1
repl/src/main/scala/org/apache/livy/repl/SQLInterpreter.scala

@@ -97,7 +97,15 @@ class SQLInterpreter(
       val rows = result.getClass.getMethod("take", classOf[Int])
         .invoke(result, maxResult: java.lang.Integer)
         .asInstanceOf[Array[Row]]
-          .map(_.toSeq)
+        .map {
+          _.toSeq.map {
+            // Convert java BigDecimal type to Scala BigDecimal, because current version of
+            // Json4s doesn't support java BigDecimal as a primitive type (LIVY-455).
+            case i: java.math.BigDecimal => BigDecimal(i)
+            case e => e
+          }
+        }
+
       val jRows = Extraction.decompose(rows)
 
       Interpreter.ExecuteSuccess(

+ 36 - 0
repl/src/test/scala/org/apache/livy/repl/SQLInterpreterSpec.scala

@@ -95,6 +95,42 @@ class SQLInterpreterSpec extends BaseInterpreterSpec {
     ))
   }
 
+  it should "handle java BigDecimal" in withInterpreter { interpreter =>
+    val rdd = sparkEntries.sc().parallelize(Seq(
+      ("1", new java.math.BigDecimal(1.0)),
+      ("2", new java.math.BigDecimal(2.0))))
+    val df = sparkEntries.sqlctx().createDataFrame(rdd).selectExpr("_1 as col1", "_2 as col2")
+    df.registerTempTable("test")
+
+    val resp1 = interpreter.execute(
+      """
+        |SELECT * FROM test
+      """.stripMargin)
+
+    val expectedResult = (nullable: Boolean) => {
+      Interpreter.ExecuteSuccess(
+        APPLICATION_JSON -> (("schema" ->
+          (("type" -> "struct") ~
+            ("fields" -> List(
+              ("name" -> "col1") ~ ("type" -> "string") ~ ("nullable" -> true) ~
+                ("metadata" -> List()),
+              ("name" -> "col2") ~ ("type" -> "decimal(38,18)") ~ ("nullable" -> nullable) ~
+                ("metadata" -> List())
+            )))) ~
+          ("data" -> List(
+            List[JValue]("1", 1.0d),
+            List[JValue]("2", 2.0d)
+          )))
+      )
+    }
+
+    val result = Try { resp1 should equal(expectedResult(false))}
+      .orElse(Try { resp1 should equal(expectedResult(true)) })
+    if (result.isFailure) {
+      fail(s"$resp1 doesn't equal to expected result")
+    }
+  }
+
   it should "throw exception for illegal query" in withInterpreter { interpreter =>
     val resp = interpreter.execute(
       """