Sfoglia il codice sorgente

LIVY-244. Replace livy-server repl state polling thread with RPC. (#231)

- Remove repl state polling thread in InteractiveSession.
- Modified Rpc protocol between livy-server and livy-repl to push repl state changes using a Rpc call from repl to server.
- Changed InteractiveSession's state to reflect repl state returned from Rpc protocol.
Alex Man 8 anni fa
parent
commit
fe9286825c

+ 16 - 0
core/src/main/scala/com/cloudera/livy/sessions/SessionState.scala

@@ -25,6 +25,22 @@ sealed trait SessionState {
 
 object SessionState {
 
+  def apply(s: String): SessionState = {
+    s match {
+      case "not_started" => NotStarted()
+      case "starting" => Starting()
+      case "recovering" => Recovering()
+      case "idle" => Idle()
+      case "running" => Running()
+      case "busy" => Busy()
+      case "shutting_down" => ShuttingDown()
+      case "error" => Error()
+      case "dead" => Dead()
+      case "success" => Success()
+      case _ => throw new IllegalArgumentException(s"Illegal session state: $s")
+    }
+  }
+
   case class NotStarted() extends SessionState {
     override def isActive: Boolean = true
 

+ 20 - 13
repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala

@@ -26,8 +26,10 @@ import org.apache.spark.SparkConf
 import org.apache.spark.api.java.JavaSparkContext
 
 import com.cloudera.livy.Logging
-import com.cloudera.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf}
+import com.cloudera.livy.rsc.BaseProtocol.ReplState
+import com.cloudera.livy.rsc.{BaseProtocol, RSCConf, ReplJobResults}
 import com.cloudera.livy.rsc.driver._
+import com.cloudera.livy.rsc.rpc.Rpc
 import com.cloudera.livy.sessions._
 
 class ReplDriver(conf: SparkConf, livyConf: RSCConf)
@@ -47,7 +49,8 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
       case Spark() => new SparkInterpreter(conf)
       case SparkR() => SparkRInterpreter(conf)
     }
-    session = new Session(interpreter)
+    session = new Session(interpreter, { s => broadcast(new ReplState(s.toString)) })
+
     Option(Await.result(session.start(), Duration.Inf))
       .map(new JavaSparkContext(_))
       .orNull
@@ -70,22 +73,20 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
   /**
    * Return statement results. Results are sorted by statement id.
    */
-  def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResults): ReplJobResults =
-    session.synchronized {
-      val stmts = if (msg.allResults) {
-        session.statements.values.toArray
+  def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResults): ReplJobResults = {
+    val statements = if (msg.allResults) {
+      session.statements.values.toArray
+    } else {
+      assert(msg.from != null)
+      assert(msg.size != null)
+      if (msg.size == 1) {
+        session.statements.get(msg.from).toArray
       } else {
-        assert(msg.from != null)
-        assert(msg.size != null)
         val until = msg.from + msg.size
         session.statements.filterKeys(id => id >= msg.from && id < until).values.toArray
       }
-      val state = session.state.toString
-      new ReplJobResults(stmts.sortBy(_.id), state)
     }
-
-  def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplState): String = {
-    session.state.toString
+    new ReplJobResults(statements.sortBy(_.id))
   }
 
   override protected def createWrapper(msg: BaseProtocol.BypassJobRequest): BypassJobWrapper = {
@@ -108,4 +109,10 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf)
       case _ => super.addJarOrPyFile(path)
     }
   }
+
+  override protected def onClientAuthenticated(client: Rpc): Unit = {
+    if (session != null) {
+      client.call(new ReplState(session.state.toString))
+    }
+  }
 }

+ 32 - 24
repl/src/main/scala/com/cloudera/livy/repl/Session.scala

@@ -21,7 +21,7 @@ package com.cloudera.livy.repl
 import java.util.concurrent.Executors
 import java.util.concurrent.atomic.AtomicInteger
 
-import scala.collection.mutable
+import scala.collection.concurrent.TrieMap
 import scala.concurrent.{ExecutionContext, Future}
 
 import org.apache.spark.SparkContext
@@ -44,7 +44,7 @@ object Session {
   val TRACEBACK = "traceback"
 }
 
