# coding=utf-8
import sys

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, udf
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType, IntegerType
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
def somefunc(value):
    """Binarize supply_price: values below 100 map to label 0, otherwise 1."""
    if value < 100:
        return 0
    else:
        return 1
def to_array(col):
    def to_array_(v):
        return v.toArray().tolist()
    # Important: asNondeterministic requires Spark 2.3 or later.
    # It can be safely removed, i.e.
    #   return udf(to_array_, ArrayType(DoubleType()))(col)
    # but at the cost of decreased performance.
    return udf(to_array_, ArrayType(DoubleType())).asNondeterministic()(col)
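# Note: on Spark 3.0+ the built-in pyspark.ml.functions.vector_to_array could replace the
# custom UDF above. This is only a hedged alternative sketch, not required by the pipeline:
#   from pyspark.ml.functions import vector_to_array
#   df.withColumn("probability_arr", vector_to_array(col("probability")))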
def main_func(spark, sc):
    # Select the feature columns from the source Hive table
    print('step 1')
    df01 = spark.sql(
        'select req_qty, row_qty,row_amt,purc_price,supply_price,ship_qty,receipt_qty,audit_qty from dataming.d_evt_product_order_dtl')
    # Cast every feature column to double and scale it by coff
    coff = 1.2
    for c in ['supply_price', 'req_qty', 'row_qty', 'row_amt', 'purc_price',
              'ship_qty', 'receipt_qty', 'audit_qty']:
        df01 = df01.withColumn(c, col(c).cast('double') * coff)
    # Replace all nulls in df01 with 0
    df02 = df01.na.fill(0)
    # Wrap somefunc as a UDF by passing in the function and its return type.
    # Add a "label" column: label = 0 when supply_price < 100, otherwise label = 1.
    udfsomefunc = F.udf(somefunc, IntegerType())
    df03 = df02.withColumn("label", udfsomefunc("supply_price"))
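    # A UDF-free alternative (a sketch, assuming the same <100 threshold) would use the
    # built-in expression API and avoid Python serialization overhead:
    #   df03 = df02.withColumn("label", F.when(F.col("supply_price") < 100, 0).otherwise(1))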
    print('step 2')
    # 60/20/20 split into train / validation / test sets
    dfs = df03.randomSplit([0.6, 0.2, 0.2], seed=26)
    lr = LogisticRegression(regParam=0.01, maxIter=5000)
    featureLst = ['req_qty', 'row_qty', 'row_amt', 'purc_price', 'supply_price', 'ship_qty', 'receipt_qty', 'audit_qty']
    vectorAssembler = VectorAssembler().setInputCols(featureLst).setOutputCol("features")
    pipeline = Pipeline(stages=[vectorAssembler, lr])
    trainDF = dfs[0]
    model = pipeline.fit(trainDF)
- print("\n-------------------------------------------------------------------------")
- print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
- print("-------------------------------------------------------------------------\n")
- validDF = dfs[1]
- labelsAndPreds = model.transform(validDF).withColumn("probability_xj", to_array(col("probability"))[1]).select(
- "label", "prediction", "probability_xj")
- labelsAndPreds.show()
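    # Optional sanity check (a sketch, not part of the original flow): Spark's built-in
    # BinaryClassificationEvaluator can report the area under ROC on the validation split:
    #   from pyspark.ml.evaluation import BinaryClassificationEvaluator
    #   evaluator = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
    #   print("validation AUC:", evaluator.evaluate(model.transform(validDF)))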
    # Evaluate precision and recall at different probability thresholds
    print("step 3")
    labelsAndPreds_label_1 = labelsAndPreds.where(labelsAndPreds.label == 1)
    labelsAndPreds_label_0 = labelsAndPreds.where(labelsAndPreds.label == 0)
    labelsAndPreds_label_1.show(3)
    labelsAndPreds_label_0.show(3)
    pos_cnt = labelsAndPreds_label_1.count()
    neg_cnt = labelsAndPreds_label_0.count()
    print("thre\ttp\tfn\tfp\ttn\tprecision\trecall")
    for thre in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        tp = labelsAndPreds_label_1.where(labelsAndPreds_label_1.probability_xj > thre).count()
        fn = pos_cnt - tp   # positives predicted below the threshold
        fp = labelsAndPreds_label_0.where(labelsAndPreds_label_0.probability_xj > thre).count()
        tn = neg_cnt - fp   # negatives predicted below the threshold
        precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = float(tp) / pos_cnt if pos_cnt > 0 else 0.0
        print("%.1f\t%d\t%d\t%d\t%d\t%.4f\t%.4f" % (thre, tp, fn, fp, tn, precision, recall))
    model.write().overwrite().save("hdfs:/tmp/target/model/lrModel")
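    # The saved pipeline can later be restored for scoring; a minimal sketch, assuming the
    # same HDFS path is reachable from the scoring job:
    #   from pyspark.ml import PipelineModel
    #   loaded = PipelineModel.load("hdfs:/tmp/target/model/lrModel")
    #   loaded.transform(dfs[2]).select("label", "prediction").show(5)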
# argv[0] inputs: {"input1_key": "input1_path", "input2_key": "input2_path", ...}
# argv[1] outputs: [result_path1, result_path2, ...]
def run(inputs: dict, outputs: list):
    spark = SparkSession.builder.config('hive.metastore.uris',
                                        'thrift://10.116.1.72:9083').enableHiveSupport().getOrCreate()
    param_dict = preprocess(input_infos=inputs, ss=spark)
    rets = main_func(**param_dict, spark=spark, sc=spark.sparkContext)
    postprocess(rets=rets, outputs=outputs)
def read_table(ss: SparkSession, tb_name: str) -> DataFrame:
    return ss.sql(f'select * from {tb_name}')

def write_table(df: DataFrame, tb_name: str):
    df.write.mode("overwrite").saveAsTable(tb_name)

def preprocess(input_infos: dict, ss: SparkSession) -> dict:
    return {k: read_table(ss=ss, tb_name=v) for k, v in input_infos.items()}
def postprocess(rets, outputs):
    if isinstance(rets, list):
        for idx, df in enumerate(rets):
            if idx == 0:
                write_table(df=df, tb_name=outputs[idx])
            else:
                write_table(df=df, tb_name=outputs[0].replace("_output0", "_output" + str(idx)))
if __name__ == '__main__':
    run(inputs={}, outputs=['dataming.job32_task25_subnodehrgKx21DwpKfXAetpGoW8_output0_tmp'])
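# A typical way to launch this script (a sketch; the script file name and master are
# assumptions, not taken from the original):
#   spark-submit --master yarn --deploy-mode cluster lr_pipeline.py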