|
@@ -18,10 +18,13 @@
|
|
|
|
|
|
package com.cloudera.livy.test
|
|
|
|
|
|
+import java.util.concurrent.atomic.AtomicInteger
|
|
|
+
|
|
|
import scala.concurrent.duration._
|
|
|
|
|
|
import org.scalatest.BeforeAndAfter
|
|
|
|
|
|
+import com.cloudera.livy.sessions._
|
|
|
import com.cloudera.livy.test.framework.BaseIntegrationTestSuite
|
|
|
|
|
|
private case class TestStatement(
|
|
@@ -35,23 +38,56 @@ class InteractiveIT extends BaseIntegrationTestSuite with BeforeAndAfter {
|
|
|
|
|
|
after {
|
|
|
livyClient.stopSession(sessionId)
|
|
|
+ sessionId = -1
|
|
|
}
|
|
|
|
|
|
test("basic interactive session") {
|
|
|
- sessionId = livyClient.startSession()
|
|
|
+ sessionId = livyClient.startSession(Spark())
|
|
|
|
|
|
val testStmts = List(
|
|
|
new TestStatement("1+1", Some("res0: Int = 2")),
|
|
|
- new TestStatement("val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)",
|
|
|
- Some("hiveContext: org.apache.spark.sql.hive.HiveContext = " +
|
|
|
- "org.apache.spark.sql.hive.HiveContext")))
|
|
|
+ new TestStatement("val sqlContext = new org.apache.spark.sql.SQLContext(sc)",
|
|
|
+ Some("sqlContext: org.apache.spark.sql.SQLContext = " +
|
|
|
+ "org.apache.spark.sql.SQLContext")))
|
|
|
+ runAndValidateStatements(testStmts)
|
|
|
+ }
|
|
|
|
|
|
- waitTillSessionIdle(sessionId)
|
|
|
+ pytest("pyspark interactive session") {
|
|
|
+ sessionId = livyClient.startSession(PySpark())
|
|
|
|
|
|
- // Run the statements
|
|
|
- testStmts.foreach {
|
|
|
- runAndValidateStatement(_)
|
|
|
- }
|
|
|
+ val testStmts = List(
|
|
|
+ new TestStatement("1+1", Some("2")),
|
|
|
+ new TestStatement(
|
|
|
+ "sc.parallelize(range(100)).map(lambda x: x * 2).reduce(lambda x, y: x + y)",
|
|
|
+ Some("9900")))
|
|
|
+ runAndValidateStatements(testStmts)
|
|
|
+ }
|
|
|
+
|
|
|
+ rtest("R interactive session") {
|
|
|
+ sessionId = livyClient.startSession(SparkR())
|
|
|
+
|
|
|
+ // R's output sometimes includes the count of statements, which makes it annoying to test
|
|
|
+ // things. This helps a bit.
|
|
|
+ val curr = new AtomicInteger()
|
|
|
+ def count: Int = curr.incrementAndGet()
|
|
|
+
|
|
|
+ val testStmts = List(
|
|
|
+ new TestStatement("1+1", Some(s"[$count] 2")),
|
|
|
+ new TestStatement("sqlContext <- sparkRSQL.init(sc)", None),
|
|
|
+ new TestStatement(
|
|
|
+ """localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18))""", None),
|
|
|
+ new TestStatement("df <- createDataFrame(sqlContext, localDF)", None),
|
|
|
+ new TestStatement("printSchema(df)", Some(
|
|
|
+ """|root
|
|
|
+ | |-- name: string (nullable = true)
|
|
|
+ | |-- age: double (nullable = true)""".stripMargin))
|
|
|
+ )
|
|
|
+ runAndValidateStatements(testStmts)
|
|
|
+ }
|
|
|
+
|
|
|
+ private def runAndValidateStatements(statements: Seq[TestStatement]) = {
|
|
|
+ waitTillSessionIdle(sessionId)
|
|
|
+ statements.foreach(runAndValidateStatement)
|
|
|
}
|
|
|
|
|
|
private def runAndValidateStatement(testStmt: TestStatement) = {
|
|
@@ -61,8 +97,7 @@ class InteractiveIT extends BaseIntegrationTestSuite with BeforeAndAfter {
|
|
|
|
|
|
testStmt.expectedResult.map { s =>
|
|
|
val result = livyClient.getStatementResult(sessionId, testStmt.stmtId)
|
|
|
- assert(result.indexOf(s) >= 0,
|
|
|
- s"Statement result doesn't match. Expected: $s. Actual: $result")
|
|
|
+ assert(result.contains(s))
|
|
|
}
|
|
|
|
|
|
}
|