
Merge remote-tracking branch 'origin/master'

# Conflicts:
#	server.py
luoyulong 2 years ago
parent commit 5bab33427e

+ 8 - 7
Dockerfile

@@ -68,7 +68,7 @@ RUN mamba update -n base -c defaults conda -y && mamba env create -f /environmen
 RUN echo "\
 [program:be]\n\
 directory=%(ENV_WORKDIR)s\n\
-command= /opt/conda/envs/py38/bin/gunicorn  server:app -b 0.0.0.0:8080 --reload  -k utils.r_uvicorn_worker.RestartableUvicornWorker \n\
+command= /opt/conda/envs/py38/bin/gunicorn  server:app -b 0.0.0.0:8080 --reload  -k utils.r_uvicorn_worker.RestartableUvicornWorker -w 2 \n\
 autorestart=true\n\
 startretries=1000\n\
 redirect_stderr=true\n\
@@ -81,10 +81,7 @@ EXPOSE 8080
 
 FROM builder2 as image-dev
 
-RUN apt-get update && apt-get install -y --no-install-recommends \
-     krb5-user
-
-
+RUN apt-get update && apt-get install -y --no-install-recommends krb5-user
 ADD . ${WORKDIR}
 RUN mv docker/dev/krb5.conf /etc/
 
@@ -92,12 +89,16 @@ RUN mv docker/dev/krb5.conf /etc/
 
 FROM builder2 as image-test
 RUN apt-get update && apt-get install -y --no-install-recommends krb5-user
-
-
 ADD . ${WORKDIR}
 RUN mv docker/test/krb5.conf /etc/
 
 
+
+FROM builder2 as image-idctest
+RUN apt-get update
+RUN apt-get update && apt-get install -y --no-install-recommends krb5-user
+ADD . ${WORKDIR}
+
 # FROM builder2 as builder3
 
 # RUN apt-get update && apt-get install -y --no-install-recommends openssh-server && rm -rf /var/lib/apt/lists/*

+ 5 - 1
Makefile

@@ -23,4 +23,8 @@ deploy: image
 	@docker-compose down  && docker-compose up -d
 
 remote:
-	@ssh -t $(USER)@$(HOST) "cd $(REMOTE_WORKSPACE); make deploy"
+	@ssh -t $(USER)@$(HOST) "cd $(REMOTE_WORKSPACE); make deploy"
+
+idctest:
+	@docker build --target image-idctest -t registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):idctest .
+	@docker push registry.cn-hangzhou.aliyuncs.com/sxtest/$(NAME):idctest

+ 106 - 30
app/core/datasource/hive.py

@@ -5,6 +5,41 @@ from pyhive.exc import DatabaseError
 from app.utils.get_kerberos import get_kerberos_to_local
 from configs.logging import logger
 from utils import flat_map
+import sasl
+from thrift_sasl import TSaslClientTransport
+from thrift.transport.TSocket import TSocket
+
+
+def create_hive_plain_transport(host, port, username, password, timeout=10):
+    socket = TSocket(host, port)
+    socket.setTimeout(timeout * 1000)
+
+    sasl_auth = 'PLAIN'
+
+    def sasl_factory():
+        sasl_client = sasl.Client()
+        sasl_client.setAttr('host', host)
+        sasl_client.setAttr('username', username)
+        sasl_client.setAttr('password', password)
+        sasl_client.init()
+        return sasl_client
+
+    return TSaslClientTransport(sasl_factory, sasl_auth, socket)
+
+def create_hive_kerberos_plain_transport(host, port, kerberos_service_name, timeout=10):
+    socket = TSocket(host, port)
+    socket.setTimeout(timeout * 1000)
+
+    sasl_auth = 'GSSAPI'
+
+    def sasl_factory():
+        sasl_client = sasl.Client()
+        sasl_client.setAttr('host', host)
+        sasl_client.setAttr('service', kerberos_service_name)
+        sasl_client.init()
+        return sasl_client
+
+    return TSaslClientTransport(sasl_factory, sasl_auth, socket)
 
 class HiveDS(DataSourceBase):
     type = 'hive'
@@ -44,7 +79,17 @@ class HiveDS(DataSourceBase):
         res = []
         try:
             if self.kerberos == 0:
-                conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
+                # conn = hive.Connection(host=self.host, port=self.port, username=self.username, database=self.database_name)
+                conn = hive.connect(
+                    thrift_transport=create_hive_plain_transport(
+                        host=self.host,
+                        port=self.port,
+                        username=self.username,
+                        password=self.password,
+                        timeout=10
+                    ),
+                    database=self.database_name
+                )
             else:
                 file_name = ''
                 if self.path_type == 'minio':
@@ -52,8 +97,19 @@ class HiveDS(DataSourceBase):
                     file_name = './assets/kerberos/'+self.keytab.split("/")[-1]
                 else:
                     file_name = self.keytab
-                os.system(f'kinit -kt {file_name} {self.principal}')
-                conn = hive.Connection(host=self.host, database=self.database_name, port=self.port,  auth="KERBEROS", kerberos_service_name=self.kerberos_service_name)
+                auth_res = os.system(f'kinit -kt {file_name} {self.principal}')
+                if auth_res != 0:
+                    raise Exception('hive 连接失败')
+                # conn = hive.Connection(host=self.host, port=self.port,  auth="KERBEROS", kerberos_service_name=self.kerberos_service_name, database=self.database_name)
+                conn = hive.connect(
+                    thrift_transport=create_hive_kerberos_plain_transport(
+                        host=self.host,
+                        port=self.port,
+                        kerberos_service_name=self.kerberos_service_name,
+                        timeout=10
+                    ),
+                    database=self.database_name
+                )
 
 
             cursor = conn.cursor()
@@ -63,7 +119,7 @@ class HiveDS(DataSourceBase):
             # logger.info(res)
         except Exception as e:
             logger.error(e)
-
+            raise Exception('hive 连接失败')
         finally:
             if conn is not None:
                 conn.close()
@@ -80,13 +136,17 @@ class HiveDS(DataSourceBase):
 
 
     def get_preview_data(self, table_name, limit=100, page = 0):
-        sql1 = f'describe {self.database_name}.{table_name}'
+        table_schema = self.get_table_schema(table_name)
+        c_list = []
+        for col in table_schema:
+            c = col.split(':')
+            c_list.append(c)
         sql2 = f"SELECT * FROM {table_name} LIMIT {page},{limit}"
-        res = self._execute_sql([sql1, sql2])
+        res = self._execute_sql([ sql2])
         logger.info(res)
         return {
-            'header': flat_map(lambda x: [':'.join(x[:2])], res[0]),
-            'content': res[1]
+            'header': flat_map(lambda x: [':'.join(x[1:3])], c_list),
+            'content': res[0]
         }
 
     def get_data_num(self, table_name):
@@ -102,27 +162,43 @@ class HiveDS(DataSourceBase):
 
     def get_table_schema(self, table_name):
         logger.info(self.database_name)
-        sql1 = f'show columns in {self.database_name}.{table_name}'
-        res = self._execute_sql([sql1])
-        if res:
-            columns = list(map(lambda x: x[0],res[0]))
-            # logger.info(columns)
-        else:
-            raise Exception(f'{table_name} no columns')
-        ans = []
-        for i, col in enumerate(columns):
-            sql = f'describe {self.database_name}.{table_name} {col}'
-            try:
-                res = self._execute_sql([sql])
-                if res:
-                        # print(res[0])
-                        res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
-                        ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
-
+        sql_test = f'desc {self.database_name}.{table_name}'
+        res_test = self._execute_sql([sql_test])
+        table_schema = []
+        if res_test and  len(res_test) > 0:
+            index = 0
+            for col in res_test[0]:
+                col_name = col[0]
+                col_type = col[1]
+                if col_name != '' and col_name.find('#') < 0:
+                    col_str = f'{index}:{col_name}:{col_type}'
+                    table_schema.append(col_str)
+                    index+=1
                 else:
-                    raise Exception('table not found')
-            except Exception:
-                return ans
-
-        return ans
+                    break
+        return table_schema
+
+        # sql1 = f'show columns in {self.database_name}.{table_name}'
+        # res = self._execute_sql([sql1])
+        # print("===",res)
+        # if res:
+        #     columns = list(map(lambda x: x[0],res[0]))
+        #     # logger.info(columns)
+        # else:
+        #     raise Exception(f'{table_name} no columns')
+        # ans = []
+        # for i, col in enumerate(columns):
+        #     sql = f'describe {self.database_name}.{table_name} {col}'
+        #     try:
+        #         res = self._execute_sql([sql])
+        #         if res:
+        #                 res = [[str(i), *x] for x in filter(lambda x: x[0] != '', res[0])]
+        #                 ans.append(''.join(flat_map(lambda x: ':'.join(x[:3]), res)))
+
+        #         else:
+        #             raise Exception('table not found')
+        #     except Exception:
+        #         return ans
+
+        # return ans
 

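For reference, the reworked get_table_schema above now returns one "index:name:type" string per column, and get_preview_data rebuilds its preview header from that format instead of issuing a separate describe query. A minimal sketch of the parsing step, with illustrative column values only:

# Illustrative schema strings, in the shape produced by the new get_table_schema
table_schema = ['0:id:int', '1:name:string', '2:created_at:timestamp']

# get_preview_data splits each entry and keeps "name:type" for the header
c_list = [col.split(':') for col in table_schema]
header = [':'.join(c[1:3]) for c in c_list]
print(header)  # ['id:int', 'name:string', 'created_at:timestamp']
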
+ 5 - 3
app/core/datasource/mysql.py

@@ -40,7 +40,8 @@ class MysqlDS(DataSourceBase):
                                     database=self.database_name,
                                     user=self.username,
                                     password=self.password,
-                                    ssl_disabled=not use_ssl)
+                                    ssl_disabled=not use_ssl,
+                                    connection_timeout=5)
             if conn.is_connected():
                 logger.info('Connected to MySQL database')
 
@@ -64,7 +65,8 @@ class MysqlDS(DataSourceBase):
                                     database=self.database_name,
                                     user=self.username,
                                     password=self.password,
-                                      ssl_disabled=not use_ssl)
+                                    ssl_disabled=not use_ssl,
+                                    connection_timeout=5)
             cursor = conn.cursor()
             for sql in sqls:
                 cursor.execute(sql)
@@ -72,7 +74,7 @@ class MysqlDS(DataSourceBase):
             logger.info(res)
         except Error as e:
             logger.error(e)
-
+            raise Exception('mysql 连接失败')
         finally:
             if conn is not None and conn.is_connected():
                 conn.close()

+ 6 - 2
app/core/datax/engine.py

@@ -55,7 +55,11 @@ class DataXEngine:
 
     def build_setting(self):
         return {
-            'speed': {
-                'channel': '1'
+            "speed": {
+                "channel": "5"
             }
+            # },
+            # "errorLimit": {
+            #     "record": 0
+            # }
         }

+ 15 - 2
app/crud/af_run.py

@@ -23,9 +23,16 @@ def get_airflow_runs(db: Session):
     return res
 
 
-def get_airflow_runs_by_af_job_ids(db: Session, job_ids: List[int]):
-    res: List[models.AirflowRun] = db.query(models.AirflowRun) \
+def get_airflow_runs_by_af_job_ids(db: Session, job_ids: List[int], start: int = None, end: int = None):
+    res: List[models.AirflowRun] = []
+    if start is None or end is None:
+        res = db.query(models.AirflowRun) \
         .filter(models.AirflowRun.job_id.in_(job_ids)).all()
+    else:
+        res = db.query(models.AirflowRun) \
+            .filter(models.AirflowRun.job_id.in_(job_ids))\
+            .order_by(models.AirflowRun.start_time.desc())\
+            .slice(start,end).all()
     return res
 
 
@@ -43,3 +50,9 @@ def get_airflow_run_once_debug_mode(db: Session, job_id: int):
 def get_airflow_run_once(db: Session, item_id: int):
     res: models.AirflowRun = db.query(models.AirflowRun).filter(models.AirflowRun.id == item_id).first()
     return res
+
+def count_airflow_runs_by_job_ids(db: Session, job_ids: List[int]):
+    count = db.query(models.AirflowRun) \
+        .filter(models.AirflowRun.job_id.in_(job_ids)).count()
+
+    return count

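The paginated callers in app/routers/job_log.py and app/routers/jm_job_log.py (further down in this diff) derive start and end from the page parameters before calling this query. A minimal sketch of that calculation, with hypothetical page values:

# Hypothetical paging values; the routers read them from the Params dependency
page, size = 2, 20
start = (page - 1) * size   # 20
end = page * size           # 40 -> .slice(20, 40) returns rows 21..40

# runs = crud.get_airflow_runs_by_af_job_ids(db, job_ids, start, end)
# total = crud.count_airflow_runs_by_job_ids(db, job_ids)
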
+ 19 - 2
app/crud/data_management.py

@@ -1,9 +1,10 @@
+from re import S
 import time
 from typing import List
 from app import models, schemas
 from sqlalchemy.orm import Session
 
-def create_data_management(db: Session, item: schemas.DataManagementCreate, table_name: str):
+def create_data_management(db: Session, item: schemas.DataManagementCreate, table_name: str, af_run_id: str):
     create_time: int = int(time.time())
     db_item = models.DataManagement(**{
         'name': item.name,
@@ -11,17 +12,33 @@ def create_data_management(db: Session, item: schemas.DataManagementCreate, tabl
         'create_time': create_time,
         'user_name': item.user_name,
         'user_id': item.user_id,
-        'project_id': item.project_id
+        'project_id': item.project_id,
+        'af_run_id': af_run_id,
+        'status': 1
     })
     db.add(db_item)
     db.commit()
     db.refresh(db_item)
     return db_item
 
+def update_data_management_status(db: Session, d_id: int, status: int):
+    db_item: models.DataManagement = db.query(models.DataManagement).filter(models.DataManagement.id == d_id).first()
+    db_item.status = status
+    db.commit()
+    db.flush()
+    db.refresh(db_item)
+    return db_item
+
 def get_data_managements(db: Session, user_id: str, project_id: str):
     res: List[models.DataManagement] = db.query(models.DataManagement).filter(models.DataManagement.project_id == project_id).all()
     return res
 
+def get_data_management_info(db: Session, d_id: int):
+    db_item: models.DataManagement = db.query(models.DataManagement).filter(models.DataManagement.id == d_id).first()
+    if not db_item:
+        raise Exception('该数据不存在')
+    return db_item
+
 def delete_data_management(db: Session, d_id: int):
     dm_item = db.query(models.DataManagement).filter(models.DataManagement.id == d_id).first()
     if not dm_item:

+ 41 - 7
app/crud/job_jdbc_datasource.py

@@ -5,7 +5,7 @@ from app.core.datasource.datasource import DataSrouceFactory
 import app.schemas as schemas
 import app.models as models
 from app.utils import decode_user
-
+from app.utils.utils import decode_base64
 
 def _decode(url, datasource, database_name):
     url =  url.replace('jdbc:', '').replace('hive2://', '').replace(f'{datasource}://', '').replace(f'/{database_name}','')
@@ -20,7 +20,11 @@ def _format_datasource(db: Session, item: schemas.JobJdbcDatasourceBase, ds_id:
             raise Exception('未找到该数据源')
         item.jdbc_url = _decode(item.jdbc_url, item.datasource, item.database_name)
         item.jdbc_username, item.jdbc_password = decode_user(item.jdbc_username, item.jdbc_password)
-    host, port = item.jdbc_url.split(':')
+    try:
+        host, port = item.jdbc_url.split(':')
+    except:
+        raise Exception('数据库地址填写错误')
+
     if not host or not port:
         raise Exception('jdbc_url无效')
     ds = None
@@ -46,6 +50,8 @@ def _format_datasource(db: Session, item: schemas.JobJdbcDatasourceBase, ds_id:
 
 
 def test_datasource_connection(db: Session, item: schemas.JobJdbcDatasourceCreate):
+    if item.jdbc_password and item.jdbc_password != '':
+        item.jdbc_password = decode_base64(item.jdbc_password)
     ds, item = _format_datasource(db, item)
     return ds.is_connect()
 
@@ -66,10 +72,18 @@ def get_table_names(db: Session, ds_id: int):
 
 
 def create_job_jdbc_datasource(db: Session, item: schemas.JobJdbcDatasourceCreate):
+    if item.jdbc_password and item.jdbc_password != '':
+        item.jdbc_password = decode_base64(item.jdbc_password)
     ds, item = _format_datasource(db, item)
-
-    #
+    con_result = ds.is_connect()
+    if not con_result:
+        raise Exception('连接失败,不允许添加')
     create_time: int = int(time.time())
+    name_item = db.query(models.JobJdbcDatasource)\
+        .filter(models.JobJdbcDatasource.datasource_name == item.datasource_name)\
+        .filter(models.JobJdbcDatasource.status == 1).first()
+    if name_item:
+        raise Exception('数据源名称重复')
     db_item = models.JobJdbcDatasource(**item.dict(), **{
         'status': 1,
         'create_time': create_time,
@@ -89,21 +103,41 @@ def get_job_jdbc_datasources(db: Session, datasource_type: str = None, skip: int
     if datasource_type is not None and datasource_type != '':
         res = db.query(models.JobJdbcDatasource)\
             .filter(models.JobJdbcDatasource.datasource == datasource_type)\
-            .filter(models.JobJdbcDatasource.status == 1).all()
+            .filter(models.JobJdbcDatasource.status == 1)\
+            .order_by(models.JobJdbcDatasource.create_time.desc()).all()
     else:
         res = db.query(models.JobJdbcDatasource)\
-            .filter(models.JobJdbcDatasource.status == 1).all()
+            .filter(models.JobJdbcDatasource.status == 1)\
+            .order_by(models.JobJdbcDatasource.create_time.desc()).all()
     for item in res:
         item.jdbc_url = _decode(item.jdbc_url, item.datasource, item.database_name)
     return res
 
+def get_job_jdbc_datasources_info(db: Session, ds_id: int):
+    db_item: models.JobJdbcDatasource = db.query(models.JobJdbcDatasource)\
+        .filter(models.JobJdbcDatasource.id == ds_id).first()
+    db_item.jdbc_url = _decode(db_item.jdbc_url, db_item.datasource, db_item.database_name)
+    if db_item.jdbc_username and db_item.jdbc_username != '':
+        db_item.jdbc_username = decode_base64(db_item.jdbc_username)
+    return db_item
 
 def update_job_jdbc_datasources(db: Session, ds_id: int, update_item: schemas.JobJdbcDatasourceUpdate):
+    if update_item.jdbc_password and update_item.jdbc_password != '':
+        update_item.jdbc_password = decode_base64(update_item.jdbc_password)
+        print(update_item.jdbc_password)
     ds, update_item = _format_datasource(db, update_item)
-
+    con_result = ds.is_connect()
+    if not con_result:
+        raise Exception('连接失败,不允许添加')
     db_item = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
     if not db_item:
         raise Exception('未找到该数据源')
+    name_item = db.query(models.JobJdbcDatasource)\
+        .filter(models.JobJdbcDatasource.datasource_name == update_item.datasource_name)\
+        .filter(models.JobJdbcDatasource.status == 1)\
+        .filter(models.JobJdbcDatasource.id != ds_id).first()
+    if name_item:
+        raise Exception('数据源名称重复')
     update_dict = update_item.dict(exclude_unset=True)
     for k, v in update_dict.items():
         setattr(db_item, k, v)

+ 5 - 1
app/models/data_management.py

@@ -18,4 +18,8 @@ class DataManagement(BaseModel):
     # 创建人编号
     user_id = Column(String, nullable=False)
     # 项目编号
-    project_id = Column(String, nullable=False)
+    project_id = Column(String, nullable=False)
+    # af_run_id
+    af_run_id = Column(String, nullable=False)
+    # 状态(1:转存中,2:成功,3:失败)
+    status = Column(Integer, nullable=False)

+ 8 - 0
app/routers/dag.py

@@ -23,6 +23,14 @@ def execute_dag(dag: schemas.Dag,db: Session = Depends(get_db)):
     af_job = dag_job_submit(dag.dag_uuid, dag.dag_script,db)
     return af_job
 
+@router.get("/debug_execute")
+@web_try()
+@sxtimeit
+def debug_execute(dag_uuid: str, db: Session = Depends(get_db)):
+    relation = crud.get_dag_af_id(db,dag_uuid, 'debug')
+    if relation is None:
+        return False
+    return True
 
 @router.get("/debug_status")
 @web_try()

+ 30 - 6
app/routers/data_management.py

@@ -1,5 +1,5 @@
 from asyncio import current_task
-from re import A
+from re import A, I
 import time
 from typing import Optional
 from fastapi import APIRouter
@@ -10,7 +10,8 @@ from app import schemas
 
 import app.crud as crud
 from app.services.dag import get_tmp_table_name
-from app.utils.send_util import data_transfer_run
+from app.utils.send_util import data_transfer_run, get_data_transfer_run_status
+from constants.constants import RUN_STATUS
 from utils.sx_time import sxtimeit
 from utils.sx_web import web_try
 from app.common.hive import hiveDs
@@ -32,9 +33,15 @@ def create_data_management(item: schemas.DataManagementCreate, db: Session = Dep
     current_time = int(time.time())
     table_name = f'project{item.project_id.lower()}_user{item.user_id.lower()}_{item.name.lower()}_{current_time}'
     tmp_table_name = get_tmp_table_name(item.dag_uuid, item.node_id, str(item.out_pin), db)
-    af_run_id = data_transfer_run(tmp_table_name, table_name)
-    res = crud.create_data_management(db, item, table_name)
-    return res
+    af_run_res = data_transfer_run(database_name+'.'+tmp_table_name, database_name+'.'+table_name)
+    af_run = af_run_res['data'] if 'data' in af_run_res.keys() else None
+    af_run_id = af_run['af_run_id'] if af_run and 'af_run_id' in af_run.keys() else None
+    if af_run_id:
+        item.name = item.name + '_' + str(current_time)
+        res = crud.create_data_management(db, item, table_name, af_run_id)
+        return res
+    else:
+        raise Exception('中间结果转存失败')
 
 
 @router.get("/")
@@ -42,9 +49,26 @@ def create_data_management(item: schemas.DataManagementCreate, db: Session = Dep
 @sxtimeit
 def get_data_managements(user_id: str, project_id: str, db: Session = Depends(get_db)):
     res = crud.get_data_managements(db, user_id, project_id)
+    data_management_list = []
     for item in res:
         item.table_name = f'{database_name}.{item.table_name}'
-    return res
+        data_management_list.append(item)
+    return data_management_list
+
+@router.get("/info")
+@web_try()
+@sxtimeit
+def get_data_management_info(id: int, db: Session = Depends(get_db)):
+    item = crud.get_data_management_info(db, id)
+    if item.status == 1:
+        transfer_run_res = get_data_transfer_run_status(item.af_run_id)
+        transfer_run = transfer_run_res['data'] if 'data' in transfer_run_res.keys() else None
+        transfer_run_status = transfer_run['status'] if transfer_run and 'status' in transfer_run.keys() else None
+        if transfer_run_status:
+            item = crud.update_data_management_status(db, item.id, RUN_STATUS[transfer_run_status])
+    item.table_name = f'{database_name}.{item.table_name}'
+    return item
+
 
 @router.get("/local")
 @web_try()

+ 2 - 6
app/routers/jm_job_log.py

@@ -38,12 +38,8 @@ def get_job_logs(job_id: int = None, params: Params=Depends(get_page), db: Sessi
     relations = crud.get_af_ids(db,id_to_job.keys(), 'job')
     af_to_datax = {relation.af_id:relation.se_id for relation in relations}
     # 获取任务运行记录
-    af_job_runs = crud.get_airflow_runs_by_af_job_ids(db, af_to_datax.keys())
-    # 根据时间进行排序
-    af_job_runs.sort(key=lambda x: x.start_time, reverse=True)
-    total = len(af_job_runs)
-    # 进行分页
-    af_job_runs = af_job_runs[(params['page'] - 1) * params['size']:params['page'] * params['size']]
+    af_job_runs = crud.get_airflow_runs_by_af_job_ids(db, af_to_datax.keys(),(params['page'] - 1) * params['size'],params['page'] * params['size'])
+    total = crud.count_airflow_runs_by_job_ids(db, af_to_datax.keys())
     res = []
     for af_job_run in af_job_runs:
         job_id = af_to_datax[int(af_job_run.job_id)]

+ 6 - 0
app/routers/job_jdbc_datasource.py

@@ -62,6 +62,12 @@ def create_datasource(ds: schemas.JobJdbcDatasourceCreate, db: Session = Depends
 def get_datasources(datasource_type: Optional[str] = None, params: Params=Depends(), db: Session = Depends(get_db)):
     return paginate(crud.get_job_jdbc_datasources(db, datasource_type), params)
 
+@router.get("/info")
+@web_try()
+@sxtimeit
+def get_datasources_info(ds_id: int, db: Session = Depends(get_db)):
+    return crud.get_job_jdbc_datasources_info(db, ds_id)
+
 @router.put("/{ds_id}")
 @web_try()
 @sxtimeit

+ 16 - 13
app/routers/job_log.py

@@ -33,21 +33,23 @@ def get_job_logs(job_id: Optional[int] = None, params: Params=Depends(get_page),
     relations = crud.get_af_ids(db, id_to_job.keys(), 'datax')
     af_to_datax = {relation.af_id:relation.se_id for relation in relations}
     # 获取运行记录
-    af_job_runs = crud.get_airflow_runs_by_af_job_ids(db, af_to_datax.keys())
-    # 根据开始时间排序
-    af_job_runs.sort(key=lambda x: x.start_time, reverse=True)
-    total = len(af_job_runs)
-    # 进行分页
-    af_job_runs = af_job_runs[(params['page'] - 1) * params['size']:params['page'] * params['size']]
+    af_job_runs = crud.get_airflow_runs_by_af_job_ids(db, af_to_datax.keys(),(params['page'] - 1) * params['size'],params['page'] * params['size'])
+    total = crud.count_airflow_runs_by_job_ids(db, af_to_datax.keys())
     res = []
     # 循环获取日志
     for af_job_run in af_job_runs:
         job_id = af_to_datax[int(af_job_run.job_id)]
+        print(f'job_id==>{job_id}')
         # 获取af_job
-        af_job = crud.get_airflow_job_once(db, af_job_run.job_id)
-        task = list(af_job.tasks)[0] if len(list(af_job.tasks))>0 else None
-        log_res = get_task_log(af_job.id, af_job_run.af_run_id, task['id'])
-        job_log = log_res['data'] if 'data' in log_res.keys() else None
+        job_log = None
+        if len(af_job_run.details['tasks']) > 0:
+            job_log = list(af_job_run.details['tasks'].values())[0]
+        else:
+            af_job = crud.get_airflow_job_once(db, af_job_run.job_id)
+            task = list(af_job.tasks)[0] if len(list(af_job.tasks))>0 else None
+            print(f"datax任务的作业{task['id']}")
+            log_res = get_task_log(af_job.id, af_job_run.af_run_id, task['id'])
+            job_log = log_res['data'] if 'data' in log_res.keys() else None
         log = {
             "id": af_job_run.id,
             "job_id": job_id,
@@ -55,8 +57,8 @@ def get_job_logs(job_id: Optional[int] = None, params: Params=Depends(get_page),
             "af_job_id": int(af_job_run.job_id),
             "run_id": af_job_run.id,
             "af_run_id": af_job_run.af_run_id,
-            "start_time": job_log['start_time'],
-            "result": RUN_STATUS[job_log['status']] if job_log['status'] else 0,
+            "start_time": job_log['start_time'] if job_log and 'start_time' in job_log.keys() else None,
+            "result": RUN_STATUS[job_log['status']] if job_log and  'status' in job_log.keys() else 0,
         }
         res.append(log)
     return page_help(res,params['page'],params['size'],total)
@@ -75,7 +77,8 @@ def get_job_logs_once(run_id: int, db: Session = Depends(get_db)):
     log_res = get_task_log(af_job.id, af_job_run.af_run_id, task['id'])
     job_log = log_res['data'] if 'data' in log_res.keys() else None
     log = {
-        "log": job_log['log'] if 'log' in job_log.keys() else None
+        "log": job_log['log'] if  job_log and 'log' in job_log.keys() else None,
+        "status": RUN_STATUS[job_log['status']] if job_log and 'status' in job_log.keys() else None
     }
     return log
 

+ 2 - 1
app/services/dag.py

@@ -125,7 +125,8 @@ def get_tmp_table_name(dag_uuid: str, node_id: str, out_pin: str, db: Session):
     if task_id:
         table_name = f'job{job_id}_task{task_id}_subnode{node_id}_output{out_pin}_tmp'
         t_list = hiveDs.list_tables()
-        if table_name.lower() not in t_list:
+        table_name = table_name.lower()
+        if table_name not in t_list:
             raise Exception('该节点不存在中间结果')
         return table_name
     else:

+ 32 - 18
app/services/datax.py

@@ -33,16 +33,23 @@ def datax_create_task(job_info: models.JobInfo):
     partition_list = []
     if job_info.partition_info is not None and job_info.partition_info != '':
         partition_list = job_info.partition_info.split(',')
-    envs = {}
-    if job_info.inc_start_time and job_info.last_time and len(partition_list) > 0 and job_info.current_time:
-        envs = {
-            "first_begin_time": job_info.inc_start_time,
-            "last_key": job_info.last_time,
-            "current_key": job_info.current_time,
+    first_begin_time = int(time.time())
+    if job_info.inc_start_time is not None and job_info.inc_start_time != '':
+        first_begin_time = job_info.inc_start_time
+    last_key = 'lastTime'
+    if job_info.last_time is not None and job_info.last_time != '':
+        last_key = job_info.last_time
+    current_key = 'currentTime'
+    if job_info.current_time is not None and job_info.current_time != '':
+        current_key = job_info.current_time
+    envs = {
+            "first_begin_time": first_begin_time,
+            "last_key": last_key,
+            "current_key": current_key,
             "partition_key": "partition",
-            "partition_word": partition_list[0] if len(partition_list) > 0 else '',
-            "partition_format": partition_list[2]  if len(partition_list) > 0 else '',
-            "partition_diff": partition_list[1]  if len(partition_list) > 0 else ''
+            "partition_word": partition_list[0] if len(partition_list) > 0 else 'xujiayue',
+            "partition_format": partition_list[2]  if len(partition_list) > 0 else '%Y-%m-%d',
+            "partition_diff": partition_list[1]  if len(partition_list) > 0 else 0
         }
     af_task = {
         "name": job_info.job_desc,
@@ -90,16 +97,23 @@ def datax_put_task(job_info: models.JobInfo,old_af_task):
     partition_list = []
     if job_info.partition_info is not None and job_info.partition_info != '':
         partition_list = job_info.partition_info.split(',')
-    envs = {}
-    if job_info.inc_start_time and job_info.last_time and len(partition_list) > 0 and job_info.current_time:
-        envs = {
-            "first_begin_time": job_info.inc_start_time,
-            "last_key": job_info.last_time,
-            "current_key": job_info.current_time,
+    first_begin_time = int(time.time())
+    if job_info.inc_start_time is not None and job_info.inc_start_time != '':
+        first_begin_time = job_info.inc_start_time
+    last_key = 'lastTime'
+    if job_info.last_time is not None and job_info.last_time != '':
+        last_key = job_info.last_time
+    current_key = 'currentTime'
+    if job_info.current_time is not None and job_info.current_time != '':
+        current_key = job_info.current_time
+    envs = {
+            "first_begin_time": first_begin_time,
+            "last_key": last_key,
+            "current_key": current_key,
             "partition_key": "partition",
-            "partition_word": partition_list[0] if len(partition_list) > 0 else '',
-            "partition_format": partition_list[2]  if len(partition_list) > 0 else '',
-            "partition_diff": partition_list[1]  if len(partition_list) > 0 else ''
+            "partition_word": partition_list[0] if len(partition_list) > 0 else 'xujiayue',
+            "partition_format": partition_list[2]  if len(partition_list) > 0 else '%Y-%m-%d',
+            "partition_diff": partition_list[1]  if len(partition_list) > 0 else 0
         }
     af_task = {
         "name": job_info.job_desc,

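Both datax_create_task and datax_put_task now always populate envs, falling back to defaults when the job carries no incremental settings. A condensed sketch of that fallback, using hypothetical empty job fields (the real code spells out the is-not-None / not-empty checks):

import time

# Hypothetical empty job fields; real values come from models.JobInfo
inc_start_time, last_time, current_time, partition_info = None, '', '', ''

partition_list = partition_info.split(',') if partition_info else []
envs = {
    "first_begin_time": inc_start_time if inc_start_time not in (None, '') else int(time.time()),
    "last_key": last_time if last_time not in (None, '') else 'lastTime',
    "current_key": current_time if current_time not in (None, '') else 'currentTime',
    "partition_key": "partition",
    "partition_word": partition_list[0] if len(partition_list) > 0 else 'xujiayue',
    "partition_format": partition_list[2] if len(partition_list) > 0 else '%Y-%m-%d',
    "partition_diff": partition_list[1] if len(partition_list) > 0 else 0,
}
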
+ 13 - 1
app/services/job_info.py

@@ -8,6 +8,11 @@ import app.crud as crud
 def create_job_info_services(db: Session, item: schemas.JobInfoCreate):
     create_time: int = int(time.time())
     item_dict = item.dict()
+    name_item = db.query(models.JobInfo)\
+        .filter(models.JobInfo.job_desc == item.job_desc)\
+        .filter(models.JobInfo.delete_status == 1).first()
+    if name_item:
+        raise Exception('同步配置名称重复')
     # 定时任务对象转为cron表达式
     cron_expression_dict = item_dict.pop('cron_expression')
     cron_expression = joint_cron_expression(schemas.CronExpression(**cron_expression_dict))
@@ -38,13 +43,20 @@ def create_job_info_services(db: Session, item: schemas.JobInfoCreate):
     af_job = datax_create_job(db_item, db)
     # 创建本地同步任务
     db_item = crud.create_job_info(db, db_item)
+    job_info = db_item.to_dict()
     crud.create_relation(db, db_item.id,'datax', af_job['id'])
-    return db_item
+    return job_info
 
 
 def update_job_info_services(db: Session, id: int, update_item: schemas.JobInfoUpdate):
     # 获取任务信息
     db_item = crud.get_job_info(db,id)
+    name_item = db.query(models.JobInfo)\
+        .filter(models.JobInfo.job_desc == update_item.job_desc)\
+        .filter(models.JobInfo.delete_status == 1)\
+        .filter(models.JobInfo.id != id).first()
+    if name_item:
+        raise Exception('同步配置名称重复')
     update_dict = update_item.dict(exclude_unset=True)
     # 定时任务对象转为cron表达式
     cron_expression_dict = update_dict.pop('cron_expression')

+ 12 - 1
app/utils/send_util.py

@@ -113,4 +113,15 @@ def get_task_log(job_id: str, af_run_id: str, task_id: str):
         return res.json()
     else:
         msg = result['msg'] if 'msg' in result.keys() else result
-        raise Exception(f'获取task日志,请求airflow失败-->{msg}')
+        raise Exception(f'获取task日志,请求airflow失败-->{msg}')
+
+
+# 获取中间结果转存状态
+def get_data_transfer_run_status(af_run_id: str):
+    res = requests.get(url=f'http://{HOST}:{PORT}/af/af_run/data_transfer_log/{af_run_id}')
+    result = res.json()
+    if 'code' in result.keys() and result['code'] == 200:
+        return res.json()
+    else:
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'获取中间结果转存状态,请求airflow失败-->{msg}')

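The new /info route in app/routers/data_management.py is the consumer of this helper. A minimal sketch of that polling flow, assuming the app package is importable and using a made-up af_run_id:

from app.utils.send_util import get_data_transfer_run_status
from constants.constants import RUN_STATUS

af_run_id = 'manual__2022-11-01T00:00:00+00:00'  # hypothetical run id

res = get_data_transfer_run_status(af_run_id)
data = res['data'] if 'data' in res else None
status = data['status'] if data and 'status' in data else None
if status is not None:
    numeric = RUN_STATUS[status]  # 0 queued, 1 scheduled/running, 2 success, 3 failed/skipped
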
+ 3 - 0
app/utils/utils.py

@@ -18,6 +18,9 @@ def decode_user(username, password):
 def encode_base64(str):
     return  base64.encodebytes(str.encode('utf-8')).decode('utf-8')
 
+def decode_base64(str):
+    return  base64.decodebytes(str.encode('utf-8')).decode('utf-8')
+
 
 def byte_conversion(size):
     if size < 1024:

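decode_base64 mirrors the existing encode_base64; both are thin wrappers over base64.encodebytes/decodebytes. A quick standalone round-trip check, with an obviously fake secret:

import base64

def encode_base64(s):
    return base64.encodebytes(s.encode('utf-8')).decode('utf-8')

def decode_base64(s):
    return base64.decodebytes(s.encode('utf-8')).decode('utf-8')

assert decode_base64(encode_base64('not-a-real-password')) == 'not-a-real-password'
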
+ 1 - 1
constants/constants.py

@@ -8,4 +8,4 @@ CONSTANTS = {
     'DATASOURCES': DATASOURCES
 }
 
-RUN_STATUS = {"queued": 0, 'running': 1, 'success': 2, 'failed': 3, 'upstream_failed': 3}
+RUN_STATUS = {"queued": 0, 'scheduled': 1, 'running': 1, 'success': 2, 'failed': 3, 'skipped': 3,'upstream_failed': 3}

+ 7 - 0
data/data.sql

@@ -336,4 +336,11 @@ ALTER TABLE `jm_job_info`
 ADD COLUMN `create_time` int(20) NULL COMMENT '创建时间' AFTER `project_id`,
 ADD COLUMN `update_time` int(20) NULL COMMENT '修改时间' AFTER `create_time`;
 
+-- ----------------------------
+-- Alter for data_management
+-- ----------------------------
+ALTER TABLE `data_management`
+ADD COLUMN `af_run_id` varchar(100) NOT NULL COMMENT 'airflow运行id' AFTER `project_id`,
+ADD COLUMN `status` tinyint(4) NOT NULL COMMENT '状态(1:转存中,2:成功,3:失败)' AFTER `af_run_id`;
+
 SET FOREIGN_KEY_CHECKS = 1;

+ 13 - 13
idctest.ini

@@ -7,24 +7,24 @@ port = 3306
 ssl_disabled = true
 
 [MINIO]
-k8s_url = aihub-minio-yili-test:9000
-url = aihub-minio-yili-test:9000
+k8s_url = aihub-dag-minio:9000
+url = aihub-dag-minio:9000
 access_key = minioadmin
 secret_key = minioadmin
 
 
 [AF_BACKEND]
-uri=aihub-backend-af-yili-test:8080
-host=aihub-backend-af-yili-test
+uri=aihub-dag-backend-af:8080
+host=aihub-dag-backend-af
 port=8080
 dag_files_dir=/dags/
 
 [K8S]
 image_pull_key=codingregistrykey
-enable_kerberos=true
+enable_kerberos=false
 
 [BACKEND]
-url=aihub-backend-yili-test:8080
+url=aihub-dag-backend:8080
 
 [AIRFLOW]
 uri=airflow-webserver:8080
@@ -32,12 +32,12 @@ api_token=YWRtaW46YWRtaW4=
 
 
 [HIVE]
-host = 10.254.20.22
-port = 7001
-username = hive
-password = hive
-database_name = ailab
-kerberos = 1
+host = 10.116.1.75
+port = 10000
+username = aiuser
+password = aiuser
+database_name = dataming
+kerberos = 0
 keytab = assets/test/user.keytab
 krb5config = assets/test/krb5.conf
 kerberos_service_name = hadoop
@@ -45,7 +45,7 @@ principal = ailab@EMR-5XJSY31F
 
 
 [HIVE_METASTORE]
-uris=thrift://10.254.20.18:7004,thrift://10.254.20.22:7004
+uris=thrift://10.116.1.72:9083
 
 [TASK_IMAGES]
 datax=yldc-docker.pkg.coding.yili.com/aiplatform/docker/aihub-datax-yili:latest

+ 1 - 1
server.py

@@ -58,7 +58,7 @@ print('server init finish:)!!!')
 
 
 # Get 健康检查
-@app.get("/ping", description="健康检查")
+@app.get("/jpt/ping", description="健康检查")
 def ping():
     return "pong!!"