
[Security] Update to support pyspark and sparkr changes in Spark 2.3.1

jerryshao 6 years ago
parent
commit
2196302731

+ 9 - 2
core/src/main/scala/org/apache/livy/Utils.scala

@@ -17,10 +17,11 @@
 
 package org.apache.livy
 
-import java.io.{Closeable, File, FileInputStream, InputStreamReader}
+import java.io.{Closeable, File, InputStreamReader}
 import java.net.URL
 import java.nio.charset.StandardCharsets.UTF_8
-import java.util.Properties
+import java.security.SecureRandom
+import java.util.{Base64, Properties}
 
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
@@ -106,4 +107,10 @@ object Utils {
     }
   }
 
+  def createSecret(secretBitLength: Int): String = {
+    val rnd = new SecureRandom()
+    val secretBytes = new Array[Byte](secretBitLength / java.lang.Byte.SIZE)
+    rnd.nextBytes(secretBytes)
+    Base64.getEncoder.encodeToString(secretBytes)
+  }
 }
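
For context, a minimal standalone sketch (not part of the commit) of how the new helper is meant to be used: a 256-bit request yields 32 random bytes, Base64-encoded to a 44-character token that is later exported to the Python and R child processes. The object name below is illustrative only.

    // Sketch only: assumes the Utils object shown above is on the classpath.
    // 256 bits -> 32 random bytes -> a 44-character Base64 token.
    import org.apache.livy.Utils

    object SecretSizeCheck extends App {
      val token = Utils.createSecret(256)
      assert(java.util.Base64.getDecoder.decode(token).length == 32)
      println(s"gateway secret: $token (${token.length} chars)")
    }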

+ 4 - 4
pom.xml

@@ -1196,13 +1196,13 @@
         </property>
       </activation>
       <properties>
-        <spark.scala-2.11.version>2.3.0</spark.scala-2.11.version>
+        <spark.scala-2.11.version>2.3.1</spark.scala-2.11.version>
         <spark.scala-2.10.version>2.2.0</spark.scala-2.10.version>
         <spark.version>${spark.scala-2.11.version}</spark.version>
         <netty.spark-2.11.version>4.1.17.Final</netty.spark-2.11.version>
         <netty.spark-2.10.version>4.0.37.Final</netty.spark-2.10.version>
         <java.version>1.8</java.version>
-        <py4j.version>0.10.4</py4j.version>
+        <py4j.version>0.10.7</py4j.version>
         <json4s.version>3.2.11</json4s.version>
       </properties>
     </profile>
@@ -1216,9 +1216,9 @@
       </activation>
       <properties>
         <spark.bin.download.url>
-          http://apache.mirrors.ionfish.org/spark/spark-2.3.0/spark-2.3.0-bin-hadoop2.7.tgz
+          http://mirrors.advancedhosters.com/apache/spark/spark-2.3.1/spark-2.3.1-bin-hadoop2.7.tgz
         </spark.bin.download.url>
-        <spark.bin.name>spark-2.3.0-bin-hadoop2.7</spark.bin.name>
+        <spark.bin.name>spark-2.3.1-bin-hadoop2.7</spark.bin.name>
       </properties>
     </profile>
 

+ 17 - 6
repl/src/main/resources/fake_shell.py

@@ -569,7 +569,13 @@ def main():
             from pyspark.sql import SQLContext, HiveContext, Row
             # Connect to the gateway
             gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
-            gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
+            try:
+                from py4j.java_gateway import GatewayParameters
+                gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
+                gateway = JavaGateway(gateway_parameters=GatewayParameters(
+                    port=gateway_port, auth_token=gateway_secret, auto_convert=True))
+            except:
+                gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
 
             # Import the classes used by PySpark
             java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -613,12 +619,17 @@ def main():
 
             #Start py4j callback server
             from py4j.protocol import ENTRY_POINT_OBJECT_ID
-            from py4j.java_gateway import JavaGateway, GatewayClient, CallbackServerParameters
+            from py4j.java_gateway import CallbackServerParameters
+
+            try:
+                gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
+                gateway.start_callback_server(
+                    callback_server_parameters=CallbackServerParameters(
+                        port=0, auth_token=gateway_secret))
+            except:
+                gateway.start_callback_server(
+                    callback_server_parameters=CallbackServerParameters(port=0))
 
-            gateway_client_port = int(os.environ.get("PYSPARK_GATEWAY_PORT"))
-            gateway = JavaGateway(GatewayClient(port=gateway_client_port))
-            gateway.start_callback_server(
-                callback_server_parameters=CallbackServerParameters(port=0))
             socket_info = gateway._callback_server.server_socket.getsockname()
             listening_port = socket_info[1]
             pyspark_job_processor = PySparkJobProcessorImpl()

+ 28 - 2
repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala

@@ -18,8 +18,10 @@
 package org.apache.livy.repl
 
 import java.io._
+import java.lang.{Integer => JInteger}
 import java.lang.ProcessBuilder.Redirect
 import java.lang.reflect.Proxy
+import java.net.InetAddress
 import java.nio.file.{Files, Paths}
 
 import scala.annotation.tailrec
@@ -35,7 +37,7 @@ import org.json4s.jackson.Serialization.write
 import py4j._
 import py4j.reflection.PythonProxyHandler
 
-import org.apache.livy.Logging
+import org.apache.livy.{Logging, Utils}
 import org.apache.livy.client.common.ClientConf
 import org.apache.livy.rsc.driver.SparkEntries
 import org.apache.livy.sessions._
@@ -49,7 +51,8 @@ object PythonInterpreter extends Logging {
       .orElse(sys.props.get("pyspark.python")) // This java property is only used for internal UT.
       .getOrElse("python")
 
-    val gatewayServer = new GatewayServer(sparkEntries, 0)
+    val secretKey = Utils.createSecret(256)
+    val gatewayServer = createGatewayServer(sparkEntries, secretKey)
     gatewayServer.start()
 
     val builder = new ProcessBuilder(Seq(pythonExec, createFakeShell().toString).asJava)
@@ -65,6 +68,7 @@ object PythonInterpreter extends Logging {
     env.put("PYTHONPATH", pythonPath.mkString(File.pathSeparator))
     env.put("PYTHONUNBUFFERED", "YES")
     env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+    env.put("PYSPARK_GATEWAY_SECRET", secretKey)
     env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
     env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1"))
     builder.redirectError(Redirect.PIPE)
@@ -131,6 +135,28 @@ object PythonInterpreter extends Logging {
     file
   }
 
+  private def createGatewayServer(sparkEntries: SparkEntries, secretKey: String): GatewayServer = {
+    try {
+      val clz = Class.forName("py4j.GatewayServer$GatewayServerBuilder", true,
+        Thread.currentThread().getContextClassLoader)
+      val builder = clz.getConstructor(classOf[Object])
+        .newInstance(sparkEntries)
+
+      val localhost = InetAddress.getLoopbackAddress()
+      builder.getClass.getMethod("authToken", classOf[String]).invoke(builder, secretKey)
+      builder.getClass.getMethod("javaPort", classOf[Int]).invoke(builder, 0: JInteger)
+      builder.getClass.getMethod("javaAddress", classOf[InetAddress]).invoke(builder, localhost)
+      builder.getClass
+        .getMethod("callbackClient", classOf[Int], classOf[InetAddress], classOf[String])
+        .invoke(builder, GatewayServer.DEFAULT_PYTHON_PORT: JInteger, localhost, secretKey)
+      builder.getClass.getMethod("build").invoke(builder).asInstanceOf[GatewayServer]
+    } catch {
+      case NonFatal(e) =>
+        warn("Fail to create GatewayServer with auth parameter, downgrade to old constructor", e)
+        new GatewayServer(sparkEntries, 0)
+    }
+  }
+
   private def initiatePy4jCallbackGateway(server: GatewayServer): PySparkJobProcessor = {
     val f = server.getClass.getDeclaredField("gateway")
     f.setAccessible(true)

+ 33 - 7
repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala

@@ -24,6 +24,8 @@ import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe
+import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.commons.codec.binary.Base64
 import org.apache.commons.lang.StringEscapeUtils
@@ -33,13 +35,14 @@ import org.apache.spark.sql.SQLContext
 import org.json4s._
 import org.json4s.JsonDSL._
 
+import org.apache.livy.Logging
 import org.apache.livy.client.common.ClientConf
 import org.apache.livy.rsc.driver.SparkEntries
 
 private case class RequestResponse(content: String, error: Boolean)
 
 // scalastyle:off println
-object SparkRInterpreter {
+object SparkRInterpreter extends Logging {
   private val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----"
   private val LIVY_ERROR_MARKER = "----LIVY_END_OF_ERROR----"
   private val PRINT_MARKER = f"""print("$LIVY_END_MARKER")"""
@@ -76,12 +79,25 @@ object SparkRInterpreter {
     val backendInstance = sparkRBackendClass.getDeclaredConstructor().newInstance()
 
     var sparkRBackendPort = 0
+    var sparkRBackendSecret: String = null
     val initialized = new Semaphore(0)
     // Launch a SparkR backend server for the R process to connect to
     val backendThread = new Thread("SparkR backend") {
       override def run(): Unit = {
-        sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance)
-          .asInstanceOf[Int]
+        try {
+          sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance)
+            .asInstanceOf[Int]
+        } catch {
+          case NonFatal(e) =>
+            warn("Fail to init Spark RBackend, using different method signature", e)
+            val retTuple = sparkRBackendClass.getMethod("init").invoke(backendInstance)
+              .asInstanceOf[(Int, Object)]
+            sparkRBackendPort = retTuple._1
+            sparkRBackendSecret = Try {
+              val rAuthHelper = retTuple._2
+              rAuthHelper.getClass.getMethod("secret").invoke(rAuthHelper).asInstanceOf[String]
+            }.getOrElse(null)
+        }
 
         initialized.release()
         sparkRBackendClass.getMethod("run").invoke(backendInstance)
@@ -116,6 +132,9 @@ object SparkRInterpreter {
       val env = builder.environment()
       env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
       env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+      if (sparkRBackendSecret != null) {
+        env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret)
+      }
       env.put("SPARKR_PACKAGE_DIR", packageDir)
       env.put("R_PROFILE_USER",
         Seq(packageDir, "SparkR", "profile", "general.R").mkString(File.separator))
@@ -123,7 +142,7 @@ object SparkRInterpreter {
       builder.redirectErrorStream(true)
       val process = builder.start()
       new SparkRInterpreter(process, backendInstance, backendThread,
-        conf.getInt("spark.livy.spark_major_version", 1))
+        conf.getInt("spark.livy.spark_major_version", 1), sparkRBackendSecret != null)
     } catch {
       case e: Exception =>
         if (backendThread != null) {
@@ -149,10 +168,12 @@ object SparkRInterpreter {
   }
 }
 
-class SparkRInterpreter(process: Process,
+class SparkRInterpreter(
+    process: Process,
     backendInstance: Any,
     backendThread: Thread,
-    val sparkMajorVersion: Int)
+    val sparkMajorVersion: Int,
+    authProvided: Boolean)
   extends ProcessInterpreter(process) {
   import SparkRInterpreter._
 
@@ -169,7 +190,12 @@ class SparkRInterpreter(process: Process,
       // scalastyle:off line.size.limit
       sendRequest("library(SparkR)")
       sendRequest("""port <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")""")
-      sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""")
+      if (authProvided) {
+        sendRequest("""authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET", "")""")
+        sendRequest("""SparkR:::connectBackend("localhost", port, 6000, authSecret)""")
+      } else {
+        sendRequest("""SparkR:::connectBackend("localhost", port, 6000)""")
+      }
       sendRequest("""assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv)""")
 
       sendRequest("""assign(".sc", SparkR:::callJStatic("org.apache.livy.repl.SparkRInterpreter", "getSparkContext"), envir = SparkR:::.sparkREnv)""")