run.py 1.6 KB

# coding=utf-8
import json
from pyspark.sql import SparkSession, DataFrame, Row
import pyspark.sql.functions as F
from pyspark.sql.types import *


def main_func(spark, sc):
    # Select the feature columns from the source order-detail table.
    df = spark.sql('select * from data_lake.d_evt_product_order_dtl')
    return [df]


# argv[0] inputs: {"input1_key": "input1_path", "input2_key": "input2_path", ...}
# argv[1] outputs: [result_path1, result_path2, ...]
def run(inputs: dict, outputs: list):
    # Build a Hive-enabled session against the metastore, read the input tables,
    # run the user logic, then persist the returned DataFrames.
    spark = SparkSession.builder \
        .config('hive.metastore.uris', 'thrift://10.116.1.72:9083') \
        .enableHiveSupport() \
        .getOrCreate()
    param_dict = preprocess(input_infos=inputs, ss=spark)
    rets = main_func(**param_dict, spark=spark, sc=spark.sparkContext)
    postprocess(rets=rets, outputs=outputs)


def read_table(ss: SparkSession, tb_name: str) -> DataFrame:
    return ss.sql(f'select * from {tb_name}')


def write_table(df: DataFrame, tb_name: str):
    df.write.mode("overwrite").saveAsTable(tb_name)


def preprocess(input_infos: dict, ss: SparkSession) -> dict:
    # Map each input key to the DataFrame read from its table name.
    return {k: read_table(ss=ss, tb_name=v) for k, v in input_infos.items()}


def postprocess(rets, outputs):
    # Write each returned DataFrame to its output table; results beyond the first
    # reuse the first output name with the "_output<idx>" suffix swapped in.
    if isinstance(rets, list):
        for idx, df in enumerate(rets):
            if idx == 0:
                write_table(df=df, tb_name=outputs[idx])
            else:
                write_table(df=df, tb_name=outputs[0].replace("_output0", "_output" + str(idx)))


if __name__ == '__main__':
    run(inputs={}, outputs=['dataming.job32_task25_subnodehrgKx21DwpKfXAetpGoW8_output0_tmp'])
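
# Usage sketch (hypothetical, not part of the original run.py): with a non-empty
# `inputs` mapping, preprocess() reads each value as a Hive table and passes the
# resulting DataFrames to main_func() under the corresponding keyword names, so
# main_func() would also need a matching parameter (here, `orders`). The table
# and output names below are illustrative placeholders only.
#
# run(
#     inputs={'orders': 'data_lake.d_evt_product_order_dtl'},
#     outputs=['dataming.example_job_output0_tmp'],
# )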