Browse source

Merge remote-tracking branch 'origin/master'

# Conflicts:
#	app/services/dag.py
luoyulong 2 years ago
parent
commit
294e6e422f

+ 6 - 0
app/crud/job_jdbc_datasource.py

@@ -123,3 +123,9 @@ def delete_job_jdbc_datasource(db: Session, ds_id: int):
     db.flush()
     db.refresh(db_item)
     return db_item
+
+def get_job_jdbc_datasource(db: Session, ds_id: int):
+    db_item: models.JobJdbcDatasource = db.query(models.JobJdbcDatasource).filter(models.JobJdbcDatasource.id == ds_id).first()
+    if not db_item:
+        raise Exception('未找到该数据源')
+    return db_item
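
The new lookup raises when the id is unknown, so callers can rely on getting a row back. A minimal usage sketch (the qualify_table helper is hypothetical, not part of this commit) showing how code like jm_job.py below can prefix a table with its datasource's database:

    from sqlalchemy.orm import Session
    from app import crud

    def qualify_table(db: Session, ds_id: int, table: str) -> str:
        # Raises if the datasource id does not exist (see get_job_jdbc_datasource above).
        ds = crud.get_job_jdbc_datasource(db, ds_id)
        return f"{ds.database_name}.{table}"  # e.g. "ailab.my_table"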

+ 1 - 2
app/routers/data_management.py

@@ -27,10 +27,9 @@ router = APIRouter(
 @web_try()
 @sxtimeit
 def create_data_management(item: schemas.DataManagementCreate, db: Session = Depends(get_db)):
-    current_time = int(time.time())
+
     table_name = f'project{item.project_id.lower()}_user{item.user_id.lower()}_{item.name.lower()}_{current_time}'
     tmp_table_name = get_tmp_table_name(item.dag_uuid, item.node_id, str(item.out_pin), db)
-    # transfer the temporary table to permanent storage; not supported yet, so just store the temp table name for now
     af_run_id = data_transfer_run(tmp_table_name, table_name)
     res = crud.create_data_management(db, item, table_name)
     return res
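
With current_time removed here, the handler presumably switches to the naming helper added in app/services/dag.py below; a hedged sketch of that replacement (an assumption about intent, not the committed code):

    # Sketch only: get_transfer_table_name is the helper added in app/services/dag.py.
    from app.services.dag import get_transfer_table_name

    def build_table_name(item) -> str:
        # item carries project_id, user_id and name, as in schemas.DataManagementCreate
        return get_transfer_table_name(item.project_id, item.user_id, item.name)
        # -> "ailab.project<pid>_user<uid>_<name>_<unix timestamp>"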

+ 22 - 19
app/services/dag.py

@@ -5,9 +5,11 @@ from app.utils.send_util import *
 from app.utils.utils import get_cmd_parameter
 from sqlalchemy.orm import Session
 from app.common.hive import hiveDs
+from configs.settings import DefaultOption, config
 
+database_name = config.get('HIVE', 'DATABASE_NAME')
 
-def dag_create_job(dag_uuid: str, dag_script: str, db: Session):
+def dag_create_job(dag_uuid:str,dag_script: str,db: Session):
     af_task = dag_create_task(dag_script)
     af_job = {
         "tasks": [af_task],
@@ -20,16 +22,15 @@ def dag_create_job(dag_uuid: str, dag_script: str, db: Session):
         "executor_timeout": 0,
         "executor_fail_retry_count": 0,
         "trigger_status": 1,
-        "job_mode": 2,
+        "job_mode":2,
         "job_type": 0,
         "user_id": 0,
     }
     res = send_post('/af/af_job', af_job)
     af_job = res['data']
-    crud.create_debug_relation(db, dag_uuid, 'debug', af_job['id'])
+    crud.create_debug_relation(db,dag_uuid,'debug',af_job['id'])
     return af_job
 
-
 def dag_create_task(dag_script: str):
     af_task = {
         "name": "调试作业",
@@ -46,14 +47,13 @@ def dag_create_task(dag_script: str):
     af_task = res['data']
     return af_task
 
-
-def dag_update_job(dag_uuid: str, dag_script: str, db: Session):
+def dag_update_job(dag_uuid:str,dag_script: str, db: Session):
     relation = crud.get_dag_af_id(db, dag_uuid, 'debug')
     af_job_id = relation.af_id
-    res = send_get("/af/af_job/getOnce", af_job_id)
+    res = send_get("/af/af_job/getOnce",af_job_id)
     old_af_job = res['data']
     old_af_task = old_af_job['tasks'][0]
-    af_task = dag_put_task(dag_script, old_af_task)
+    af_task = dag_put_task(dag_script,old_af_task)
     af_job = {
         "tasks": [af_task],
         "name": "调试任务",
@@ -71,7 +71,7 @@ def dag_update_job(dag_uuid: str, dag_script: str, db: Session):
     return af_job
 
 
-def dag_put_task(dag_script: str, old_af_task):
+def dag_put_task(dag_script: str,old_af_task):
     af_task = {
         "name": "调试作业",
         "file_urls": [],
@@ -82,13 +82,12 @@ def dag_put_task(dag_script: str, old_af_task):
         "run_image": "",
         "task_type": "sparks",
     }
-    res = send_put('/af/af_task', old_af_task['id'], af_task)
+    res = send_put('/af/af_task', old_af_task['id'],af_task)
     af_task = res['data']
     return af_task
 
-
-def dag_job_submit(dag_uuid: str, dag_script: str, db: Session):
-    job_relation = crud.get_dag_af_id(db, dag_uuid, 'debug')
+def dag_job_submit(dag_uuid:str,dag_script: str,db: Session):
+    job_relation = crud.get_dag_af_id(db,dag_uuid,'debug')
     af_job = None
     if job_relation is None:
         af_job = dag_create_job(dag_uuid, dag_script, db)
@@ -96,12 +95,12 @@ def dag_job_submit(dag_uuid: str, dag_script: str, db: Session):
         af_job = dag_update_job(dag_uuid, dag_script, db)
     current_time = int(time.time())
     send_submit(af_job['id'])
-    for i in range(0, 11):
+    for i in range(0,11):
         time.sleep(2)
         res = get_job_last_parsed_time(af_job['id'])
         last_parsed_time = res['data']['last_parsed_time']
         if last_parsed_time and current_time < int(last_parsed_time):
-            send_pause(af_job['id'], 1)
+            send_pause(af_job['id'],1)
             send_execute(af_job['id'])
             print(f"{af_job['id']}<==执行成功==>{last_parsed_time}")
             break
@@ -111,10 +110,10 @@ def dag_job_submit(dag_uuid: str, dag_script: str, db: Session):
 
 
 def get_tmp_table_name(dag_uuid: str, node_id: str, out_pin: str, db: Session):
-    relation = crud.get_dag_af_id(db, dag_uuid, 'debug')
+    relation = crud.get_dag_af_id(db,dag_uuid, 'debug')
     job_id = relation.af_id
-    af_job_run = crud.get_airflow_run_once_debug_mode(db, job_id)
-    tasks = af_job_run.details['tasks'] if len(af_job_run.details['tasks']) > 0 else {}
+    af_job_run = crud.get_airflow_run_once_debug_mode(db,job_id)
+    tasks = af_job_run.details['tasks'] if len(af_job_run.details['tasks'])>0 else {}
     task_id = None
     for task in tasks:
         t_id = task.split('_')[0]
@@ -123,10 +122,14 @@ def get_tmp_table_name(dag_uuid: str, node_id: str, out_pin: str, db: Session):
             task_id = t_id
             break
     if task_id:
-        table_name = f'job{job_id}_task{task_id}_subnode{node_id}_output{out_pin}_tmp'
+        table_name = f'{database_name}.job{job_id}_task{task_id}_subnode{node_id}_output{out_pin}_tmp'
         t_list = hiveDs.list_tables()
         if table_name.lower() not in t_list:
             raise Exception('该节点不存在中间结果')
         return table_name
     else:
         raise Exception('该节点不存在中间结果')
+
+def get_transfer_table_name(project_id: str, user_id: str, name: str, ):
+    current_time = int(time.time())
+    return f'{database_name}.project{project_id.lower()}_user{user_id.lower()}_{name.lower()}_{current_time}'
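
Both helpers now emit database-qualified Hive identifiers; a short illustration of the names they produce, assuming the ailab database configured in production.ini below (ids and names are made-up examples):

    import time

    database_name = 'ailab'  # config.get('HIVE', 'DATABASE_NAME')

    # intermediate-result table looked up by get_tmp_table_name
    tmp_table = f'{database_name}.job12_task34_subnode56_output0_tmp'
    # -> "ailab.job12_task34_subnode56_output0_tmp"

    # destination table built by get_transfer_table_name
    dest_table = f'{database_name}.projectp1_useru2_sales_{int(time.time())}'
    # -> e.g. "ailab.projectp1_useru2_sales_1670000000"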

+ 1 - 1
app/services/datax.py

@@ -57,7 +57,7 @@ def datax_create_task(job_info: models.JobInfo):
         "task_type": "datax",
         "user_id": 0,
     }
-    res = send_post('/jpt/af_task', af_task)
+    res = send_post('/af/af_task', af_task)
     af_task = res['data']
     return af_task
 

+ 5 - 2
app/services/jm_job.py

@@ -4,6 +4,7 @@ import time
 from turtle import update
 from app import crud, models
 from app.common import minio
+from app.core.datasource.datasource import DataSourceBase
 from app.crud.jm_homework_datasource_relation import get_jm_relations
 from app.utils.send_util import *
 from sqlalchemy.orm import Session
@@ -158,7 +159,8 @@ def red_dag_and_format(jm_homework: models.JmHomework, db: Session):
             for filed in fileds:
                 script += filed['dataField'] + ','
             script = script.strip(',')
-            script += ' from ' + node_relation_dict[node['id']].table
+            data_source = crud.get_job_jdbc_datasource(db,node_relation_dict[node['id']].datasource_id)
+            script += ' from ' + data_source.database_name + '.'+node_relation_dict[node['id']].table+''
             sub_node = {
                 "id": node['id'],
                 "name": node['name'],
@@ -168,8 +170,9 @@ def red_dag_and_format(jm_homework: models.JmHomework, db: Session):
             sub_nodes.append(sub_node)
         elif node['op'] == 'outputsource':
             fileds = node['data']['output_source']
+            data_source = crud.get_job_jdbc_datasource(db,node_relation_dict[node['id']].datasource_id)
             script = '''def main_func (input0, spark,sc):
-    input0.write.mode("overwrite").saveAsTable("'''+node_relation_dict[node['id']].table+'''")'''
+    input0.write.mode("overwrite").saveAsTable("''' + data_source.database_name + '.'+node_relation_dict[node['id']].table+'''")'''
             inputs = {}
             index = 0
             input_list = t_s[node['id']] if node['id'] in t_s.keys() else []
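
With the datasource lookup in place, both the input SELECT and the generated saveAsTable call now target <database>.<table>. An illustration of the strings this produces, assuming the script begins as a SELECT list (its initialization is outside this hunk); field and table names are made-up examples:

    # Illustrative values; in the real code they come from the node data and
    # from crud.get_job_jdbc_datasource(...).database_name.
    fields = ['id', 'name']
    database_name = 'ailab'
    table = 'orders'

    # inputsource branch: SELECT over the qualified table
    select_script = 'select ' + ','.join(fields) + ' from ' + database_name + '.' + table
    # -> "select id,name from ailab.orders"

    # outputsource branch: main_func writes to the qualified table
    output_script = ('def main_func (input0, spark,sc):\n'
                     '    input0.write.mode("overwrite").saveAsTable("'
                     + database_name + '.' + table + '")')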

+ 20 - 16
app/utils/send_util.py

@@ -11,17 +11,17 @@ def send_post(uri,data):
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        print(result)
-        raise Exception(f'{uri}-->请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'{uri}-->请求airflow失败-->{msg}')
 
 def send_submit(af_job_id):
     res = requests.post(url=f'http://{HOST}:{PORT}/af/af_job/submit?id='+str(af_job_id))
     result = res.json()
-    print(result)
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception('提交任务,请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'提交任务,请求airflow失败-->{msg}')
 
 
 def send_put(uri,path_data,data):
@@ -30,7 +30,8 @@ def send_put(uri,path_data,data):
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception(f'{uri}-->请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'{uri}-->请求airflow失败-->{msg}')
 
 def send_get(uri,path_data):
     res = requests.get(url=f'http://{HOST}:{PORT}{uri}/{path_data}')
@@ -38,18 +39,19 @@ def send_get(uri,path_data):
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception(f'{uri}-->请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'{uri}-->请求airflow失败-->{msg}')
 
 
 # Run a job once
 def send_execute(path_data):
     res = requests.post(url=f'http://{HOST}:{PORT}/af/af_job/{str(path_data)}/run')
     result = res.json()
-    print(result)
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception('执行一次任务,请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'执行任务,请求airflow失败-->{msg}')
 
 # Start or pause a job
 def send_pause(af_job_id, status):
@@ -59,7 +61,8 @@ def send_pause(af_job_id, status):
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception('修改任务状态,请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'修改任务状态,请求airflow失败-->{msg}')
 
 # Delete a task
 def send_delete(uri, path_data):
@@ -68,28 +71,28 @@ def send_delete(uri, path_data):
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        print(result)
-        raise Exception(f'{uri}-->请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'{uri}-->请求airflow失败-->{msg}')
 
 # Get the time the DAG file was last parsed on the Airflow side
 def get_job_last_parsed_time(path_data):
     res = requests.get(url=f'http://{HOST}:{PORT}/af/af_job/{path_data}/last_parsed_time')
     result = res.json()
-    print(result)
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception('获取上次转化时间-->请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'获取上次转化时间-->请求airflow失败-->{msg}')
 
 # Get the status of a specific job run
 def get_job_run_status(path_data):
     res = requests.get(url=f'http://{HOST}:{PORT}/af/af_run/{path_data}/status')
     result = res.json()
-    print(result)
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception('获取job某次运行的状态-->请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'获取job某次运行的状态-->请求airflow失败-->{msg}')
 
 # Transfer intermediate results
 def data_transfer_run(source_tb: str, target_tb: str):
@@ -99,4 +102,5 @@ def data_transfer_run(source_tb: str, target_tb: str):
     if 'code' in result.keys() and result['code'] == 200:
         return res.json()
     else:
-        raise Exception('中间结果转存,请求airflow失败-->'+result['msg'])
+        msg = result['msg'] if 'msg' in result.keys() else result
+        raise Exception(f'中间结果转存,请求airflow失败-->{msg}')
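
Every error branch in this file now extracts msg with the same guard; a hedged consolidation sketch (the airflow_result helper is hypothetical, not part of this commit):

    # Hypothetical helper that centralizes the repeated check/raise pattern above.
    def airflow_result(res, context: str):
        result = res.json()
        if result.get('code') == 200:
            return result
        # Fall back to the whole payload when 'msg' is missing, as the new branches do.
        msg = result.get('msg', result)
        raise Exception(f'{context} --> airflow request failed --> {msg}')

Each send_* function could then end with: return airflow_result(res, uri).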

+ 0 - 12
constants/test.table

@@ -1,12 +0,0 @@
-{
-A: [1,2,3]
-},
-{
-B: [2,3,4]
-},
-{
-C: [4,5,6]
-},
-{
-D: [5,6,7]
-}

+ 1 - 1
production.ini

@@ -34,7 +34,7 @@ host = 10.254.20.22
 port = 7001
 username = hive
 password = hive
-database_name = default
+database_name = ailab
 kerberos = 1
 keytab = assets/test/user.keytab
 krb5config = assets/test/krb5.conf
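
The DATABASE_NAME read in app/services/dag.py above resolves to this value; a minimal configparser sketch of the equivalent lookup, assuming the entry sits in a [HIVE] section as config.get('HIVE', ...) suggests (the file path is an assumption):

    import configparser

    config = configparser.ConfigParser()
    config.read('production.ini')
    # Option names are case-insensitive, so DATABASE_NAME matches database_name.
    database_name = config.get('HIVE', 'DATABASE_NAME')  # -> "ailab" with this file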