Преглед на файлове

Merge branch 'master' of http://gogsb.soaringnova.com/sxwl_DL/datax-admin

Zhang Li преди 2 години
родител
ревизия
41fa030822
променени са 3 файла, в които са добавени 66 реда и са изтрити 9 реда
  1. 60 5
      app/core/datasource/hive.py
  2. 5 3
      app/core/datasource/mysql.py
  3. 1 1
      constants/constants.py

+ 60 - 5
app/core/datasource/hive.py

@@ -5,6 +5,41 @@ from pyhive.exc import DatabaseError
 from app.utils.get_kerberos import get_kerberos_to_local
 from configs.logging import logger
 from utils import flat_map
+import sasl
+from thrift_sasl import TSaslClientTransport
+from thrift.transport.TSocket import TSocket
+
+
def _create_sasl_transport(host, port, mechanism, attrs, timeout):
    """Build a SASL-wrapped thrift transport for a HiveServer2 connection.

    host/port: thrift endpoint of HiveServer2.
    mechanism: SASL mechanism name ('PLAIN' or 'GSSAPI').
    attrs: mechanism-specific SASL client attributes (besides 'host').
    timeout: connect/read timeout in seconds.
    """
    socket = TSocket(host, port)
    # TSocket.setTimeout expects milliseconds, callers pass seconds.
    socket.setTimeout(timeout * 1000)

    def sasl_factory():
        sasl_client = sasl.Client()
        sasl_client.setAttr('host', host)
        for key, value in attrs.items():
            sasl_client.setAttr(key, value)
        sasl_client.init()
        return sasl_client

    return TSaslClientTransport(sasl_factory, mechanism, socket)


def create_hive_plain_transport(host, port, username, password, timeout=10):
    """Transport for username/password (SASL PLAIN) authentication."""
    return _create_sasl_transport(
        host, port, 'PLAIN',
        {'username': username, 'password': password},
        timeout,
    )


def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
    """Transport for Kerberos (SASL GSSAPI) authentication.

    Assumes a valid Kerberos ticket already exists in the environment
    (the caller runs `kinit` beforehand) — TODO confirm against caller.
    """
    return _create_sasl_transport(
        host, port, 'GSSAPI',
        {'service': kerberos_service_name},
        timeout,
    )
 
 class HiveDS(DataSourceBase):
     type = 'hive'
@@ -44,7 +79,17 @@ class HiveDS(DataSourceBase):
         res = []
         try:
             if self.kerberos == 0:
-                conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
+                # conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
+                conn = hive.connect(
+                    thrift_transport=create_hive_plain_transport(
+                        host=self.host,
+                        port=self.port,
+                        username=self.username,
+                        password=self.password,
+                        timeout=10
+                    ),
+                    database=self.database_name
+                )
             else:
                 file_name = ''
                 if self.path_type == 'minio':
@@ -52,8 +97,19 @@ class HiveDS(DataSourceBase):
                     file_name = './assets/kerberos/'+self.keytab.split("/")[-1]
                 else:
                     file_name = self.keytab
-                os.system(f'kinit -kt {file_name} {self.principal}')
-                conn = hive.Connection(host=self.host, database=self.database_name, port=self.port,  auth="KERBEROS", kerberos_service_name=self.kerberos_service_name)
+                auth_res = os.system(f'kinit -kt {file_name} {self.principal}')
+                if auth_res != 0:
+                    raise Exception('hive 连接失败')
+                # conn = hive.Connection(host=self.host, port=self.port,  auth="KERBEROS", kerberos_service_name=self.kerberos_service_name, database=self.database_name)
+                conn = hive.connect(
+                    thrift_transport=create_hive_kerberos_plain_transport(
+                        host=self.host,
+                        port=self.port,
+                        kerberos_service_name=self.kerberos_service_name,
+                        timeout=10
+                    ),
+                    database=self.database_name
+                )
 
 
             cursor = conn.cursor()
@@ -63,7 +119,7 @@ class HiveDS(DataSourceBase):
             # logger.info(res)
         except Exception as e:
             logger.error(e)
-
+            raise Exception('hive 连接失败')
         finally:
             if conn is not None:
                 conn.close()
@@ -85,7 +141,6 @@ class HiveDS(DataSourceBase):
         for col in table_schema:
             c = col.split(':')
             c_list.append(c)
-        print(c_list)
         sql2 = f"SELECT * FROM {table_name} LIMIT {page},{limit}"
         res = self._execute_sql([ sql2])
         logger.info(res)

+ 5 - 3
app/core/datasource/mysql.py

@@ -40,7 +40,8 @@ class MysqlDS(DataSourceBase):
                                     database=self.database_name,
                                     user=self.username,
                                     password=self.password,
-                                    ssl_disabled=not use_ssl)
+                                    ssl_disabled=not use_ssl,
+                                    connection_timeout=5)
             if conn.is_connected():
                 logger.info('Connected to MySQL database')
 
@@ -64,7 +65,8 @@ class MysqlDS(DataSourceBase):
                                     database=self.database_name,
                                     user=self.username,
                                     password=self.password,
-                                      ssl_disabled=not use_ssl)
+                                    ssl_disabled=not use_ssl,
+                                    connection_timeout=5)
             cursor = conn.cursor()
             for sql in sqls:
                 cursor.execute(sql)
@@ -72,7 +74,7 @@ class MysqlDS(DataSourceBase):
             logger.info(res)
         except Error as e:
             logger.error(e)
-
+            raise Exception('mysql 连接失败')
         finally:
             if conn is not None and conn.is_connected():
                 conn.close()

+ 1 - 1
constants/constants.py

@@ -8,4 +8,4 @@ CONSTANTS = {
     'DATASOURCES': DATASOURCES
 }
 
-RUN_STATUS = {"queued": 0, 'running': 1, 'success': 2, 'failed': 3, 'upstream_failed': 3}
+RUN_STATUS = {"queued": 0, 'scheduled': 1, 'running': 1, 'success': 2, 'failed': 3, 'upstream_failed': 3}