|
@@ -5,6 +5,41 @@ from pyhive.exc import DatabaseError
|
|
|
from app.utils.get_kerberos import get_kerberos_to_local
|
|
|
from configs.logging import logger
|
|
|
from utils import flat_map
|
|
|
+import sasl
|
|
|
+from thrift_sasl import TSaslClientTransport
|
|
|
+from thrift.transport.TSocket import TSocket
|
|
|
+
|
|
|
+
|
|
|
+def create_hive_plain_transport(host, port, username, password, timeout=10):
|
|
|
+ socket = TSocket(host, port)
|
|
|
+ socket.setTimeout(timeout * 1000)
|
|
|
+
|
|
|
+ sasl_auth = 'PLAIN'
|
|
|
+
|
|
|
+ def sasl_factory():
|
|
|
+ sasl_client = sasl.Client()
|
|
|
+ sasl_client.setAttr('host', host)
|
|
|
+ sasl_client.setAttr('username', username)
|
|
|
+ sasl_client.setAttr('password', password)
|
|
|
+ sasl_client.init()
|
|
|
+ return sasl_client
|
|
|
+
|
|
|
+ return TSaslClientTransport(sasl_factory, sasl_auth, socket)
|
|
|
+
|
|
|
+def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
|
|
|
+ socket = TSocket(host, port)
|
|
|
+ socket.setTimeout(timeout * 1000)
|
|
|
+
|
|
|
+ sasl_auth = 'GSSAPI'
|
|
|
+
|
|
|
+ def sasl_factory():
|
|
|
+ sasl_client = sasl.Client()
|
|
|
+ sasl_client.setAttr('host', host)
|
|
|
+ sasl_client.setAttr('service', kerberos_service_name)
|
|
|
+ sasl_client.init()
|
|
|
+ return sasl_client
|
|
|
+
|
|
|
+ return TSaslClientTransport(sasl_factory, sasl_auth, socket)
|
|
|
|
|
|
class HiveDS(DataSourceBase):
|
|
|
type = 'hive'
|
|
@@ -44,7 +79,17 @@ class HiveDS(DataSourceBase):
|
|
|
res = []
|
|
|
try:
|
|
|
if self.kerberos == 0:
|
|
|
- conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
|
|
|
+ # conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
|
|
|
+ conn = hive.connect(
|
|
|
+ thrift_transport=create_hive_plain_transport(
|
|
|
+ host=self.host,
|
|
|
+ port=self.port,
|
|
|
+ username=self.username,
|
|
|
+ password=self.password,
|
|
|
+ timeout=10
|
|
|
+ ),
|
|
|
+ database=self.database_name
|
|
|
+ )
|
|
|
else:
|
|
|
file_name = ''
|
|
|
if self.path_type == 'minio':
|
|
@@ -52,8 +97,19 @@ class HiveDS(DataSourceBase):
|
|
|
file_name = './assets/kerberos/'+self.keytab.split("/")[-1]
|
|
|
else:
|
|
|
file_name = self.keytab
|
|
|
- os.system(f'kinit -kt {file_name} {self.principal}')
|
|
|
- conn = hive.Connection(host=self.host, database=self.database_name, port=self.port, auth="KERBEROS", kerberos_service_name=self.kerberos_service_name)
|
|
|
+ auth_res = os.system(f'kinit -kt {file_name} {self.principal}')
|
|
|
+ if auth_res != 0:
|
|
|
+ raise Exception('hive 连接失败')
|
|
|
+ # conn = hive.Connection(host=self.host, port=self.port, auth="KERBEROS", kerberos_service_name=self.kerberos_service_name, database=self.database_name)
|
|
|
+ conn = hive.connect(
|
|
|
+ thrift_transport=create_hive_kerberos_plain_transport(
|
|
|
+ host=self.host,
|
|
|
+ port=self.port,
|
|
|
+ kerberos_service_name=self.kerberos_service_name,
|
|
|
+ timeout=10
|
|
|
+ ),
|
|
|
+ database=self.database_name
|
|
|
+ )
|
|
|
|
|
|
|
|
|
cursor = conn.cursor()
|
|
@@ -63,7 +119,7 @@ class HiveDS(DataSourceBase):
|
|
|
# logger.info(res)
|
|
|
except Exception as e:
|
|
|
logger.error(e)
|
|
|
-
|
|
|
+ raise Exception('hive 连接失败')
|
|
|
finally:
|
|
|
if conn is not None:
|
|
|
conn.close()
|
|
@@ -80,13 +136,17 @@ class HiveDS(DataSourceBase):
|
|
|
|
|
|
|
|
|
def get_preview_data(self, table_name, limit=100, page = 0):
|
|
|
- sql1 = f'describe {self.database_name}.{table_name}'
|
|
|
+ table_schema = self.get_table_schema(table_name)
|
|
|
+ c_list = []
|
|
|
+ for col in table_schema:
|
|
|
+ c = col.split(':')
|
|
|
+ c_list.append(c)
|
|
|
sql2 = f"SELECT * FROM {table_name} LIMIT {page},{limit}"
|
|
|
- res = self._execute_sql([sql1, sql2])
|
|
|
+ res = self._execute_sql([ sql2])
|
|
|
logger.info(res)
|
|
|
return {
|
|
|
- 'header': flat_map(lambda x: [':'.join(x[:2])], res[0]),
|
|
|
- 'content': res[1]
|
|
|
+ 'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
|
|
|
+ 'content': res[0]
|
|
|
}
|
|
|
|
|
|
def get_data_num(self, table_name):
|
|
@@ -102,27 +162,43 @@ class HiveDS(DataSourceBase):
|
|
|
|
|
|
def get_table_schema(self, table_name):
|
|
|
logger.info(self.database_name)
|
|
|
- sql1 = f'show columns in {self.database_name}.{table_name}'
|
|
|
- res = self._execute_sql([sql1])
|
|
|
- if res:
|
|
|
- columns = list(map(lambda x: x[0],res[0]))
|
|
|
- # logger.info(columns)
|
|
|
- else:
|
|
|
- raise Exception(f'{table_name} no columns')
|
|
|
- ans = []
|
|
|
- for i, col in enumerate(columns):
|
|
|
- sql = f'describe {self.database_name}.{table_name} {col}'
|
|
|
- try:
|
|
|
- res = self._execute_sql([sql])
|
|
|
- if res:
|
|
|
- # print(res[0])
|
|
|
- res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
|
|
|
- ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
|
|
|
-
|
|
|
+ sql_test = f'desc {self.database_name}.{table_name}'
|
|
|
+ res_test = self._execute_sql([sql_test])
|
|
|
+ table_schema = []
|
|
|
+ if res_test and len(res_test) > 0:
|
|
|
+ index = 0
|
|
|
+ for col in res_test[0]:
|
|
|
+ col_name = col[0]
|
|
|
+ col_type = col[1]
|
|
|
+ if col_name != '' and col_name.find('#') < 0:
|
|
|
+ col_str = f'{index}:{col_name}:{col_type}'
|
|
|
+ table_schema.append(col_str)
|
|
|
+ index+=1
|
|
|
else:
|
|
|
- raise Exception('table not found')
|
|
|
- except Exception:
|
|
|
- return ans
|
|
|
-
|
|
|
- return ans
|
|
|
+ break
|
|
|
+ return table_schema
|
|
|
+
|
|
|
+ # sql1 = f'show columns in {self.database_name}.{table_name}'
|
|
|
+ # res = self._execute_sql([sql1])
|
|
|
+ # print("===",res)
|
|
|
+ # if res:
|
|
|
+ # columns = list(map(lambda x: x[0],res[0]))
|
|
|
+ # # logger.info(columns)
|
|
|
+ # else:
|
|
|
+ # raise Exception(f'{table_name} no columns')
|
|
|
+ # ans = []
|
|
|
+ # for i, col in enumerate(columns):
|
|
|
+ # sql = f'describe {self.database_name}.{table_name} {col}'
|
|
|
+ # try:
|
|
|
+ # res = self._execute_sql([sql])
|
|
|
+ # if res:
|
|
|
+ # res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
|
|
|
+ # ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
|
|
|
+
|
|
|
+ # else:
|
|
|
+ # raise Exception('table not found')
|
|
|
+ # except Exception:
|
|
|
+ # return ans
|
|
|
+
|
|
|
+ # return ans
|
|
|
|