import os
import random
import re
import subprocess

import sasl
from kazoo.client import KazooClient
from pyhive import hive
from pyhive.exc import DatabaseError
from thrift.transport.TSocket import TSocket
from thrift_sasl import TSaslClientTransport

from app.core.datasource.datasource import DataSourceBase
from app.utils.get_kerberos import get_kerberos_to_local
from configs.logging import logger
from utils import flat_map

# Extracts the "<host>:<port>" part from ZooKeeper child znode names of the
# form "serverUri=<host>:<port>;..." (same result as the previous
# split('=')[1].split(';')[0] parsing for such names).
_SERVER_URI_RE = re.compile(r'serverUri=([^;]+)')


def _create_socket(host, port, timeout):
    """Return a TSocket for host:port whose timeout is given in seconds."""
    socket = TSocket(host, port)
    socket.setTimeout(timeout * 1000)  # TSocket.setTimeout takes milliseconds
    return socket


def create_hive_plain_transport(host, port, username, password, timeout=10):
    """Build a SASL PLAIN (username/password) Thrift transport to HiveServer2.

    Args:
        host: Server host name or address.
        port: Server port.
        username: SASL PLAIN user name.
        password: SASL PLAIN password.
        timeout: Socket timeout in seconds.
    """
    socket = _create_socket(host, port, timeout)

    def sasl_factory():
        sasl_client = sasl.Client()
        sasl_client.setAttr('host', host)
        sasl_client.setAttr('username', username)
        sasl_client.setAttr('password', password)
        sasl_client.init()
        return sasl_client

    return TSaslClientTransport(sasl_factory, 'PLAIN', socket)


def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
    """Build a SASL GSSAPI (Kerberos) Thrift transport to HiveServer2.

    Credentials come from the process's Kerberos ticket cache, which is
    expected to have been populated beforehand (``HiveDS`` runs ``kinit``).

    Args:
        host: Server host name or address.
        port: Server port.
        kerberos_service_name: Kerberos service name passed to SASL.
        timeout: Socket timeout in seconds.
    """
    socket = _create_socket(host, port, timeout)

    def sasl_factory():
        sasl_client = sasl.Client()
        sasl_client.setAttr('host', host)
        sasl_client.setAttr('service', kerberos_service_name)
        sasl_client.init()
        return sasl_client

    return TSaslClientTransport(sasl_factory, 'GSSAPI', socket)


