# coding=utf-8
import json
from pyspark.sql import SparkSession, DataFrame, Row
import pyspark.sql.functions as F
from pyspark.sql.types import *

def main_func(spark, sc):
    # Select the feature columns from the source table.
    df = spark.sql('select * from data_lake.d_evt_product_order_dtl')
    return [df]

# argv[0] inputs:  {"input1_key": "input1_path", "input2_key": "input2_path", ...}
# argv[1] outputs: [result_path1, result_path2, ...]
def run(inputs: dict, outputs: list):
    spark = SparkSession.builder \
        .config('hive.metastore.uris', 'thrift://10.116.1.72:9083') \
        .enableHiveSupport() \
        .getOrCreate()
    param_dict = preprocess(input_infos=inputs, ss=spark)
    rets = main_func(**param_dict, spark=spark, sc=spark.sparkContext)
    postprocess(rets=rets, outputs=outputs)

def read_table(ss: SparkSession, tb_name: str) -> DataFrame:
    # Load an input table registered in the Hive metastore as a DataFrame.
    return ss.sql(f'select * from {tb_name}')

def write_table(df: DataFrame, tb_name: str):
    # Persist a result DataFrame as a Hive table, replacing any existing data.
    df.write.mode("overwrite").saveAsTable(tb_name)

def preprocess(input_infos: dict, ss: SparkSession) -> dict:
    # Turn the {input_key: table_name} mapping into {input_key: DataFrame}.
    return {k: read_table(ss=ss, tb_name=v) for k, v in input_infos.items()}

def postprocess(rets, outputs):
    # Write each returned DataFrame to its output table; outputs beyond the
    # first reuse the first table name with the "_output<idx>" suffix swapped in.
    if isinstance(rets, list):
        for idx, df in enumerate(rets):
            if idx == 0:
                write_table(df=df, tb_name=outputs[idx])
            else:
                write_table(df=df, tb_name=outputs[0].replace("_output0", "_output" + str(idx)))

if __name__ == '__main__':
    run(inputs={}, outputs=['dataming.job32_task25_subnodehrgKx21DwpKfXAetpGoW8_output0_tmp'])
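Because preprocess() expands the inputs mapping into keyword arguments of main_func(), each input key must match a parameter name of main_func(), and the corresponding value arrives already loaded as a DataFrame. A hypothetical variant of main_func() with one named input is sketched below; the key name "orders" and the column names are illustrative assumptions, not part of the original job.

# Hypothetical variant: an inputs dict of {"orders": "<table_name>"} would make
# preprocess() pass the loaded table as the "orders" keyword argument here.
def main_func(spark, sc, orders):
    # Keep only the feature columns needed downstream (column names are
    # illustrative; the real table schema is not shown in the original job).
    feature_df = orders.select('order_id', 'product_id', 'order_amt')
    return [feature_df]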
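The comments above run() indicate that the scheduling platform supplies the inputs mapping and outputs list as JSON-encoded command-line arguments. A minimal driver sketch under that assumption follows; the argument positions, the module name job_template, and the example values are assumptions, not confirmed by the script.

# Driver sketch: assumes the job script above is saved as job_template.py and
# that the platform passes JSON-encoded arguments in this order.
import json
import sys

from job_template import run

if __name__ == '__main__':
    # First argument: JSON object mapping input keys to source table names.
    inputs = json.loads(sys.argv[1])
    # Second argument: JSON array of output table names.
    outputs = json.loads(sys.argv[2])
    run(inputs=inputs, outputs=outputs)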