import os
import subprocess

from pyhive import hive
from pyhive.exc import DatabaseError

from app.core.datasource.datasource import DataSourceBase
from app.utils.get_kerberos import get_kerberos_to_local
from configs.logging import logger
from utils import flat_map


class HiveDS(DataSourceBase):
    """Hive data source that runs statements through PyHive.

    A new connection is opened for every batch of statements and closed
    afterwards. When ``kerberos`` is non-zero a ticket is requested with
    ``kinit`` before connecting with ``auth="KERBEROS"``.
    """

    type = 'hive'

    def __init__(self, host, port, database_name,
                 username=None, password=None, kerberos=0,
                 keytab=None, krb5config=None, kerberos_service_name=None,
                 principal=None, type='hive', path_type='minio'):
        """Store the connection settings.

        Args:
            host: HiveServer2 host.
            port: HiveServer2 port.
            database_name: Database to use; falls back to ``'default'``
                when empty or ``None``.
            username: User name for non-Kerberos connections.
            password: Stored on the instance; not passed to PyHive by
                this class.
            kerberos: ``0`` for a plain connection, anything else (after
                ``int()``) for Kerberos authentication.
            keytab: Keytab location; interpreted according to ``path_type``.
            krb5config: Stored on the instance; not used by this class.
            kerberos_service_name: Service name handed to PyHive.
            principal: Principal passed to ``kinit``.
            type: Data source type forwarded to ``DataSourceBase``.
            path_type: ``'minio'`` means the keytab is first fetched with
                ``get_kerberos_to_local``; any other value means ``keytab``
                is used as a path directly.
        """
        DataSourceBase.__init__(self, host, port, username, password,
                                database_name, type)
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database_name = 'default' if not database_name else database_name
        self.kerberos = int(kerberos)
        self.keytab = keytab
        self.krb5config = krb5config
        self.kerberos_service_name = kerberos_service_name
        self.principal = principal
        self.path_type = path_type

    @property
    def jdbc_url(self):
        """JDBC URL built from host, port and database name."""
        return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'

    @property
    def jdbc_driver_class(self):
        """Fully qualified class name of the Hive JDBC driver."""
        return 'org.apache.hive.jdbc.HiveDriver'

    @property
    def connection_str(self):
        """Not implemented for Hive; always ``None``."""
        return None

    def _kinit(self):
        """Request a Kerberos ticket for ``self.principal`` from the keytab.

        Failures are logged and otherwise ignored so that the caller still
        attempts the connection, exactly as it did when ``kinit`` was
        launched without checking its exit status.
        """
        if self.path_type == 'minio':
            # presumably this stages the keytab under ./assets/kerberos/;
            # the path below relies on that -- verify against the helper.
            get_kerberos_to_local(self.keytab)
            keytab_path = './assets/kerberos/' + self.keytab.split("/")[-1]
        else:
            keytab_path = self.keytab

        # An argument list (no shell) keeps spaces or shell metacharacters
        # in the keytab path / principal from being interpreted.
        command = ['kinit', '-kt', str(keytab_path), str(self.principal)]
        try:
            completed = subprocess.run(command)
        except OSError as e:
            # e.g. the kinit binary is not installed.
            logger.error(e)
            return
        if completed.returncode != 0:
            logger.error(f'kinit exited with status {completed.returncode}')

    def _execute_sql(self, sqls):
        """Run each statement on a fresh connection and collect the rows.

        Args:
            sqls: Iterable of SQL strings, executed in order.

        Returns:
            A list with one ``fetchall()`` result per successfully executed
            statement. Any exception is logged and swallowed, so the list
            is shorter than ``sqls`` (possibly empty) when something failed;
            ``is_connect`` depends on this behaviour.
        """
        conn = None
        res = []
        try:
            if self.kerberos == 0:
                conn = hive.Connection(host=self.host,
                                       port=self.port,
                                       username=self.username,
                                       database=self.database_name)
            else:
                self._kinit()
                conn = hive.Connection(
                    host=self.host,
                    database=self.database_name,
                    port=self.port,
                    auth="KERBEROS",
                    kerberos_service_name=self.kerberos_service_name)
            cursor = conn.cursor()
            for sql in sqls:
                cursor.execute(sql)
                res.append(cursor.fetchall())
        except Exception as e:
            logger.error(e)
        finally:
            if conn is not None:
                conn.close()
        return res

    def is_connect(self) -> bool:
        """Return ``True`` when ``select 1`` can be executed."""
        res = self._execute_sql(['select 1'])
        logger.info(res)
        return bool(res)

    def get_preview_data(self, table_name, limit=100, page=0):
        """Return column headers and a slice of rows from ``table_name``.

        Args:
            table_name: Table to read. It is interpolated into the SQL
                text, so it must come from a trusted source.
            limit: Maximum number of rows to return.
            page: Passed to Hive as the row offset of ``LIMIT offset,rows``.
                NOTE(review): despite the name this is not multiplied by
                ``limit`` -- confirm callers pass an offset.

        Returns:
            ``{'header': [...'name:type'...], 'content': rows}``.

        Raises:
            IndexError: If the query failed and produced no result set.
        """
        table_schema = self.get_table_schema(table_name)
        # Schema entries look like "index:name:type". Limit the split so a
        # complex type such as "struct<a:int,b:string>" stays in one piece.
        c_list = [col.split(':', 2) for col in table_schema]
        sql = f"SELECT * FROM {table_name} LIMIT {page},{limit}"
        res = self._execute_sql([sql])
        logger.info(res)
        return {
            'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
            'content': res[0],
        }

    def get_data_num(self, table_name) -> int:
        """Return the number of rows in ``table_name``.

        The count is computed by Hive instead of pulling one row per table
        row to the client and measuring the list.

        Raises:
            IndexError: If the query failed and produced no result set.
        """
        sql = f"SELECT COUNT(*) FROM {table_name}"
        res = self._execute_sql([sql])
        return int(res[0][0][0])

    def list_tables(self):
        """Return the table names of the connected database."""
        res = self._execute_sql(['show tables'])
        # Each row is a one-element tuple; flatten to plain names.
        return flat_map(lambda x: x, res[0])

    def get_table_schema(self, table_name):
        """Describe ``table_name`` as a list of ``"index:name:type"`` strings.

        Only the leading column section of the ``desc`` output is used:
        reading stops at the first row whose name is empty or contains
        ``#``. Returns an empty list when the statement failed.
        """
        logger.info(self.database_name)
        sql = f'desc {self.database_name}.{table_name}'
        res = self._execute_sql([sql])
        table_schema = []
        if not res:
            return table_schema
        for index, col in enumerate(res[0]):
            col_name = col[0]
            col_type = col[1]
            if not col_name or '#' in col_name:
                break
            table_schema.append(f'{index}:{col_name}:{col_type}')
        return table_schema