lr_df_demo.py

# coding=utf-8
import json
import sys
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import udf, col
from pyspark.context import SparkContext

# argv[1] inputs:  {"input1_key": "input1_path", "input2_key": "input2_path", ...}
# argv[2] outputs: ["result_path1", "result_path2", ...]
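# Illustrative invocation (the key names come from main_func's signature; the paths
# below are placeholders, not real ones from this project):
#   spark-submit lr_df_demo.py \
#       '{"train_df": "hdfs:///path/to/train.tsv", "test_df": "hdfs:///path/to/test.tsv"}' \
#       '["hdfs:///path/to/output"]'
# The JSON keys must match main_func's parameter names (train_df, test_df),
# because run() expands the parsed dict with main_func(**inputs).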
def run(inputs: dict, outputs: list):
    spark = SparkSession.builder.getOrCreate()
    sc = spark.sparkContext
    inputs = preprocess(inputs, spark=spark, sc=sc)
    rets = main_func(**inputs)
    postprocess(rets=rets, outputs=outputs)
def read_input_file(spark: SparkSession, sc: SparkContext, uri) -> DataFrame:  # todo: to be replaced
    rdd_train = sc.textFile(uri)
    col_lst = ['uid', 'label', 'feature1', 'feature2', 'feature3', 'feature4', 'feature5',
               'feature6', 'feature7', 'feature8', 'feature9']
    rdd_data = rdd_train.map(lambda x: getDict(x.split('\t'), col_lst))
    return spark.createDataFrame(rdd_data)
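# read_input_file expects each line to be a tab-separated record of
# uid, label, feature1 ... feature9. Every field (including uid and label) is run
# through parseFloat, so non-numeric values silently become 0.0. A hypothetical line:
#   "10001\t1\t0.5\t1.2\t0.0\t3.4\t0.7\t0.1\t2.2\t0.9\t1.0"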
def write_output_file(result: DataFrame, result_path: str):  # todo: to be replaced
    result.write.mode("overwrite").options(header="true").csv(result_path)
def preprocess(input_infos: dict, spark: SparkSession, sc: SparkContext) -> dict:
    return {k: read_input_file(spark=spark, sc=sc, uri=v) for k, v in input_infos.items()}
def postprocess(rets, outputs):
    for idx, ret in enumerate(rets):
        write_output_file(result=ret, result_path=outputs[idx])
def getFeatureName():
    featureLst = ['feature1', 'feature2', 'feature3', 'feature4', 'feature5',
                  'feature6', 'feature7', 'feature8', 'feature9']
    colLst = ['uid', 'label'] + featureLst
    return featureLst, colLst
def parseFloat(x):
    try:
        rx = float(x)
    except (ValueError, TypeError):
        rx = 0.0
    return rx
def getDict(dictDataLst, colLst):
    dictData = {}
    for i in range(len(colLst)):
        dictData[colLst[i]] = parseFloat(dictDataLst[i])
    return dictData
def to_array(column):
    def to_array_(v):
        return v.toArray().tolist()
    return udf(to_array_, ArrayType(DoubleType())).asNondeterministic()(column)
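# The "probability" column produced by LogisticRegression is an ML Vector, which
# cannot be indexed with [] directly, so to_array converts it to an
# ArrayType(DoubleType()) column first; element [1] is then the probability of the
# positive class. On Spark 3.0+ the built-in pyspark.ml.functions.vector_to_array
# could be used instead of this hand-rolled UDF.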
def main_func(train_df, test_df):
    feature_lst, col_lst = getFeatureName()
    vectorAssembler = VectorAssembler().setInputCols(feature_lst).setOutputCol("features")
    print("step 1")
    lr = LogisticRegression(regParam=0.01, maxIter=100)  # regParam: regularization strength
    pipeline = Pipeline(stages=[vectorAssembler, lr])
    model = pipeline.fit(train_df)
    # print the model parameters
    print("\n-------------------------------------------------------------------------")
    print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
    print("-------------------------------------------------------------------------\n")
    print("step 2")
    labelsAndPreds = model.transform(test_df).withColumn("probability_xj", to_array(col("probability"))[1]) \
        .select("uid", "label", "prediction", "probability_xj")
    labelsAndPreds.show()
    print(f'labelsAndPreds type is {type(labelsAndPreds)}')
    return [labelsAndPreds]
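# Optional: a quick AUC check could be added inside main_func before returning;
# BinaryClassificationEvaluator accepts a double-valued score column such as
# probability_xj, so a minimal sketch would be:
#   from pyspark.ml.evaluation import BinaryClassificationEvaluator
#   evaluator = BinaryClassificationEvaluator(labelCol="label",
#                                             rawPredictionCol="probability_xj",
#                                             metricName="areaUnderROC")
#   print("test AUC =", evaluator.evaluate(labelsAndPreds))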
if __name__ == '__main__':
    inputs_str = sys.argv[1]
    outputs_str = sys.argv[2]
    run(inputs=json.loads(inputs_str), outputs=json.loads(outputs_str))