-class Session(interpreter: Interpreter)
+class Session(interpreter: Interpreter, stateChangedCallback: SessionState => Unit = { _ => } )
   extends Logging
 {
   import Session._
@@ -54,19 +54,21 @@ class Session(interpreter: Interpreter)
   private implicit val formats = DefaultFormats
 
   private var _state: SessionState = SessionState.NotStarted()
-  private val _statements = mutable.Map[Int, Statement]()
+  private val _statements = TrieMap[Int, Statement]()
 
   private val newStatementId = new AtomicInteger(0)
 
+  stateChangedCallback(_state)
+
   def start(): Future[SparkContext] = {
     val future = Future {
-      _state = SessionState.Starting()
+      changeState(SessionState.Starting())
       val sc = interpreter.start()
-      _state = SessionState.Idle()
+      changeState(SessionState.Idle())
       sc
     }
     future.onFailure { case _ =>
-      _state = SessionState.Error(System.currentTimeMillis())
+      changeState(SessionState.Error())
     }
     future
   }
@@ -75,24 +77,16 @@ class Session(interpreter: Interpreter)
 
   def state: SessionState = _state
 
-  def statements: mutable.Map[Int, Statement] = _statements
+  def statements: collection.Map[Int, Statement] = _statements.readOnlySnapshot()
 
   def execute(code: String): Int = {
     val statementId = newStatementId.getAndIncrement()
-    synchronized {
-      _statements(statementId) = new Statement(statementId, StatementState.Waiting, null)
-    }
+    _statements(statementId) = new Statement(statementId, StatementState.Waiting, null)
     Future {
-      synchronized {
-        _statements(statementId) = new Statement(statementId, StatementState.Running, null)
-      }
+      _statements(statementId) = new Statement(statementId, StatementState.Running, null)
 
-      val statement =
+      _statements(statementId) =
         new Statement(statementId, StatementState.Available, executeCode(statementId, code))
-
-      synchronized {
-        _statements(statementId) = statement
-      }
     }
     statementId
   }
@@ -106,20 +100,34 @@ class Session(interpreter: Interpreter)
     _statements.clear()
   }
 
