hive.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
import os
import subprocess

import sasl
from pyhive import hive
from pyhive.exc import DatabaseError
from thrift.transport.TSocket import TSocket
from thrift_sasl import TSaslClientTransport

from app.core.datasource.datasource import DataSourceBase
from app.utils.get_kerberos import get_kerberos_to_local
from configs.logging import logger
from utils import flat_map
  11. def create_hive_plain_transport(host, port, username, password, timeout=10):
  12. socket = TSocket(host, port)
  13. socket.setTimeout(timeout * 1000)
  14. sasl_auth = 'PLAIN'
  15. def sasl_factory():
  16. sasl_client = sasl.Client()
  17. sasl_client.setAttr('host', host)
  18. sasl_client.setAttr('username', username)
  19. sasl_client.setAttr('password', password)
  20. sasl_client.init()
  21. return sasl_client
  22. return TSaslClientTransport(sasl_factory, sasl_auth, socket)
  23. def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
  24. socket = TSocket(host, port)
  25. socket.setTimeout(timeout * 1000)
  26. sasl_auth = 'GSSAPI'
  27. def sasl_factory():
  28. sasl_client = sasl.Client()
  29. sasl_client.setAttr('host', host)
  30. sasl_client.setAttr('service', kerberos_service_name)
  31. sasl_client.init()
  32. return sasl_client
  33. return TSaslClientTransport(sasl_factory, sasl_auth, socket)
  34. class HiveDS(DataSourceBase):
  35. type = 'hive'
  36. def __init__(self, host, port,database_name,\
  37. username=None, password=None, kerberos=0, \
  38. keytab=None, krb5config=None, kerberos_service_name=None, \
  39. principal=None, type='hive', path_type='minio'):
  40. DataSourceBase.__init__(self, host, port, username, password, database_name, type)
  41. self.host = host
  42. self.port = port
  43. self.username = username
  44. self.password = password
  45. self.database_name = 'default' if not database_name else database_name
  46. self.kerberos = int(kerberos)
  47. self.keytab = keytab
  48. self.krb5config = krb5config
  49. self.kerberos_service_name = kerberos_service_name
  50. self.principal = principal
  51. self.path_type = path_type
  52. @property
  53. def jdbc_url(self):
  54. return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'
  55. @property
  56. def jdbc_driver_class(self):
  57. return 'org.apache.hive.jdbc.HiveDriver'
  58. @property
  59. def connection_str(self):
  60. pass
  61. def _execute_sql(self, sqls):
  62. conn = None
  63. res = []
  64. try:
  65. if self.kerberos == 0:
  66. # conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
  67. conn = hive.connect(
  68. thrift_transport=create_hive_plain_transport(
  69. host=self.host,
  70. port=self.port,
  71. username=self.username,
  72. password=self.password,
  73. timeout=10
  74. ),
  75. database=self.database_name
  76. )
  77. else:
  78. file_name = ''
  79. if self.path_type == 'minio':
  80. get_kerberos_to_local(self.keytab)
  81. file_name = './assets/kerberos/'+self.keytab.split("/")[-1]
  82. else:
  83. file_name = self.keytab
  84. auth_res = os.system(f'kinit -kt {file_name} {self.principal}')
  85. if auth_res != 0:
  86. raise Exception('hive 连接失败')
  87. # conn = hive.Connection(host=self.host, port=self.port, auth="KERBEROS", kerberos_service_name=self.kerberos_service_name, database=self.database_name)
  88. conn = hive.connect(
  89. thrift_transport=create_hive_kerberos_plain_transport(
  90. host=self.host,
  91. port=self.port,
  92. kerberos_service_name=self.kerberos_service_name,
  93. timeout=10
  94. ),
  95. database=self.database_name
  96. )
  97. cursor = conn.cursor()
  98. for sql in sqls:
  99. cursor.execute(sql)
  100. res.append(cursor.fetchall())
  101. # logger.info(res)
  102. except Exception as e:
  103. logger.error(e)
  104. raise Exception('hive 连接失败')
  105. finally:
  106. if conn is not None:
  107. conn.close()
  108. return res
  109. def is_connect(self):
  110. sql = 'select 1'
  111. res = self._execute_sql([sql])
  112. logger.info(res)
  113. if res:
  114. return True
  115. else:
  116. return False
  117. def get_preview_data(self, table_name, limit=100, page = 0):
  118. table_schema = self.get_table_schema(table_name)
  119. c_list = []
  120. for col in table_schema:
  121. c = col.split(':')
  122. c_list.append(c)
  123. sql2 = f"SELECT * FROM `{table_name}` LIMIT {page},{limit}"
  124. res = self._execute_sql([ sql2])
  125. logger.info(res)
  126. return {
  127. 'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
  128. 'content': res[0]
  129. }
  130. def get_data_num(self, table_name):
  131. sql2 = f"SELECT 1 FROM `{table_name}`"
  132. res = self._execute_sql([sql2])
  133. return len(res[0])
  134. def list_tables(self):
  135. sql = f'show tables'
  136. res = self._execute_sql([sql])
  137. return flat_map(lambda x: x, res[0])
  138. def get_table_schema(self, table_name):
  139. logger.info(self.database_name)
  140. sql_test = f'desc `{self.database_name}`.`{table_name}`'
  141. res_test = self._execute_sql([sql_test])
  142. table_schema = []
  143. if res_test and len(res_test) > 0:
  144. index = 0
  145. for col in res_test[0]:
  146. col_name = col[0]
  147. col_type = col[1]
  148. if col_name != '' and col_name.find('#') < 0:
  149. col_str = f'{index}:{col_name}:{col_type}'
  150. table_schema.append(col_str)
  151. index+=1
  152. else:
  153. break
  154. return table_schema
  155. # sql1 = f'show columns in {self.database_name}.{table_name}'
  156. # res = self._execute_sql([sql1])
  157. # print("===",res)
  158. # if res:
  159. # columns = list(map(lambda x: x[0],res[0]))
  160. # # logger.info(columns)
  161. # else:
  162. # raise Exception(f'{table_name} no columns')
  163. # ans = []
  164. # for i, col in enumerate(columns):
  165. # sql = f'describe {self.database_name}.{table_name} {col}'
  166. # try:
  167. # res = self._execute_sql([sql])
  168. # if res:
  169. # res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
  170. # ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
  171. # else:
  172. # raise Exception('table not found')
  173. # except Exception:
  174. # return ans
  175. # return ans