123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import os
- from app.core.datasource.datasource import DataSourceBase
- from pyhive import hive
- from pyhive.exc import DatabaseError
- from configs.logging import logger
- from utils import flat_map
class HiveDS(DataSourceBase):
    """Hive data source that talks to HiveServer2 through PyHive.

    Each query helper opens a fresh connection via ``_execute_sql`` and closes
    it again, so an instance holds no open connection between calls.

    NOTE(review): table and column names are interpolated into SQL unescaped;
    callers must only pass trusted identifiers.
    """

    type = 'hive'

    def __init__(self, host, port, database_name,
                 username=None, password=None, kerberos=0,
                 keytab=None, krb5config=None, kerberos_service_name=None,
                 principal=None, type='hive'):
        """Store the connection settings.

        :param host: HiveServer2 host.
        :param port: HiveServer2 port.
        :param database_name: database to connect to; ``'default'`` is used
            when it is empty or ``None``.
        :param username: user name for non-Kerberos connections.
        :param password: forwarded to ``DataSourceBase``; not passed to
            ``hive.Connection`` by this class.
        :param kerberos: ``0`` (or any falsy value) for plain auth; any other
            int-convertible value enables Kerberos.
        :param keytab: keytab path handed to ``kinit -kt``.
        :param krb5config: stored only; not used by this class.
        :param kerberos_service_name: service name PyHive uses for Kerberos.
        :param principal: principal handed to ``kinit``.
        :param type: data source type forwarded to ``DataSourceBase``.
        """
        DataSourceBase.__init__(self, host, port, username, password, database_name, type)
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database_name = database_name or 'default'
        # ``or 0`` lets None / '' mean "no Kerberos" instead of raising in int().
        self.kerberos = int(kerberos or 0)
        self.keytab = keytab
        self.krb5config = krb5config
        self.kerberos_service_name = kerberos_service_name
        self.principal = principal

    @property
    def jdbc_url(self):
        """JDBC URL for this database on the configured HiveServer2."""
        return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'

    @property
    def jdbc_driver_class(self):
        """Fully qualified class name of the Hive JDBC driver."""
        return 'org.apache.hive.jdbc.HiveDriver'

    @property
    def connection_str(self):
        """Not implemented for Hive; always returns ``None``."""
        return None

    def _kinit(self):
        """Obtain a Kerberos ticket from the configured keytab.

        Best effort: failures are logged and never raised, so a ticket that
        is already in the credential cache can still be used.
        """
        import subprocess

        if not (self.keytab and self.principal):
            logger.warning('kerberos enabled but keytab/principal not set; skipping kinit')
            return
        try:
            # Argument list with shell=False: the keytab path and principal
            # are never interpreted by a shell.
            proc = subprocess.run(
                ['kinit', '-kt', str(self.keytab), str(self.principal)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                check=False,
            )
        except OSError as e:
            logger.error(f'kinit could not be started: {e}')
            return
        if proc.returncode != 0:
            logger.error(f'kinit failed (exit {proc.returncode}): {proc.stderr.strip()}')

    def _execute_sql(self, sqls):
        """Run ``sqls`` in order on one fresh connection and collect the results.

        :param sqls: iterable of SQL statements; each must produce a result set.
        :return: list holding one ``cursor.fetchall()`` result per statement.
            Errors are logged, not raised. After a failure the list contains
            only the results of the statements that completed before it, so
            ``[]`` means nothing succeeded (including connection failures).
        """
        conn = None
        res = []
        try:
            if self.kerberos == 0:
                conn = hive.Connection(host=self.host, port=self.port,
                                       username=self.username, database=self.database_name)
            else:
                self._kinit()
                conn = hive.Connection(host=self.host, port=self.port,
                                       database=self.database_name, auth="KERBEROS",
                                       kerberos_service_name=self.kerberos_service_name)
            cursor = conn.cursor()
            for sql in sqls:
                cursor.execute(sql)
                res.append(cursor.fetchall())
        except Exception as e:
            # Deliberate boundary: callers treat a short result list as failure.
            logger.error(f'hive query failed: {e}')
        finally:
            if conn is not None:
                conn.close()
        return res

    @staticmethod
    def _column_rows(rows):
        """Return the leading column rows of a ``describe`` result.

        Stops at the first row whose name is blank or starts with ``#``. That
        is where Hive begins metadata sections such as
        ``# Partition Information``, which repeat partition columns already
        listed above and contain ``None`` fields.
        """
        columns = []
        for row in rows:
            name = str(row[0] or '').strip()
            if not name or name.startswith('#'):
                break
            columns.append(row)
        return columns

    def is_connect(self):
        """Return True if ``select 1`` runs successfully against the server."""
        res = self._execute_sql(['select 1'])
        logger.info(res)
        return bool(res)

    def get_preview_data(self, table_name, limit=100, page=0):
        """Return a column header and up to ``limit`` rows of ``table_name``.

        :param table_name: table in the connected database.
        :param limit: maximum number of rows to return.
        :param page: row offset used in ``LIMIT offset,rows``.
            NOTE(review): this is used directly as an offset, not multiplied
            by ``limit``; confirm that is what callers expect.
        :return: ``{'header': ['name:type', ...], 'content': [row, ...]}``.
        :raises Exception: if either query fails (the error is already logged).
        """
        describe_sql = f'describe {self.database_name}.{table_name}'
        select_sql = f"SELECT * FROM {table_name} LIMIT {page},{limit}"
        res = self._execute_sql([describe_sql, select_sql])
        logger.info(res)
        if len(res) < 2:
            raise Exception(f'failed to preview {table_name}')
        return {
            'header': [':'.join(map(str, row[:2])) for row in self._column_rows(res[0])],
            'content': res[1],
        }

    def get_data_num(self, table_name):
        """Return the number of rows in ``table_name``.

        The count is aggregated server-side instead of streaming every row to
        the client and counting it here.

        :raises Exception: if the count query fails.
        """
        res = self._execute_sql([f"SELECT COUNT(*) FROM {table_name}"])
        if not res or not res[0]:
            raise Exception(f'failed to count rows of {table_name}')
        return int(res[0][0][0])

    def list_tables(self):
        """Return the names of the tables in the connected database.

        :raises Exception: if ``show tables`` fails.
        """
        res = self._execute_sql(['show tables'])
        if not res:
            raise Exception(f'failed to list tables of {self.database_name}')
        return [name for row in res[0] for name in row]

    def get_table_schema(self, table_name):
        """Return one ``'index:name:type'`` string per column of ``table_name``.

        Runs ``show columns`` and then ``describe <table> <column>`` for every
        column. All the describes are batched on one connection.

        :return: list of schema strings. If a ``describe`` fails part-way,
            only the columns described before the failure are returned.
        :raises Exception: if ``show columns`` fails.
        """
        logger.info(self.database_name)
        qualified = f'{self.database_name}.{table_name}'
        res = self._execute_sql([f'show columns in {qualified}'])
        if not res:
            raise Exception(f'{table_name} no columns')
        columns = [row[0] for row in res[0]]
        if not columns:
            return []
        # One connection (and one kinit) for all columns instead of one per
        # column. _execute_sql stops at the first failing statement and
        # returns what it gathered, so the result is still "everything
        # described before the failure".
        described = self._execute_sql([f'describe {qualified} {col}' for col in columns])
        ans = []
        for i, rows in enumerate(described):
            parts = [':'.join([str(i), *map(str, row[:2])]) for row in rows if row[0] != '']
            ans.append(''.join(parts))
        return ans
|