lrDemo.py 4.1 KB

from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import udf, col

input_path_1 = "hdfs://192.168.199.27:9000/user/sxkj/train.txt"
input_path_2 = "hdfs://192.168.199.27:9000/user/sxkj/test.txt"
output_path = "hdfs://192.168.199.27:9000/tmp/sparkDemo/${ModelType}"
def getFeatureName():
    featureLst = ['feature1', 'feature2', 'feature3', 'feature4', 'feature5',
                  'feature6', 'feature7', 'feature8', 'feature9']
    colLst = ['uid', 'label'] + featureLst
    return featureLst, colLst
def parseFloat(x):
    # Fall back to 0.0 for non-numeric values instead of failing the job
    try:
        rx = float(x)
    except (ValueError, TypeError):
        rx = 0.0
    return rx
def getDict(dictDataLst, colLst):
    dictData = {}
    for i in range(len(colLst)):
        dictData[colLst[i]] = parseFloat(dictDataLst[i])
    return dictData
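# For illustration, a hypothetical tab-separated input line such as
#   "1001\t1\t0.5\t0.3\t0.9\t..."
# maps to {'uid': 1001.0, 'label': 1.0, 'feature1': 0.5, ...}. Note that a
# non-numeric uid would be coerced to 0.0 by parseFloat.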
def to_array(col):
    def to_array_(v):
        return v.toArray().tolist()
    # Important: asNondeterministic requires Spark 2.3 or later.
    # It can be safely removed, i.e.
    #   return udf(to_array_, ArrayType(DoubleType()))(col)
    # but at the cost of decreased performance.
    return udf(to_array_, ArrayType(DoubleType())).asNondeterministic()(col)
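# Note: on Spark 3.0+ the built-in pyspark.ml.functions.vector_to_array
# provides the same conversion without a Python UDF, e.g.:
#   from pyspark.ml.functions import vector_to_array
#   df.withColumn("probability_arr", vector_to_array("probability"))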
def main():
    # spark = SparkSession.builder.master("yarn").appName("spark_demo").getOrCreate()
    spark = SparkSession.builder.getOrCreate()
    print("Session created!")
    sc = spark.sparkContext
    print("application id: " + sc.applicationId)

    sampleHDFS_train = input_path_1  # sys.argv[1]
    sampleHDFS_test = input_path_2   # sys.argv[2]
    outputHDFS = output_path         # sys.argv[3]

    featureLst, colLst = getFeatureName()
    # Read the data from HDFS and convert the RDDs to DataFrames
    # Training data
    rdd_train = sc.textFile(sampleHDFS_train)
    rowRDD_train = rdd_train.map(lambda x: getDict(x.split('\t'), colLst))
    trainDF = spark.createDataFrame(rowRDD_train)
    # Test data
    rdd_test = sc.textFile(sampleHDFS_test)
    rowRDD_test = rdd_test.map(lambda x: getDict(x.split('\t'), colLst))
    testDF = spark.createDataFrame(rowRDD_test)
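    # Note: inferring a schema from an RDD of dicts is deprecated in Spark 2.x+;
    # an assumed-equivalent variant builds Rows explicitly, e.g.:
    #   from pyspark.sql import Row
    #   rowRDD_train = rdd_train.map(lambda x: Row(**getDict(x.split('\t'), colLst)))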
    # Assemble the training features in featureLst into a single vector column
    vectorAssembler = VectorAssembler().setInputCols(featureLst).setOutputCol("features")
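    # Equivalent constructor form (keyword arguments instead of setters):
    #   vectorAssembler = VectorAssembler(inputCols=featureLst, outputCol="features")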
    #### Training ####
    print("step 1")
    lr = LogisticRegression(regParam=0.01, maxIter=100)  # regParam: regularization strength
    pipeline = Pipeline(stages=[vectorAssembler, lr])
    model = pipeline.fit(trainDF)
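    # The fitted LogisticRegressionModel is the last pipeline stage; its
    # learned weights can be inspected if needed, e.g.:
    #   lrModel = model.stages[-1]
    #   print(lrModel.coefficients, lrModel.intercept)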
    # Print the model parameters
    print("\n-------------------------------------------------------------------------")
    print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
    print("-------------------------------------------------------------------------\n")
    #### Predict and save the results ####
    print("step 2")
    labelsAndPreds = model.transform(testDF).withColumn("probability_xj", to_array(col("probability"))[1]) \
                          .select("uid", "label", "prediction", "probability_xj")
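    # "probability" is a vector of per-class probabilities; index [1] is
    # P(label = 1), extracted here as the scalar column probability_xj.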
    labelsAndPreds.show()
    labelsAndPreds.write.mode("overwrite").options(header="true").csv(outputHDFS + "/target/output")
    #### Evaluate precision and recall at different thresholds ####
    print("step 3")
    labelsAndPreds_label_1 = labelsAndPreds.where(labelsAndPreds.label == 1)
    labelsAndPreds_label_0 = labelsAndPreds.where(labelsAndPreds.label == 0)
    labelsAndPreds_label_1.show(3)
    labelsAndPreds_label_0.show(3)
    t_cnt = labelsAndPreds_label_1.count()
    f_cnt = labelsAndPreds_label_0.count()
    print("thre\ttp\tfn\tfp\ttn\tprecision\trecall")
    for thre in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        tp = labelsAndPreds_label_1.where(labelsAndPreds_label_1.probability_xj > thre).count()
        fn = t_cnt - tp   # positives scored at or below the threshold
        fp = labelsAndPreds_label_0.where(labelsAndPreds_label_0.probability_xj > thre).count()
        tn = f_cnt - fp   # negatives scored at or below the threshold
        precision = float(tp) / (tp + fp) if (tp + fp) > 0 else 0.0  # guard: no positive predictions
        recall = float(tp) / t_cnt if t_cnt > 0 else 0.0
        print("%.1f\t%d\t%d\t%d\t%d\t%.4f\t%.4f" % (thre, tp, fn, fp, tn, precision, recall))
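    # A built-in, threshold-free alternative (a sketch, not part of the
    # original script): area under the ROC curve via
    # pyspark.ml.evaluation.BinaryClassificationEvaluator, e.g.:
    #   from pyspark.ml.evaluation import BinaryClassificationEvaluator
    #   evaluator = BinaryClassificationEvaluator(labelCol="label")
    #   auc = evaluator.evaluate(model.transform(testDF))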
    # Save the model
    model.write().overwrite().save(outputHDFS + "/target/model/lrModel")
    # Load the model
    # model.load(outputHDFS + "/target/model/lrModel")
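    # (A saved pipeline model is reloaded as a PipelineModel, e.g.:
    #   from pyspark.ml import PipelineModel
    #   loaded = PipelineModel.load(outputHDFS + "/target/model/lrModel"))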
    print("output:", outputHDFS)


if __name__ == '__main__':
    main()