import os
import random
import re

import sasl
from kazoo.client import KazooClient
from pyhive import hive
from thrift.transport.TSocket import TSocket
from thrift_sasl import TSaslClientTransport

from app.core.datasource.datasource import DataSourceBase
from app.utils.get_kerberos import get_kerberos_to_local
from configs.logging import logger
from utils import flat_map


def create_hive_plain_transport(host, port, username, password, timeout=10):
    """Build a SASL PLAIN (username/password) thrift transport for HiveServer2."""
    socket = TSocket(host, port)
    socket.setTimeout(timeout * 1000)  # TSocket timeouts are in milliseconds
    sasl_auth = 'PLAIN'

    def sasl_factory():
        sasl_client = sasl.Client()
        sasl_client.setAttr('host', host)
        sasl_client.setAttr('username', username)
        sasl_client.setAttr('password', password)
        sasl_client.init()
        return sasl_client

    return TSaslClientTransport(sasl_factory, sasl_auth, socket)


def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
    """Build a SASL GSSAPI (Kerberos) thrift transport for HiveServer2."""
    socket = TSocket(host, port)
    socket.setTimeout(timeout * 1000)  # TSocket timeouts are in milliseconds
    sasl_auth = 'GSSAPI'

    def sasl_factory():
        sasl_client = sasl.Client()
        sasl_client.setAttr('host', host)
        sasl_client.setAttr('service', kerberos_service_name)
        sasl_client.init()
        return sasl_client

    return TSaslClientTransport(sasl_factory, sasl_auth, socket)
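

# A minimal sketch (not used by HiveDS below) of how these transports plug into
# pyhive: pass the transport via `thrift_transport` instead of letting
# hive.connect open its own socket. The host, port, and credentials here are
# placeholders, not values from this project.
def _demo_plain_transport_connection():
    conn = hive.connect(
        thrift_transport=create_hive_plain_transport(
            host='127.0.0.1', port=10000,
            username='hive', password='hive', timeout=10
        ),
        database='default'
    )
    try:
        cursor = conn.cursor()
        cursor.execute('select 1')
        return cursor.fetchall()
    finally:
        conn.close()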


class HiveDS(DataSourceBase):
    type = 'hive'

    def __init__(self, host, port, database_name,
                 username=None, password=None, kerberos=0,
                 keytab=None, krb5config=None, kerberos_service_name=None,
                 principal=None, type='hive', path_type='minio',
                 zookeeper_enable=0, zookeeper_hosts=None, zookeeper_namespace=None):
        DataSourceBase.__init__(self, host, port, username, password, database_name, type)
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.database_name = database_name if database_name else 'default'
        self.kerberos = int(kerberos)
        self.keytab = keytab
        self.krb5config = krb5config
        self.kerberos_service_name = kerberos_service_name
        self.principal = principal
        self.path_type = path_type
        self.zookeeper_enable = zookeeper_enable
        self.zookeeper_hosts = zookeeper_hosts
        self.zookeeper_namespace = zookeeper_namespace

    @property
    def jdbc_url(self):
        return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'

    @property
    def jdbc_driver_class(self):
        return 'org.apache.hive.jdbc.HiveDriver'

    @property
    def connection_str(self):
        pass

    def _connect(self, db_name: str = None):
        host_list = [f'{self.host}:{self.port}']
        if self.zookeeper_enable == 1:
            # Discover HiveServer2 instances registered in ZooKeeper; znode
            # names look like 'serverUri=host:port;version=...;sequence=...'.
            zk_client = KazooClient(hosts=self.zookeeper_hosts)
            zk_client.start()
            result = zk_client.get_children(self.zookeeper_namespace)
            zk_client.stop()
            host_list = []
            for host in result:
                if re.search(r'serverUri', host):
                    host_list.append(host.split('=')[1].split(';')[0])
        # Try the candidate hosts in random order until one accepts a connection.
        host_count = len(host_list)
        while host_count > 0:
            host_count -= 1
            index = random.randint(0, host_count)
            host_str = host_list.pop(index).split(':')
            try:
                if self.kerberos == 0:
                    conn = hive.connect(
                        thrift_transport=create_hive_plain_transport(
                            host=host_str[0],
                            port=int(host_str[1]),
                            username=self.username,
                            password=self.password,
                            timeout=10
                        ),
                        database=db_name if db_name else self.database_name
                    )
                else:
                    if self.path_type == 'minio':
                        # Fetch the keytab from object storage to a local path first.
                        get_kerberos_to_local(self.keytab)
                        file_name = './assets/kerberos/' + self.keytab.split('/')[-1]
                    else:
                        file_name = self.keytab
                    # Obtain a Kerberos ticket before opening the GSSAPI transport.
                    auth_res = os.system(f'kinit -kt {file_name} {self.principal}')
                    if auth_res != 0:
                        raise Exception('kinit failed, hive connection aborted')
                    conn = hive.connect(
                        thrift_transport=create_hive_kerberos_plain_transport(
                            host=host_str[0],
                            port=int(host_str[1]),
                            kerberos_service_name=self.kerberos_service_name,
                            timeout=10
                        ),
                        database=db_name if db_name else self.database_name
                    )
                # Smoke-test the connection before handing it back.
                cursor = conn.cursor()
                cursor.execute('select 1')
                result = cursor.fetchall()
                logger.info('connection established, test query returned ===> %s', result)
                if result:
                    return conn
            except Exception as e:
                # Log the failure and fall through to the next candidate host.
                logger.error(e)
        raise Exception('hive connection failed')

    def _execute_sql(self, sqls, db_name: str = None):
        conn = None
        res = []
        try:
            conn = self._connect(db_name)
            cursor = conn.cursor()
            for sql in sqls:
                cursor.execute(sql)
                res.append(cursor.fetchall())
        except Exception as e:
            logger.error(e)
            raise Exception('hive sql execution failed')
        finally:
            if conn is not None:
                conn.close()
        return res

    def _execute_create_sql(self, sqls):
        conn = None
        res = []
        try:
            conn = self._connect()
            cursor = conn.cursor()
            for sql in sqls:
                cursor.execute(sql)
        except Exception as e:
            logger.error(e)
            raise Exception('table creation failed, please check the column definitions')
        finally:
            if conn is not None:
                conn.close()
        return res

    def is_connect(self):
        res = self._execute_sql(['select 1'])
        logger.info(res)
        return bool(res)

    def get_preview_data(self, table_name, size=100, start=0, db_name: str = None):
        table_schema = self.get_table_schema(table_name, db_name)
        c_list = [col.split(':') for col in table_schema]
        # 'LIMIT offset,rows' requires Hive 2.0 or later.
        sql = f'SELECT * FROM `{table_name}` LIMIT {start},{size}'
        res = self._execute_sql([sql], db_name)
        logger.info(res)
        return {
            # Each schema entry is 'index:name:type'; keep 'name:type' for the header.
            'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
            'content': res[0]
        }

    def get_data_num(self, table_name, db_name: str = None):
        # COUNT(*) avoids transferring every row just to measure table size.
        sql = f'SELECT COUNT(*) FROM `{table_name}`'
        res = self._execute_sql([sql], db_name)
        return res[0][0][0]

    def list_tables(self, db_name: str = None):
        res = self._execute_sql(['show tables'], db_name)
        return flat_map(lambda x: x, res[0])

    def get_table_schema(self, table_name, db_name: str = None):
        db = db_name if db_name else self.database_name
        logger.info(db)
        sql = f'desc `{db}`.`{table_name}`'
        res = self._execute_sql([sql], db_name)
        table_schema = []
        if res and len(res) > 0:
            index = 0
            for col in res[0]:
                col_name = col[0]
                col_type = col[1]
                # 'desc' ends the column section with a blank name or a '# ...'
                # header (partition information); stop there.
                if col_name != '' and col_name.find('#') < 0:
                    table_schema.append(f'{index}:{col_name}:{col_type}')
                    index += 1
                else:
                    break
        return table_schema
        # Legacy per-column implementation, kept for reference:
        # sql1 = f'show columns in {self.database_name}.{table_name}'
        # res = self._execute_sql([sql1])
        # print("===", res)
        # if res:
        #     columns = list(map(lambda x: x[0], res[0]))
        #     # logger.info(columns)
        # else:
        #     raise Exception(f'{table_name} no columns')
        # ans = []
        # for i, col in enumerate(columns):
        #     sql = f'describe {self.database_name}.{table_name} {col}'
        #     try:
        #         res = self._execute_sql([sql])
        #         if res:
        #             res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
        #             ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
        #         else:
        #             raise Exception('table not found')
        #     except Exception:
        #         return ans
        # return ans
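

# A minimal usage sketch, assuming a reachable non-Kerberos HiveServer2; the
# host, port, and credentials below are placeholders, not project defaults.
if __name__ == '__main__':
    ds = HiveDS(
        host='127.0.0.1',
        port=10000,
        database_name='default',
        username='hive',
        password='hive',
    )
    print('connected:', ds.is_connect())
    print('tables:', ds.list_tables())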