run1.py

# -*- coding: utf-8 -*-
import json
import sys

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql.types import *
from pyspark.sql.functions import udf, col
import pyspark.sql.functions as F
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline


def somefunc(value):
    # Label rule: supply_price < 100 -> 0, otherwise 1.
    if value < 100:
        return 0
    else:
        return 1


def to_array(col):
    def to_array_(v):
        return v.toArray().tolist()
    # Important: asNondeterministic requires Spark 2.3 or later.
    # It can be safely removed, i.e.
    #     return udf(to_array_, ArrayType(DoubleType()))(col)
    # but at the cost of decreased performance.
    return udf(to_array_, ArrayType(DoubleType())).asNondeterministic()(col)
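# Note (not in the original script): on Spark 3.0+ the same vector-to-array
# conversion is available as a built-in, which could replace this helper:
#     from pyspark.ml.functions import vector_to_array
#     df.withColumn("probability_arr", vector_to_array(col("probability")))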


def main_func(spark, sc):
    # Select the feature columns.
    print('step 1')
    df01 = spark.sql(
        'select req_qty, row_qty,row_amt,purc_price,supply_price,ship_qty,receipt_qty,audit_qty from dataming.d_evt_product_order_dtl')
    # Cast every column to double (and scale by coff).
    coff = 1.2
    df01 = df01.withColumn("supply_price", df01.supply_price.cast('double') * coff)
    df01 = df01.withColumn("req_qty", df01.req_qty.cast('double') * coff)
    df01 = df01.withColumn("row_qty", df01.row_qty.cast('double') * coff)
    df01 = df01.withColumn("row_amt", df01.row_amt.cast('double') * coff)
    df01 = df01.withColumn("purc_price", df01.purc_price.cast('double') * coff)
    df01 = df01.withColumn("ship_qty", df01.ship_qty.cast('double') * coff)
    df01 = df01.withColumn("receipt_qty", df01.receipt_qty.cast('double') * coff)
    df01 = df01.withColumn("audit_qty", df01.audit_qty.cast('double') * coff)
    # Fill all nulls in df01 with 0.
    df02 = df01.na.fill(0)
    # Convert somefunc to a UDF by passing in the function and its return type.
    # Add a "label" column: supply_price < 100 -> label=0, otherwise label=1.
    udfsomefunc = F.udf(somefunc, IntegerType())
    df03 = df02.withColumn("label", udfsomefunc("supply_price"))
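    # An equivalent alternative without a Python UDF, using the built-in
    # when/otherwise column expressions (kept here only as a comment):
    #     df03 = df02.withColumn(
    #         "label", F.when(F.col("supply_price") < 100, 0).otherwise(1))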
    print('step 2')
    dfs = df03.randomSplit([0.6, 0.2, 0.2], seed=26)
    lr = LogisticRegression(regParam=0.01, maxIter=5000)
    featureLst = ['req_qty', 'row_qty', 'row_amt', 'purc_price', 'supply_price', 'ship_qty', 'receipt_qty', 'audit_qty']
    vectorAssembler = VectorAssembler().setInputCols(featureLst).setOutputCol("features")
    pipeline = Pipeline(stages=[vectorAssembler, lr])
    trainDF = dfs[0]
    model = pipeline.fit(trainDF)
    print("\n-------------------------------------------------------------------------")
    print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
    print("-------------------------------------------------------------------------\n")
    validDF = dfs[1]
    labelsAndPreds = model.transform(validDF).withColumn(
        "probability_xj", to_array(col("probability"))[1]).select(
        "label", "prediction", "probability_xj")
    labelsAndPreds.show()
    #### Evaluate precision and recall under different thresholds.
    print("step 3")
    labelsAndPreds_label_1 = labelsAndPreds.where(labelsAndPreds.label == 1)
    labelsAndPreds_label_0 = labelsAndPreds.where(labelsAndPreds.label == 0)
    labelsAndPreds_label_1.show(3)
    labelsAndPreds_label_0.show(3)
    t_cnt = labelsAndPreds_label_1.count()
    f_cnt = labelsAndPreds_label_0.count()
    print("thre\ttp\ttn\tfp\tfn\tprecision\trecall")
    for thre in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        # A row is predicted positive when its class-1 probability exceeds the threshold.
        tp = labelsAndPreds_label_1.where(labelsAndPreds_label_1.probability_xj > thre).count()
        fn = t_cnt - tp
        fp = labelsAndPreds_label_0.where(labelsAndPreds_label_0.probability_xj > thre).count()
        tn = f_cnt - fp
        precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = float(tp) / t_cnt if t_cnt > 0 else 0.0
        print("%.1f\t%d\t%d\t%d\t%d\t%.4f\t%.4f" % (thre, tp, tn, fp, fn, precision, recall))
    model.write().overwrite().save("hdfs:/tmp/target/model/lrModel")
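    # The persisted PipelineModel can be reloaded later for scoring; a usage
    # sketch, not part of this job's flow (`new_df` is a placeholder):
    #     from pyspark.ml import PipelineModel
    #     loaded = PipelineModel.load("hdfs:/tmp/target/model/lrModel")
    #     loaded.transform(new_df).select("prediction", "probability").show()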


# argv[0] inputs:  {"input1_key": "input1_path", "input2_key": "input2_path", ...}
# argv[1] outputs: [result_path1, result_path2, ...]
def run(inputs: dict, outputs: list):
    spark = SparkSession.builder.config('hive.metastore.uris',
                                        'thrift://10.116.1.72:9083').enableHiveSupport().getOrCreate()
    param_dict = preprocess(input_infos=inputs, ss=spark)
    rets = main_func(**param_dict, spark=spark, sc=spark.sparkContext)
    postprocess(rets=rets, outputs=outputs)


def read_table(ss: SparkSession, tb_name: str) -> DataFrame:
    return ss.sql(f'select * from {tb_name}')


def write_table(df: DataFrame, tb_name: str):
    df.write.mode("overwrite").saveAsTable(tb_name)


def preprocess(input_infos: dict, ss: SparkSession) -> dict:
    return {k: read_table(ss=ss, tb_name=v) for k, v in input_infos.items()}


def postprocess(rets, outputs):
    # Write each returned DataFrame to its corresponding output table; extra
    # outputs reuse the first output name with an incremented suffix.
    if isinstance(rets, list):
        for idx, df in enumerate(rets):
            if idx == 0:
                write_table(df=df, tb_name=outputs[idx])
            else:
                write_table(df=df, tb_name=outputs[0].replace("_output0", "_output" + str(idx)))


if __name__ == '__main__':
    run(inputs={}, outputs=['dataming.job32_task25_subnodehrgKx21DwpKfXAetpGoW8_output0_tmp'])
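
# A hedged sketch of an argv-driven entry point matching the format comments
# above run(); how the platform maps its "argv[0]"/"argv[1]" onto sys.argv is
# an assumption here, not something this script confirms:
#     inputs = json.loads(sys.argv[1])    # {"input1_key": "input1_path", ...}
#     outputs = json.loads(sys.argv[2])   # [result_path1, result_path2, ...]
#     run(inputs=inputs, outputs=outputs)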