import json from typing import Optional, Dict from app.core.airflow.af_util import spark_result_tb_name from app.schemas import AirflowTask from jinja2 import Environment, PackageLoader, select_autoescape from app.common.minio import FileHandler from configs.settings import config class TaskCompiler: def __init__(self, item: AirflowTask): self.task = item self.default_image = None self.default_cmd = None @staticmethod def render_spark_script(parameters, template_file): env = Environment( loader=PackageLoader('app.core.airflow'), autoescape=select_autoescape() ) template = env.get_template(template_file) return template.render(parameters) def translate(self, job_id, task_mode=1): return {'image': self.task.run_image or self.default_image, 'cmds': ["/bin/bash", "-c", f"{self.task.cmd or self.default_cmd} "], 'script': self.task.script, 'id': f'{self.task.id}', 'env': {**{"SCRIPT": self.task.script}, **self.task.envs}, 'operator_name': f'op_{self.task.id}', 'name': self.task.name, 'desc': "" } @staticmethod def write_to_oss(oss_path, context, bucket='mytest'): if isinstance(context, str): context = bytes(context, 'utf-8') minio_handler = FileHandler(bucket_name=bucket) return minio_handler.put_byte_file(file_name=oss_path, file_content=context) class JavaTaskCompiler(TaskCompiler): def __init__(self, item: AirflowTask): super(JavaTaskCompiler, self).__init__(item) self.default_image = config.get('TASK_IMAGES', 'java') # 'SXKJ:32775/java:1.0' self.default_cmd = "echo \"$SCRIPT\" > run.py && python run.py" self.task.cmd = self.task.cmd or self.default_cmd tar_name = self.task.file_urls[0].split('/')[-1].split('_')[-1] self.task.cmd = f'curl http://{config.get("BACKEND", "url")}/jpt/files/{self.task.file_urls[0]} --output {tar_name} && {self.task.cmd}' class PythonTaskCompiler(TaskCompiler): def __init__(self, item: Optional[AirflowTask]): super(PythonTaskCompiler, self).__init__(item) self.default_image = config.get('TASK_IMAGES', 'python') # 'SXKJ:32775/pod_python:1.1' self.default_cmd = "python main.py" self.task.cmd = "echo \"$SCRIPT\" > main.py && " + (self.task.cmd or self.default_cmd) # if config.get('HOST_ALIAS', 'enable', fallback=None) in ['true', "True", True]: # host_alias: Dict = json.loads(config.get('HOST_ALIAS', 'host_alias')) # for k, v in host_alias.items(): # self.task.cmd = f"echo '{k} {v}' >> /etc/hosts && {self.task.cmd}" class DataXTaskCompiler(TaskCompiler): def __init__(self, item: AirflowTask): super(DataXTaskCompiler, self).__init__(item) self.default_image = config.get('TASK_IMAGES', 'datax') # 'SXKJ:32775/pod_datax:0.9' self.default_cmd = f"cd datax/bin && echo \"$SCRIPT\" > transform_datax.py &&cat transform_datax.py && " \ f"python3 transform_datax.py && cat config.json && $HOME/conda/envs/py27/b" \ f"in/python datax.py {self.task.cmd_parameters} config.json " def translate(self, job_id, task_mode=1): print(f'{self.task.envs}') script_str = self.render_spark_script( parameters={'script': self.task.script, 'first_begin_time': self.task.envs.get('first_begin_time', None), 'last_key': self.task.envs.get('last_key', None), 'current_key': self.task.envs.get('current_key', None), 'location_key': self.task.envs.get('location_key', None), 'location_value': self.task.envs.get('location_value', None), 'partition_key': self.task.envs.get('partition_key', None), 'partition_word': self.task.envs.get('partition_word', None), 'partition_format': self.task.envs.get('partition_format', None), 'partition_diff': self.task.envs.get('partition_diff', None), }, template_file="transform_datax.py.jinja2") res = {'image': self.task.run_image or self.default_image, 'cmds': ["/bin/bash", "-c", f"{self.task.cmd or self.default_cmd} "], 'script': script_str, 'id': f'{self.task.id}', 'env': {**{"SCRIPT": script_str}, **self.task.envs}, 'operator_name': f'op_{self.task.id}', 'name': self.task.name, 'desc': "" } return res class SparksTaskCompiler(TaskCompiler): def __init__(self, item: Optional[AirflowTask]): super(SparksTaskCompiler, self).__init__(item) self.default_image = config.get('TASK_IMAGES', 'sparks') parameters = {"master": "yarn", "deploy-mode": "cluster", "driver-memory": "1g", "driver-cores ": 1, "executor-memory": "1g", "executor-cores": 1, "num-executors": 1, "archives": "/workspace/py37.zip'#python3env" } if item: parameters.update({ "archives": f"{self.task.envs.get('requirement_package_path','/workspace/py37.zip')}#python3env" }) spark_config = {'spark.default.parallelism': 1, "spark.executor.memoryOverhead": "1g", "spark.driver.memoryOverhead": "1g", "spark.yarn.maxAppAttempts": 1, "spark.yarn.submit.waitAppCompletion": "true", "spark.pyspark.driver.python": "python3env/py37/bin/python", "spark.yarn.appMasterEnv.PYSPARK_PYTHON": "python3env/py37/bin/python", "spark.pyspark.python": "python3env/py37/bin/python", # "spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation": "true" } param_str = ' '.join([f'--{k} {v}' for k, v in parameters.items()]) param_str += ''.join([f' --conf {k}={v}' for k, v in spark_config.items()]) basic_cmds = "cd /workspace && echo \"$SCRIPT\" > run.py && ${SPARK_HOME}/bin/spark-submit" if config.get('HOST_ALIAS', 'enable', fallback=None) in ['true', "True", True]: host_alias: Dict = json.loads(config.get('HOST_ALIAS', 'host_alias')) for k, v in host_alias.items(): basic_cmds = f"echo '{k} {v}' >> /etc/hosts && {basic_cmds}" if config.get('K8S', 'enable_kerberos', fallback=None) in ['true', "True", True]: principal = config.get('HIVE', 'principal', fallback=None) basic_cmds = f"kinit -kt /workspace/conf/user.keytab {principal} && {basic_cmds}" self.cmd_str = lambda name: f"{basic_cmds} --name {name} {param_str} run.py" def translate(self, job_id, task_mode=1): # dag_script = { # "sub_nodes": [ # { # "id": "1", # "name": "SqlNode1", # "op": "sql", # "script": "select * from train", # }, # { # "id": "2", # "name": "SqlNode1", # "op": "sql", # "script": "select * from test", # }, # { # "id": "3", # "name": "PysparkNode1", # "op": "pyspark", # or python # "inputs": {'train': ("1", 0), # 'test': ("2", 0) # }, # "script": "import os\n ...", # }, # ], # "edges": [ # ("1", "3"), # ("2", "3") # ], # "requirements":[ # # ] # } infos = json.loads(self.task.script) sub_nodes = [] skip_nodes = [] for info in infos['sub_nodes']: if info.get('skip', False): skip_nodes.append(info["id"]) continue if info['op'] == 'sql': template_file = 'sql_script_template.py.jinja2' elif info['op'] == 'pyspark': template_file = 'pyspark_script_template.py.jinja2' else: continue inputs = {k: spark_result_tb_name(job_id=job_id, task_id=self.task.id, spark_node_id=v[0], out_pin=v[1], is_tmp=task_mode) for k, v in info.get('inputs', {}).items()} outputs = [spark_result_tb_name(job_id=job_id, task_id=self.task.id, spark_node_id=info['id'], out_pin=0, is_tmp=task_mode)] sub_node = { 'id': f'{self.task.id}_{info["id"]}', 'name': info['name'], 'env': { 'SCRIPT': self.render_spark_script( parameters={'script': info['script'], 'inputs': inputs, 'outputs': outputs, "hive_metastore_uris": config.get('HIVE_METASTORE', 'uris')}, template_file=template_file), }, 'cmds': ['/bin/bash', '-c', self.cmd_str(name=f'spark_{self.task.id}_{info["id"]}')], 'image': config.get('TASK_IMAGES', 'sparks') } sub_nodes.append(sub_node) edges = [] for (source, sink) in infos['edges']: if source not in skip_nodes and sink not in skip_nodes: edges.append((f'{self.task.id}_{source}', f'{self.task.id}_{sink}')) return { "id": self.task.id, "sub_nodes": sub_nodes, "edges": edges, 'name': self.task.name, 'desc': "first spark dag task" }