# coding=utf-8
import json
import sys

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql.types import *
from pyspark.sql.functions import udf, col
import pyspark.sql.functions as F
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline


def somefunc(value):
    if value < 100:
        return 0
    else:
        return 1


def to_array(col):
    def to_array_(v):
        return v.toArray().tolist()

    # Important: asNondeterministic requires Spark 2.3 or later
    # It can be safely removed i.e.
    # return udf(to_array_, ArrayType(DoubleType()))(col)
    # but at the cost of decreased performance
    return udf(to_array_, ArrayType(DoubleType())).asNondeterministic()(col)


def main_func(spark, sc):
    # Select the feature columns
    print('step 1')
    df01 = spark.sql(
        'select req_qty, row_qty, row_amt, purc_price, supply_price, ship_qty, receipt_qty, audit_qty '
        'from dataming.d_evt_product_order_dtl')

    # Cast every column to double and scale it by coff
    coff = 1.2
    df01 = df01.withColumn("supply_price", df01.supply_price.cast('double') * coff)
    df01 = df01.withColumn("req_qty", df01.req_qty.cast('double') * coff)
    df01 = df01.withColumn("row_qty", df01.row_qty.cast('double') * coff)
    df01 = df01.withColumn("row_amt", df01.row_amt.cast('double') * coff)
    df01 = df01.withColumn("purc_price", df01.purc_price.cast('double') * coff)
    df01 = df01.withColumn("ship_qty", df01.ship_qty.cast('double') * coff)
    df01 = df01.withColumn("receipt_qty", df01.receipt_qty.cast('double') * coff)
    df01 = df01.withColumn("audit_qty", df01.audit_qty.cast('double') * coff)

    # Replace every null in df01 with 0
    df02 = df01.na.fill(0)

    # Convert somefunc to a UDF by passing in the function and its return type.
    # Add a label column: label = 0 when supply_price < 100, otherwise label = 1
    udfsomefunc = F.udf(somefunc, IntegerType())
    df03 = df02.withColumn("label", udfsomefunc("supply_price"))

    print('step 2')
    dfs = df03.randomSplit([0.6, 0.2, 0.2], seed=26)
    lr = LogisticRegression(regParam=0.01, maxIter=5000)
    featureLst = ['req_qty', 'row_qty', 'row_amt', 'purc_price', 'supply_price', 'ship_qty', 'receipt_qty', 'audit_qty']
    vectorAssembler = VectorAssembler().setInputCols(featureLst).setOutputCol("features")
    pipeline = Pipeline(stages=[vectorAssembler, lr])

    trainDF = dfs[0]
    model = pipeline.fit(trainDF)
    print("\n-------------------------------------------------------------------------")
    print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
    print("-------------------------------------------------------------------------\n")

    validDF = dfs[1]
    labelsAndPreds = model.transform(validDF).withColumn("probability_xj", to_array(col("probability"))[1]).select(
        "label", "prediction", "probability_xj")
    labelsAndPreds.show()

    #### Evaluate precision and recall at different thresholds
    print("step 3")
    labelsAndPreds_label_1 = labelsAndPreds.where(labelsAndPreds.label == 1)
    labelsAndPreds_label_0 = labelsAndPreds.where(labelsAndPreds.label == 0)
    labelsAndPreds_label_1.show(3)
    labelsAndPreds_label_0.show(3)
    t_cnt = labelsAndPreds_label_1.count()
    f_cnt = labelsAndPreds_label_0.count()
    print("thre\ttp\tfn\tfp\ttn\tprecision\trecall")
    for thre in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        tp = labelsAndPreds_label_1.where(labelsAndPreds_label_1.probability_xj > thre).count()
        fn = t_cnt - tp  # positives scored at or below the threshold
        fp = labelsAndPreds_label_0.where(labelsAndPreds_label_0.probability_xj > thre).count()
        tn = f_cnt - fp  # negatives scored at or below the threshold
print("%.1f\t%d\t%d\t%d\t%d\t%.4f\t%.4f" % (thre, tp, tn, fp, fn, float(tp) / (tp + fp), float(tp) / (t_cnt))) model.write().overwrite().save("hdfs:/tmp/target/model/lrModel") # argv[0] inputs:{"input1_key":"input1_path","input2_key":"input2_path",..} # argv[1] outputs: [result_path1,result_path2...] def run(inputs: dict, outputs: list): spark = SparkSession.builder.config('hive.metastore.uris', 'thrift://10.116.1.72:9083').enableHiveSupport().getOrCreate() param_dict = preprocess(input_infos=inputs, ss=spark) rets = main_func(**param_dict, spark=spark, sc=spark.sparkContext) postprocess(rets=rets, outputs=outputs) def read_table(ss: SparkSession, tb_name: str) -> DataFrame: return ss.sql(f'select * from {tb_name}') def write_table(df: DataFrame, tb_name: str): df.write.mode("overwrite").saveAsTable(tb_name) def preprocess(input_infos: dict, ss: SparkSession) -> dict: return {k: read_table(ss=ss, tb_name=v) for k, v in input_infos.items()} def postprocess(rets, outputs): if isinstance(rets, list): for idx, df in enumerate(rets): if idx == 0: write_table(df=df, tb_name=outputs[idx]) else: write_table(df=df, tb_name=outputs[0].replace("_output0", "_output" + str(idx))) if __name__ == '__main__': run(inputs={}, outputs=['dataming.job32_task25_subnodehrgKx21DwpKfXAetpGoW8_output0_tmp'])