+  private def changeState(newState: SessionState): Unit = {
+    synchronized {
+      _state = newState
+    }
+    stateChangedCallback(newState)
+  }
+
   private def executeCode(executionCount: Int, code: String): String = synchronized {
-    _state = SessionState.Busy()
+    changeState(SessionState.Busy())
+
+    def transitToIdle() = {
+      val executingLastStatement = executionCount == newStatementId.intValue() - 1
+      if (_statements.isEmpty || executingLastStatement) {
+        changeState(SessionState.Idle())
+      }
+    }
 
     val resultInJson = try {
       interpreter.execute(code) match {
         case Interpreter.ExecuteSuccess(data) =>
-          _state = SessionState.Idle()
+          transitToIdle()
 
           (STATUS -> OK) ~
           (EXECUTION_COUNT -> executionCount) ~
           (DATA -> data)
 
         case Interpreter.ExecuteIncomplete() =>
-          _state = SessionState.Idle()
+          transitToIdle()
 
           (STATUS -> ERROR) ~
           (EXECUTION_COUNT -> executionCount) ~
@@ -128,7 +136,7 @@ class Session(interpreter: Interpreter)
           (TRACEBACK -> List())
 
         case Interpreter.ExecuteError(ename, evalue, traceback) =>
-          _state = SessionState.Idle()
+          transitToIdle()
 
           (STATUS -> ERROR) ~
           (EXECUTION_COUNT -> executionCount) ~
@@ -137,7 +145,7 @@ class Session(interpreter: Interpreter)
           (TRACEBACK -> traceback)
 
         case Interpreter.ExecuteAborted(message) =>
-          _state = SessionState.Error(System.nanoTime())
+          changeState(SessionState.Error())
 
           (STATUS -> ERROR) ~
           (EXECUTION_COUNT -> executionCount) ~
@@ -149,7 +157,7 @@ class Session(interpreter: Interpreter)
       case e: Throwable =>
         error("Exception when executing code", e)
 
-        _state = SessionState.Idle()
+        transitToIdle()
 
         (STATUS -> ERROR) ~
         (EXECUTION_COUNT -> executionCount) ~

+ 8 - 1
repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala

@@ -18,6 +18,8 @@
 
 package com.cloudera.livy.repl
 
+import java.util.concurrent.atomic.AtomicInteger
+
 import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.postfixOps
@@ -44,10 +46,15 @@ abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitT
   }
 
   protected def withSession(testCode: Session => Any): Unit = {
-    val session = new Session(createInterpreter())
+    val stateChangedCalled = new AtomicInteger()
+    val session = new Session(createInterpreter(), { _ => stateChangedCalled.incrementAndGet() })
     try {
+      // Session's constructor should fire an initial state change event.
+      stateChangedCalled.intValue() shouldBe 1
       Await.ready(session.start(), 30 seconds)
       assert(session.state === SessionState.Idle())
+      // There should be at least 1 state change event fired when session transits to idle.
+      stateChangedCalled.intValue() should (be > 1)
       testCode(session)
     } finally {
       session.close()

+ 1 - 3
repl/src/test/scala/com/cloudera/livy/repl/ReplDriverSuite.scala

@@ -52,9 +52,7 @@ class ReplDriverSuite extends FunSuite with LivyBaseUnitTestSuite {
 
     try {
       // This is sort of what InteractiveSession.scala does to detect an idle session.
-      val handle = client.submit(new PingJob()).get(60, TimeUnit.SECONDS)
-
-      assert(client.getReplState().get(10, TimeUnit.SECONDS) === "idle")
+      client.submit(new PingJob()).get(60, TimeUnit.SECONDS)
 
       val statementId = client.submitReplCode("1 + 1").get
       eventually(timeout(30 seconds), interval(100 millis)) {

+ 81 - 0
repl/src/test/scala/com/cloudera/livy/repl/SessionSpec.scala

@@ -0,0 +1,81 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.livy.repl
+
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}
+
+
+import org.mockito.Mockito.when
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.FunSpec
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mock.MockitoSugar.mock
+import org.scalatest.time._
+
+import com.cloudera.livy.LivyBaseUnitTestSuite
+import com.cloudera.livy.repl.Interpreter.ExecuteResponse
+
+class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite {
+  override implicit val patienceConfig =
+    PatienceConfig(timeout = scaled(Span(10, Seconds)), interval = scaled(Span(100, Millis)))
+
+  describe("Session") {
+    it("should call state changed callbacks in happy path") {
+      val expectedStateTransitions = Array("not_started", "starting", "idle", "busy", "idle")
+      val actualStateTransitions = new ConcurrentLinkedQueue[String]()
+
+      val interpreter = mock[Interpreter]
+      val session = new Session(interpreter, { s => actualStateTransitions.add(s.toString) })
+
+      session.start()
+
+      session.execute("")
+
+      eventually {
+        actualStateTransitions.toArray shouldBe expectedStateTransitions
+      }
+    }
+
+    it("should not transit to idle if there're any pending statements.") {
+      val expectedStateTransitions = Array("not_started", "busy", "busy", "idle")
+      val actualStateTransitions = new ConcurrentLinkedQueue[String]()
+
+      val interpreter = mock[Interpreter]
+      val blockFirstExecuteCall = new CountDownLatch(1)
+      when(interpreter.execute("")).thenAnswer(new Answer[Interpreter.ExecuteResponse] {
+        override def answer(invocation: InvocationOnMock): ExecuteResponse = {
+          blockFirstExecuteCall.await(10, TimeUnit.SECONDS)
+          null
+        }
+      })
+      val session = new Session(interpreter, { s => actualStateTransitions.add(s.toString) })
+
+      for (_ <- 1 to 2) {
+        session.execute("")
+      }
+
+      blockFirstExecuteCall.countDown()
+      eventually {
+        actualStateTransitions.toArray shouldBe expectedStateTransitions
+      }
+    }
+  }
+}

+ 11 - 1
rsc/src/main/java/com/cloudera/livy/rsc/BaseProtocol.java

@@ -199,7 +199,17 @@ public abstract class BaseProtocol extends RpcDispatcher {
     }
   }
 
-  public static class GetReplState {
+  protected static class ReplState {
+
+    public final String state;
+
+    public ReplState(String state) {
+      this.state = state;
+    }
+
+    public ReplState() {
+      this(null);
+    }
 
   }
 

+ 4 - 0
rsc/src/main/java/com/cloudera/livy/rsc/ContextLauncher.java

@@ -342,6 +342,10 @@ class ContextLauncher {
       return this;
     }
 
+    @Override
+    public void onSaslComplete(Rpc client) {
+    }
+
     void dispose() {
       if (client != null) {
         client.close();

+ 11 - 2
rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java

@@ -61,6 +61,7 @@ public class RSCClient implements LivyClient {
 
   private ContextInfo contextInfo;
   private volatile boolean isAlive;
+  private volatile String replState;
 
   RSCClient(RSCConf conf, Promise<ContextInfo> ctx) throws IOException {
     this.conf = conf;
@@ -286,8 +287,11 @@ public class RSCClient implements LivyClient {
     return deferredCall(new BaseProtocol.GetReplJobResults(), ReplJobResults.class);
   }
 
-  public Future<String> getReplState() {
-    return deferredCall(new BaseProtocol.GetReplState(), String.class);
+  /**
+   * @return Return the repl state. If this's not connected to a repl session, it will return null.
+   */
+  public String getReplState() {
+    return replState;
   }
 
   private class ClientProtocol extends BaseProtocol {
@@ -382,5 +386,10 @@ public class RSCClient implements LivyClient {
         LOG.warn("Received event for unknown job {}", msg.id);
       }
     }
+
+    private void handle(ChannelHandlerContext ctx, ReplState msg) {
+      LOG.trace("Received repl state for {}", msg.state);
+      replState = msg.state;
+    }
   }
 }

+ 2 - 4
rsc/src/main/java/com/cloudera/livy/rsc/ReplJobResults.java

@@ -20,14 +20,12 @@ import com.cloudera.livy.rsc.driver.Statement;
 
 public class ReplJobResults {
   public final Statement[] statements;
-  public final String replState;
 
-  public ReplJobResults(Statement[] statements, String replState) {
+  public ReplJobResults(Statement[] statements) {
     this.statements = statements;
-    this.replState = replState;
   }
 
   public ReplJobResults() {
-    this(null, null);
+    this(null);
   }
 }

+ 19 - 10
rsc/src/main/java/com/cloudera/livy/rsc/driver/RSCDriver.java

@@ -172,6 +172,11 @@ public class RSCDriver extends BaseProtocol {
         registerClient(client);
         return RSCDriver.this;
       }
+
+      @Override
+      public void onSaslComplete(Rpc client) {
+        onClientAuthenticated(client);
+      }
     });
 
     // The RPC library takes care of timing out this.
@@ -241,6 +246,16 @@ public class RSCDriver extends BaseProtocol {
     }
   }
 
+  protected void broadcast(Object msg) {
+    for (Rpc client : clients) {
+      try {
+        client.call(msg);
+      } catch (Exception e) {
+        LOG.warn("Failed to send message to client " + client, e);
+      }
+    }
+  }
+
   /**
    * Initializes the SparkContext used by this driver. This implementation creates a
    * context with the provided configuration. Subclasses can override this behavior,
@@ -256,6 +271,10 @@ public class RSCDriver extends BaseProtocol {
     return sc;
   }
 
+  protected void onClientAuthenticated(final Rpc client) {
+
+  }
+
   /**
    * Called to shut down the driver; any initialization done by initializeContext() should
    * be undone here. This is guaranteed to be called only once.
@@ -281,16 +300,6 @@ public class RSCDriver extends BaseProtocol {
     }
   }
 
-  private void broadcast(Object msg) {
-    for (Rpc client : clients) {
-      try {
-        client.call(msg);
-      } catch (Exception e) {
-        LOG.warn("Failed to send message to client " + client, e);
-      }
-    }
-  }
-
   void run() throws Exception {
     this.running = true;
 

+ 9 - 6
rsc/src/main/java/com/cloudera/livy/rsc/rpc/RpcServer.java

@@ -21,12 +21,9 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.security.SecureRandom;
-import java.util.Map;
-import java.util.Properties;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
@@ -45,9 +42,6 @@ import io.netty.channel.EventLoopGroup;
 import io.netty.channel.nio.NioEventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioServerSocketChannel;
-import io.netty.util.concurrent.Future;
-import io.netty.util.concurrent.GenericFutureListener;
-import io.netty.util.concurrent.Promise;
 import io.netty.util.concurrent.ScheduledFuture;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -194,6 +188,13 @@ public class RpcServer implements Closeable {
      */
     RpcDispatcher onNewClient(Rpc client);
 
+
+    /**
+     * Called when a new client successfully completed SASL authentication.
+     *
+     * @param client The RPC instance for the new client.
+     */
+    void onSaslComplete(Rpc client);
   }
 
   private class SaslServerHandler extends SaslHandler implements CallbackHandler {
@@ -266,6 +267,8 @@ public class RpcServer implements Closeable {
       if (dispatcher != null) {
         rpc.setDispatcher(dispatcher);
       }
+
+      client.callback.onSaslComplete(rpc);
     }
 
     @Override

+ 15 - 8
rsc/src/test/java/com/cloudera/livy/rsc/rpc/TestRpc.java

@@ -20,8 +20,10 @@ package com.cloudera.livy.rsc.rpc;
 import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import javax.security.sasl.SaslException;
 
@@ -223,9 +225,10 @@ public class TestRpc {
     Future<Rpc> clientRpcFuture = Rpc.createClient(clientConf, server.getEventLoopGroup(),
         "localhost", server.getPort(), "client", secret, new TestDispatcher());
 
-    synchronized (callback) {
-      callback.wait(TimeUnit.SECONDS.toMillis(10));
-    }
+    assertTrue("onNewClient() wasn't called.",
+      callback.onNewClientCalled.await(10, TimeUnit.SECONDS));
+    assertTrue("onSaslComplete() wasn't called.",
+      callback.onSaslCompleteCalled.await(10, TimeUnit.SECONDS));
     assertNotNull(callback.client);
     Rpc serverRpc = autoClose(callback.client);
     Rpc clientRpc = autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
@@ -233,18 +236,22 @@ public class TestRpc {
   }
 
   private static class ServerRpcCallback implements RpcServer.ClientCallback {
-
+    final CountDownLatch onNewClientCalled = new CountDownLatch(1);
+    final CountDownLatch onSaslCompleteCalled = new CountDownLatch(1);
     Rpc client;
 
     @Override
     public RpcDispatcher onNewClient(Rpc client) {
-      synchronized (this) {
-        this.client = client;
-        notifyAll();
-      }
+      this.client = client;
+      onNewClientCalled.countDown();
       return new TestDispatcher();
     }
 
+    @Override
+    public void onSaslComplete(Rpc client) {
+      onSaslCompleteCalled.countDown();
+    }
+
   }
 
   private static class TestMessage {

+ 22 - 53
server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala

@@ -337,7 +337,7 @@ class InteractiveSession(
 
   import InteractiveSession._
 
-  private var _state: SessionState = initialState
+  private var serverSideState: SessionState = initialState
 
   private val operations = mutable.Map[Long, String]()
   private val operationCounter = new AtomicLong(0)
@@ -348,21 +348,6 @@ class InteractiveSession(
   _appId = appIdHint
   sessionStore.save(RECOVERY_SESSION_TYPE, recoveryMetadata)
 
-  // TODO Replace this with a Rpc call from repl to server.
-  private val stateThread = new Thread(new Runnable {
-    override def run(): Unit = {
-      try {
-        while (_state.isActive) {
-          // State is also updated when we get statement results from repl, not just here.
-          setSessionStateFromReplState(client.map(_.getReplState.get()))
-          Thread.sleep(30000)
-        }
-      } catch {
-        case _: InterruptedException =>
-      }
-    }
-  })
-
   private val app = mockApp.orElse {
     if (livyConf.isRunningOnYarn()) {
       // When Livy is running with YARN, SparkYarnApp can provide better YARN integration.
@@ -402,16 +387,14 @@ class InteractiveSession(
       override def onJobFailed(job: JobHandle[Void], cause: Throwable): Unit = errorOut()
 
       override def onJobSucceeded(job: JobHandle[Void], result: Void): Unit = {
-        transition(SessionState.Idle())
-        stateThread.setDaemon(true)
-        stateThread.start()
+        transition(SessionState.Running())
       }
 
       private def errorOut(): Unit = {
         // Other code might call stop() to close the RPC channel. When RPC channel is closing,
         // this callback might be triggered. Check and don't call stop() to avoid nested called
         // if the session is already shutting down.
-        if (_state != SessionState.ShuttingDown()) {
+        if (serverSideState != SessionState.ShuttingDown()) {
           transition(SessionState.Error())
           stop()
         }
@@ -424,15 +407,21 @@ class InteractiveSession(
   override def recoveryMetadata: RecoveryMetadata =
     InteractiveRecoveryMetadata(id, appId, appTag, kind, owner, proxyUser, rscDriverUri)
 
-  override def state: SessionState = _state
+  override def state: SessionState = {
+    if (serverSideState.isInstanceOf[SessionState.Running]) {
+      // If session is in running state, return the repl state from RSCClient.
+      client
+        .flatMap(s => Option(s.getReplState))
+        .map(SessionState(_))
+        .getOrElse(SessionState.Busy()) // If repl state is unknown, assume repl is busy.
+    } else {
+      serverSideState
+    }
+  }
 
   override def stopSession(): Unit = {
     try {
       transition(SessionState.ShuttingDown())
-      if (stateThread.isAlive) {
-        stateThread.interrupt()
-        stateThread.join()
-      }
       sessionStore.remove(RECOVERY_SESSION_TYPE, id)
       client.foreach { _.stop(true) }
     } catch {
@@ -449,16 +438,12 @@ class InteractiveSession(
   def statements: IndexedSeq[Statement] = {
     ensureActive()
     val r = client.get.getReplJobResults().get()
-
-    setSessionStateFromReplState(Option(r.replState))
     r.statements.toIndexedSeq
   }
 
   def getStatement(stmtId: Int): Option[Statement] = {
     ensureActive()
     val r = client.get.getReplJobResults(stmtId, 1).get()
-
-    setSessionStateFromReplState(Option(r.replState))
     if (r.statements.length < 1) {
       None
     } else {
@@ -472,7 +457,6 @@ class InteractiveSession(
 
   def executeStatement(content: ExecuteRequest): Statement = {
     ensureRunning()
-    setSessionStateFromReplState(client.map(_.getReplState.get()))
     recordActivity()
 
     val id = client.get.submitReplCode(content.code).get
@@ -527,24 +511,24 @@ class InteractiveSession(
     // If the session crashed because of the error, the session should instead go to dead state.
     // Since these 2 transitions are triggered by different threads, there's a race condition.
     // Make sure we won't transit from dead to error state.
-    val areSameStates = _state.getClass() == newState.getClass()
-    val transitFromInactiveToActive = !_state.isActive && newState.isActive
+    val areSameStates = serverSideState.getClass() == newState.getClass()
+    val transitFromInactiveToActive = !serverSideState.isActive && newState.isActive
     if (!areSameStates && !transitFromInactiveToActive) {
-      debug(s"$this session state change from ${_state} to $newState")
-      _state = newState
+      debug(s"$this session state change from ${serverSideState} to $newState")
+      serverSideState = newState
     }
   }
 
   private def ensureActive(): Unit = synchronized {
-    require(_state.isActive, "Session isn't active.")
+    require(serverSideState.isActive, "Session isn't active.")
     require(client.isDefined, "Session is active but client hasn't been created.")
   }
 
   private def ensureRunning(): Unit = synchronized {
-    _state match {
-      case SessionState.Idle() | SessionState.Busy() =>
+    serverSideState match {
+      case SessionState.Running() =>
       case _ =>
-        throw new IllegalStateException("Session is in state %s" format _state)
+        throw new IllegalStateException("Session is in state %s" format serverSideState)
     }
   }
 
@@ -557,21 +541,6 @@ class InteractiveSession(
     opId
    }
 
-  private def setSessionStateFromReplState(newStateStr: Option[String]): Unit = {
-    val newState = newStateStr match {
-      case Some("starting") => SessionState.Starting()
-      case Some("idle") => SessionState.Idle()
-      case Some("busy") => SessionState.Busy()
-      case Some("error") => SessionState.Error()
-      case Some(s) => // Should not happen.
-        warn(s"Unexpected repl state $s")
-        SessionState.Error()
-      case None =>
-        SessionState.Dead()
-    }
-    transition(newState)
-  }
-
   override def appIdKnown(appId: String): Unit = {
     _appId = Option(appId)
     sessionSaveLock.synchronized {