hive.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import os
  2. from app.core.datasource.datasource import DataSourceBase
  3. from pyhive import hive
  4. from pyhive.exc import DatabaseError
  5. from app.utils.get_kerberos import get_kerberos_to_local
  6. from configs.logging import logger
  7. from utils import flat_map
class HiveDS(DataSourceBase):
    """Data source backed by Apache Hive, queried through PyHive.

    Connections use a plain username when ``kerberos == 0``; otherwise
    ``kinit`` is run against the configured keytab/principal before a
    Kerberos-authenticated connection is opened.
    """

    # Identifier for this data-source kind (presumably matched against the
    # ``type`` argument handed to DataSourceBase — confirm in the registry).
    type = 'hive'
  10. def __init__(self, host, port,database_name,\
  11. username=None, password=None, kerberos=0, \
  12. keytab=None, krb5config=None, kerberos_service_name=None, \
  13. principal=None, type='hive', path_type='minio'):
  14. DataSourceBase.__init__(self, host, port, username, password, database_name, type)
  15. self.host = host
  16. self.port = port
  17. self.username = username
  18. self.password = password
  19. self.database_name = 'default' if not database_name else database_name
  20. self.kerberos = int(kerberos)
  21. self.keytab = keytab
  22. self.krb5config = krb5config
  23. self.kerberos_service_name = kerberos_service_name
  24. self.principal = principal
  25. self.path_type = path_type
  26. @property
  27. def jdbc_url(self):
  28. return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'
  29. @property
  30. def jdbc_driver_class(self):
  31. return 'org.apache.hive.jdbc.HiveDriver'
  32. @property
  33. def connection_str(self):
  34. pass
  35. def _execute_sql(self, sqls):
  36. conn = None
  37. res = []
  38. try:
  39. if self.kerberos == 0:
  40. conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
  41. else:
  42. file_name = ''
  43. if self.path_type == 'minio':
  44. get_kerberos_to_local(self.keytab)
  45. file_name = './assets/kerberos/'+self.keytab.split("/")[-1]
  46. else:
  47. file_name = self.keytab
  48. os.system(f'kinit -kt {file_name} {self.principal}')
  49. conn = hive.Connection(host=self.host, database=self.database_name, port=self.port, auth="KERBEROS", kerberos_service_name=self.kerberos_service_name)
  50. cursor = conn.cursor()
  51. for sql in sqls:
  52. cursor.execute(sql)
  53. res.append(cursor.fetchall())
  54. # logger.info(res)
  55. except Exception as e:
  56. logger.error(e)
  57. finally:
  58. if conn is not None:
  59. conn.close()
  60. return res
  61. def is_connect(self):
  62. sql = 'select 1'
  63. res = self._execute_sql([sql])
  64. logger.info(res)
  65. if res:
  66. return True
  67. else:
  68. return False
  69. def get_preview_data(self, table_name, limit=100, page = 0):
  70. table_schema = self.get_table_schema(table_name)
  71. c_list = []
  72. for col in table_schema:
  73. c = col.split(':')
  74. c_list.append(c)
  75. sql2 = f"SELECT * FROM {table_name} LIMIT {page},{limit}"
  76. res = self._execute_sql([ sql2])
  77. logger.info(res)
  78. return {
  79. 'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
  80. 'content': res[0]
  81. }
  82. def get_data_num(self, table_name):
  83. sql2 = f"SELECT 1 FROM {table_name}"
  84. res = self._execute_sql([sql2])
  85. return len(res[0])
  86. def list_tables(self):
  87. sql = f'show tables'
  88. res = self._execute_sql([sql])
  89. return flat_map(lambda x: x, res[0])
  90. def get_table_schema(self, table_name):
  91. logger.info(self.database_name)
  92. sql_test = f'desc {self.database_name}.{table_name}'
  93. res_test = self._execute_sql([sql_test])
  94. table_schema = []
  95. if res_test and len(res_test) > 0:
  96. index = 0
  97. for col in res_test[0]:
  98. col_name = col[0]
  99. col_type = col[1]
  100. if col_name != '' and col_name.find('#') < 0:
  101. col_str = f'{index}:{col_name}:{col_type}'
  102. table_schema.append(col_str)
  103. index+=1
  104. else:
  105. break
  106. return table_schema
  107. # sql1 = f'show columns in {self.database_name}.{table_name}'
  108. # res = self._execute_sql([sql1])
  109. # print("===",res)
  110. # if res:
  111. # columns = list(map(lambda x: x[0],res[0]))
  112. # # logger.info(columns)
  113. # else:
  114. # raise Exception(f'{table_name} no columns')
  115. # ans = []
  116. # for i, col in enumerate(columns):
  117. # sql = f'describe {self.database_name}.{table_name} {col}'
  118. # try:
  119. # res = self._execute_sql([sql])
  120. # if res:
  121. # res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
  122. # ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
  123. # else:
  124. # raise Exception('table not found')
  125. # except Exception:
  126. # return ans
  127. # return ans