# task.py

import json
from typing import Optional, Dict
from app.core.airflow.af_util import spark_result_tb_name
from app.schemas import AirflowTask
from jinja2 import Environment, PackageLoader, select_autoescape
from app.common.minio import FileHandler
from configs.settings import config


# Base compiler: turns an AirflowTask into the container/pod spec consumed by the DAG builder.
class TaskCompiler:
    def __init__(self, item: AirflowTask):
        self.task = item
        self.default_image = None
        self.default_cmd = None

    @staticmethod
    def render_spark_script(parameters, template_file):
        # Render a Jinja2 template bundled with app.core.airflow into an executable script.
        env = Environment(
            loader=PackageLoader('app.core.airflow'),
            autoescape=select_autoescape()
        )
        template = env.get_template(template_file)
        return template.render(parameters)

    def translate(self, job_id, task_mode=1):
        return {'image': self.task.run_image or self.default_image,
                'cmds': ["/bin/bash", "-c", f"{self.task.cmd or self.default_cmd} "],
                'script': self.task.script,
                'id': f'{self.task.id}',
                'env': {**{"SCRIPT": self.task.script}, **self.task.envs},
                'operator_name': f'op_{self.task.id}',
                'name': self.task.name,
                'desc': ""
                }

    @staticmethod
    def write_to_oss(oss_path, context, bucket='mytest'):
        # Persist the compiled artifact to object storage (MinIO) as bytes.
        if isinstance(context, str):
            context = bytes(context, 'utf-8')
        minio_handler = FileHandler(bucket_name=bucket)
        return minio_handler.put_byte_file(file_name=oss_path, file_content=context)
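

# Illustrative usage only (not part of the original module): the subclasses below override
# default_image/default_cmd and, where needed, translate(); the helpers above are shared.
# A compiled spec might be persisted like this (`task_item` and the OSS path are hypothetical):
#
#   spec = JavaTaskCompiler(item=task_item).translate(job_id=42)
#   TaskCompiler.write_to_oss('compiled/task_42.json', json.dumps(spec))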


class JavaTaskCompiler(TaskCompiler):
    def __init__(self, item: AirflowTask):
        super(JavaTaskCompiler, self).__init__(item)
        self.default_image = config.get('TASK_IMAGES', 'java')  # 'SXKJ:32775/java:1.0'
        self.default_cmd = "echo \"$SCRIPT\" > run.py && python run.py"
        self.task.cmd = self.task.cmd or self.default_cmd
        # Download the uploaded artifact from the backend before running the task command.
        tar_name = self.task.file_urls[0].split('/')[-1].split('_')[-1]
        self.task.cmd = f'curl http://{config.get("BACKEND", "url")}/jpt/files/{self.task.file_urls[0]} --output {tar_name} && {self.task.cmd}'


class PythonTaskCompiler(TaskCompiler):
    def __init__(self, item: Optional[AirflowTask]):
        super(PythonTaskCompiler, self).__init__(item)
        self.default_image = config.get('TASK_IMAGES', 'python')  # 'SXKJ:32775/pod_python:1.1'
        self.default_cmd = "python main.py"
        self.task.cmd = "echo \"$SCRIPT\" > main.py && " + (self.task.cmd or self.default_cmd)
        # Optionally prepend /etc/hosts entries so the pod can resolve configured host aliases.
        if config.get('HOST_ALIAS', 'enable', fallback=None) in ['true', "True", True]:
            host_alias: Dict = json.loads(config.get('HOST_ALIAS', 'host_alias'))
            for k, v in host_alias.items():
                self.task.cmd = f"echo '{k} {v}' >> /etc/hosts && {self.task.cmd}"
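
    # For illustration (values are hypothetical): with HOST_ALIAS enabled and
    # host_alias = {"10.0.0.1": "hadoop01"}, self.task.cmd ends up roughly as:
    #   echo '10.0.0.1 hadoop01' >> /etc/hosts && echo "$SCRIPT" > main.py && python main.py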


class DataXTaskCompiler(TaskCompiler):
    def __init__(self, item: AirflowTask):
        super(DataXTaskCompiler, self).__init__(item)
        self.default_image = config.get('TASK_IMAGES', 'datax')  # 'SXKJ:32775/pod_datax:0.9'
        self.default_cmd = f"cd datax/bin && echo \"$SCRIPT\" > transform_datax.py && " \
                           f"cat transform_datax.py && python3 transform_datax.py && cat config.json && " \
                           f"$HOME/conda/envs/py27/bin/python datax.py {self.task.cmd_parameters} config.json "

    def translate(self, job_id, task_mode=1):
        print(f'{self.task.envs}')
        # Render the DataX transform script from its Jinja2 template, passing through the
        # incremental-sync parameters carried in the task's environment variables.
        script_str = self.render_spark_script(
            parameters={'script': self.task.script,
                        'first_begin_time': self.task.envs.get('first_begin_time', None),
                        'last_key': self.task.envs.get('last_key', None),
                        'current_key': self.task.envs.get('current_key', None),
                        'location_key': self.task.envs.get('location_key', None),
                        'location_value': self.task.envs.get('location_value', None),
                        'partition_key': self.task.envs.get('partition_key', None),
                        'partition_word': self.task.envs.get('partition_word', None),
                        'partition_format': self.task.envs.get('partition_format', None),
                        'partition_diff': self.task.envs.get('partition_diff', None),
                        },
            template_file="transform_datax.py.jinja2")
        res = {'image': self.task.run_image or self.default_image,
               'cmds': ["/bin/bash", "-c", f"{self.task.cmd or self.default_cmd} "],
               'script': script_str,
               'id': f'{self.task.id}',
               'env': {**{"SCRIPT": script_str}, **self.task.envs},
               'operator_name': f'op_{self.task.id}',
               'name': self.task.name,
               'desc': "",
               'execution_timeout': self.task.envs.get('execution_timeout', 30 * 60),
               'retries': self.task.envs.get('retries', 3)
               }
        return res


class SparksTaskCompiler(TaskCompiler):
    def __init__(self, item: Optional[AirflowTask]):
        super(SparksTaskCompiler, self).__init__(item)
        self.default_image = config.get('TASK_IMAGES', 'sparks')
        parameters = {"master": "yarn",
                      "deploy-mode": "cluster",
                      "driver-memory": "1g",
                      "driver-cores": 1,
                      "executor-memory": "15g",
                      "executor-cores": 3,
                      "num-executors": 2
                      }
        # "archives": "/workspace/py37.zip#python3env"
        if item:
            parameters.update({
                "archives": f"{self.task.envs.get('requirement_package_path', '/workspace/py37.zip')}#python3env"
            })
        spark_config = {'spark.default.parallelism': 1,
                        "spark.executor.memoryOverhead": "1g",
                        "spark.driver.memoryOverhead": "1g",
                        "spark.yarn.maxAppAttempts": 1,
                        "spark.yarn.submit.waitAppCompletion": "true",
                        "spark.pyspark.driver.python": "python3env/py37/bin/python",
                        "spark.yarn.appMasterEnv.PYSPARK_PYTHON": "python3env/py37/bin/python",
                        "spark.pyspark.python": "python3env/py37/bin/python",
                        "spark.sql.hive.convertMetastoreOrc": "false",
                        "spark.sql.hive.convertMetastoreParquet": "false"
                        # "spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation": "true"
                        }
        param_str = ' '.join([f'--{k} {v}' for k, v in parameters.items()])
        param_str += ''.join([f' --conf {k}={v}' for k, v in spark_config.items()])
        basic_cmds = "cd /workspace && echo \"$SCRIPT\" > run.py && ${SPARK_HOME}/bin/spark-submit"
        # Optionally prepend /etc/hosts entries and a kinit when host aliases / Kerberos are enabled.
        if config.get('HOST_ALIAS', 'enable', fallback=None) in ['true', "True", True]:
            host_alias: Dict = json.loads(config.get('HOST_ALIAS', 'host_alias'))
            for k, v in host_alias.items():
                basic_cmds = f"echo '{k} {v}' >> /etc/hosts && {basic_cmds}"
        if config.get('K8S', 'enable_kerberos', fallback=None) in ['true', "True", True]:
            principal = config.get('HIVE', 'principal', fallback=None)
            basic_cmds = f"kinit -kt /workspace/conf/user.keytab {principal} && {basic_cmds}"
        self.cmd_str = lambda name: f"{basic_cmds} --name {name} {param_str} run.py"
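
    # For reference, self.cmd_str('spark_<task>_<node>') expands to roughly the following
    # (actual flags depend on the parameters/spark_config built above):
    #   cd /workspace && echo "$SCRIPT" > run.py && ${SPARK_HOME}/bin/spark-submit
    #     --name spark_<task>_<node> --master yarn --deploy-mode cluster --driver-memory 1g ...
    #     --conf spark.default.parallelism=1 --conf spark.yarn.maxAppAttempts=1 ... run.py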

    def translate(self, job_id, task_mode=1):
        # Expected structure of self.task.script (kept for reference):
        # dag_script = {
        #     "sub_nodes": [
        #         {
        #             "id": "1",
        #             "name": "SqlNode1",
        #             "op": "sql",
        #             "script": "select * from train",
        #         },
        #         {
        #             "id": "2",
        #             "name": "SqlNode1",
        #             "op": "sql",
        #             "script": "select * from test",
        #         },
        #         {
        #             "id": "3",
        #             "name": "PysparkNode1",
        #             "op": "pyspark",  # or python
        #             "inputs": {'train': ("1", 0),
        #                        'test': ("2", 0)
        #                        },
        #             "script": "import os\n ...",
        #         },
        #     ],
        #     "edges": [
        #         ("1", "3"),
        #         ("2", "3")
        #     ],
        #     "requirements": [
        #
        #     ]
        # }
        infos = json.loads(self.task.script)
        sub_nodes = []
        skip_nodes = []
        for info in infos['sub_nodes']:
            if info.get('skip', False):
                skip_nodes.append(info["id"])
                continue
            if info['op'] == 'sql':
                template_file = 'sql_script_template.py.jinja2'
            elif info['op'] == 'pyspark':
                template_file = 'pyspark_script_template.py.jinja2'
            else:
                continue
            # Map each node's input pins and its output pin to intermediate Hive result tables.
            inputs = {k: spark_result_tb_name(job_id=job_id, task_id=self.task.id, spark_node_id=v[0],
                                              out_pin=v[1], is_tmp=task_mode) for k, v in
                      info.get('inputs', {}).items()}
            outputs = [spark_result_tb_name(job_id=job_id, task_id=self.task.id, spark_node_id=info['id'],
                                            out_pin=0, is_tmp=task_mode)]
            sub_node = {
                'id': f'{self.task.id}_{info["id"]}',
                'name': info['name'],
                'env': {
                    'SCRIPT': self.render_spark_script(
                        parameters={'script': info['script'], 'inputs': inputs, 'outputs': outputs,
                                    "hive_metastore_uris": config.get('HIVE_METASTORE', 'uris')},
                        template_file=template_file),
                },
                'cmds': ['/bin/bash', '-c', self.cmd_str(name=f'spark_{self.task.id}_{info["id"]}')],
                'image': config.get('TASK_IMAGES', 'sparks')
            }
            sub_nodes.append(sub_node)
        # Keep only edges whose endpoints were not skipped, and namespace node ids by task id.
        edges = []
        for (source, sink) in infos['edges']:
            if source not in skip_nodes and sink not in skip_nodes:
                edges.append((f'{self.task.id}_{source}', f'{self.task.id}_{sink}'))
        return {
            "id": self.task.id,
            "sub_nodes": sub_nodes,
            "edges": edges,
            'name': self.task.name,
            'desc': "first spark dag task"
        }
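

# Minimal end-to-end sketch (illustrative only; `task_item` is an assumed AirflowTask
# loaded elsewhere, and the job id is arbitrary):
#
#   compiler = SparksTaskCompiler(item=task_item)
#   dag_fragment = compiler.translate(job_id=100, task_mode=1)
#   print(json.dumps(dag_fragment, indent=2))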