Code repository (Git)

gitee: https://gitee.com/cjunhao/sparkdemo

Common algorithms:


KNN algorithm


Naive Bayes algorithm

Rough idea: under the hood it applies Bayes' theorem. The "naive" part comes from a single assumption: all features are conditionally independent of one another.
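Spelled out: for a label $y$ and features $x_1,\dots,x_n$, Bayes' theorem combined with the independence assumption gives

```latex
P(y \mid x_1,\dots,x_n)
  = \frac{P(y)\,P(x_1,\dots,x_n \mid y)}{P(x_1,\dots,x_n)}
  \propto P(y)\prod_{i=1}^{n} P(x_i \mid y)
```

Prediction picks the label $y$ that maximizes the right-hand side; the independence assumption is exactly what lets the joint likelihood factor into the product.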
Training set (data/bayes/chugui/chugui.txt):

  name,age,job,salary,label
  张三,29,程序员,10000,是
  李四,20,外卖员,6000,否
  王五,40,公务员,9000,是
  赵六,50,老师,6000,否
  陈盼盼,30,外卖员,7000,否
  李四四,40,公务员,12000,是
  黄顶顶,13,学生,0,否
  吴秘密,60,老师,9000,否
  吴听听,25,舞蹈员,7000,是
  张韩语,32,会计,8000,否
  李听,50,会计,16000,是
  李过,20,外卖员,6000,是
  王一,34,公务员,12000,是
  赵西,40,老师,8000,否

Test set (data/bayes/chugui/test.txt):

name,age,job,salary
李婷婷,31,公务员,10000
吴晓琳,27,会计,6000
陈琳,35,老师,3000

Test results:

+------------------+----+-----------------------------------------+-----------------------------------------+----------+
|feature           |name|rawPrediction                            |probability                              |prediction|
+------------------+----+-----------------------------------------+-----------------------------------------+----------+
|[31.0,3.0,10000.0]|李婷婷 |[-245.96880163238194,-238.28167417879274]|[4.584835969725884E-4,0.9995415164030276]|1.0       |
|[27.0,7.0,6000.0] |吴晓琳 |[-229.5993100935666,-232.46015041809136] |[0.945876335461309,0.05412366453869107]  |0.0       |
|[35.0,4.0,3000.0] |陈琳  |[-230.45830976664556,-243.3420329281712] |[0.9999974609630454,2.539036954591085E-6]|0.0       |
+------------------+----+-----------------------------------------+-----------------------------------------+----------+

Code:

package com.atguigu.sparkmllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

object NaiveBayesTrainAndPredictApp {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache").setLevel(Level.WARN)
    // load the data
    val spark = SparkSession.builder().appName("出轨分析").master("local[*]").getOrCreate()
    val df = spark.read.option("header", "true").csv("data/bayes/chugui/chugui.txt")
    // prepare the data
    val df2 = df.selectExpr("name",
      "cast(age as double) age",
      "cast(case job when '程序员' then 1.0 when '外卖员' then 2.0 when '公务员' then 3.0 when '老师' then 4.0 when '学生' then 5.0 when '舞蹈员'then 6.0 when '会计' then 7.0 end as double) as job",
      "cast(salary as double) as salary", "cast(case label when '是' then 1.0 when '否' then 0.0 end as double) as label")
    val vec = (wrappedArray: mutable.WrappedArray[Double]) => {
      Vectors.dense(wrappedArray.toArray)
    }
    spark.udf.register("vec", vec)
    val df3 = df2.selectExpr("vec(array(age,job,salary)) as feature", "label")
    // train the model
    val bayes = new NaiveBayes()
      .setSmoothing(1.0)
      .setFeaturesCol("feature")
      .setLabelCol("label")
    val model = bayes.fit(df3)
    // save the model
    model.write.overwrite().save("data/bayes/chugui/model")

    // load the test set
    val testdf = spark.read.option("header", "true").csv("data/bayes/chugui/test.txt")
    // prepare the test set
    val testdf2 = testdf.selectExpr("name",
      "cast(age as double) age",
      "cast(case job when '程序员' then 1.0 when '外卖员' then 2.0 when '公务员' then 3.0 when '老师' then 4.0 when '学生' then 5.0 when '舞蹈员'then 6.0 when '会计' then 7.0 end as double) as job",
      "cast(salary as double) as salary")
    val testdf3 = testdf2.selectExpr("vec(array(age,job,salary)) as feature", "name")
    val loadedModel = NaiveBayesModel.load("data/bayes/chugui/model")
    // predict
    val predict = loadedModel.transform(testdf3)
    predict.show(50, false)
  }
}

Linear regression

Linear regression is fit by minimizing a loss function; the minimum is found via gradient descent, which at its core relies on partial derivatives.
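Written out for $m$ training samples with parameter vector $\theta$ and learning rate $\alpha$, using squared-error loss, the loss and the gradient-descent update are:

```latex
J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}\bigl(\theta^{T}x^{(i)} - y^{(i)}\bigr)^{2},
\qquad
\theta_j \leftarrow \theta_j - \alpha\,\frac{\partial J(\theta)}{\partial \theta_j}
        = \theta_j - \frac{\alpha}{m}\sum_{i=1}^{m}\bigl(\theta^{T}x^{(i)} - y^{(i)}\bigr)\,x_j^{(i)}
```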

package com.atguigu.sparkmllib

import com.atguigu.commons.SparkUtil
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

/* @description: predicting house prices with linear regression
 * @author: chengjunhao
 * @date: 2020/4/5 17:13
*/
object LinearRegression {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // load the data
    import org.apache.spark.sql.functions._
    val spark = SparkUtil.getSparkSession("LinearRegression")

    def readFile(path: String) = {
      spark.read.textFile(path)
    }
    import spark.implicits._
    // load the housing-price data
    val hdata = readFile("data/regression/linearge/housing.data")
    val df = hdata.map(str => {
      val line = str.replaceAll("(\\s+)", " ")
      val arr = line.split(" ")
      val arr2 = arr.map(item => item.toDouble)
      val (features, arrlabel) = arr2.splitAt(arr2.size - 1)
      (Vectors.dense(features), arrlabel(0))
    }).toDF("features", "label")
    // train the model
    val regression = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("label")
    val model = regression.fit(df)
    val prediction = model.transform(df)
    // house-price prediction details
    prediction.show(100, false)
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse")
    val evaluatorResult = evaluator.evaluate(prediction)
    // evaluate the predictions (RMSE)
    println(evaluatorResult)
  }
}
//evaluation result (RMSE against the actual prices)
4.728921754863951
//prediction details (features, actual price, predicted price)
+----------------------------------------------------------------------------+-----+------------------+
|features                                                                    |label|prediction        |
+----------------------------------------------------------------------------+-----+------------------+
|[0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98]     |24.0 |30.190308651220306|
|[0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14]    |21.6 |25.18082356676625 |
|[0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03]   |34.7 |30.879137984747175|
|[0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94]   |33.4 |28.929827983486426|
|[0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33]    |36.2 |28.221119235201883|
|[0.02985,0.0,2.18,0.0,0.458,6.43,58.7,6.0622,3.0,222.0,18.7,394.12,5.21]    |28.7 |25.495148321722954|
|[0.08829,12.5,7.87,0.0,0.524,6.012,66.6,5.5605,5.0,311.0,15.2,395.6,12.43]  |22.9 |22.994008794619425|
|[0.14455,12.5,7.87,0.0,0.524,6.172,96.1,5.9505,5.0,311.0,15.2,396.9,19.15]  |27.1 |19.411743228512137|
|[0.21124,12.5,7.87,0.0,0.524,5.631,100.0,6.0821,5.0,311.0,15.2,386.63,29.93]|16.5 |11.116182643466054|
|[0.17004,12.5,7.87,0.0,0.524,6.004,85.9,6.5921,5.0,311.0,15.2,386.71,17.1]  |18.9 |18.875755494056953|
|[0.22489,12.5,7.87,0.0,0.524,6.377,94.3,6.3467,5.0,311.0,15.2,392.52,20.45] |15.0 |18.883718884516938|
|[0.11747,12.5,7.87,0.0,0.524,6.009,82.9,6.2267,5.0,311.0,15.2,396.9,13.27]  |18.9 |21.614326621187548|
|[0.09378,12.5,7.87,0.0,0.524,5.889,39.0,5.4509,5.0,311.0,15.2,390.5,15.71]  |21.7 |20.79957586017244 |
|[0.62976,0.0,8.14,0.0,0.538,5.949,61.8,4.7075,4.0,307.0,21.0,396.9,8.26]    |20.4 |19.738812505148353|
|[0.63796,0.0,8.14,0.0,0.538,6.096,84.5,4.4619,4.0,307.0,21.0,380.02,10.26]  |18.2 |19.392210535216037|
|[0.62739,0.0,8.14,0.0,0.538,5.834,56.5,4.4986,4.0,307.0,21.0,395.62,8.47]   |19.9 |19.454206945022317|
|[1.05393,0.0,8.14,0.0,0.538,5.935,29.3,4.4986,4.0,307.0,21.0,386.85,6.58]   |23.1 |20.734847090514773|
|[0.7842,0.0,8.14,0.0,0.538,5.99,81.7,4.2579,4.0,307.0,21.0,386.75,14.67]    |17.5 |16.90496844020901 |
|[0.80271,0.0,8.14,0.0,0.538,5.456,36.6,3.7965,4.0,307.0,21.0,288.99,11.69]  |20.2 |16.068345653789162|
|[0.7258,0.0,8.14,0.0,0.538,5.727,69.5,3.7965,4.0,307.0,21.0,390.95,11.28]   |18.2 |18.43260488799402 |
|[1.25179,0.0,8.14,0.0,0.538,5.57,98.1,3.7979,4.0,307.0,21.0,376.57,21.02]   |13.6 |12.309407215651753|
|[0.85204,0.0,8.14,0.0,0.538,5.965,89.2,4.0123,4.0,307.0,21.0,392.53,13.83]  |19.6 |17.672096530362325|
|[1.23247,0.0,8.14,0.0,0.538,6.142,91.7,3.9769,4.0,307.0,21.0,396.9,18.72]   |15.2 |15.739106004141671|
|[0.98843,0.0,8.14,0.0,0.538,5.813,100.0,4.0952,4.0,307.0,21.0,394.54,19.88] |14.5 |13.668460075613503|
|[0.75026,0.0,8.14,0.0,0.538,5.924,94.1,4.3996,4.0,307.0,21.0,394.33,16.3]   |15.6 |15.648887074136162|
|[0.84054,0.0,8.14,0.0,0.538,5.599,85.7,4.4546,4.0,307.0,21.0,303.42,16.51]  |13.9 |13.240610975979848|
|[0.67191,0.0,8.14,0.0,0.538,5.813,90.3,4.682,4.0,307.0,21.0,376.88,14.81]   |16.6 |15.464927727688355|
|[0.95577,0.0,8.14,0.0,0.538,6.047,88.8,4.4534,4.0,307.0,21.0,306.38,17.28]  |14.8 |14.573852495098887|
|[0.77299,0.0,8.14,0.0,0.538,6.495,94.4,4.4547,4.0,307.0,21.0,387.94,12.8]   |18.4 |19.628722416039974|
|[1.00245,0.0,8.14,0.0,0.538,6.674,87.3,4.239,4.0,307.0,21.0,380.23,11.98]   |21.0 |20.966640126110367|
|[1.13081,0.0,8.14,0.0,0.538,5.713,94.1,4.233,4.0,307.0,21.0,360.17,22.6]    |12.7 |11.223753725628281|
|[1.35472,0.0,8.14,0.0,0.538,6.072,100.0,4.175,4.0,307.0,21.0,376.73,13.04]  |14.5 |18.093115056496018|
|[1.38799,0.0,8.14,0.0,0.538,5.95,82.0,3.99,4.0,307.0,21.0,232.6,27.71]      |13.2 |8.316132718067578 |
|[1.15172,0.0,8.14,0.0,0.538,5.701,95.0,3.7872,4.0,307.0,21.0,358.77,18.35]  |13.1 |14.115222877378901|
|[1.61282,0.0,8.14,0.0,0.538,6.096,96.9,3.7598,4.0,307.0,21.0,248.31,20.34]  |13.5 |13.399626400996773|
|[0.06417,0.0,5.96,0.0,0.499,5.933,68.2,3.3603,5.0,279.0,19.2,396.9,9.68]    |18.9 |23.722685383288756|
|[0.09744,0.0,5.96,0.0,0.499,5.841,61.4,3.3779,5.0,279.0,19.2,377.56,11.41]  |20.0 |22.184711007711133|
|[0.08014,0.0,5.96,0.0,0.499,5.85,41.5,3.9342,5.0,279.0,19.2,396.9,8.77]     |21.0 |23.07942553867743 |
|[0.17505,0.0,5.96,0.0,0.499,5.966,30.2,3.8473,5.0,279.0,19.2,393.43,10.13]  |24.7 |22.8507702710117  |
|[0.02763,75.0,2.95,0.0,0.428,6.595,21.8,5.4011,3.0,252.0,18.3,395.63,4.32]  |30.8 |31.28006993091833 |
|[0.03359,75.0,2.95,0.0,0.428,7.024,15.8,5.4011,3.0,252.0,18.3,395.62,1.98]  |34.9 |34.21552849591169 |
|[0.12744,0.0,6.91,0.0,0.448,6.77,2.9,5.7209,3.0,233.0,17.9,385.41,4.84]     |26.6 |28.27830217644444 |
|[0.1415,0.0,6.91,0.0,0.448,6.169,6.6,5.7209,3.0,233.0,17.9,383.37,5.81]     |25.3 |25.40527129936151 |
|[0.15936,0.0,6.91,0.0,0.448,6.211,6.5,5.7209,3.0,233.0,17.9,394.46,7.44]    |24.7 |24.787525821054338|
|[0.12269,0.0,6.91,0.0,0.448,6.069,40.0,5.7209,3.0,233.0,17.9,389.39,9.55]   |21.2 |23.05600001254663 |
|[0.17142,0.0,6.91,0.0,0.448,5.682,33.8,5.1004,3.0,233.0,17.9,396.9,10.21]   |19.3 |22.135516569051077|
|[0.18836,0.0,6.91,0.0,0.448,5.786,33.3,5.1004,3.0,233.0,17.9,396.9,14.15]   |20.0 |20.374450708031127|
|[0.22927,0.0,6.91,0.0,0.448,6.03,85.5,5.6894,3.0,233.0,17.9,392.74,18.8]    |16.6 |17.93473222778474 |
|[0.25387,0.0,6.91,0.0,0.448,5.399,95.3,5.87,3.0,233.0,17.9,396.9,30.81]     |14.4 |8.706631427576724 |
|[0.21977,0.0,6.91,0.0,0.448,5.602,62.0,6.0877,3.0,233.0,17.9,396.9,16.2]    |19.4 |17.178846384667416|
|[0.08873,21.0,5.64,0.0,0.439,5.963,45.7,6.8147,4.0,243.0,16.8,395.56,13.45] |19.7 |21.245543392473255|
|[0.04337,21.0,5.64,0.0,0.439,6.115,63.0,6.8147,4.0,243.0,16.8,393.97,9.43]  |20.5 |24.037054024454424|
|[0.0536,21.0,5.64,0.0,0.439,6.511,21.1,6.8147,4.0,243.0,16.8,396.9,5.28]    |25.0 |27.842340282261322|
|[0.04981,21.0,5.64,0.0,0.439,5.998,21.4,6.8147,4.0,243.0,16.8,396.9,8.43]   |23.4 |24.133812775357192|
|[0.0136,75.0,4.0,0.0,0.41,5.888,47.6,7.3197,3.0,469.0,21.1,396.9,14.8]      |18.9 |15.204212070389076|
|[0.01311,90.0,1.22,0.0,0.403,7.249,21.9,8.6966,5.0,226.0,17.9,395.93,4.81]  |35.4 |31.17273816244207 |
|[0.02055,85.0,0.74,0.0,0.41,6.383,35.7,9.1876,2.0,313.0,17.3,396.9,5.77]    |24.7 |25.003120474324106|
|[0.01432,100.0,1.32,0.0,0.411,6.816,40.5,8.3248,5.0,256.0,15.1,392.9,3.95]  |31.6 |33.0501442598882  |
|[0.15445,25.0,5.13,0.0,0.453,6.145,29.2,7.8148,8.0,284.0,19.7,390.68,6.86]  |23.3 |21.844459264785772|
|[0.10328,25.0,5.13,0.0,0.453,5.927,47.2,6.932,8.0,284.0,19.7,396.9,9.22]    |19.6 |21.0226888914256  |
|[0.14932,25.0,5.13,0.0,0.453,5.741,66.2,7.2254,8.0,284.0,19.7,395.11,13.15] |18.7 |17.730297622043963|
|[0.17171,25.0,5.13,0.0,0.453,5.966,93.4,6.8185,8.0,284.0,19.7,378.08,14.44] |16.0 |18.300089563663125|
|[0.11027,25.0,5.13,0.0,0.453,6.456,67.8,7.2255,8.0,284.0,19.7,396.9,6.73]   |22.2 |24.035933036957566|
|[0.1265,25.0,5.13,0.0,0.453,6.762,43.4,7.9809,8.0,284.0,19.7,395.58,9.5]    |25.0 |22.611161965976976|
|[0.01951,17.5,1.38,0.0,0.4161,7.104,59.5,9.2229,3.0,216.0,18.6,393.24,8.05] |33.0 |23.7118587624982  |
|[0.03584,80.0,3.37,0.0,0.398,6.29,17.8,6.6115,4.0,337.0,16.1,396.9,4.67]    |23.5 |30.26687448389147 |
|[0.04379,80.0,3.37,0.0,0.398,5.787,31.1,6.6115,4.0,337.0,16.1,396.9,10.24]  |19.4 |25.27801769118446 |
|[0.05789,12.5,6.07,0.0,0.409,5.878,21.4,6.498,4.0,345.0,18.9,396.21,8.1]    |22.0 |21.21194768977976 |
|[0.13554,12.5,6.07,0.0,0.409,5.594,36.8,6.498,4.0,345.0,18.9,396.9,13.09]   |17.4 |17.390004693770543|
|[0.12816,12.5,6.07,0.0,0.409,5.885,33.0,6.498,4.0,345.0,18.9,396.9,8.79]    |20.9 |20.870257832388887|
|[0.08826,0.0,10.81,0.0,0.413,6.417,6.6,5.2873,4.0,305.0,19.2,383.73,6.72]   |24.2 |25.339885054239968|
|[0.15876,0.0,10.81,0.0,0.413,5.961,17.5,5.2873,4.0,305.0,19.2,376.94,9.88]  |21.7 |21.776730166784652|
|[0.09164,0.0,10.81,0.0,0.413,6.065,7.8,5.2873,4.0,305.0,19.2,390.91,5.52]   |22.8 |24.713947364951146|
|[0.19539,0.0,10.81,0.0,0.413,6.245,6.2,5.2873,4.0,305.0,19.2,377.17,7.54]   |23.4 |24.148328355465534|
|[0.07896,0.0,12.83,0.0,0.437,6.273,6.0,4.2515,5.0,398.0,18.7,394.92,6.78]   |24.1 |25.580263420202034|
|[0.09512,0.0,12.83,0.0,0.437,6.286,45.0,4.5026,5.0,398.0,18.7,383.23,8.94]  |21.4 |23.999948769178417|
|[0.10153,0.0,12.83,0.0,0.437,6.279,74.5,4.0522,5.0,398.0,18.7,373.66,11.97] |20.0 |22.86080492474821 |
|[0.08707,0.0,12.83,0.0,0.437,6.14,45.8,4.0905,5.0,398.0,18.7,386.96,10.27]  |20.8 |23.322112071030045|
|[0.05646,0.0,12.83,0.0,0.437,6.232,53.7,5.0141,5.0,398.0,18.7,386.4,12.34]  |21.2 |21.254832078946972|
|[0.08387,0.0,12.83,0.0,0.437,5.874,36.6,4.5026,5.0,398.0,18.7,396.06,9.1]   |20.3 |22.449559654604922|
|[0.04113,25.0,4.86,0.0,0.426,6.727,33.5,5.4007,4.0,281.0,19.0,396.9,5.29]   |28.0 |28.480573367799344|
|[0.04462,25.0,4.86,0.0,0.426,6.619,70.4,5.4007,4.0,281.0,19.0,395.63,7.22]  |23.9 |27.017829328681373|
|[0.03659,25.0,4.86,0.0,0.426,6.302,32.2,5.4007,4.0,281.0,19.0,396.9,6.72]   |24.8 |26.054218988842194|
|[0.03551,25.0,4.86,0.0,0.426,6.167,46.7,5.4007,4.0,281.0,19.0,390.64,7.51]  |22.9 |25.044752060188472|
|[0.05059,0.0,4.49,0.0,0.449,6.389,48.0,4.7794,3.0,247.0,18.5,396.9,9.62]    |23.9 |24.832618514589313|
|[0.05735,0.0,4.49,0.0,0.449,6.63,56.1,4.4377,3.0,247.0,18.5,392.3,6.53]     |26.6 |27.89263939520159 |
|[0.05188,0.0,4.49,0.0,0.449,6.015,45.1,4.4272,3.0,247.0,18.5,395.99,12.86]  |22.5 |22.091164776225657|
|[0.07151,0.0,4.49,0.0,0.449,6.121,56.8,3.7476,3.0,247.0,18.5,395.15,8.44]   |22.2 |25.871768906772093|
|[0.0566,0.0,3.41,0.0,0.489,7.007,86.3,3.4217,2.0,270.0,17.8,396.9,5.5]      |23.6 |30.81825218329388 |
|[0.05302,0.0,3.41,0.0,0.489,7.079,63.1,3.4145,2.0,270.0,17.8,396.06,5.7]    |28.7 |30.971595744003565|
|[0.04684,0.0,3.41,0.0,0.489,6.417,66.1,3.0923,2.0,270.0,17.8,392.18,8.81]   |22.6 |27.12110239078219 |
|[0.03932,0.0,3.41,0.0,0.489,6.405,73.9,3.0921,2.0,270.0,17.8,393.55,8.2]    |22.0 |27.430097731675495|
|[0.04203,28.0,15.04,0.0,0.464,6.442,53.6,3.6659,4.0,270.0,18.2,395.01,8.16] |22.9 |28.908354939561406|
|[0.02875,28.0,15.04,0.0,0.464,6.211,28.9,3.6659,4.0,270.0,18.2,396.33,6.21] |25.0 |29.082827105623092|
|[0.04294,28.0,15.04,0.0,0.464,6.249,77.3,3.615,4.0,270.0,18.2,396.9,10.59]  |20.6 |26.937832632384236|
|[0.12204,0.0,2.89,0.0,0.445,6.625,57.8,3.4952,2.0,276.0,18.0,357.98,6.65]   |28.4 |28.635280481715768|
|[0.11504,0.0,2.89,0.0,0.445,6.163,69.6,3.4952,2.0,276.0,18.0,391.83,11.34]  |21.4 |24.640422958037888|
|[0.12083,0.0,2.89,0.0,0.445,8.069,76.0,3.4952,2.0,276.0,18.0,396.9,4.21]    |38.7 |35.96955000322432 |
|[0.08187,0.0,2.89,0.0,0.445,7.82,36.9,3.4952,2.0,276.0,18.0,393.53,3.57]    |43.8 |35.29890121386673 |
|[0.0686,0.0,2.89,0.0,0.445,7.416,62.5,3.4952,2.0,276.0,18.0,396.9,6.19]     |33.2 |32.35610981009244 |
+----------------------------------------------------------------------------+-----+------------------+

Math background:

Vectors


Going further with vectors: Euclidean distance (as a similarity) and cosine similarity
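For vectors $u$ and $v$, the two measures are:

```latex
\mathrm{eucSim}(u,v) = \frac{1}{1 + \lVert u-v \rVert_2},
\qquad
\cos(u,v) = \frac{u\cdot v}{\lVert u \rVert_2\,\lVert v \rVert_2}
```

Note that cosine similarity divides by the product of the two norms (the square roots of the squared sums), not by the squared sums themselves.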

//compute Euclidean similarity and cosine similarity
import org.apache.log4j.{Level, Logger}

import org.apache.spark.ml.linalg._
import org.apache.spark.sql._

object App {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // read the data
    val spark = SparkSession.builder().appName("similarityCalculate").master("local[*]").getOrCreate()
    val ds: Dataset[String] = spark.read.textFile("data/persons.txt")
    import spark.implicits._
    // prepare the data
    val attri = ds.map(line => {
      val ar = line.split(",")
      val userid = ar(0)
      val vector = Vectors.dense(ar.tail.map(item => item.toDouble))
      (userid, vector)
    }).toDF("userid", "vector")
    attri.show(100, false)
    import org.apache.spark.sql.functions._
    val joined = attri.join(attri.toDF("b_userid", "b_vector"), col("userid") < col("b_userid"))
    // compute Euclidean-distance similarity and cosine similarity
    spark.udf.register("euc", eucliSimilarity)
    spark.udf.register("cos", cosSimilarity)
    //joined.selectExpr("userid", "b_userid", "euc(vector,b_vector)", "cos(vector,b_vector)").show()
    joined.createTempView("joined")
    spark.sql(
      """
        |select
        |userid,
        |b_userid,
        |euc(vector,b_vector),
        |cos(vector,b_vector)
        |from joined
      """.stripMargin).show(200,false)
    // print the results
  }

  // Euclidean similarity
  val eucliSimilarity = (uv: Vector, buv: Vector) => {
    1 / (1 + Math.pow(Vectors.sqdist(uv, buv), 0.5))
  }
  // cosine similarity
  val cosSimilarity = (uv: Vector, buv: Vector) => {
    val sum1 = uv.toArray.map(item => Math.pow(item, 2)).sum
    val sum2 = buv.toArray.map(item => Math.pow(item, 2)).sum
    val tuples = uv.toArray.zip(buv.toArray)
    val sum3 = tuples.map(item => item._1 * item._2).sum
    // divide by the product of the norms (sqrt of each squared sum), not the squared sums themselves
    sum3 / (Math.sqrt(sum1) * Math.sqrt(sum2))
  }
}

/* 
+----persons.txt----+
u1,20,172,98,3999
u2,30,174,70,6999
u3,22,171,78,4999
u4,40,178,75,7999
u5,33,172,68,6999
u6,31,188,68,8999
u7,50,172,58,8099
u8,18,172,99,999
u9,70,160,44,1999

+----输出结果----+
+------+--------+-------------------------+-------------------------+
|userid|b_userid|UDF:euc(vector, b_vector)|UDF:cos(vector, b_vector)|
+------+--------+-------------------------+-------------------------+
|u1    |u2      |3.332058269756651E-4     |3.566143930793115E-8     |
|u1    |u3      |9.987989647192792E-4     |4.99210425243552E-8      |
|u1    |u4      |2.4992998066749325E-4    |3.120334262477809E-8     |
|u1    |u5      |3.3320247790333556E-4    |3.5661395232678165E-8    |
|u1    |u6      |1.9995490223656236E-4    |2.7735692648707768E-8    |
|u1    |u7      |2.438248392885663E-4     |3.0817318676858244E-8    |
|u1    |u8      |3.332221667160612E-4     |2.4252302194838063E-7    |
|u1    |u9      |4.994033335448089E-4     |1.2429473207172002E-7    |
|u2    |u3      |4.997415712132636E-4     |2.8548568958543982E-8    |
|u2    |u4      |9.989306472244388E-4     |1.7849858206328596E-8    |
|u2    |u5      |0.1951941016011038       |2.0399131030943938E-8    |
|u2    |u6      |4.997375754634144E-4     |1.5867017705841344E-8    |
|u2    |u7      |9.080597966973777E-4     |1.7629908498384664E-8    |
|u2    |u8      |1.6663660497596988E-4    |1.3818168544604422E-7    |
|u2    |u9      |1.9995012468525684E-4    |7.10140480147098E-8      |
|u3    |u4      |3.332151901115194E-4     |2.4980364474953618E-8    |
|u3    |u5      |4.997362643788828E-4     |2.8548670489039806E-8    |
|u3    |u6      |2.499338456621841E-4     |2.2204923975223692E-8    |
|u3    |u7      |3.2245674655075133E-4    |2.4672014941383436E-8    |
|u3    |u8      |2.499339393612514E-4     |1.9374745563676967E-7    |
|u3    |u9      |3.331560083851983E-4     |9.944297027049687E-8     |
|u4    |u5      |9.98934139515095E-4      |1.7850018870123525E-8    |
|u4    |u6      |9.988862484373682E-4     |1.3884385937169933E-8    |
|u4    |u7      |0.009699049240532491     |1.5427018165936453E-8    |
|u4    |u8      |1.428351404244442E-4     |1.2086241316083946E-7    |
|u4    |u9      |1.666338375630366E-4     |6.213190885762279E-8     |
|u5    |u6      |4.997338919664413E-4     |1.586716435702208E-8     |
|u5    |u7      |9.081193831836474E-4     |1.7630092300025892E-8    |
|u5    |u8      |1.666361491300729E-4     |1.3817428819747693E-7    |
|u5    |u9      |1.9995165586315277E-4    |7.101417386546018E-8     |
|u6    |u7      |0.0011093875601599508    |1.3713415495054653E-8    |
|u6    |u8      |1.249830237975477E-4     |1.0739544562964691E-7    |
|u6    |u9      |1.428325392852325E-4     |5.522204853293927E-8     |
|u7    |u8      |1.408214582261319E-4     |1.1933383836016132E-7    |
|u7    |u9      |1.6390592660457887E-4    |6.136345718376115E-8     |
|u8    |u9      |9.960832005447605E-4     |4.856381190959876E-7     |
+------+--------+-------------------------+-------------------------+

*/

User-profile model tags

Sentiment analysis


Implementation steps:

  1. Chinese word segmentation
  2. Convert the text into feature vectors
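A minimal sketch of step 1 (the sample sentence is invented; `HanLP.segment` is the same call used in the full implementation further down):

```scala
import com.hankcs.hanlp.HanLP
import scala.collection.JavaConversions._

object SegmentSketch {
  def main(args: Array[String]): Unit = {
    // segment a (made-up) hotel review into terms; each term carries word + part-of-speech
    val terms = HanLP.segment("房间很干净,服务也不错")
    terms.foreach(t => println(s"${t.word}/${t.nature}"))
  }
}
```

Filtering by `t.nature` (as the full code does to drop punctuation and particles) operates on these same `Term` objects.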

Vectorization options:

tf-idf = tf * idf
idf = lg(total number of documents / (1 + number of documents containing the term))
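A quick worked instance of the formula (all numbers invented): with $N=10$ documents, a term that appears in 4 of them and occurs 3 times in a 100-word document:

```latex
\mathrm{idf} = \lg\frac{10}{1+4} = \lg 2 \approx 0.301,
\qquad
\mathrm{tf} = \frac{3}{100} = 0.03,
\qquad
\text{tf-idf} \approx 0.03 \times 0.301 \approx 0.009
```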

tf-idf, hand-rolled implementation:

package com.atguigu.sparkmllib

import org.apache.commons.codec.digest.DigestUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}

object TFIDF {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // load the data
    val spark = SparkSession.builder().appName("tfidf").master("local[*]").getOrCreate()
    val df = spark.read.option("header", "true").csv("data/tfidf/train.txt")
    df.cache()
    val totalCount = df.count()
    // count the number of distinct words
    val wordNum = df.rdd.map(row => row.getAs[String]("doc")).flatMap(item => item.split(" ")).distinct().count()
    val bucketNum = 26
    // prepare the data
    import spark.implicits._
    // compute the tf vector
    val tfVector = df.map({
      case Row(id: String, doc: String) => {
        val arr = doc.split(" ")
        val wordcount = arr.map(item => (item, 1)).groupBy(_._1).mapValues(_.size)
        val vectorarr = Vectors.zeros(bucketNum).toArray
        wordcount.foreach(item => {
          // mask off the sign bit first, then take the modulus to pick a bucket
          val index: Int = (DigestUtils.md5Hex(item._1).hashCode & Int.MaxValue) % bucketNum
          vectorarr(index) = item._2
        })
        (id, vectorarr)
      }
    })
    tfVector.show(30, false)
    // compute the idf vector
    val dtfVector = tfVector.rdd.map(item => item._2.map(item => {
      if (item != 0) {
        1
      } else {
        0
      }
    })).reduce((i, current) => i.zip(current).map((item1) => {
      item1._1 + item1._2
    })).map(value => Math.log10(totalCount / (0.01 + value)))
    // compute the final tf-idf vectors
    val bdTFVector = spark.sparkContext.broadcast(dtfVector)
    val tfIdfVectors = tfVector.map(item => {
      val tfIdfValues = item._2.zip(bdTFVector.value).map(tp => tp._1 * tp._2)
      (item._1,tfIdfValues)
    }).toDF("id","tf-idf")
    tfIdfVectors.show(40,false)
  }
}

tf-idf, Spark ML API version:

package com.atguigu.sparkmllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.sql.SparkSession

object TFIDFV2 {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // load the data
    val spark = SparkSession.builder().appName("tfidf").master("local[*]").getOrCreate()
    val df = spark.read.option("header", "true").csv("data/tfidf/train.txt")
    df.cache()
    val totalCount = df.count()
    // count the number of distinct words
    val wordNum = df.rdd.map(row => row.getAs[String]("doc")).flatMap(item => item.split(" ")).distinct().count()
    val df2 = df.selectExpr("id", "split(doc,' ') as words")
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("tfvec")
      .setNumFeatures(Math.pow(wordNum, 2).toInt)
    val tfdf = hashingTF.transform(df2)
    tfdf.show(10, false)

    val idf = new IDF()
      .setInputCol("tfvec")
      .setOutputCol("idfvec")
    val model = idf.fit(tfdf)
    val tfidf = model.transform(tfdf)
    tfidf.show(10,false)
  }
}

+---+------------------------+---------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
|id |words                   |tfvec                                        |idfvec                                                                                                                  |
+---+------------------------+---------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
|1  |[a, a, a, x, x, y]      |(121,[59,102,106],[2.0,1.0,3.0])             |(121,[59,102,106],[0.3083013596545167,0.15415067982725836,2.5418935811616112])                                          |
|2  |[b, b, c, x, y]         |(121,[59,74,101,102],[1.0,1.0,2.0,1.0])      |(121,[59,74,101,102],[0.15415067982725836,1.252762968495368,2.505525936990736,0.15415067982725836])                     |
|3  |[d, d, d, d, x, x, x, y]|(121,[59,96,102],[3.0,4.0,1.0])              |(121,[59,96,102],[0.46245203948177505,3.3891914415488147,0.15415067982725836])                                          |
|4  |[e, d, x, x]            |(121,[59,77,96],[2.0,1.0,1.0])               |(121,[59,77,96],[0.3083013596545167,1.252762968495368,0.8472978603872037])                                              |
|5  |[g, f, y, y]            |(121,[52,102,110],[1.0,2.0,1.0])             |(121,[52,102,110],[1.252762968495368,0.3083013596545167,1.252762968495368])                                             |
|6  |[i, i, i, a, h, h, x, y]|(121,[3,59,69,102,106],[3.0,1.0,2.0,1.0,1.0])|(121,[3,59,69,102,106],[3.758288905486104,0.15415067982725836,2.505525936990736,0.15415067982725836,0.8472978603872037])|
+---+------------------------+---------------------------------------------+------------------------------------------------------------------------------------------------------------------------+

Sentiment analysis: full implementation

package com.atguigu.sparkmllib

import com.atguigu.commons.SparkUtil
import com.hankcs.hanlp.HanLP
import com.hankcs.hanlp.corpus.tag.Nature
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/* @description: sentiment analysis
 * @author: chengjunhao
 * @date: 2020/4/5 10:49
 * pos.txt: positive hotel reviews
 * neg.txt: negative hotel reviews
 * data source: https://www.aitechclub.com/data-detail?data_id=23
*/
object CommentClassify {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // load the data
    import org.apache.spark.sql.functions._
    val spark = SparkUtil.getSparkSession("CommentClassify")

    def readFile(path: String) = {
      spark.read.format("text").load(path).select(decode(col("value"), "GBK"))
    }
    // load the positive and negative hotel reviews
    val posData = readFile("data/comment/hotel/pos.txt")
    val negData = readFile("data/comment/hotel/neg.txt")
    // prepare the data
    import scala.collection.JavaConversions._
    import spark.implicits._
    def transformDF(data: DataFrame, label: Int) = {
      val df = data.map(row => {
        val str = row(0).toString
        val terms = HanLP.segment(str)
        val words = terms.filter(n => {
          n.nature != Nature.w && n.nature != Nature.u && n.nature != Nature.y
        }).map(term => term.word)
        (label, words)
      }).toDF("label", "comment")
      df
    }
    val posdf = transformDF(posData, 1)
    val negdf = transformDF(negData, 0)
    val alldf = posdf.union(negdf)
    // compute tf-idf
    val htf = new HashingTF()
      .setInputCol("comment")
      .setOutputCol("comment_tf")
      .setNumFeatures(1000000)
    val tf: DataFrame = htf.transform(alldf)
    val idf = new IDF()
      .setInputCol("comment_tf")
      .setOutputCol("comment_idf")
    val idfModel = idf.fit(tf)
    var attr = idfModel.transform(tf)
    // extract the feature vector
    attr = attr.drop("comment", "comment_tf")
      .withColumnRenamed("comment_idf", "comment_tfidf")
    // split into training and test sets (90% / 10%)
    val Array(train, test) = attr.randomSplit(Array(0.9, 0.1))
    val bayes = new NaiveBayes()
      .setFeaturesCol("comment_tfidf")
      .setLabelCol("label")
      .setSmoothing(1.0)
    // train the model
    val bayesModel = bayes.fit(train)
    // evaluate on the test set
    val testResult = bayesModel.transform(test)
    val testResult2 = testResult.drop("comment_tfidf", "rawPrediction")
    testResult2.show(100,false)
    val correctCount = testResult2.filter("prediction==label").count()
    val totalCount = testResult2.count()
    println("Test samples: " + totalCount)
    println("Correct predictions: " + correctCount)
    println("Accuracy: " + (correctCount.toDouble / totalCount.toDouble) * 100 + "%")
  }
}


Results:

Test samples: 1006
Correct predictions: 891
Accuracy: 88.56858846918489%

First 100 prediction rows:
+-----+-------------------------------------------+----------+
|label|probability                                |prediction|
+-----+-------------------------------------------+----------+
|1    |[9.678717143234523E-59,1.0]                |1.0       |
|1    |[8.775062678782543E-142,1.0]               |1.0       |
|1    |[6.285157015859581E-287,1.0]               |1.0       |
|1    |[6.116057531716703E-113,1.0]               |1.0       |
|1    |[4.737281726805822E-150,1.0]               |1.0       |
|1    |[3.851913734023942E-10,0.9999999996148086] |1.0       |
|1    |[1.3879510356857376E-104,1.0]              |1.0       |
|1    |[1.0,3.651790756375366E-62]                |0.0       |
|1    |[7.70413622355533E-40,1.0]                 |1.0       |
|1    |[3.9760940936939995E-28,1.0]               |1.0       |
|1    |[2.78706766383241E-48,1.0]                 |1.0       |
|1    |[9.288763280945696E-166,1.0]               |1.0       |
|1    |[0.9997137425609458,2.8625743905415067E-4] |0.0       |
|1    |[9.84622497818606E-59,1.0]                 |1.0       |
|1    |[0.9999999999988398,1.1600819770604523E-12]|0.0       |
|1    |[2.3374019217668343E-97,1.0]               |1.0       |
|1    |[1.367557692324186E-9,0.9999999986324424]  |1.0       |
|1    |[9.978630188774865E-68,1.0]                |1.0       |
|1    |[6.0737540285315896E-86,1.0]               |1.0       |
|1    |[1.2923974592986347E-17,1.0]               |1.0       |
|1    |[5.087678913605508E-9,0.999999994912321]   |1.0       |
|1    |[4.766936602531464E-8,0.999999952330634]   |1.0       |
|1    |[1.141142922229361E-127,1.0]               |1.0       |
|1    |[2.970308596304562E-13,0.9999999999997029] |1.0       |
|1    |[1.4296598862118238E-42,1.0]               |1.0       |
|1    |[6.971012507333794E-148,1.0]               |1.0       |
|1    |[7.114643055746503E-99,1.0]                |1.0       |
|1    |[6.245858335922584E-82,1.0]                |1.0       |
|1    |[1.2987703764237814E-40,1.0]               |1.0       |
|1    |[1.9927201000367488E-136,1.0]              |1.0       |
|1    |[1.0586783109230259E-66,1.0]               |1.0       |
|1    |[4.475148365444469E-34,1.0]                |1.0       |
|1    |[6.42265390063968E-149,1.0]                |1.0       |
|1    |[1.7670324846527388E-53,1.0]               |1.0       |
|1    |[3.214215818203068E-29,1.0]                |1.0       |
|1    |[2.7031436124761305E-91,1.0]               |1.0       |
|1    |[1.1178924868445532E-27,1.0]               |1.0       |
|1    |[7.097686736295291E-14,0.999999999999929]  |1.0       |
|1    |[4.095785012296749E-222,1.0]               |1.0       |
|1    |[3.89706391895739E-299,1.0]                |1.0       |
|1    |[1.580874182227348E-116,1.0]               |1.0       |
|1    |[1.6808571855446918E-26,1.0]               |1.0       |
|1    |[2.122226182799178E-43,1.0]                |1.0       |
|1    |[2.4510384299167502E-95,1.0]               |1.0       |
|1    |[4.948393044822232E-71,1.0]                |1.0       |
|1    |[9.214193434525995E-64,1.0]                |1.0       |
|1    |[1.0829059305785953E-56,1.0]               |1.0       |
|1    |[3.1353496199852386E-69,1.0]               |1.0       |
|1    |[3.230419370748443E-89,1.0]                |1.0       |
|1    |[1.208868987138325E-47,1.0]                |1.0       |
|1    |[0.0,1.0]                                  |1.0       |
|1    |[0.9999999999998364,1.6365818485051497E-13]|0.0       |
|1    |[1.526303473662237E-139,1.0]               |1.0       |
|1    |[5.831580863834545E-109,1.0]               |1.0       |
|1    |[9.693350571265665E-20,1.0]                |1.0       |
|1    |[6.323437863366204E-103,1.0]               |1.0       |
|1    |[1.0,1.4768880153660372E-17]               |0.0       |
|1    |[2.8273558435432666E-115,1.0]              |1.0       |
|1    |[1.3864521422689692E-198,1.0]              |1.0       |
|1    |[1.0,4.0873159958445865E-38]               |0.0       |
|1    |[4.379279554305029E-88,1.0]                |1.0       |
|1    |[0.999999999925818,7.418191559676693E-11]  |0.0       |
|1    |[5.020061656540557E-25,1.0]                |1.0       |
|1    |[2.4363804471091126E-56,1.0]               |1.0       |
|1    |[1.0,2.5598337472000046E-23]               |0.0       |
|1    |[5.518535257484849E-36,1.0]                |1.0       |
|1    |[5.1309103906743525E-33,1.0]               |1.0       |
|1    |[1.6224323763499578E-29,1.0]               |1.0       |
|1    |[4.536961252116455E-61,1.0]                |1.0       |
|1    |[9.894093466435771E-76,1.0]                |1.0       |
|1    |[1.2451317061247508E-17,1.0]               |1.0       |
|1    |[9.024371536509077E-109,1.0]               |1.0       |
|1    |[2.0559019081889856E-67,1.0]               |1.0       |
|1    |[1.4145467277937284E-26,1.0]               |1.0       |
|1    |[3.8234346133769476E-44,1.0]               |1.0       |
|1    |[8.29804708513053E-69,1.0]                 |1.0       |
|1    |[1.8247535837043456E-11,0.9999999999817524]|1.0       |
|1    |[2.523500736280223E-28,1.0]                |1.0       |
|1    |[9.94053781304072E-38,1.0]                 |1.0       |
|1    |[3.8448382830707386E-14,0.9999999999999616]|1.0       |
|1    |[2.7033352857318492E-39,1.0]               |1.0       |
|1    |[0.07353055667343208,0.9264694433265679]   |1.0       |
|1    |[2.562594162188542E-30,1.0]                |1.0       |
|1    |[1.0,2.5308953260610245E-17]               |0.0       |
|1    |[0.9999886571179135,1.134288208641545E-5]  |0.0       |
|1    |[8.509917767568811E-222,1.0]               |1.0       |
|1    |[6.086933620962101E-97,1.0]                |1.0       |
|1    |[1.0,3.96823938785063E-18]                 |0.0       |
|1    |[8.568268333992698E-4,0.9991431731666007]  |1.0       |
|1    |[1.6693327787350398E-60,1.0]               |1.0       |
|1    |[7.210961234876599E-149,1.0]               |1.0       |
|1    |[2.2681097229354707E-206,1.0]              |1.0       |
|1    |[7.480989153144268E-24,1.0]                |1.0       |
|1    |[3.5723396221655744E-47,1.0]               |1.0       |
|1    |[0.9755649920158308,0.024435007984169203]  |0.0       |
|1    |[5.725331858163097E-52,1.0]                |1.0       |
|1    |[9.179762909765966E-15,0.9999999999999909] |1.0       |
|1    |[2.3344359417810775E-56,1.0]               |1.0       |
|1    |[1.1420687103520189E-6,0.9999988579312895] |1.0       |
|1    |[1.9164354713814646E-12,0.9999999999980835]|1.0       |
+-----+-------------------------------------------+----------+

Gender prediction

Feature data (figure omitted):
Most of the features are discrete, which makes Naive Bayes a good fit. Two fields are continuous, though: "number of orders in the last 30 days" and "total spend in the last 30 days"; these can be discretized into intervals.
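One way to do that interval discretization, sketched with Spark ML's `Bucketizer` (the column names and split points here are invented, not taken from the original feature set):

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object BucketizeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("bucketize").master("local[*]").getOrCreate()
    import spark.implicits._
    // hypothetical (userid, 30-day spend) rows
    val df = Seq(("u1", 120.0), ("u2", 860.0), ("u3", 4500.0)).toDF("userid", "spend")
    // map the continuous spend onto interval buckets 0.0, 1.0, 2.0, ...
    val bucketizer = new Bucketizer()
      .setInputCol("spend")
      .setOutputCol("spend_bucket")
      .setSplits(Array(0.0, 500.0, 1000.0, 5000.0, Double.PositiveInfinity))
    bucketizer.transform(df).show(false)
    spark.stop()
  }
}
```

The resulting `spend_bucket` column is a discrete feature, which a Naive Bayes model can consume like the other categorical fields.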

User churn risk

Features (figure omitted):

The features here are continuous, and the prediction depends on the magnitudes of the feature values, so logistic regression is a better fit.
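A minimal sketch of such a model with Spark ML's `LogisticRegression` (the toy features and labels below are invented purely for illustration):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object ChurnLRSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("churn").master("local[*]").getOrCreate()
    import spark.implicits._
    // toy rows: features = (days since last login, orders in last 30 days), label = 1.0 for churned
    val df = Seq(
      (Vectors.dense(30.0, 0.0), 1.0),
      (Vectors.dense(2.0, 5.0), 0.0),
      (Vectors.dense(25.0, 1.0), 1.0),
      (Vectors.dense(1.0, 8.0), 0.0)
    ).toDF("features", "label")
    // train and inspect predictions on the same toy data
    val lr = new LogisticRegression().setFeaturesCol("features").setLabelCol("label")
    val model = lr.fit(df)
    model.transform(df).select("label", "probability", "prediction").show(false)
    spark.stop()
  }
}
```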