Model Development with PySpark

PySpark official documentation: http://spark.apache.org/docs/latest/api/python/index.html

Member Churn Prediction Model

General Model Development Workflow

[Figure 1: model development workflow]

Requirements Gathering and Problem Definition

Define the churn criterion: for example, a churned customer is one whose most recent purchase is further in the past than the average purchase interval plus three standard deviations; a non-churned customer is one whose purchasing fluctuates little and whose purchase frequency is stable (see the sketch after this list).
Choose the time window: for example, use the order history in the year preceding each member's most recent purchase.
Hypothesize possible influencing factors: brainstorm and do an initial feature screen from the business perspective, keeping as many plausible factors as possible as the raw feature set.
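
A minimal PySpark sketch of how such a churn label could be derived from a raw order table; the table and column names (orders, member_id, order_dt) are hypothetical and stand in for whatever the order fact table actually contains.

    import pyspark.sql.functions as fn
    from pyspark.sql import SparkSession
    from pyspark.sql.window import Window

    spark = SparkSession.builder.appName("churnLabel").getOrCreate()
    orders = spark.table("orders")  # hypothetical order table with member_id, order_dt

    # Gap in days between consecutive purchases of the same member
    w = Window.partitionBy("member_id").orderBy("order_dt")
    gaps = (orders
            .withColumn("prev_dt", fn.lag("order_dt").over(w))
            .withColumn("gap_days", fn.datediff("order_dt", "prev_dt")))

    # Per-member recency and purchase-interval statistics
    stats = gaps.groupBy("member_id").agg(
        fn.max("order_dt").alias("last_dt"),
        fn.avg("gap_days").alias("avg_gap"),
        fn.stddev("gap_days").alias("std_gap"))

    # label = 1.0 (churned) when days since the last purchase exceed avg_gap + 3 * std_gap
    labeled = stats.withColumn(
        "label",
        fn.when(fn.datediff(fn.current_date(), fn.col("last_dt"))
                > fn.col("avg_gap") + 3 * fn.col("std_gap"), 1.0).otherwise(0.0))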

Data Integration and Feature Engineering

1) Consolidate data from the different source tables into a single wide table, usually with SQL (see the sketch below).
2) Data preprocessing and feature engineering.
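
A hedged sketch of the wide-table assembly step. The target table spider.t06_wajue_data and the selected columns come from the code in Appendix 2; the two source tables (spider.member_info and spider.order_stats) are made up here for illustration.

    # Join member attributes with aggregated order metrics into one wide table.
    # Source table names are hypothetical; the target table is the one read in Appendix 2.
    wide = spark.sql("""
        SELECT m.label,
               m.erp_code,
               m.gender,
               m.erp_corp,
               o.year_count_erp,
               o.year_recv_amt_sum
        FROM spider.member_info m          -- hypothetical member attribute table
        LEFT JOIN spider.order_stats o     -- hypothetical per-member order aggregates
          ON m.erp_code = o.erp_code
    """)
    wide.write.mode("overwrite").saveAsTable("spider.t06_wajue_data")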
[Figure 2: preprocessing and feature engineering]

Model Development and Evaluation

1) First split the sample data randomly within the positive and negative classes separately, then recombine the pieces into training and test sets. This guarantees that no record appears in both sets, that the positive/negative ratio is roughly the same in the training and test sets, and that, after resampling, the ratio in each set is close to 1:1.
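
A condensed sketch of that split, following the same approach as the full code in Appendix 1 but using the simpler DataFrame-level sample() call; data is assumed to be the preprocessed DataFrame with a label column.

    # Split positives and negatives separately so train/test share no rows and
    # keep the same class ratio; then oversample positives towards 1:1.
    dp = data.where('label = 1.0')   # positive (churned) examples
    dn = data.where('label = 0.0')   # negative examples

    (dp_train, dp_test) = dp.randomSplit([0.7, 0.3], seed=1)
    (dn_train, dn_test) = dn.randomSplit([0.7, 0.3], seed=2)

    ratio = dn.count() / dp.count()  # how many times to oversample the positives
    train_df = dn_train.union(dp_train.sample(withReplacement=True, fraction=ratio, seed=100))
    test_df = dn_test.union(dp_test.sample(withReplacement=True, fraction=ratio, seed=100))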

[Figure 3: dataset splitting process]

2) For model building, more features are not necessarily better; the goal is to get the best possible performance out of the simplest possible model. Removing low-value, low-contribution features helps find the simplest model whose performance is unchanged or only slightly reduced.

[Figure 4: feature selection and model building]

Use a chi-square test to check the independence of each feature from the target variable. If independence is high, the two have little to do with each other and the feature can be dropped; if independence is low, the correlation is high, the feature has a substantial influence on the target, and it should be kept.
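
One way to inspect the per-feature p-values directly is pyspark.ml.stat.ChiSquareTest; this is only an illustrative sketch (the appendix code instead applies ChiSqSelector inside the pipeline). Here data3 and assembler refer to the objects built in the Appendix 1 code.

    from pyspark.ml.stat import ChiSquareTest

    # 'features' must already be a vector column produced by VectorAssembler.
    r = ChiSquareTest.test(assembler.transform(data3), "features", "label").head()
    for i, p in enumerate(r.pValues):
        # A large p-value means independence cannot be rejected: the feature tells us
        # little about the label and is a candidate for removal.
        print("feature {}: p-value = {:.4f}".format(i, p))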
3) CrossValidator (CV) or TrainValidationSplit (TVS) splits the data into training and validation sets. For each (train, validation) pair it iterates over a grid of parameter combinations, fits a model with each combination, evaluates the trained model with AUC and accuracy, and selects the parameter combination of the best-performing model.
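
The appendices import CrossValidator and TrainValidationSplit but do not show the tuning step itself, so this is a hedged sketch of how the search could be wired up; lr, lr_pipeline, evaluator, train_df and test_df refer to the objects defined in the Appendix 1 code.

    from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

    # Grid of logistic-regression hyperparameters to try.
    paramGrid = (ParamGridBuilder()
                 .addGrid(lr.regParam, [0.001, 0.01, 0.1])
                 .addGrid(lr.elasticNetParam, [0.0, 0.5])
                 .build())

    # TrainValidationSplit evaluates each combination once on a single train/validation
    # split; CrossValidator would average the metric over numFolds folds instead.
    tvs = TrainValidationSplit(estimator=lr_pipeline,
                               estimatorParamMaps=paramGrid,
                               evaluator=evaluator,   # AUC evaluator from Appendix 1
                               trainRatio=0.8)
    tvsModel = tvs.fit(train_df)
    bestModel = tvsModel.bestModel                    # pipeline fitted with the best params
    print("AUC on the test set:", evaluator.evaluate(bestModel.transform(test_df)))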

[Figure 5: hyperparameter tuning and model evaluation]

Model Application and Iterative Optimization

Apply the model's predictions/scores to targeted marketing or win-back campaigns, keep refining the model as real-world results come in, and re-predict with the refined model, forming a closed loop of iterative optimization.

[Figure 6: model iteration loop]

Model Code

Appendix 1: Python code for local development


# coding: utf-8
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
import pyspark.sql.functions as fn
from pyspark.sql.functions import regexp_extract, col  # regexp_extract does regex extraction in pyspark.sql.functions
from pyspark.ml.feature import Bucketizer, QuantileDiscretizer, OneHotEncoder, StringIndexer, IndexToString, VectorIndexer, VectorAssembler
from pyspark.ml.feature import ChiSqSelector, StandardScaler
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, LogisticRegressionModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, TrainValidationSplit

if __name__ == '__main__':
    # Create the Spark session (local mode, all cores)
    spark = SparkSession.builder.appName("usrLostModel").master("local[*]").getOrCreate()
    # Define the schema
    schema = StructType([
        StructField('label', DoubleType(), True),
        StructField('erp_code', StringType(), True),
        StructField('gender', StringType(), True),
        StructField('is_nt_concn_webchat', StringType(), True),
        StructField('erp_corp', StringType(), True),
        StructField('is_nm', DoubleType(), True),
        StructField('is_tel', DoubleType(), True),
        StructField('is_birth_dt', DoubleType(), True),
        StructField('is_cert_num', DoubleType(), True),
        StructField('is_handl_org', DoubleType(), True),
        StructField('is_addr', DoubleType(), True),
        StructField('year_count_erp', DoubleType(), True),
        StructField('haft_year_count_erp', DoubleType(), True),
        StructField('three_month_count_erp', DoubleType(), True),
        StructField('one_month_count_erp', DoubleType(), True),
        StructField('year_avg_count_erp', DoubleType(), True),
        StructField('haftyear_avg_count_erp', DoubleType(), True),
        StructField('threemonth_avg_count_erp', DoubleType(), True),
        StructField('onemonth_avg_count_erp', DoubleType(), True),
        StructField('year_recv_amt_sum', DoubleType(), True),
        StructField('year_discnt_amt_sum', DoubleType(), True),
        StructField('haftyear_recv_amt_sum', DoubleType(), True),
        StructField('haftyear_discnt_amt_sum', DoubleType(), True),
        StructField('threemonth_recv_amt_sum', DoubleType(), True),
        StructField('threemonth_discnt_amt_sum', DoubleType(), True),
        StructField('onemonth_recv_amt_sum', DoubleType(), True),
        StructField('onemonth_discnt_amt_sum', DoubleType(), True),
        StructField('is_prdc_mfr_pianhao', DoubleType(), True),
        StructField('sell_tm_erp', StringType(), True),
        StructField('shopgui_erp', StringType(), True),
        StructField('duration_erp', DoubleType(), True),
        StructField('is_shopgui_erp', DoubleType(), True),
        StructField('lvl_shopgui_erp', DoubleType(), True)
    ])
    # Load the CSV files as DataFrames; header=True means the first row holds the column
    # names, and schema fixes the column types (schema=None would let Spark infer them)
    data = spark.read.csv("./data/data.csv", header=True, schema=schema)
    vdata = spark.read.csv("./data/vdata.csv", header=True, schema=schema)
    # data.show(5)
    # vdata.groupBy("gender").count().show()
    # vdata.groupBy("is_nt_concn_webchat").count().show()
    # vdata.groupBy("erp_corp").count().show()

    def featureEngineering(data):
        # Fill missing is_nt_concn_webchat values with '0'
        data0 = data.na.fill({'is_nt_concn_webchat': '0'})
        # data.groupBy("is_nt_concn_webchat").count().show()
        # Check the fraction of missing values in each column
        # data.agg(*[(1-(fn.count(c)/fn.count('*'))).alias(c+'_null') for c in data.columns]).show()
        # data.agg(*[(fn.count('*')-fn.count(c)).alias(c+'_null') for c in data.columns]).show()
        # Fill missing values in the double columns with 0.0
        data0 = data0.na.fill({'year_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'haftyear_recv_amt_sum': 0.0})
        data0 = data0.na.fill({'haftyear_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'threemonth_recv_amt_sum': 0.0})
        data0 = data0.na.fill({'threemonth_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'onemonth_recv_amt_sum': 0.0})
        data0 = data0.na.fill({'onemonth_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'is_prdc_mfr_pianhao': 0.0})
        data0 = data0.na.fill({'sell_tm_erp': 0.0})
        data0 = data0.na.fill({'shopgui_erp': 0.0})
        data0 = data0.na.fill({'duration_erp': 0.0})
        data0 = data0.na.fill({'is_shopgui_erp': 0.0})
        data0 = data0.na.fill({'lvl_shopgui_erp': 0.0})
        # Check the fraction of missing values in each column again
        # data.agg(*[(1-(fn.count(c)/fn.count('*'))).alias(c+'_null') for c in data.columns]).show()
        # Drop sell_tm_erp and shopgui_erp
        data1 = data0.drop('sell_tm_erp').drop('shopgui_erp')
        # data1.dtypes
        # data2.describe(['year_count_erp','haft_year_count_erp','three_month_count_erp','one_month_count_erp']).show()
        # data2.where('haft_year_count_erp>10 and haft_year_count_erp<=20').count()
        # data2.describe(['year_recv_amt_sum','haftyear_recv_amt_sum','threemonth_recv_amt_sum','onemonth_recv_amt_sum']).show()
        # data2.where('year_recv_amt_sum<0').count()
        # data2.where('onemonth_recv_amt_sum>=50 and onemonth_recv_amt_sum<100').count()
        # data2.describe(['year_discnt_amt_sum','haftyear_discnt_amt_sum','threemonth_discnt_amt_sum','onemonth_discnt_amt_sum','duration_erp']).show()
        # Drop outliers: unusually frequent buyers and non-positive yearly spend
        data2 = data1.where('year_count_erp<200 and year_recv_amt_sum>0')
        return data2

    # Discretize the continuous features into 5 quantile buckets each
    bucketizer1 = QuantileDiscretizer(numBuckets=5, inputCol='year_count_erp',
                                      outputCol='year_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer2 = QuantileDiscretizer(numBuckets=5, inputCol='haft_year_count_erp',
                                      outputCol='haft_year_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer3 = QuantileDiscretizer(numBuckets=5, inputCol='three_month_count_erp',
                                      outputCol='three_month_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer4 = QuantileDiscretizer(numBuckets=5, inputCol='one_month_count_erp',
                                      outputCol='one_month_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer5 = QuantileDiscretizer(numBuckets=5, inputCol='year_recv_amt_sum',
                                      outputCol='year_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer6 = QuantileDiscretizer(numBuckets=5, inputCol='haftyear_recv_amt_sum',
                                      outputCol='haftyear_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer7 = QuantileDiscretizer(numBuckets=5, inputCol='threemonth_recv_amt_sum',
                                      outputCol='threemonth_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer8 = QuantileDiscretizer(numBuckets=5, inputCol='onemonth_recv_amt_sum',
                                      outputCol='onemonth_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer9 = QuantileDiscretizer(numBuckets=5, inputCol='duration_erp',
                                      outputCol='duration_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    # Index the string columns as numeric category indices
    indexer1 = StringIndexer().setInputCol("gender").setOutputCol("gender_index")
    indexer2 = StringIndexer().setInputCol("is_nt_concn_webchat").setOutputCol("webchat_index")
    indexer3 = StringIndexer().setInputCol("erp_corp").setOutputCol("corp_index")
    # One-hot encode the indexed and bucketed columns
    encoder01 = OneHotEncoder().setInputCol("gender_index").setOutputCol("gender_vec").setDropLast(False)
    encoder02 = OneHotEncoder().setInputCol("webchat_index").setOutputCol("webchat_vec").setDropLast(False)
    encoder03 = OneHotEncoder().setInputCol("corp_index").setOutputCol("corp_vec").setDropLast(False)
    encoder1 = OneHotEncoder().setInputCol("year_count_bucketed").setOutputCol("year_count_vec").setDropLast(False)
    encoder2 = OneHotEncoder().setInputCol("haft_year_count_bucketed").setOutputCol("haft_year_count_vec").setDropLast(False)
    encoder3 = OneHotEncoder().setInputCol("three_month_count_bucketed").setOutputCol("three_month_count_vec").setDropLast(False)
    encoder4 = OneHotEncoder().setInputCol("one_month_count_bucketed").setOutputCol("one_month_count_vec").setDropLast(False)
    encoder5 = OneHotEncoder().setInputCol("year_recv_bucketed").setOutputCol("year_recv_vec").setDropLast(False)
    encoder6 = OneHotEncoder().setInputCol("haftyear_recv_bucketed").setOutputCol("haftyear_recv_vec").setDropLast(False)
    encoder7 = OneHotEncoder().setInputCol("threemonth_recv_bucketed").setOutputCol("threemonth_recv_vec").setDropLast(False)
    encoder8 = OneHotEncoder().setInputCol("onemonth_recv_bucketed").setOutputCol("onemonth_recv_vec").setDropLast(False)
    encoder9 = OneHotEncoder().setInputCol("duration_bucketed").setOutputCol("duration_vec").setDropLast(False)
    encoder10 = OneHotEncoder().setInputCol("lvl_shopgui_erp").setOutputCol("lvl_shopgui_vec").setDropLast(False)
    preprocessPipeline = Pipeline(stages=[bucketizer1, bucketizer2, bucketizer3, bucketizer4, bucketizer5, bucketizer6, bucketizer7, bucketizer8,
                                          bucketizer9, indexer1, indexer2, indexer3, encoder01, encoder02, encoder03, encoder1, encoder2, encoder3,
                                          encoder4, encoder5, encoder6, encoder7, encoder8, encoder9, encoder10])
    # (dt1, dt2) = data.randomSplit([0.9, 0.1], seed=1)
    data2 = featureEngineering(data)
    preP = preprocessPipeline.fit(data2)
    data3 = preP.transform(data2)
    data3.cache()
    dp = data3.where('label=1.0')
    dn = data3.where('label=0.0')
    # print(dp.count())
    # print(dn.count())
    # Ratio of negatives to positives, used to oversample the positive class
    samplerate = round(dn.count() / dp.count())
    # print(samplerate)
    # Split into training and test sets, 70% training / 30% test, per class
    (dp1, dp2) = dp.randomSplit([0.7, 0.3], seed=1)
    (dn1, dn2) = dn.randomSplit([0.7, 0.3], seed=2)
    df1 = dp1.union(dn1)
    df2 = dp2.union(dn2)
    df1.groupBy("label").count().show()
    df2.groupBy("label").count().show()
    data_p1 = df1.where('label=1.0')
    data_p1.show(5)
    data_n1 = df1.where('label=0.0')
    data_p2 = df2.where('label=1.0')
    data_n2 = df2.where('label=0.0')
    # Oversample the positives with replacement so both sets are close to 1:1
    data_p11 = data_p1.rdd.sample(True, samplerate, 100)
    data_p12 = spark.createDataFrame(data_p11)
    data_p21 = data_p2.rdd.sample(True, samplerate, 100)
    data_p22 = spark.createDataFrame(data_p21)
    train_df = data_n1.union(data_p12)
    test_df = data_n2.union(data_p22)
    # print('train_df class ratio:')
    # train_df.groupBy("label").count().show()
    # print('test_df class ratio:')
    # test_df.groupBy("label").count().show()
    train_df.cache()
    test_df.cache()
    # Initial feature selection
    assemblerInputs = ['gender_vec', 'webchat_vec', 'corp_vec', 'is_nm'
                       # assemblerInputs = ['gender_vec', 'webchat_vec', 'is_nm'
                       , 'is_tel'
                       , 'is_birth_dt'
                       , 'is_cert_num'
                       , 'is_handl_org'
                       , 'is_addr'
                       , 'year_count_vec'
                       , 'haft_year_count_vec'
                       , 'three_month_count_vec'
                       , 'one_month_count_vec'
                       , 'year_recv_vec'
                       , 'haftyear_recv_vec'
                       , 'threemonth_recv_vec'
                       , 'onemonth_recv_vec'
                       , 'is_prdc_mfr_pianhao'
                       , 'duration_vec'
                       , 'lvl_shopgui_vec'
                       , 'lvl_shopgui_erp'
                       ]
    # VectorAssembler combines the feature columns into a single feature vector
    assembler = VectorAssembler(inputCols=assemblerInputs, outputCol='features')
    # ChiSqSelector picks important features with a chi-square test. selectorType='fpr'
    # keeps features whose p-value is below fpr=0.05 (without it the default selectorType
    # would instead keep the numTopFeatures highest-scoring features, e.g. the top 10)
    chiSqSelector = ChiSqSelector(featuresCol="features", selectorType="fpr", fpr=0.05,
                                  outputCol="selectedFeatures", labelCol="label")
    # Evaluator for binary classification: area under the ROC curve
    evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label",
                                              metricName="areaUnderROC")
    lr = LogisticRegression(elasticNetParam=0.0, maxIter=30, regParam=0.001, featuresCol="selectedFeatures", labelCol="label")
    lr_pipeline = Pipeline(stages=[assembler, chiSqSelector, lr])
    lrModel = lr_pipeline.fit(train_df)
    # Save the trained pipeline model
    lrModel.write().overwrite().save("./model/")
    # Load the model back
    samelrModel = PipelineModel.load("./model/")
    # Validate the model: evaluate classification performance on the test set
    result = samelrModel.transform(test_df)
    auc = evaluator.evaluate(result)
    print("AUC (areaUnderROC): {}".format(auc))
    total_amount = result.count()
    correct_amount = result.filter(result.label == result.prediction).count()
    precision_rate = correct_amount / total_amount
    print("Overall accuracy: {}".format(precision_rate))
    positive_amount = result.filter(result.label == 1).count()
    negative_amount = result.filter(result.label == 0).count()
    print("Positive samples: {}, negative samples: {}".format(positive_amount, negative_amount))
    positive_precision_amount = result.filter(result.label == 1).filter(result.prediction == 1).count()
    negative_precision_amount = result.filter(result.label == 0).filter(result.prediction == 0).count()
    positive_false_amount = result.filter(result.label == 1).filter(result.prediction == 0).count()
    negative_false_amount = result.filter(result.label == 0).filter(result.prediction == 1).count()
    print("Correctly predicted positives: {}, correctly predicted negatives: {}".format(positive_precision_amount, negative_precision_amount))
    print("Misclassified positives: {}, misclassified negatives: {}".format(positive_false_amount, negative_false_amount))
    recall_rate1 = positive_precision_amount / positive_amount
    recall_rate2 = negative_precision_amount / negative_amount
    print("Positive-class recall: {}, negative-class recall: {}".format(recall_rate1, recall_rate2))
    prediction = result.select('label', 'prediction', 'rawPrediction', 'probability')
    result.show(5)
    prediction.show(5)
    spark.stop()

Appendix 2: Python code for the distributed environment

# coding: utf-8
import logging
import sys
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
import pyspark.sql.functions as fn
from pyspark.sql.functions import regexp_extract, col  # regexp_extract does regex extraction in pyspark.sql.functions
from pyspark.ml.feature import Bucketizer, QuantileDiscretizer, OneHotEncoder, StringIndexer, IndexToString, VectorIndexer, VectorAssembler
from pyspark.ml.feature import ChiSqSelector, StandardScaler
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, LogisticRegressionModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, TrainValidationSplit

if __name__ == '__main__':
    # Create the Spark session with Hive support; "yarn-client" is the legacy master
    # string (newer Spark uses master("yarn") with the client deploy mode)
    spark = SparkSession.builder.appName("usrLostModel").enableHiveSupport().master("yarn-client").getOrCreate()
    # Load the wide table from Hive as a DataFrame, casting the numeric fields to double
    data = spark.sql("""
        select label, erp_code, gender, is_nt_concn_webchat, erp_corp, is_nm, is_tel, is_birth_dt,
               is_cert_num, is_handl_org, is_addr, cast(year_count_erp as double), cast(haft_year_count_erp as double),
               cast(three_month_count_erp as double), cast(one_month_count_erp as double), cast(year_avg_count_erp as double),
               cast(haftyear_avg_count_erp as double), cast(threemonth_avg_count_erp as double), cast(onemonth_avg_count_erp as double),
               cast(year_recv_amt_sum as double), cast(year_discnt_amt_sum as double), cast(haftyear_recv_amt_sum as double),
               cast(haftyear_discnt_amt_sum as double), cast(threemonth_recv_amt_sum as double), cast(threemonth_discnt_amt_sum as double),
               cast(onemonth_recv_amt_sum as double), cast(onemonth_discnt_amt_sum as double), cast(is_prdc_mfr_pianhao as double),
               sell_tm_erp, shopgui_erp, duration_erp, is_shopgui_erp, lvl_shopgui_erp
        from spider.t06_wajue_data""")
    data.show(5)

    def featureEngineering(data):
        # Fill missing is_nt_concn_webchat values with '0'
        data0 = data.na.fill({'is_nt_concn_webchat': '0'})
        # data.groupBy("is_nt_concn_webchat").count().show()
        # Check the fraction of missing values in each column
        # data.agg(*[(1-(fn.count(c)/fn.count('*'))).alias(c+'_null') for c in data.columns]).show()
        # data.agg(*[(fn.count('*')-fn.count(c)).alias(c+'_null') for c in data.columns]).show()
        # Fill missing values in the double columns with 0.0
        data0 = data0.na.fill({'year_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'haftyear_recv_amt_sum': 0.0})
        data0 = data0.na.fill({'haftyear_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'threemonth_recv_amt_sum': 0.0})
        data0 = data0.na.fill({'threemonth_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'onemonth_recv_amt_sum': 0.0})
        data0 = data0.na.fill({'onemonth_discnt_amt_sum': 0.0})
        data0 = data0.na.fill({'is_prdc_mfr_pianhao': 0.0})
        # data0 = data0.na.fill({'sell_tm_erp': 0.0})
        data0 = data0.na.fill({'shopgui_erp': 0.0})
        data0 = data0.na.fill({'duration_erp': 0.0})
        data0 = data0.na.fill({'is_shopgui_erp': 0.0})
        data0 = data0.na.fill({'lvl_shopgui_erp': 0.0})
        # Check the fraction of missing values in each column again
        # data.agg(*[(1-(fn.count(c)/fn.count('*'))).alias(c+'_null') for c in data.columns]).show()
        # Drop shopgui_erp
        data1 = data0.drop('shopgui_erp')
        # data1.dtypes
        # data2.describe(['year_count_erp','haft_year_count_erp','three_month_count_erp','one_month_count_erp']).show()
        # data2.where('haft_year_count_erp>10 and haft_year_count_erp<=20').count()
        # data2.describe(['year_recv_amt_sum','haftyear_recv_amt_sum','threemonth_recv_amt_sum','onemonth_recv_amt_sum']).show()
        # data2.where('year_recv_amt_sum<0').count()
        # data2.where('onemonth_recv_amt_sum>=50 and onemonth_recv_amt_sum<100').count()
        # data2.describe(['year_discnt_amt_sum','haftyear_discnt_amt_sum','threemonth_discnt_amt_sum','onemonth_discnt_amt_sum','duration_erp']).show()
        # Drop outliers: unusually frequent buyers and non-positive yearly spend
        data2 = data1.where('year_count_erp<200 and year_recv_amt_sum>0')
        return data2

    data2 = featureEngineering(data)
    # Discretize the continuous features into 5 quantile buckets each
    bucketizer1 = QuantileDiscretizer(numBuckets=5, inputCol='year_count_erp',
                                      outputCol='year_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer2 = QuantileDiscretizer(numBuckets=5, inputCol='haft_year_count_erp',
                                      outputCol='haft_year_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer3 = QuantileDiscretizer(numBuckets=5, inputCol='three_month_count_erp',
                                      outputCol='three_month_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer4 = QuantileDiscretizer(numBuckets=5, inputCol='one_month_count_erp',
                                      outputCol='one_month_count_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer5 = QuantileDiscretizer(numBuckets=5, inputCol='year_recv_amt_sum',
                                      outputCol='year_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer6 = QuantileDiscretizer(numBuckets=5, inputCol='haftyear_recv_amt_sum',
                                      outputCol='haftyear_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer7 = QuantileDiscretizer(numBuckets=5, inputCol='threemonth_recv_amt_sum',
                                      outputCol='threemonth_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer8 = QuantileDiscretizer(numBuckets=5, inputCol='onemonth_recv_amt_sum',
                                      outputCol='onemonth_recv_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    bucketizer9 = QuantileDiscretizer(numBuckets=5, inputCol='duration_erp',
                                      outputCol='duration_bucketed',
                                      relativeError=0.01, handleInvalid='error')
    # Index the string columns as numeric category indices
    indexer1 = StringIndexer().setInputCol("gender").setOutputCol("gender_index")
    indexer2 = StringIndexer().setInputCol("is_nt_concn_webchat").setOutputCol("webchat_index")
    indexer3 = StringIndexer().setInputCol("erp_corp").setOutputCol("corp_index")
    # One-hot encode the indexed and bucketed columns
    encoder01 = OneHotEncoder().setInputCol("gender_index").setOutputCol("gender_vec").setDropLast(False)
    encoder02 = OneHotEncoder().setInputCol("webchat_index").setOutputCol("webchat_vec").setDropLast(False)
    encoder03 = OneHotEncoder().setInputCol("corp_index").setOutputCol("corp_vec").setDropLast(False)
    encoder1 = OneHotEncoder().setInputCol("year_count_bucketed").setOutputCol("year_count_vec").setDropLast(False)
    encoder2 = OneHotEncoder().setInputCol("haft_year_count_bucketed").setOutputCol("haft_year_count_vec").setDropLast(False)
    encoder3 = OneHotEncoder().setInputCol("three_month_count_bucketed").setOutputCol("three_month_count_vec").setDropLast(False)
    encoder4 = OneHotEncoder().setInputCol("one_month_count_bucketed").setOutputCol("one_month_count_vec").setDropLast(False)
    encoder5 = OneHotEncoder().setInputCol("year_recv_bucketed").setOutputCol("year_recv_vec").setDropLast(False)
    encoder6 = OneHotEncoder().setInputCol("haftyear_recv_bucketed").setOutputCol("haftyear_recv_vec").setDropLast(False)
    encoder7 = OneHotEncoder().setInputCol("threemonth_recv_bucketed").setOutputCol("threemonth_recv_vec").setDropLast(False)
    encoder8 = OneHotEncoder().setInputCol("onemonth_recv_bucketed").setOutputCol("onemonth_recv_vec").setDropLast(False)
    encoder9 = OneHotEncoder().setInputCol("duration_bucketed").setOutputCol("duration_vec").setDropLast(False)
    encoder10 = OneHotEncoder().setInputCol("lvl_shopgui_erp").setOutputCol("lvl_shopgui_vec").setDropLast(False)
    preprocessPipeline = Pipeline(stages=[bucketizer1, bucketizer2, bucketizer3, bucketizer4, bucketizer5, bucketizer6, bucketizer7, bucketizer8,
                                          bucketizer9, indexer1, indexer2, indexer3, encoder01, encoder02, encoder03, encoder1, encoder2, encoder3,
                                          encoder4, encoder5, encoder6, encoder7, encoder8, encoder9, encoder10])
    preP = preprocessPipeline.fit(data2)
    data3 = preP.transform(data2)
    data3.cache()
    dp = data3.where('label=1.0')
    dn = data3.where('label=0.0')
    print(dp.count())
    print(dn.count())
    # Ratio of negatives to positives, used to oversample the positive class
    samplerate = round(dn.count() / dp.count())
    # Split into training and test sets, 70% training / 30% test, per class
    (dp1, dp2) = dp.randomSplit([0.7, 0.3], seed=1)
    (dn1, dn2) = dn.randomSplit([0.7, 0.3], seed=2)
    df1 = dp1.union(dn1)
    df2 = dp2.union(dn2)
    df1.groupBy("label").count().show()
    df2.groupBy("label").count().show()
    data_p1 = df1.where('label=1.0')
    data_n1 = df1.where('label=0.0')
    data_p1.show(5)
    data_p2 = df2.where('label=1.0')
    data_n2 = df2.where('label=0.0')
    # Oversample the positives with replacement so both sets are close to 1:1
    data_p11 = data_p1.rdd.sample(True, samplerate, 100)
    data_p12 = spark.createDataFrame(data_p11)
    data_p21 = data_p2.rdd.sample(True, samplerate, 100)
    data_p22 = spark.createDataFrame(data_p21)
    train_df = data_n1.union(data_p12)
    test_df = data_n2.union(data_p22)
    # print('train_df class ratio:')
    # train_df.groupBy("label").count().show()
    # print('test_df class ratio:')
    # test_df.groupBy("label").count().show()
    train_df.cache()
    test_df.cache()
    # Initial feature selection
    assemblerInputs = ['gender_vec', 'webchat_vec', 'corp_vec', 'is_nm'
                       # assemblerInputs = ['gender_vec', 'webchat_vec', 'is_nm'
                       , 'is_tel'
                       , 'is_birth_dt'
                       , 'is_cert_num'
                       , 'is_handl_org'
                       , 'is_addr'
                       , 'year_count_vec'
                       , 'haft_year_count_vec'
                       , 'three_month_count_vec'
                       , 'one_month_count_vec'
                       , 'year_recv_vec'
                       , 'haftyear_recv_vec'
                       , 'threemonth_recv_vec'
                       , 'onemonth_recv_vec'
                       , 'is_prdc_mfr_pianhao'
                       , 'duration_vec'
                       , 'lvl_shopgui_vec'
                       , 'lvl_shopgui_erp'
                       ]
    # VectorAssembler combines the feature columns into a single feature vector
    assembler = VectorAssembler(inputCols=assemblerInputs, outputCol='features')
    # ChiSqSelector picks important features with a chi-square test. selectorType='fpr'
    # keeps features whose p-value is below fpr=0.05 (without it the default selectorType
    # would instead keep the numTopFeatures highest-scoring features, e.g. the top 10)
    chiSqSelector = ChiSqSelector(featuresCol="features", selectorType="fpr", fpr=0.05,
                                  outputCol="selectedFeatures", labelCol="label")
    # Evaluator for binary classification: area under the ROC curve
    evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label",
                                              metricName="areaUnderROC")
    lr = LogisticRegression(elasticNetParam=0.0, maxIter=30, regParam=0.001, featuresCol="selectedFeatures", labelCol="label")
    lr_pipeline = Pipeline(stages=[assembler, chiSqSelector, lr])
    lrModel = lr_pipeline.fit(train_df)
    # Save the trained pipeline model to HDFS
    lrModel.write().overwrite().save("hdfs://nameservice1/BDP/model/usrLostModel")
    # Load the model back
    samelrModel = PipelineModel.load("hdfs://nameservice1/BDP/model/usrLostModel")
    # Validate the model: evaluate classification performance on the test set
    result = samelrModel.transform(test_df)
    # Get a logger instance (getLogger() with no name would return the root logger)
    logger = logging.getLogger("usrLostModel")
    # Log message format
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s')
    # File handler
    file_handler = logging.FileHandler("/data/appdata/datamining/m01_usrLost_prediction/usrLostPredModel_pys_online.log")
    file_handler.setFormatter(formatter)  # setFormatter attaches the output format
    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.formatter = formatter  # the formatter attribute can also be assigned directly
    # Attach the handlers; custom handlers could route output elsewhere
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    # Minimum log level (the default is WARNING)
    logger.setLevel(logging.INFO)
    auc = evaluator.evaluate(result)
    logger.info("AUC (areaUnderROC): {}".format(auc))
    total_amount = result.count()
    correct_amount = result.filter(result.label == result.prediction).count()
    precision_rate = correct_amount / total_amount
    logger.info("Overall accuracy: {}".format(precision_rate))
    positive_amount = result.filter(result.label == 1).count()
    negative_amount = result.filter(result.label == 0).count()
    logger.info("Positive samples: {}, negative samples: {}".format(positive_amount, negative_amount))
    positive_precision_amount = result.filter(result.label == 1).filter(result.prediction == 1).count()
    negative_precision_amount = result.filter(result.label == 0).filter(result.prediction == 0).count()
    positive_false_amount = result.filter(result.label == 1).filter(result.prediction == 0).count()
    negative_false_amount = result.filter(result.label == 0).filter(result.prediction == 1).count()
    logger.info("Correctly predicted positives: {}, correctly predicted negatives: {}".format(positive_precision_amount, negative_precision_amount))
    logger.info("Misclassified positives: {}, misclassified negatives: {}".format(positive_false_amount, negative_false_amount))
    recall_rate1 = positive_precision_amount / positive_amount
    recall_rate2 = negative_precision_amount / negative_amount
    logger.info("Positive-class recall: {}, negative-class recall: {}".format(recall_rate1, recall_rate2))
    # Detach the file handler
    logger.removeHandler(file_handler)
    # Score the full dataset and write the predictions back to Hive
    result0 = samelrModel.transform(data3)
    predictions = result0.select('label', 'erp_code', 'prediction', 'probability', 'erp_corp')
    predictions.registerTempTable("tempTable")  # createOrReplaceTempView is the non-deprecated equivalent
    spark.sql('insert overwrite table spider.t06_wajue_usrLostprediction select * from tempTable')
    # prediction = result0.select('label', 'prediction','rawPrediction', 'probability')
    # prediction.toPandas().to_csv('/data/appdata/pred.csv')
    spark.stop()