|
@@ -24,6 +24,8 @@ import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
|
|
|
import scala.annotation.tailrec
|
|
|
import scala.collection.JavaConverters._
|
|
|
import scala.reflect.runtime.universe
|
|
|
+import scala.util.Try
|
|
|
+import scala.util.control.NonFatal
|
|
|
|
|
|
import org.apache.commons.codec.binary.Base64
|
|
|
import org.apache.commons.lang.StringEscapeUtils
|
|
@@ -33,13 +35,14 @@ import org.apache.spark.sql.SQLContext
|
|
|
import org.json4s._
|
|
|
import org.json4s.JsonDSL._
|
|
|
|
|
|
+import org.apache.livy.Logging
|
|
|
import org.apache.livy.client.common.ClientConf
|
|
|
import org.apache.livy.rsc.driver.SparkEntries
|
|
|
|
|
|
private case class RequestResponse(content: String, error: Boolean)
|
|
|
|
|
|
// scalastyle:off println
|
|
|
-object SparkRInterpreter {
|
|
|
+object SparkRInterpreter extends Logging {
|
|
|
private val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----"
|
|
|
private val LIVY_ERROR_MARKER = "----LIVY_END_OF_ERROR----"
|
|
|
private val PRINT_MARKER = f"""print("$LIVY_END_MARKER")"""
|
|
@@ -76,12 +79,25 @@ object SparkRInterpreter {
|
|
|
val backendInstance = sparkRBackendClass.getDeclaredConstructor().newInstance()
|
|
|
|
|
|
var sparkRBackendPort = 0
|
|
|
+ var sparkRBackendSecret: String = null
|
|
|
val initialized = new Semaphore(0)
|
|
|
// Launch a SparkR backend server for the R process to connect to
|
|
|
val backendThread = new Thread("SparkR backend") {
|
|
|
override def run(): Unit = {
|
|
|
- sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance)
|
|
|
- .asInstanceOf[Int]
|
|
|
+ try {
|
|
|
+ sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance)
|
|
|
+ .asInstanceOf[Int]
|
|
|
+ } catch {
|
|
|
+ case NonFatal(e) =>
|
|
|
+ warn("Fail to init Spark RBackend, using different method signature", e)
|
|
|
+ val retTuple = sparkRBackendClass.getMethod("init").invoke(backendInstance)
|
|
|
+ .asInstanceOf[(Int, Object)]
|
|
|
+ sparkRBackendPort = retTuple._1
|
|
|
+ sparkRBackendSecret = Try {
|
|
|
+ val rAuthHelper = retTuple._2
|
|
|
+ rAuthHelper.getClass.getMethod("secret").invoke(rAuthHelper).asInstanceOf[String]
|
|
|
+ }.getOrElse(null)
|
|
|
+ }
|
|
|
|
|
|
initialized.release()
|
|
|
sparkRBackendClass.getMethod("run").invoke(backendInstance)
|
|
@@ -116,6 +132,9 @@ object SparkRInterpreter {
|
|
|
val env = builder.environment()
|
|
|
env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
|
|
|
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
|
|
|
+ if (sparkRBackendSecret != null) {
|
|
|
+ env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret)
|
|
|
+ }
|
|
|
env.put("SPARKR_PACKAGE_DIR", packageDir)
|
|
|
env.put("R_PROFILE_USER",
|
|
|
Seq(packageDir, "SparkR", "profile", "general.R").mkString(File.separator))
|
|
@@ -123,7 +142,7 @@ object SparkRInterpreter {
|
|
|
builder.redirectErrorStream(true)
|
|
|
val process = builder.start()
|
|
|
new SparkRInterpreter(process, backendInstance, backendThread,
|
|
|
- conf.getInt("spark.livy.spark_major_version", 1))
|
|
|
+ conf.getInt("spark.livy.spark_major_version", 1), sparkRBackendSecret != null)
|
|
|
} catch {
|
|
|
case e: Exception =>
|
|
|
if (backendThread != null) {
|
|
@@ -149,10 +168,12 @@ object SparkRInterpreter {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-class SparkRInterpreter(process: Process,
|
|
|
+class SparkRInterpreter(
|
|
|
+ process: Process,
|
|
|
backendInstance: Any,
|
|
|
backendThread: Thread,
|
|
|
- val sparkMajorVersion: Int)
|
|
|
+ val sparkMajorVersion: Int,
|
|
|
+ authProvided: Boolean)
|
|
|
extends ProcessInterpreter(process) {
|
|
|
import SparkRInterpreter._
|
|
|
|
|
@@ -169,7 +190,12 @@ class SparkRInterpreter(process: Process,
|
|
|
// scalastyle:off line.size.limit
|
|
|
sendRequest("library(SparkR)")
|
|
|
sendRequest("""port <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")""")
|
|
|
- sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""")
|
|
|
+ if (authProvided) {
|
|
|
+ sendRequest("""authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET", "")""")
|
|
|
+ sendRequest("""SparkR:::connectBackend("localhost", port, 6000, authSecret)""")
|
|
|
+ } else {
|
|
|
+ sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""")
|
|
|
+ }
|
|
|
sendRequest("""assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv)""")
|
|
|
|
|
|
sendRequest("""assign(".sc", SparkR:::callJStatic("org.apache.livy.repl.SparkRInterpreter", "getSparkContext"), envir = SparkR:::.sparkREnv)""")
|