class HiveDS(DataSourceBase):
    """Hive data source accessed through PyHive over a SASL Thrift transport.

    Connects either to a fixed ``host:port`` or, when ``zookeeper_enable`` is 1,
    to one of the servers listed under ``zookeeper_namespace`` in ZooKeeper.
    Authentication is SASL PLAIN when ``kerberos`` is 0, otherwise GSSAPI after
    running ``kinit`` with the configured keytab and principal.
    """

    type = 'hive'

    def __init__(self, host, port, database_name,
                 username=None, password=None, kerberos=0,
                 keytab=None, krb5config=None, kerberos_service_name=None,
                 principal=None, type='hive', path_type='minio',
                 zookeeper_enable=0, zookeeper_hosts=None, zookeeper_namespace=None):
        DataSourceBase.__init__(self, host, port, username, password, database_name, type)
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database_name = 'default' if not database_name else database_name
        self.kerberos = int(kerberos)
        self.keytab = keytab
        self.krb5config = krb5config
        self.kerberos_service_name = kerberos_service_name
        self.principal = principal
        # 'minio': keytab is a remote path fetched via get_kerberos_to_local;
        # anything else: keytab is used as a local path directly.
        self.path_type = path_type
        # Normalised like `kerberos` so a string '1' from config still enables
        # ZooKeeper discovery (the check below compares against the int 1).
        self.zookeeper_enable = int(zookeeper_enable) if zookeeper_enable else 0
        self.zookeeper_hosts = zookeeper_hosts
        self.zookeeper_namespace = zookeeper_namespace

    @property
    def jdbc_url(self):
        return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'

    @property
    def jdbc_driver_class(self):
        return 'org.apache.hive.jdbc.HiveDriver'

    @property
    def connection_str(self):
        # Not implemented for Hive; always returns None.
        pass

    @staticmethod
    def _quote_identifier(name):
        """Wrap *name* in backticks, doubling embedded backticks (HiveQL escaping)."""
        return '`' + str(name).replace('`', '``') + '`'

    def _candidate_hosts(self):
        """Return the list of "host:port" strings to try when connecting."""
        if self.zookeeper_enable != 1:
            return [f'{self.host}:{self.port}']
        zk_client = KazooClient(hosts=self.zookeeper_hosts)
        zk_client.start()
        try:
            nodes = zk_client.get_children(self.zookeeper_namespace)
        finally:
            # Always release the ZooKeeper session, even if the lookup fails.
            zk_client.stop()
            zk_client.close()
        return [m.group(1) for m in map(_SERVER_URI_RE.search, nodes) if m]

    def _kinit(self):
        """Obtain a Kerberos ticket for ``self.principal`` from the keytab.

        Raises:
            Exception: 'hive 连接失败' if kinit cannot be run or exits non-zero.
        """
        if self.path_type == 'minio':
            get_kerberos_to_local(self.keytab)
            file_name = './assets/kerberos/' + os.path.basename(self.keytab)
        else:
            file_name = self.keytab
        # Argument list (no shell) so paths/principals cannot inject commands.
        try:
            completed = subprocess.run(['kinit', '-kt', file_name, self.principal], check=False)
        except OSError as e:
            raise Exception('hive 连接失败') from e
        if completed.returncode != 0:
            raise Exception('hive 连接失败')

    def _create_transport(self, host, port):
        """Return a PLAIN or GSSAPI transport depending on ``self.kerberos``."""
        if self.kerberos == 0:
            return create_hive_plain_transport(
                host=host,
                port=port,
                username=self.username,
                password=self.password,
                timeout=10,
            )
        return create_hive_kerberos_plain_transport(
            host=host,
            port=port,
            kerberos_service_name=self.kerberos_service_name,
            timeout=10,
        )

    def _connect(self, db_name: str = None):
        """Open a health-checked connection to one Hive server.

        Candidate hosts are tried in random order; the first whose connection
        answers ``select 1`` with at least one row is returned. Failing hosts
        are logged and skipped so ZooKeeper-discovered replicas act as failover.

        Args:
            db_name: Database to open; defaults to ``self.database_name``.

        Returns:
            An open PyHive connection (caller must close it).

        Raises:
            Exception: 'hive 连接失败' if discovery/kinit fails or no host works.
        """
        database = db_name if db_name else self.database_name
        try:
            host_list = self._candidate_hosts()
            if self.kerberos != 0:
                # One ticket serves every host attempt, so kinit only once.
                self._kinit()
        except Exception as e:
            logger.error(e)
            raise Exception('hive 连接失败') from e

        random.shuffle(host_list)
        last_error = None
        for host_port in host_list:
            conn = None
            try:
                host, port = host_port.split(':')[:2]
                conn = hive.connect(
                    thrift_transport=self._create_transport(host, int(port)),
                    database=database,
                )
                cursor = conn.cursor()
                cursor.execute('select 1')
                result = cursor.fetchall()
                logger.debug(f'获取连接,通过连接查询数据测试===> {result}')
                if result:
                    return conn
                last_error = Exception(f'{host_port}: health check returned no rows')
            except Exception as e:
                logger.error(e)
                last_error = e
            # This host is unusable; release its connection before the next try.
            if conn is not None:
                try:
                    conn.close()
                except Exception as close_error:
                    logger.warning(close_error)
        raise Exception('hive 连接失败') from last_error

    def _execute_sql(self, sqls, db_name: str = None):
        """Run each statement in *sqls* and collect its ``fetchall()`` rows.

        Returns:
            A list with one row list per statement, in order.

        Raises:
            Exception: 'hive 连接失败' on any connection or execution error.
        """
        conn = None
        res = []
        try:
            conn = self._connect(db_name)
            cursor = conn.cursor()
            for sql in sqls:
                cursor.execute(sql)
                res.append(cursor.fetchall())
        except Exception as e:
            logger.error(e)
            raise Exception('hive 连接失败') from e
        finally:
            if conn is not None:
                conn.close()
        return res

    def _execute_create_sql(self, sqls):
        """Execute DDL statements without fetching results; returns an empty list.

        Raises:
            Exception: '表创建失败,请检查字段填写是否有误' on any error.
        """
        conn = None
        res = []
        try:
            conn = self._connect()
            cursor = conn.cursor()
            for sql in sqls:
                cursor.execute(sql)
        except Exception as e:
            logger.error(e)
            raise Exception('表创建失败,请检查字段填写是否有误') from e
        finally:
            if conn is not None:
                conn.close()
        return res

    def is_connect(self):
        """Return True if ``select 1`` succeeds (raises on connection failure)."""
        res = self._execute_sql(['select 1'])
        logger.info(res)
        return bool(res)

    def get_preview_data(self, table_name, size=100, start=0, db_name: str = None):
        """Return up to *size* rows of *table_name* starting at offset *start*.

        Returns:
            ``{'header': ['name:type', ...], 'content': rows}``.
        """
        table_schema = self.get_table_schema(table_name, db_name)
        # Schema entries are "index:name:type"; split at most twice so complex
        # types such as struct<a:int> keep their own colons.
        header = [':'.join(col.split(':', 2)[1:3]) for col in table_schema]
        sql = (f"SELECT * FROM {self._quote_identifier(table_name)} "
               f"LIMIT {int(start)},{int(size)}")
        res = self._execute_sql([sql], db_name)
        logger.info(res)
        return {
            'header': header,
            'content': res[0]
        }

    def get_data_num(self, table_name, db_name: str = None):
        """Return the number of rows in *table_name*."""
        # COUNT(*) lets Hive count server-side instead of shipping every row.
        sql = f"SELECT COUNT(*) FROM {self._quote_identifier(table_name)}"
        rows = self._execute_sql([sql], db_name)[0]
        return int(rows[0][0]) if rows else 0

    def list_tables(self, db_name: str = None):
        """Return the table names in *db_name* (or the default database)."""
        res = self._execute_sql(['show tables'], db_name)
        return flat_map(lambda x: x, res[0])

    def get_table_schema(self, table_name, db_name: str = None):
        """Return the columns of *table_name* as "index:name:type" strings.

        Args:
            table_name: Table to describe.
            db_name: Database containing the table; defaults to ``self.database_name``.
        """
        database = db_name if db_name else self.database_name
        sql = f'desc {self._quote_identifier(database)}.{self._quote_identifier(table_name)}'
        res = self._execute_sql([sql], db_name)
        table_schema = []
        if not res:
            return table_schema
        for index, col in enumerate(res[0]):
            col_name, col_type = col[0], col[1]
            # The column list ends at the first blank row or '#' header row
            # (presumably extra DESC sections such as partition info).
            if col_name == '' or '#' in col_name:
                break
            table_schema.append(f'{index}:{col_name}:{col_type}')
        return table_schema