|
@@ -30,6 +30,7 @@ import threading
|
|
|
import tempfile
|
|
|
import shutil
|
|
|
import pickle
|
|
|
+import textwrap
|
|
|
|
|
|
if sys.version >= '3':
|
|
|
unicode = str
|
|
@@ -534,8 +535,21 @@ def main():
|
|
|
exec('from pyspark.sql import HiveContext', global_dict)
|
|
|
exec('from pyspark.streaming import StreamingContext', global_dict)
|
|
|
exec('import pyspark.cloudpickle as cloudpickle', global_dict)
|
|
|
+
|
|
|
if spark_major_version >= "2":
|
|
|
exec('from pyspark.shell import spark', global_dict)
|
|
|
+ else:
|
|
|
+ # LIVY-294, need to check whether HiveContext can work properly,
|
|
|
+ # fallback to SQLContext if HiveContext can not be initialized successfully.
|
|
|
+ # Only for spark-1.
|
|
|
+ code = textwrap.dedent("""
|
|
|
+ import py4j
|
|
|
+ from pyspark.sql import SQLContext
|
|
|
+ try:
|
|
|
+ sqlContext.tables()
|
|
|
+ except py4j.protocol.Py4JError:
|
|
|
+ sqlContext = SQLContext(sc)""")
|
|
|
+ exec(code, global_dict)
|
|
|
|
|
|
#Start py4j callback server
|
|
|
from py4j.protocol import ENTRY_POINT_OBJECT_ID
|