# coding: UTF-8
# input script according to the definition of the "run" interface
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import date_format, concat, col, lit
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

df = spark.read.json("examples/src/main/resources/people.json")

def read_input_file(spark_session, dag_id, predecessor_id, out_pin_id):
    # Build the CSV path of the upstream output pin and read it from HDFS.
    uri = f"{dag_id}/{predecessor_id}/{out_pin_id}.csv"
    data = spark_session.read.csv(f'hdfs:///tmp/{uri}', header=True)
    return data

def write_output_file(data, dag_id, subjob_id, pin_id):
    # Write one output DataFrame as CSV to the pin's HDFS path (mirrors the read path above).
    uri = f"{dag_id}/{subjob_id}/{pin_id}.csv"
    data.write.mode("overwrite").options(header="true").csv(f"hdfs:///tmp/{uri}")

def do_something(df):
    # define the processing to be executed
    df1 = df.withColumn(
        "inv_date",
        date_format(
            concat(col('inv_year'), lit('/'), col('inv_month'), lit('/01')),
            'yyyy/MM/dd'
        )
    )
    return df1

def data_handle(element):
    from prophet import PySparkException
    from graphframes import GraphFrame
    from pyspark.sql import DataFrame
    if isinstance(element, GraphFrame):
        return element.vertices
    elif isinstance(element, DataFrame):
        return element
    else:
        raise PySparkException("Invalid input data type")

def run(t1, context_string):
    """
    Define the main line of the script (a single upstream input in this example).
    Given the input data (DataFrames) and the task configuration, the output
    data is returned as a list of DataFrames.
    Params:
        t1              DataFrame, upstream data; its name should be consistent
                        with the first input-slot definition
        context_string  String, task configuration; its name must be "context_string"
    Return:
        One or more output DataFrames wrapped in a list
    """
    sc = SparkContext._active_spark_context
    sqlContext = SQLContext(sc)
    t1 = do_something(t1)  # data processing
    # Input source handler for the Prophet Platform
    r = data_handle(t1)
    return [r]
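
# Illustrative only: a minimal local check of the "run" interface, assuming the
# upstream DataFrame carries the inv_year / inv_month columns that do_something
# expects. The sample values and the empty config string are made up for this
# sketch, and executing it also requires the graphframes / prophet packages
# imported inside data_handle.
#
#     sample = spark.createDataFrame([("2023", "07"), ("2023", "08")],
#                                    ["inv_year", "inv_month"])
#     outputs = run(sample, context_string="")
#     outputs[0].show()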

class DagTask():
    def __init__(self, dag_id, spark_session):
        self.dag_id = dag_id
        self.spark_session = spark_session
    def preprocess(self, input_infos):
        # Read every upstream output pin declared in input_infos.
        inputs_data = []
        for (predecessor_id, out_pin_id) in input_infos:
            data = read_input_file(spark_session=self.spark_session, dag_id=self.dag_id,
                                   predecessor_id=predecessor_id, out_pin_id=out_pin_id)
            inputs_data.append(data)
        return tuple(inputs_data)
    def postprocess(self, rets):
        # Persist every returned DataFrame; the subjob id and pin ids used here are
        # placeholders that the platform fills in with the real task metadata.
        for pin_id, ret in enumerate(rets):
            write_output_file(data=ret, dag_id=self.dag_id, subjob_id=0, pin_id=pin_id)
    def main_func(self, input_infos, context_string=''):
        # Unpack the upstream DataFrames (one input slot in this example; the
        # unpacking must match the slots declared by "run") and call the interface.
        (t1,) = self.preprocess(input_infos)
        rets = run(t1, context_string)
        self.postprocess(rets)

if __name__ == '__main__':
    # input_infos is the list of (predecessor_id, out_pin_id) tuples provided by the platform.
    dt = DagTask(dag_id=1, spark_session=spark)
    dt.main_func(input_infos)