hive.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import os
  2. import random
  3. import re
  4. from app.core.datasource.datasource import DataSourceBase
  5. from pyhive import hive
  6. from pyhive.exc import DatabaseError
  7. from app.utils.get_kerberos import get_kerberos_to_local
  8. from configs.logging import logger
  9. from utils import flat_map
  10. import sasl
  11. from thrift_sasl import TSaslClientTransport
  12. from thrift.transport.TSocket import TSocket
  13. from kazoo.client import KazooClient
  def create_hive_plain_transport(host, port, username, password, timeout=10):
      """Build a SASL ``PLAIN`` Thrift transport for a HiveServer2 endpoint.

      Args:
          host: HiveServer2 host name or address.
          port: HiveServer2 Thrift port; passed unchanged to ``TSocket``
              (``HiveDS._connect`` passes it as a string).
          username: User name handed to the SASL client.
          password: Password handed to the SASL client.
          timeout: Socket timeout in seconds. ``None`` disables the timeout
              (blocking socket) instead of failing on ``None * 1000``.

      Returns:
          A ``TSaslClientTransport`` wrapping the socket. The transport is not
          opened here; it is meant for ``hive.connect(thrift_transport=...)``.
      """
      socket = TSocket(host, port)
      # TSocket.setTimeout expects milliseconds; None means no timeout.
      socket.setTimeout(None if timeout is None else timeout * 1000)

      def sasl_factory():
          # Factory handed to TSaslClientTransport to build the SASL client.
          sasl_client = sasl.Client()
          sasl_client.setAttr('host', host)
          sasl_client.setAttr('username', username)
          sasl_client.setAttr('password', password)
          sasl_client.init()
          return sasl_client

      return TSaslClientTransport(sasl_factory, 'PLAIN', socket)
  def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
      """Build a Kerberos (SASL ``GSSAPI``) Thrift transport for HiveServer2.

      Despite the ``plain`` in its name, this uses the GSSAPI mechanism. The
      caller is expected to hold a Kerberos ticket already; ``HiveDS._connect``
      runs ``kinit`` before calling it.

      Args:
          host: HiveServer2 host name or address.
          port: HiveServer2 Thrift port; passed unchanged to ``TSocket``.
          kerberos_service_name: Service name set on the SASL client.
          timeout: Socket timeout in seconds. ``None`` disables the timeout
              (blocking socket) instead of failing on ``None * 1000``.

      Returns:
          An unopened ``TSaslClientTransport`` wrapping the socket.
      """
      socket = TSocket(host, port)
      # TSocket.setTimeout expects milliseconds; None means no timeout.
      socket.setTimeout(None if timeout is None else timeout * 1000)

      def sasl_factory():
          # Factory handed to TSaslClientTransport to build the SASL client.
          sasl_client = sasl.Client()
          sasl_client.setAttr('host', host)
          sasl_client.setAttr('service', kerberos_service_name)
          sasl_client.init()
          return sasl_client

      return TSaslClientTransport(sasl_factory, 'GSSAPI', socket)
  class HiveDS(DataSourceBase):
      """Hive data source that talks to HiveServer2 through PyHive.

      Connections go either straight to ``host:port`` or, when
      ``zookeeper_enable == 1``, to a server discovered under
      ``zookeeper_namespace`` on ``zookeeper_hosts``. Authentication is SASL
      PLAIN (``kerberos == 0``) or Kerberos/GSSAPI after a ``kinit`` with the
      configured keytab and principal.
      """

      type = 'hive'

      def __init__(self, host, port, database_name,
                   username=None, password=None, kerberos=0,
                   keytab=None, krb5config=None, kerberos_service_name=None,
                   principal=None, type='hive', path_type='minio',
                   zookeeper_enable=0, zookeeper_hosts=None, zookeeper_namespace=None):
          """Store connection settings.

          Args:
              host, port: HiveServer2 address used when ZooKeeper discovery is off.
              database_name: Default database; ``'default'`` when empty/None.
              username, password: Credentials for SASL PLAIN.
              kerberos: Non-zero enables Kerberos; ints or numeric strings.
              keytab: Keytab location. With ``path_type == 'minio'`` it is fetched
                  via ``get_kerberos_to_local`` (presumably from MinIO) first;
                  otherwise it is used as a local path.
              krb5config: Stored; not read anywhere in this class.
              kerberos_service_name: Kerberos service name of HiveServer2.
              principal: Principal passed to ``kinit``.
              type: Data-source type forwarded to ``DataSourceBase``.
              path_type: How ``keytab`` is resolved (see above).
              zookeeper_enable: ``1`` enables ZooKeeper discovery; ints or
                  numeric strings.
              zookeeper_hosts, zookeeper_namespace: ZooKeeper quorum and the
                  znode whose children list HiveServer2 instances.
          """
          DataSourceBase.__init__(self, host, port, username, password, database_name, type)
          self.host = host
          self.port = port
          self.username = username
          self.password = password
          self.database_name = 'default' if not database_name else database_name
          # Flags may arrive as strings (hence int()); None/'' mean "disabled".
          self.kerberos = int(kerberos or 0)
          self.keytab = keytab
          self.krb5config = krb5config
          self.kerberos_service_name = kerberos_service_name
          self.principal = principal
          self.path_type = path_type
          # Normalised like `kerberos` so the `== 1` check in _discover_hosts
          # also works for "1".
          self.zookeeper_enable = int(zookeeper_enable or 0)
          self.zookeeper_hosts = zookeeper_hosts
          self.zookeeper_namespace = zookeeper_namespace

      @property
      def jdbc_url(self):
          """JDBC URL for the configured host, port and default database."""
          return f'jdbc:hive2://{self.host}:{self.port}/{self.database_name}'

      @property
      def jdbc_driver_class(self):
          """Fully-qualified class name of the Hive JDBC driver."""
          return 'org.apache.hive.jdbc.HiveDriver'

      @property
      def connection_str(self):
          """Not implemented for Hive; always returns None."""
          pass

      @staticmethod
      def _quote_identifier(name):
          """Wrap ``name`` in backticks, doubling embedded backticks (HiveQL quoting)."""
          return '`' + str(name).replace('`', '``') + '`'

      def _discover_hosts(self):
          """Return the candidate ``host:port`` strings for HiveServer2.

          Without ZooKeeper this is just the configured ``host:port``. With
          ZooKeeper it is parsed from the children of ``zookeeper_namespace``.
          ZooKeeper errors propagate unchanged.
          """
          if self.zookeeper_enable != 1:
              return [f'{self.host}:{self.port}']
          zk_client = KazooClient(hosts=self.zookeeper_hosts)
          zk_client.start()
          try:
              children = zk_client.get_children(self.zookeeper_namespace)
          finally:
              # Release the ZooKeeper session even when the lookup fails.
              zk_client.stop()
              zk_client.close()
          # Entries are expected to look like "serverUri=host:port;..." and
          # only the host:port part is kept.
          return [
              child.split('=')[1].split(';')[0]
              for child in children
              if 'serverUri' in child
          ]

      def _kerberos_login(self):
          """Run ``kinit -kt <keytab> <principal>`` for this data source.

          Raises:
              RuntimeError: if ``kinit`` exits with a non-zero status.
              OSError: if ``kinit`` cannot be executed.
          """
          import subprocess

          if self.path_type == 'minio':
              # Project helper that makes the keytab available under
              # ./assets/kerberos/, where it is then referenced by file name.
              get_kerberos_to_local(self.keytab)
              keytab_file = './assets/kerberos/' + self.keytab.split('/')[-1]
          else:
              keytab_file = self.keytab
          # Argument list, no shell: the keytab path and principal cannot
          # inject shell commands.
          completed = subprocess.run(['kinit', '-kt', str(keytab_file), str(self.principal)])
          if completed.returncode != 0:
              raise RuntimeError(f'kinit failed with exit code {completed.returncode}')

      def _make_transport(self, host, port):
          """Build the SASL transport matching the configured auth mode."""
          if self.kerberos == 0:
              return create_hive_plain_transport(
                  host=host,
                  port=port,
                  username=self.username,
                  password=self.password,
                  timeout=10,
              )
          return create_hive_kerberos_plain_transport(
              host=host,
              port=port,
              kerberos_service_name=self.kerberos_service_name,
              timeout=10,
          )

      def _connect(self, db_name: str = None):
          """Return an open, verified PyHive connection.

          Candidate servers come from ``_discover_hosts`` and are tried in
          random order. The first one whose ``select 1`` probe returns rows is
          returned; a connection that fails the probe is closed before the
          next candidate is tried.

          Args:
              db_name: Database to open; defaults to ``self.database_name``.

          Raises:
              Exception: ``'hive 连接失败'`` when there is no candidate server,
                  the Kerberos login fails, or every candidate fails.
          """
          database = db_name if db_name else self.database_name
          host_list = self._discover_hosts()
          if not host_list:
              logger.error('no HiveServer2 candidate address found')
              raise Exception('hive 连接失败')

          if self.kerberos != 0:
              # One kinit per connect, not one per candidate host.
              try:
                  self._kerberos_login()
              except Exception as e:
                  logger.error(e)
                  raise Exception('hive 连接失败') from e

          # Random order, presumably to spread load across discovered servers.
          random.shuffle(host_list)
          last_error = None
          for candidate in host_list:
              host, sep, port = candidate.rpartition(':')
              if not sep:
                  logger.error(f'invalid HiveServer2 address: {candidate}')
                  continue
              conn = None
              try:
                  conn = hive.connect(
                      thrift_transport=self._make_transport(host, port),
                      database=database,
                  )
                  cursor = conn.cursor()
                  cursor.execute('select 1')
                  result = cursor.fetchall()
                  logger.info(f'获取连接,通过连接查询数据测试===> {result}')
                  if result:
                      return conn
              except Exception as e:
                  last_error = e
                  logger.error(e)
              # This candidate failed; release its connection before moving on.
              if conn is not None:
                  try:
                      conn.close()
                  except Exception as close_error:
                      logger.error(close_error)
          raise Exception('hive 连接失败') from last_error

      def _execute_sql(self, sqls, db_name: str = None):
          """Run each statement in ``sqls`` and collect its ``fetchall()`` rows.

          Returns:
              A list with one row list per statement, in order.

          Raises:
              Exception: ``'hive 连接失败'`` for any connection or query error.
          """
          conn = None
          results = []
          try:
              conn = self._connect(db_name)
              cursor = conn.cursor()
              for sql in sqls:
                  cursor.execute(sql)
                  results.append(cursor.fetchall())
          except Exception as e:
              logger.error(e)
              raise Exception('hive 连接失败') from e
          finally:
              if conn is not None:
                  conn.close()
          return results

      def _execute_create_sql(self, sqls):
          """Run DDL statements in order on the default database.

          Returns:
              An empty list (no result sets are collected).

          Raises:
              Exception: ``'表创建失败,请检查字段填写是否有误'`` on any error.
          """
          conn = None
          try:
              conn = self._connect()
              cursor = conn.cursor()
              for sql in sqls:
                  cursor.execute(sql)
          except Exception as e:
              logger.error(e)
              raise Exception('表创建失败,请检查字段填写是否有误') from e
          finally:
              if conn is not None:
                  conn.close()
          return []

      def is_connect(self):
          """Return True when a ``select 1`` round-trip succeeds.

          Failures are not turned into False; they propagate as the exception
          raised by ``_execute_sql``.
          """
          res = self._execute_sql(['select 1'])
          logger.info(res)
          if res:
              return True
          else:
              return False

      def get_preview_data(self, table_name, size=100, start=0, db_name: str = None):
          """Return up to ``size`` rows of ``table_name`` starting at offset ``start``.

          Returns:
              ``{'header': [...'name:type'...], 'content': rows}``.
          """
          table_schema = self.get_table_schema(table_name, db_name)
          # Schema entries are "index:name:type". Split at most twice so complex
          # types such as struct<a:int,b:string> keep their own colons.
          c_list = [col.split(':', 2) for col in table_schema]
          # int() keeps the LIMIT clause numeric even if strings are passed in.
          sql = (f'SELECT * FROM {self._quote_identifier(table_name)} '
                 f'LIMIT {int(start)},{int(size)}')
          res = self._execute_sql([sql], db_name)
          logger.info(res)
          return {
              'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
              'content': res[0]
          }

      def get_data_num(self, table_name, db_name: str = None):
          """Return the number of rows in ``table_name``.

          The count is computed server-side instead of fetching every row.
          """
          sql = f'SELECT COUNT(1) FROM {self._quote_identifier(table_name)}'
          res = self._execute_sql([sql], db_name)
          if not res or not res[0]:
              return 0
          return int(res[0][0][0])

      def list_tables(self, db_name: str = None):
          """Return the table names of ``db_name`` (or the default database)."""
          res = self._execute_sql(['show tables'], db_name)
          return flat_map(lambda x: x, res[0])

      def get_table_schema(self, table_name, db_name: str = None):
          """Return the table's regular columns as ``'index:name:type'`` strings.

          Reading stops at the first blank row or row whose name contains
          ``#``, which is where DESC output switches to extra sections.
          """
          sl_db_name = db_name if db_name else self.database_name
          sql = (f'desc {self._quote_identifier(sl_db_name)}.'
                 f'{self._quote_identifier(table_name)}')
          res = self._execute_sql([sql], db_name)
          table_schema = []
          if not res:
              return table_schema
          for col in res[0]:
              col_name, col_type = col[0], col[1]
              if col_name == '' or '#' in col_name:
                  break
              table_schema.append(f'{len(table_schema)}:{col_name}:{col_type}')
          return table_schema

      def get_table_info(self, table_name: str, db_name: str = None):
          """Return the raw ``DESCRIBE FORMATTED`` result for ``table_name``."""
          sql = f'DESCRIBE FORMATTED {self._quote_identifier(table_name)}'
          return self._execute_sql([sql], db_name)