123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- import json
- import sys
- from pyspark.sql.types import *
- from pyspark.ml.classification import LogisticRegression
- from pyspark.ml.feature import VectorAssembler
- from pyspark.ml import Pipeline
- from pyspark.sql.functions import udf, col
- from pyspark.sql import SparkSession, DataFrame
- import numpy
# argv[1] inputs: {"input1_key": "input1_path", "input2_key": "input2_path", ...}
# argv[2] outputs: [result_path1, result_path2, ...]
def run(inputs: dict, outputs: list):
    """Entry point: read the input Hive tables, run the job, persist results.

    inputs maps input names to Hive table names; outputs lists the Hive
    table names that receive the job's result DataFrames, in order.
    """
    session = (
        SparkSession.builder
        .config('hive.metastore.uris', 'thrift://192.168.199.27:9083')
        .enableHiveSupport()
        .getOrCreate()
    )
    frames = preprocess(input_infos=inputs, ss=session)
    results = main_func(**frames)
    postprocess(rets=results, outputs=outputs)
def read_table(ss: SparkSession, tb_name: str) -> DataFrame:
    """Load an entire Hive table into a DataFrame."""
    query = f'select * from {tb_name}'
    return ss.sql(query)
def write_table(df: DataFrame, tb_name: str):
    """Persist df as a Hive table, replacing any existing contents."""
    writer = df.write.mode("overwrite")
    writer.saveAsTable(tb_name)
def preprocess(input_infos: dict, ss: SparkSession) -> dict:
    """Resolve each named input table into a DataFrame, keyed by input name."""
    frames = {}
    for key, table_name in input_infos.items():
        frames[key] = read_table(ss=ss, tb_name=table_name)
    return frames
def postprocess(rets, outputs):
    """Write each result DataFrame to its positionally-matched output table.

    rets: sequence of DataFrames produced by main_func.
    outputs: Hive table names; as in the original, an IndexError is raised
    if rets has more entries than outputs.
    """
    # A plain loop replaces the original list comprehension, which was
    # executed purely for its side effects and built a throwaway list.
    for idx, df in enumerate(rets):
        write_table(df=df, tb_name=outputs[idx])
def to_array(col):
    """Return a Column that exposes a Vector column as array<double>.

    NOTE: the parameter is named `col` and shadows the imported
    pyspark.sql.functions.col inside this function body.
    """
    def vector_to_list(v):
        return v.toArray().tolist()

    converter = udf(vector_to_list, ArrayType(DoubleType())).asNondeterministic()
    return converter(col)
def main_func(train_df: DataFrame, test_df: DataFrame):
    """Train a logistic-regression pipeline and score the test set.

    Returns a single-element list holding a DataFrame with columns
    (uuid, label, prediction, probability_xj), where probability_xj is
    the predicted probability for class index 1.
    """
    feat_cols = ['feature1', 'feature2', 'feature3', 'feature4', 'feature5', 'feature6', 'feature7', 'feature8',
                 'feature9']
    vector_assembler = VectorAssembler(inputCols=feat_cols, outputCol="features")

    #### training ####
    print("step 1")
    lr = LogisticRegression(regParam=0.01, maxIter=100)  # regParam: regularization strength
    model = Pipeline(stages=[vector_assembler, lr]).fit(train_df)

    # Dump the estimator's parameters into the job log.
    print("\n-------------------------------------------------------------------------")
    print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
    print("-------------------------------------------------------------------------\n")

    #### predict and keep only the columns needed downstream ####
    print("step 2")
    scored = model.transform(test_df)
    labels_and_preds = scored.withColumn(
        "probability_xj", to_array(col("probability"))[1]
    ).select("uuid", "label", "prediction", "probability_xj")
    return [labels_and_preds]
if __name__ == '__main__':
    # argv[1]: JSON object mapping input names to table paths.
    # argv[2]: JSON array of output table paths.
    run(inputs=json.loads(sys.argv[1]), outputs=json.loads(sys.argv[2]))
|