spark_template_demo.py

# coding: UTF-8
# Input script following the definition of the "run" interface.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import date_format, concat, col, lit, to_date

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# Sample read from the Spark quick-start data; not used by the template below.
df = spark.read.json("examples/src/main/resources/people.json")

def read_input_file(spark_session, dag_id, predecessor_id, out_pin_id):
    # Read the CSV written by the upstream node from HDFS.
    uri = f"{dag_id}/{predecessor_id}/{out_pin_id}.csv"
    data = spark_session.read.csv(f'hdfs:///tmp/{uri}', header=True)
    return data

def write_output_file(df, dag_id, subjob_id, pin_id):
    # Write an output DataFrame to HDFS next to the inputs (mirrors read_input_file).
    # The original wrote an undefined jdbcDF to a hard-coded local path, so the
    # DataFrame is taken as a parameter here and the derived uri is used instead.
    uri = f"{dag_id}/{subjob_id}/{pin_id}.csv"
    df.write.mode("overwrite").options(header="true").csv(f"hdfs:///tmp/{uri}")

def do_something(df):
    # Process to be executed: derive an inv_date column (first day of the month)
    # from the inv_year and inv_month columns. to_date parses the concatenated
    # string; the original passed the raw string straight to date_format, which
    # casts it with the default format and yields null for slash-separated input.
    df1 = df.withColumn(
        "inv_date",
        date_format(
            to_date(concat(col('inv_year'), lit('/'), col('inv_month'), lit('/01')), 'yyyy/M/d'),
            'yyyy/MM/dd'
        )
    )
    return df1
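
# A minimal, hypothetical illustration of do_something: any DataFrame with
# string inv_year / inv_month columns gains an inv_date column formatted as
# yyyy/MM/dd. The sample data below is made up purely for demonstration:
#
#   sample = spark.createDataFrame([("2023", "7")], ["inv_year", "inv_month"])
#   do_something(sample).show()   # inv_date -> "2023/07/01"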

def data_handle(element):
    # Normalize a platform input to a DataFrame: GraphFrame inputs are reduced
    # to their vertices DataFrame, plain DataFrames pass through unchanged.
    from prophet import PySparkException
    from graphframes import GraphFrame
    from pyspark.sql import DataFrame
    if isinstance(element, GraphFrame):
        return element.vertices
    elif isinstance(element, DataFrame):
        return element
    else:
        raise PySparkException("Invalid input data type")

def run(t1, context_string):
    """
    Main line of the script (one upstream input in this instance). Given the
    input data (DataFrames) and the task configuration, output data is
    returned as a list of DataFrames.
    Params:
        t1              DataFrame, upstream data; the name must be consistent with the first input slot definition
        context_string  String, task config; the name must be "context_string"
    Return:
        One or more output DataFrames wrapped in a list
    """
    sc = SparkContext._active_spark_context
    sqlContext = SQLContext(sc)  # available for SQL-style processing if needed
    t1 = do_something(t1)  # data processing
    # Input source handler for the Prophet Platform.
    r = data_handle(t1)
    return [r]
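
# Sketch of how the platform is expected to call run(): one upstream DataFrame
# plus the task's config string, returning a list of DataFrames. The column
# names and config value below are hypothetical and only illustrate the contract:
#
#   upstream = spark.createDataFrame([("2023", "7", 10.0)],
#                                    ["inv_year", "inv_month", "amount"])
#   outputs = run(upstream, context_string='{"param": "value"}')
#   outputs[0].show()   # the single output DataFrame, with inv_date added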

class DagTask():
    def __init__(self, dag_id, subjob_id, spark_session):
        # NOTE: subjob_id is added here so postprocess can build output paths;
        # the original constructor took only (dag_id, spark_session).
        self.dag_id = dag_id
        self.subjob_id = subjob_id
        self.spark_session = spark_session

    def preprocess(self, input_infos):
        # Load every upstream output listed in input_infos.
        inputs_data = []
        for (predecessor_id, out_pin_id) in input_infos:
            data = read_input_file(spark_session=self.spark_session, dag_id=self.dag_id,
                                   predecessor_id=predecessor_id, out_pin_id=out_pin_id)
            inputs_data.append(data)
        return tuple(inputs_data)

    def postprocess(self, rets):
        # Write each returned DataFrame to one output pin, keyed by position
        # (positional pin numbering is an assumption of this template).
        for pin_id, ret in enumerate(rets):
            write_output_file(df=ret, dag_id=self.dag_id,
                              subjob_id=self.subjob_id, pin_id=pin_id)

    def main_func(self, input_infos, context_string=None):
        # The number of unpacked inputs must match input_infos and run()'s slots.
        (t1,) = self.preprocess(input_infos)
        rets = run(t1, context_string)
        self.postprocess(rets)


if __name__ == '__main__':
    # Placeholder ids and inputs for a local smoke test; the platform supplies real values.
    input_infos = [("predecessor_subjob_id", "out_pin_id")]
    dt = DagTask(dag_id=1, subjob_id=1, spark_session=spark)
    dt.main_func(input_infos)