Code repository (Git)
gitee: https://gitee.com/cjunhao/sparkdemo
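Common algorithms:
KNN
Naive Bayes
Rough principle: the algorithm is built on Bayes' formula. It is called "naive" because it rests on one assumption: the features are mutually independent given the class.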
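As a compact statement of the principle above (standard textbook notation, not taken from the original notes): for a class label $y$ and feature values $x_1, \dots, x_n$, Bayes' rule combined with the independence assumption gives

```latex
P(y \mid x_1, \dots, x_n)
  = \frac{P(y)\, P(x_1, \dots, x_n \mid y)}{P(x_1, \dots, x_n)}
  \;\propto\; P(y) \prod_{i=1}^{n} P(x_i \mid y),
\qquad
\hat{y} = \arg\max_{y} \; P(y) \prod_{i=1}^{n} P(x_i \mid y)
```

The product form is what the independence assumption buys: each $P(x_i \mid y)$ can be estimated separately from counts in the training set, and the denominator can be ignored because it is the same for every class.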
Training set (data/bayes/chugui/chugui.txt); the label column is 是 (yes) or 否 (no):
name,age,job,salary,label
张三,29,程序员,10000,是
李四,20,外卖员,6000,否
王五,40,公务员,9000,是
赵六,50,老师,6000,否
陈盼盼,30,外卖员,7000,否
李四四,40,公务员,12000,是
黄顶顶,13,学生,0,否
吴秘密,60,老师,9000,否
吴听听,25,舞蹈员,7000,是
张韩语,32,会计,8000,否
李听,50,会计,16000,是
李过,20,外卖员,6000,是
王一,34,公务员,12000,是
赵西,40,老师,8000,否
Test set (data/bayes/chugui/test.txt)
name,age,job,salary
李婷婷,31,公务员,10000
吴晓琳,27,会计,6000
陈琳,35,老师,3000
Test results
+------------------+----+-----------------------------------------+-----------------------------------------+----------+
|feature |name|rawPrediction |probability |prediction|
+------------------+----+-----------------------------------------+-----------------------------------------+----------+
|[31.0,3.0,10000.0]|李婷婷 |[-245.96880163238194,-238.28167417879274]|[4.584835969725884E-4,0.9995415164030276]|1.0 |
|[27.0,7.0,6000.0] |吴晓琳 |[-229.5993100935666,-232.46015041809136] |[0.945876335461309,0.05412366453869107] |0.0 |
|[35.0,4.0,3000.0] |陈琳 |[-230.45830976664556,-243.3420329281712] |[0.9999974609630454,2.539036954591085E-6]|0.0 |
+------------------+----+-----------------------------------------+-----------------------------------------+----------+
Code:
package com.atguigu.sparkmllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

object NaiveBayesTrainAndPredictApp {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache").setLevel(Level.WARN)
    // Load the training data
    val spark = SparkSession.builder().appName("出轨分析").master("local[*]").getOrCreate()
    val df = spark.read.option("header", "true").csv("data/bayes/chugui/chugui.txt")
    // Preprocess: cast numeric columns to double and encode job and label as numbers
    // (程序员 programmer, 外卖员 delivery rider, 公务员 civil servant, 老师 teacher, 学生 student, 舞蹈员 dancer, 会计 accountant)
    val df2 = df.selectExpr("name",
      "cast(age as double) age",
      "cast(case job when '程序员' then 1.0 when '外卖员' then 2.0 when '公务员' then 3.0 when '老师' then 4.0 when '学生' then 5.0 when '舞蹈员' then 6.0 when '会计' then 7.0 end as double) as job",
      "cast(salary as double) as salary", "cast(case label when '是' then 1.0 when '否' then 0.0 end as double) as label")
    // UDF that turns an array column into an ML dense vector
    val vec = (wrappedArray: mutable.WrappedArray[Double]) => {
      Vectors.dense(wrappedArray.toArray)
    }
    spark.udf.register("vec", vec)
    val df3 = df2.selectExpr("vec(array(age,job,salary)) as feature", "label")
    // Train the model
    val bayes = new NaiveBayes()
      .setSmoothing(1.0)
      .setFeaturesCol("feature")
      .setLabelCol("label")
    val model = bayes.fit(df3)
    // Save the model
    model.write.overwrite().save("data/bayes/chugui/model")
    // Load the test set
    val testdf = spark.read.option("header", "true").csv("data/bayes/chugui/test.txt")
    // Preprocess the test set the same way as the training set
    val testdf2 = testdf.selectExpr("name",
      "cast(age as double) age",
      "cast(case job when '程序员' then 1.0 when '外卖员' then 2.0 when '公务员' then 3.0 when '老师' then 4.0 when '学生' then 5.0 when '舞蹈员' then 6.0 when '会计' then 7.0 end as double) as job",
      "cast(salary as double) as salary")
    val testdf3 = testdf2.selectExpr("vec(array(age,job,salary)) as feature", "name")
    val loadedModel = NaiveBayesModel.load("data/bayes/chugui/model")
    // Predict
    val predict = loadedModel.transform(testdf3)
    predict.show(50, false)
  }
}
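Linear regression
A linear regression is fitted by minimizing a loss function. The minimum can be found with gradient descent, which repeatedly steps along the partial derivatives of the loss with respect to each coefficient.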
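To make the gradient-descent idea concrete, here is a minimal plain-Scala sketch for a single feature, fitting y = w * x + b on the mean squared error. The names `GradientDescentSketch`, `fit`, `learningRate` and `iterations` are illustrative assumptions; this is not the solver that Spark's `LinearRegression` runs internally.

```scala
// Minimal sketch: fit y = w * x + b by batch gradient descent on the mean squared error.
// All names here are illustrative, not Spark APIs.
object GradientDescentSketch {
  def fit(xs: Array[Double], ys: Array[Double],
          learningRate: Double = 0.05, iterations: Int = 5000): (Double, Double) = {
    require(xs.length == ys.length && xs.nonEmpty, "need equally sized, non-empty inputs")
    val n = xs.length.toDouble
    var w = 0.0
    var b = 0.0
    for (_ <- 0 until iterations) {
      // Residuals of the current line on every sample
      val errors = xs.zip(ys).map { case (x, y) => w * x + b - y }
      // Partial derivatives of MSE = (1/n) * sum(error^2) with respect to w and b
      val gradW = 2.0 / n * errors.zip(xs).map { case (e, x) => e * x }.sum
      val gradB = 2.0 / n * errors.sum
      // Step against the gradient
      w -= learningRate * gradW
      b -= learningRate * gradB
    }
    (w, b)
  }
}
```

Calling `GradientDescentSketch.fit(Array(0.0, 1.0, 2.0, 3.0, 4.0), Array(1.0, 3.0, 5.0, 7.0, 9.0))` should recover a slope close to 2 and an intercept close to 1. With unscaled features such as the housing data, a learning rate this large would likely diverge, so features are usually standardized first.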
package com.atguigu.sparkmllib

import com.atguigu.commons.SparkUtil
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

/* @description: predict house prices with linear regression
 * @author: chengjunhao
 * @date: 2020/4/5 17:13
 */
object LinearRegression {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // Load data
    import org.apache.spark.sql.functions._
    val spark = SparkUtil.getSparkSession("LinearRegression")
    def readFile(path: String) = {
      spark.read.textFile(path)
    }
    import spark.implicits._
    // Load the housing data
    val hdata = readFile("data/regression/linearge/housing.data")
    val df = hdata.map(str => {
      // Columns are separated by runs of whitespace; trim first so that
      // a leading space does not produce an empty field
      val arr = str.trim.split("\\s+")
      val arr2 = arr.map(item => item.toDouble)
      // The last column is the label (price); everything before it is a feature
      val (features, arrlabel) = arr2.splitAt(arr2.size - 1)
      (Vectors.dense(features), arrlabel(0))
    }).toDF("features", "label")
    // Train the model
    val regression = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("label")
    val model = regression.fit(df)
    val prediction = model.transform(df)
    // Prediction details
    prediction.show(100, false)
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse")
    val evaluatorResult = evaluator.evaluate(prediction)
    // Evaluation of the predictions (RMSE)
    println(evaluatorResult)
  }
}
// Evaluation of the price predictions (RMSE against the true prices)
4.728921754863951
// Prediction details (features, true price, predicted price)
+----------------------------------------------------------------------------+-----+------------------+
|features |label|prediction |
+----------------------------------------------------------------------------+-----+------------------+
|[0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98] |24.0 |30.190308651220306|
|[0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14] |21.6 |25.18082356676625 |
|[0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03] |34.7 |30.879137984747175|
|[0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94] |33.4 |28.929827983486426|
|[0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33] |36.2 |28.221119235201883|
|[0.02985,0.0,2.18,0.0,0.458,6.43,58.7,6.0622,3.0,222.0,18.7,394.12,5.21] |28.7 |25.495148321722954|
|[0.08829,12.5,7.87,0.0,0.524,6.012,66.6,5.5605,5.0,311.0,15.2,395.6,12.43] |22.9 |22.994008794619425|
|[0.14455,12.5,7.87,0.0,0.524,6.172,96.1,5.9505,5.0,311.0,15.2,396.9,19.15] |27.1 |19.411743228512137|
|[0.21124,12.5,7.87,0.0,0.524,5.631,100.0,6.0821,5.0,311.0,15.2,386.63,29.93]|16.5 |11.116182643466054|
|[0.17004,12.5,7.87,0.0,0.524,6.004,85.9,6.5921,5.0,311.0,15.2,386.71,17.1] |18.9 |18.875755494056953|
|[0.22489,12.5,7.87,0.0,0.524,6.377,94.3,6.3467,5.0,311.0,15.2,392.52,20.45] |15.0 |18.883718884516938|
|[0.11747,12.5,7.87,0.0,0.524,6.009,82.9,6.2267,5.0,311.0,15.2,396.9,13.27] |18.9 |21.614326621187548|
|[0.09378,12.5,7.87,0.0,0.524,5.889,39.0,5.4509,5.0,311.0,15.2,390.5,15.71] |21.7 |20.79957586017244 |
|[0.62976,0.0,8.14,0.0,0.538,5.949,61.8,4.7075,4.0,307.0,21.0,396.9,8.26] |20.4 |19.738812505148353|
|[0.63796,0.0,8.14,0.0,0.538,6.096,84.5,4.4619,4.0,307.0,21.0,380.02,10.26] |18.2 |19.392210535216037|
|[0.62739,0.0,8.14,0.0,0.538,5.834,56.5,4.4986,4.0,307.0,21.0,395.62,8.47] |19.9 |19.454206945022317|
|[1.05393,0.0,8.14,0.0,0.538,5.935,29.3,4.4986,4.0,307.0,21.0,386.85,6.58] |23.1 |20.734847090514773|
|[0.7842,0.0,8.14,0.0,0.538,5.99,81.7,4.2579,4.0,307.0,21.0,386.75,14.67] |17.5 |16.90496844020901 |
|[0.80271,0.0,8.14,0.0,0.538,5.456,36.6,3.7965,4.0,307.0,21.0,288.99,11.69] |20.2 |16.068345653789162|
|[0.7258,0.0,8.14,0.0,0.538,5.727,69.5,3.7965,4.0,307.0,21.0,390.95,11.28] |18.2 |18.43260488799402 |
|[1.25179,0.0,8.14,0.0,0.538,5.57,98.1,3.7979,4.0,307.0,21.0,376.57,21.02] |13.6 |12.309407215651753|
|[0.85204,0.0,8.14,0.0,0.538,5.965,89.2,4.0123,4.0,307.0,21.0,392.53,13.83] |19.6 |17.672096530362325|
|[1.23247,0.0,8.14,0.0,0.538,6.142,91.7,3.9769,4.0,307.0,21.0,396.9,18.72] |15.2 |15.739106004141671|
|[0.98843,0.0,8.14,0.0,0.538,5.813,100.0,4.0952,4.0,307.0,21.0,394.54,19.88] |14.5 |13.668460075613503|
|[0.75026,0.0,8.14,0.0,0.538,5.924,94.1,4.3996,4.0,307.0,21.0,394.33,16.3] |15.6 |15.648887074136162|
|[0.84054,0.0,8.14,0.0,0.538,5.599,85.7,4.4546,4.0,307.0,21.0,303.42,16.51] |13.9 |13.240610975979848|
|[0.67191,0.0,8.14,0.0,0.538,5.813,90.3,4.682,4.0,307.0,21.0,376.88,14.81] |16.6 |15.464927727688355|
|[0.95577,0.0,8.14,0.0,0.538,6.047,88.8,4.4534,4.0,307.0,21.0,306.38,17.28] |14.8 |14.573852495098887|
|[0.77299,0.0,8.14,0.0,0.538,6.495,94.4,4.4547,4.0,307.0,21.0,387.94,12.8] |18.4 |19.628722416039974|
|[1.00245,0.0,8.14,0.0,0.538,6.674,87.3,4.239,4.0,307.0,21.0,380.23,11.98] |21.0 |20.966640126110367|
|[1.13081,0.0,8.14,0.0,0.538,5.713,94.1,4.233,4.0,307.0,21.0,360.17,22.6] |12.7 |11.223753725628281|
|[1.35472,0.0,8.14,0.0,0.538,6.072,100.0,4.175,4.0,307.0,21.0,376.73,13.04] |14.5 |18.093115056496018|
|[1.38799,0.0,8.14,0.0,0.538,5.95,82.0,3.99,4.0,307.0,21.0,232.6,27.71] |13.2 |8.316132718067578 |
|[1.15172,0.0,8.14,0.0,0.538,5.701,95.0,3.7872,4.0,307.0,21.0,358.77,18.35] |13.1 |14.115222877378901|
|[1.61282,0.0,8.14,0.0,0.538,6.096,96.9,3.7598,4.0,307.0,21.0,248.31,20.34] |13.5 |13.399626400996773|
|[0.06417,0.0,5.96,0.0,0.499,5.933,68.2,3.3603,5.0,279.0,19.2,396.9,9.68] |18.9 |23.722685383288756|
|[0.09744,0.0,5.96,0.0,0.499,5.841,61.4,3.3779,5.0,279.0,19.2,377.56,11.41] |20.0 |22.184711007711133|
|[0.08014,0.0,5.96,0.0,0.499,5.85,41.5,3.9342,5.0,279.0,19.2,396.9,8.77] |21.0 |23.07942553867743 |
|[0.17505,0.0,5.96,0.0,0.499,5.966,30.2,3.8473,5.0,279.0,19.2,393.43,10.13] |24.7 |22.8507702710117 |
|[0.02763,75.0,2.95,0.0,0.428,6.595,21.8,5.4011,3.0,252.0,18.3,395.63,4.32] |30.8 |31.28006993091833 |
|[0.03359,75.0,2.95,0.0,0.428,7.024,15.8,5.4011,3.0,252.0,18.3,395.62,1.98] |34.9 |34.21552849591169 |
|[0.12744,0.0,6.91,0.0,0.448,6.77,2.9,5.7209,3.0,233.0,17.9,385.41,4.84] |26.6 |28.27830217644444 |
|[0.1415,0.0,6.91,0.0,0.448,6.169,6.6,5.7209,3.0,233.0,17.9,383.37,5.81] |25.3 |25.40527129936151 |
|[0.15936,0.0,6.91,0.0,0.448,6.211,6.5,5.7209,3.0,233.0,17.9,394.46,7.44] |24.7 |24.787525821054338|
|[0.12269,0.0,6.91,0.0,0.448,6.069,40.0,5.7209,3.0,233.0,17.9,389.39,9.55] |21.2 |23.05600001254663 |
|[0.17142,0.0,6.91,0.0,0.448,5.682,33.8,5.1004,3.0,233.0,17.9,396.9,10.21] |19.3 |22.135516569051077|
|[0.18836,0.0,6.91,0.0,0.448,5.786,33.3,5.1004,3.0,233.0,17.9,396.9,14.15] |20.0 |20.374450708031127|
|[0.22927,0.0,6.91,0.0,0.448,6.03,85.5,5.6894,3.0,233.0,17.9,392.74,18.8] |16.6 |17.93473222778474 |
|[0.25387,0.0,6.91,0.0,0.448,5.399,95.3,5.87,3.0,233.0,17.9,396.9,30.81] |14.4 |8.706631427576724 |
|[0.21977,0.0,6.91,0.0,0.448,5.602,62.0,6.0877,3.0,233.0,17.9,396.9,16.2] |19.4 |17.178846384667416|
|[0.08873,21.0,5.64,0.0,0.439,5.963,45.7,6.8147,4.0,243.0,16.8,395.56,13.45] |19.7 |21.245543392473255|
|[0.04337,21.0,5.64,0.0,0.439,6.115,63.0,6.8147,4.0,243.0,16.8,393.97,9.43] |20.5 |24.037054024454424|
|[0.0536,21.0,5.64,0.0,0.439,6.511,21.1,6.8147,4.0,243.0,16.8,396.9,5.28] |25.0 |27.842340282261322|
|[0.04981,21.0,5.64,0.0,0.439,5.998,21.4,6.8147,4.0,243.0,16.8,396.9,8.43] |23.4 |24.133812775357192|
|[0.0136,75.0,4.0,0.0,0.41,5.888,47.6,7.3197,3.0,469.0,21.1,396.9,14.8] |18.9 |15.204212070389076|
|[0.01311,90.0,1.22,0.0,0.403,7.249,21.9,8.6966,5.0,226.0,17.9,395.93,4.81] |35.4 |31.17273816244207 |
|[0.02055,85.0,0.74,0.0,0.41,6.383,35.7,9.1876,2.0,313.0,17.3,396.9,5.77] |24.7 |25.003120474324106|
|[0.01432,100.0,1.32,0.0,0.411,6.816,40.5,8.3248,5.0,256.0,15.1,392.9,3.95] |31.6 |33.0501442598882 |
|[0.15445,25.0,5.13,0.0,0.453,6.145,29.2,7.8148,8.0,284.0,19.7,390.68,6.86] |23.3 |21.844459264785772|
|[0.10328,25.0,5.13,0.0,0.453,5.927,47.2,6.932,8.0,284.0,19.7,396.9,9.22] |19.6 |21.0226888914256 |
|[0.14932,25.0,5.13,0.0,0.453,5.741,66.2,7.2254,8.0,284.0,19.7,395.11,13.15] |18.7 |17.730297622043963|
|[0.17171,25.0,5.13,0.0,0.453,5.966,93.4,6.8185,8.0,284.0,19.7,378.08,14.44] |16.0 |18.300089563663125|
|[0.11027,25.0,5.13,0.0,0.453,6.456,67.8,7.2255,8.0,284.0,19.7,396.9,6.73] |22.2 |24.035933036957566|
|[0.1265,25.0,5.13,0.0,0.453,6.762,43.4,7.9809,8.0,284.0,19.7,395.58,9.5] |25.0 |22.611161965976976|
|[0.01951,17.5,1.38,0.0,0.4161,7.104,59.5,9.2229,3.0,216.0,18.6,393.24,8.05] |33.0 |23.7118587624982 |
|[0.03584,80.0,3.37,0.0,0.398,6.29,17.8,6.6115,4.0,337.0,16.1,396.9,4.67] |23.5 |30.26687448389147 |
|[0.04379,80.0,3.37,0.0,0.398,5.787,31.1,6.6115,4.0,337.0,16.1,396.9,10.24] |19.4 |25.27801769118446 |
|[0.05789,12.5,6.07,0.0,0.409,5.878,21.4,6.498,4.0,345.0,18.9,396.21,8.1] |22.0 |21.21194768977976 |
|[0.13554,12.5,6.07,0.0,0.409,5.594,36.8,6.498,4.0,345.0,18.9,396.9,13.09] |17.4 |17.390004693770543|
|[0.12816,12.5,6.07,0.0,0.409,5.885,33.0,6.498,4.0,345.0,18.9,396.9,8.79] |20.9 |20.870257832388887|
|[0.08826,0.0,10.81,0.0,0.413,6.417,6.6,5.2873,4.0,305.0,19.2,383.73,6.72] |24.2 |25.339885054239968|
|[0.15876,0.0,10.81,0.0,0.413,5.961,17.5,5.2873,4.0,305.0,19.2,376.94,9.88] |21.7 |21.776730166784652|
|[0.09164,0.0,10.81,0.0,0.413,6.065,7.8,5.2873,4.0,305.0,19.2,390.91,5.52] |22.8 |24.713947364951146|
|[0.19539,0.0,10.81,0.0,0.413,6.245,6.2,5.2873,4.0,305.0,19.2,377.17,7.54] |23.4 |24.148328355465534|
|[0.07896,0.0,12.83,0.0,0.437,6.273,6.0,4.2515,5.0,398.0,18.7,394.92,6.78] |24.1 |25.580263420202034|
|[0.09512,0.0,12.83,0.0,0.437,6.286,45.0,4.5026,5.0,398.0,18.7,383.23,8.94] |21.4 |23.999948769178417|
|[0.10153,0.0,12.83,0.0,0.437,6.279,74.5,4.0522,5.0,398.0,18.7,373.66,11.97] |20.0 |22.86080492474821 |
|[0.08707,0.0,12.83,0.0,0.437,6.14,45.8,4.0905,5.0,398.0,18.7,386.96,10.27] |20.8 |23.322112071030045|
|[0.05646,0.0,12.83,0.0,0.437,6.232,53.7,5.0141,5.0,398.0,18.7,386.4,12.34] |21.2 |21.254832078946972|
|[0.08387,0.0,12.83,0.0,0.437,5.874,36.6,4.5026,5.0,398.0,18.7,396.06,9.1] |20.3 |22.449559654604922|
|[0.04113,25.0,4.86,0.0,0.426,6.727,33.5,5.4007,4.0,281.0,19.0,396.9,5.29] |28.0 |28.480573367799344|
|[0.04462,25.0,4.86,0.0,0.426,6.619,70.4,5.4007,4.0,281.0,19.0,395.63,7.22] |23.9 |27.017829328681373|
|[0.03659,25.0,4.86,0.0,0.426,6.302,32.2,5.4007,4.0,281.0,19.0,396.9,6.72] |24.8 |26.054218988842194|
|[0.03551,25.0,4.86,0.0,0.426,6.167,46.7,5.4007,4.0,281.0,19.0,390.64,7.51] |22.9 |25.044752060188472|
|[0.05059,0.0,4.49,0.0,0.449,6.389,48.0,4.7794,3.0,247.0,18.5,396.9,9.62] |23.9 |24.832618514589313|
|[0.05735,0.0,4.49,0.0,0.449,6.63,56.1,4.4377,3.0,247.0,18.5,392.3,6.53] |26.6 |27.89263939520159 |
|[0.05188,0.0,4.49,0.0,0.449,6.015,45.1,4.4272,3.0,247.0,18.5,395.99,12.86] |22.5 |22.091164776225657|
|[0.07151,0.0,4.49,0.0,0.449,6.121,56.8,3.7476,3.0,247.0,18.5,395.15,8.44] |22.2 |25.871768906772093|
|[0.0566,0.0,3.41,0.0,0.489,7.007,86.3,3.4217,2.0,270.0,17.8,396.9,5.5] |23.6 |30.81825218329388 |
|[0.05302,0.0,3.41,0.0,0.489,7.079,63.1,3.4145,2.0,270.0,17.8,396.06,5.7] |28.7 |30.971595744003565|
|[0.04684,0.0,3.41,0.0,0.489,6.417,66.1,3.0923,2.0,270.0,17.8,392.18,8.81] |22.6 |27.12110239078219 |
|[0.03932,0.0,3.41,0.0,0.489,6.405,73.9,3.0921,2.0,270.0,17.8,393.55,8.2] |22.0 |27.430097731675495|
|[0.04203,28.0,15.04,0.0,0.464,6.442,53.6,3.6659,4.0,270.0,18.2,395.01,8.16] |22.9 |28.908354939561406|
|[0.02875,28.0,15.04,0.0,0.464,6.211,28.9,3.6659,4.0,270.0,18.2,396.33,6.21] |25.0 |29.082827105623092|
|[0.04294,28.0,15.04,0.0,0.464,6.249,77.3,3.615,4.0,270.0,18.2,396.9,10.59] |20.6 |26.937832632384236|
|[0.12204,0.0,2.89,0.0,0.445,6.625,57.8,3.4952,2.0,276.0,18.0,357.98,6.65] |28.4 |28.635280481715768|
|[0.11504,0.0,2.89,0.0,0.445,6.163,69.6,3.4952,2.0,276.0,18.0,391.83,11.34] |21.4 |24.640422958037888|
|[0.12083,0.0,2.89,0.0,0.445,8.069,76.0,3.4952,2.0,276.0,18.0,396.9,4.21] |38.7 |35.96955000322432 |
|[0.08187,0.0,2.89,0.0,0.445,7.82,36.9,3.4952,2.0,276.0,18.0,393.53,3.57] |43.8 |35.29890121386673 |
|[0.0686,0.0,2.89,0.0,0.445,7.416,62.5,3.4952,2.0,276.0,18.0,396.9,6.19] |33.2 |32.35610981009244 |
+----------------------------------------------------------------------------+-----+------------------+
Math background:
Vectors
Extensions of vectors: Euclidean distance (turned into a similarity) and cosine similarity
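For reference, the two measures can be written as follows for vectors $u$ and $v$ of length $n$. The mapping $1 / (1 + d)$ from distance to similarity is the one used by the `eucliSimilarity` function in this section; the rest is standard notation.

```latex
d(u, v) = \sqrt{\sum_{i=1}^{n} (u_i - v_i)^2},
\qquad
\mathrm{sim}_{\mathrm{euc}}(u, v) = \frac{1}{1 + d(u, v)},
\qquad
\cos(u, v) = \frac{\sum_{i=1}^{n} u_i v_i}
                  {\sqrt{\sum_{i=1}^{n} u_i^2}\; \sqrt{\sum_{i=1}^{n} v_i^2}}
```

The Euclidean similarity lies in $(0, 1]$ and equals 1 only for identical vectors. The cosine similarity lies in $[-1, 1]$ and depends only on direction, so it is unaffected when one vector is rescaled by a positive constant.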
// Compute the Euclidean similarity and the cosine similarity
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.linalg._
import org.apache.spark.sql._

object App {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // Read the data
    val spark = SparkSession.builder().appName("similarityCalculate").master("local[*]").getOrCreate()
    val ds: Dataset[String] = spark.read.textFile("data/persons.txt")
    import spark.implicits._
    // Parse each line into (userid, feature vector)
    val attri = ds.map(line => {
      val ar = line.split(",")
      val userid = ar(0)
      val vector = Vectors.dense(ar.tail.map(item => item.toDouble))
      (userid, vector)
    }).toDF("userid", "vector")
    attri.show(100, false)
    import org.apache.spark.sql.functions._
    // Self-join so that every unordered pair of users appears exactly once
    val joined = attri.join(attri.toDF("b_userid", "b_vector"), col("userid") < col("b_userid"))
    // Register the Euclidean and cosine similarity functions as UDFs
    spark.udf.register("euc", eucliSimilarity)
    spark.udf.register("cos", cosSimilarity)
    // joined.selectExpr("userid", "b_userid", "euc(vector,b_vector)", "cos(vector,b_vector)").show()
    joined.createTempView("joined")
    // Print the pairwise similarities
    spark.sql(
      """
        |select
        |userid,
        |b_userid,
        |euc(vector,b_vector),
        |cos(vector,b_vector)
        |from joined
      """.stripMargin).show(200, false)
  }

  // Euclidean similarity: 1 / (1 + Euclidean distance)
  val eucliSimilarity = (uv: Vector, buv: Vector) => {
    1 / (1 + Math.sqrt(Vectors.sqdist(uv, buv)))
  }

  // Cosine similarity: dot(u, v) / (|u| * |v|)
  val cosSimilarity = (uv: Vector, buv: Vector) => {
    val sum1 = uv.toArray.map(item => Math.pow(item, 2)).sum
    val sum2 = buv.toArray.map(item => Math.pow(item, 2)).sum
    val tuples = uv.toArray.zip(buv.toArray)
    val sum3 = tuples.map(item => item._1 * item._2).sum
    // Divide by the product of the norms, i.e. the square roots of the squared sums
    sum3 / (Math.sqrt(sum1) * Math.sqrt(sum2))
  }
}
/*
+----persons.txt----+
u1,20,172,98,3999
u2,30,174,70,6999
u3,22,171,78,4999
u4,40,178,75,7999
u5,33,172,68,6999
u6,31,188,68,8999
u7,50,172,58,8099
u8,18,172,99,999
u9,70,160,44,1999
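+----Output----+
Note: the cos column below was computed as dot(u, v) / (|u|^2 * |v|^2), without square roots on the norms, so its values are far smaller than a true cosine similarity.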
+------+--------+-------------------------+-------------------------+
|userid|b_userid|UDF:euc(vector, b_vector)|UDF:cos(vector, b_vector)|
+------+--------+-------------------------+-------------------------+
|u1 |u2 |3.332058269756651E-4 |3.566143930793115E-8 |
|u1 |u3 |9.987989647192792E-4 |4.99210425243552E-8 |
|u1 |u4 |2.4992998066749325E-4 |3.120334262477809E-8 |
|u1 |u5 |3.3320247790333556E-4 |3.5661395232678165E-8 |
|u1 |u6 |1.9995490223656236E-4 |2.7735692648707768E-8 |
|u1 |u7 |2.438248392885663E-4 |3.0817318676858244E-8 |
|u1 |u8 |3.332221667160612E-4 |2.4252302194838063E-7 |
|u1 |u9 |4.994033335448089E-4 |1.2429473207172002E-7 |
|u2 |u3 |4.997415712132636E-4 |2.8548568958543982E-8 |
|u2 |u4 |9.989306472244388E-4 |1.7849858206328596E-8 |
|u2 |u5 |0.1951941016011038 |2.0399131030943938E-8 |
|u2 |u6 |4.997375754634144E-4 |1.5867017705841344E-8 |
|u2 |u7 |9.080597966973777E-4 |1.7629908498384664E-8 |
|u2 |u8 |1.6663660497596988E-4 |1.3818168544604422E-7 |
|u2 |u9 |1.9995012468525684E-4 |7.10140480147098E-8 |
|u3 |u4 |3.332151901115194E-4 |2.4980364474953618E-8 |
|u3 |u5 |4.997362643788828E-4 |2.8548670489039806E-8 |
|u3 |u6 |2.499338456621841E-4 |2.2204923975223692E-8 |
|u3 |u7 |3.2245674655075133E-4 |2.4672014941383436E-8 |
|u3 |u8 |2.499339393612514E-4 |1.9374745563676967E-7 |
|u3 |u9 |3.331560083851983E-4 |9.944297027049687E-8 |
|u4 |u5 |9.98934139515095E-4 |1.7850018870123525E-8 |
|u4 |u6 |9.988862484373682E-4 |1.3884385937169933E-8 |
|u4 |u7 |0.009699049240532491 |1.5427018165936453E-8 |
|u4 |u8 |1.428351404244442E-4 |1.2086241316083946E-7 |
|u4 |u9 |1.666338375630366E-4 |6.213190885762279E-8 |
|u5 |u6 |4.997338919664413E-4 |1.586716435702208E-8 |
|u5 |u7 |9.081193831836474E-4 |1.7630092300025892E-8 |
|u5 |u8 |1.666361491300729E-4 |1.3817428819747693E-7 |
|u5 |u9 |1.9995165586315277E-4 |7.101417386546018E-8 |
|u6 |u7 |0.0011093875601599508 |1.3713415495054653E-8 |
|u6 |u8 |1.249830237975477E-4 |1.0739544562964691E-7 |
|u6 |u9 |1.428325392852325E-4 |5.522204853293927E-8 |
|u7 |u8 |1.408214582261319E-4 |1.1933383836016132E-7 |
|u7 |u9 |1.6390592660457887E-4 |6.136345718376115E-8 |
|u8 |u9 |9.960832005447605E-4 |4.856381190959876E-7 |
+------+--------+-------------------------+-------------------------+
*/
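User-profile model tags
Sentiment analysis

Implementation steps:
- Chinese word segmentation
- Turning the text into feature vectors
Vectorization scheme:

tf-idf = tf * idf
idf = lg(total number of documents / (1 + number of documents containing the term))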
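The formula above can be worked through on a toy corpus. The sketch below is plain Scala with illustrative names (`TfIdfSketch` and the three tiny documents are assumptions, not part of the original project); it takes tf as the raw count of a term in a document and idf exactly as written above.

```scala
// Minimal sketch of tf-idf with tf = raw count and idf = log10(N / (1 + df)).
// TfIdfSketch and its helpers are illustrative names.
object TfIdfSketch {
  // Number of documents that contain each term
  def documentFrequency(docs: Seq[Seq[String]]): Map[String, Int] =
    docs.flatMap(_.distinct).groupBy(identity).map { case (term, hits) => term -> hits.size }

  // idf = log10(total documents / (1 + documents containing the term))
  def idf(totalDocs: Int, docFreq: Int): Double =
    math.log10(totalDocs.toDouble / (1 + docFreq))

  // tf-idf weight of every term that occurs in one document
  def tfIdf(doc: Seq[String], docs: Seq[Seq[String]]): Map[String, Double] = {
    val df = documentFrequency(docs)
    doc.groupBy(identity).map { case (term, hits) =>
      term -> hits.size * idf(docs.size, df.getOrElse(term, 0))
    }
  }
}
```

For the corpus `Seq(Seq("a", "a", "x"), Seq("b", "x"), Seq("c"))`, the term "a" in the first document gets 2 * lg(3 / 2), while "x", which occurs in two of the three documents, gets 2... no weight at all: lg(3 / 3) = 0. With this smoothing a term that occurs in every document receives a slightly negative idf, which is one reason libraries often use a different smoothing.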
tf-idf, hand-written version:
package com.atguigu.sparkmllib

import org.apache.commons.codec.digest.DigestUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}

object TFIDF {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // Load data
    val spark = SparkSession.builder().appName("tfidf").master("local[*]").getOrCreate()
    val df = spark.read.option("header", "true").csv("data/tfidf/train.txt")
    df.cache()
    val totalCount = df.count()
    // Number of distinct words
    val wordNum = df.rdd.map(row => row.getAs[String]("doc")).flatMap(item => item.split(" ")).distinct().count()
    // Number of hash buckets, i.e. the length of every tf vector
    val bucketNum = 26
    import spark.implicits._
    // Build the tf vector of each document with the hashing trick
    val tfVector = df.map({
      case Row(id: String, doc: String) => {
        val arr = doc.split(" ")
        val wordcount = arr.map(item => (item, 1)).groupBy(_._1).mapValues(_.size)
        val vectorarr = Vectors.zeros(bucketNum).toArray
        wordcount.foreach(item => {
          // Clear the sign bit first, then take the modulus; % binds tighter than &,
          // so the parentheses are required
          val index: Int = (DigestUtils.md5Hex(item._1).hashCode & Int.MaxValue) % bucketNum
          // Accumulate so that words colliding in one bucket add up instead of overwriting each other
          vectorarr(index) += item._2
        })
        (id, vectorarr)
      }
    })
    tfVector.show(30, false)
    // Build the idf vector: count the documents that hit each bucket, then take the logarithm.
    // The 0.01 in the denominator avoids division by zero for buckets that no document hits.
    val dtfVector = tfVector.rdd.map(item => item._2.map(item => {
      if (item != 0) {
        1
      } else {
        0
      }
    })).reduce((i, current) => i.zip(current).map((item1) => {
      item1._1 + item1._2
    })).map(value => Math.log10(totalCount / (0.01 + value)))
    // Compute the final tf-idf vectors
    val bdTFVector = spark.sparkContext.broadcast(dtfVector)
    val tfIdfVectors = tfVector.map(item => {
      val tfIdfValues = item._2.zip(bdTFVector.value).map(tp => tp._1 * tp._2)
      (item._1, tfIdfValues)
    }).toDF("id", "tf-idf")
    tfIdfVectors.show(40, false)
  }
}
tf-idf, library version (HashingTF and IDF):
package com.atguigu.sparkmllib

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.sql.SparkSession

object TFIDFV2 {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // Load data
    val spark = SparkSession.builder().appName("tfidf").master("local[*]").getOrCreate()
    val df = spark.read.option("header", "true").csv("data/tfidf/train.txt")
    df.cache()
    // Number of distinct words
    val wordNum = df.rdd.map(row => row.getAs[String]("doc")).flatMap(item => item.split(" ")).distinct().count()
    // Split every document into words
    val df2 = df.selectExpr("id", "split(doc,' ') as words")
    // Term frequencies via feature hashing; using wordNum^2 buckets makes collisions unlikely
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("tfvec")
      .setNumFeatures(Math.pow(wordNum, 2).toInt)
    val tfdf = hashingTF.transform(df2)
    tfdf.show(10, false)
    // Fit the idf weights on the tf vectors and apply them
    val idf = new IDF()
      .setInputCol("tfvec")
      .setOutputCol("idfvec")
    val model = idf.fit(tfdf)
    val tfidf = model.transform(tfdf)
    tfidf.show(10, false)
  }
}
+---+------------------------+---------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
|id |words |tfvec |idfvec |
+---+------------------------+---------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
|1 |[a, a, a, x, x, y] |(121,[59,102,106],[2.0,1.0,3.0]) |(121,[59,102,106],[0.3083013596545167,0.15415067982725836,2.5418935811616112]) |
|2 |[b, b, c, x, y] |(121,[59,74,101,102],[1.0,1.0,2.0,1.0]) |(121,[59,74,101,102],[0.15415067982725836,1.252762968495368,2.505525936990736,0.15415067982725836]) |
|3 |[d, d, d, d, x, x, x, y]|(121,[59,96,102],[3.0,4.0,1.0]) |(121,[59,96,102],[0.46245203948177505,3.3891914415488147,0.15415067982725836]) |
|4 |[e, d, x, x] |(121,[59,77,96],[2.0,1.0,1.0]) |(121,[59,77,96],[0.3083013596545167,1.252762968495368,0.8472978603872037]) |
|5 |[g, f, y, y] |(121,[52,102,110],[1.0,2.0,1.0]) |(121,[52,102,110],[1.252762968495368,0.3083013596545167,1.252762968495368]) |
|6 |[i, i, i, a, h, h, x, y]|(121,[3,59,69,102,106],[3.0,1.0,2.0,1.0,1.0])|(121,[3,59,69,102,106],[3.758288905486104,0.15415067982725836,2.505525936990736,0.15415067982725836,0.8472978603872037])|
+---+------------------------+---------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
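The idfvec values in this table do not follow the lg formula given earlier. According to the Spark ML documentation, `IDF` uses a natural logarithm with add-one smoothing, ln((N + 1) / (df + 1)), and the output vector is that weight multiplied by the raw term count. The plain-Scala sketch below (the name `SparkIdfCheck` is illustrative) reproduces two values from the table under that assumption, with N = 6 documents.

```scala
// Minimal check of the smoothed idf that Spark ML's IDF is documented to use.
// SparkIdfCheck is an illustrative name, not a Spark class.
object SparkIdfCheck {
  // idf = ln((number of documents + 1) / (document frequency + 1))
  def idf(numDocs: Long, docFreq: Long): Double =
    math.log((numDocs + 1.0) / (docFreq + 1.0))
}
```

The word x occurs in 5 of the 6 documents, so a single occurrence weighs ln(7 / 6), which agrees with the 0.15415067982725836 shown for bucket 102 in row 1. The word a occurs in 2 documents and three times in row 1, giving 3 * ln(7 / 3), which agrees with the 2.5418935811616112 shown for bucket 106.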
Sentiment analysis: complete implementation
package com.atguigu.sparkmllib

import com.atguigu.commons.SparkUtil
import com.hankcs.hanlp.HanLP
import com.hankcs.hanlp.corpus.tag.Nature
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/* @description: sentiment analysis
 * @author: chengjunhao
 * @date: 2020/4/5 10:49
 * pos.txt: positive hotel reviews
 * neg.txt: negative hotel reviews
 * Data source: https://www.aitechclub.com/data-detail?data_id=23
 */
object CommentClassify {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    // Load data
    import org.apache.spark.sql.functions._
    val spark = SparkUtil.getSparkSession("CommentClassify")
    def readFile(path: String) = {
      // Read the file line by line and decode each line as GBK
      spark.read.format("text").load(path).select(decode(col("value"), "GBK"))
    }
    // Load the positive and negative hotel reviews
    val posData = readFile("data/comment/hotel/pos.txt")
    val negData = readFile("data/comment/hotel/neg.txt")
    // Preprocess the data
    import scala.collection.JavaConversions._
    import spark.implicits._
    def transformDF(data: DataFrame, label: Int) = {
      val df = data.map(row => {
        val str = row(0).toString
        // Segment the review into words with HanLP
        val terms = HanLP.segment(str)
        // Drop tokens tagged w, u or y (punctuation, auxiliary words, modal particles)
        val words = terms.filter(n => {
          n.nature != Nature.w && n.nature != Nature.u && n.nature != Nature.y
        }).map(term => term.word)
        (label, words)
      }).toDF("label", "comment")
      df
    }
    val posdf = transformDF(posData, 1)
    val negdf = transformDF(negData, 0)
    val alldf = posdf.union(negdf)
    // Compute tf-idf
    val htf = new HashingTF()
      .setInputCol("comment")
      .setOutputCol("comment_tf")
      .setNumFeatures(1000000)
    val tf: DataFrame = htf.transform(alldf)
    val idf = new IDF()
      .setInputCol("comment_tf")
      .setOutputCol("comment_idf")
    val idfModel = idf.fit(tf)
    var attr = idfModel.transform(tf)
    // Keep only the label and the tf-idf feature vector
    attr = attr.drop("comment", "comment_tf")
      .withColumnRenamed("comment_idf", "comment_tfidf")
    // Split into training set and test set (9 : 1)
    val Array(train, test) = attr.randomSplit(Array(9, 1))
    val bayes = new NaiveBayes()
      .setFeaturesCol("comment_tfidf")
      .setLabelCol("label")
      .setSmoothing(1.0)
    // Train the model
    val bayesModel = bayes.fit(train)
    // Test the model
    val testResult = bayesModel.transform(test)
    val testResult2 = testResult.drop("comment_tfidf", "rawPrediction")
    testResult2.show(100, false)
    val correctCount = testResult2.filter("prediction==label").count()
    val totalCount = testResult2.count()
    println("Test samples: " + totalCount)
    println("Correct predictions: " + correctCount)
    println("Accuracy: " + (correctCount.toDouble / totalCount.toDouble) * 100 + "%")
  }
}
Results:
Test samples: 1006
Correct predictions: 891
Accuracy: 88.56858846918489%
First 100 prediction details:
+-----+-------------------------------------------+----------+
|label|probability |prediction|
+-----+-------------------------------------------+----------+
|1 |[9.678717143234523E-59,1.0] |1.0 |
|1 |[8.775062678782543E-142,1.0] |1.0 |
|1 |[6.285157015859581E-287,1.0] |1.0 |
|1 |[6.116057531716703E-113,1.0] |1.0 |
|1 |[4.737281726805822E-150,1.0] |1.0 |
|1 |[3.851913734023942E-10,0.9999999996148086] |1.0 |
|1 |[1.3879510356857376E-104,1.0] |1.0 |
|1 |[1.0,3.651790756375366E-62] |0.0 |
|1 |[7.70413622355533E-40,1.0] |1.0 |
|1 |[3.9760940936939995E-28,1.0] |1.0 |
|1 |[2.78706766383241E-48,1.0] |1.0 |
|1 |[9.288763280945696E-166,1.0] |1.0 |
|1 |[0.9997137425609458,2.8625743905415067E-4] |0.0 |
|1 |[9.84622497818606E-59,1.0] |1.0 |
|1 |[0.9999999999988398,1.1600819770604523E-12]|0.0 |
|1 |[2.3374019217668343E-97,1.0] |1.0 |
|1 |[1.367557692324186E-9,0.9999999986324424] |1.0 |
|1 |[9.978630188774865E-68,1.0] |1.0 |
|1 |[6.0737540285315896E-86,1.0] |1.0 |
|1 |[1.2923974592986347E-17,1.0] |1.0 |
|1 |[5.087678913605508E-9,0.999999994912321] |1.0 |
|1 |[4.766936602531464E-8,0.999999952330634] |1.0 |
|1 |[1.141142922229361E-127,1.0] |1.0 |
|1 |[2.970308596304562E-13,0.9999999999997029] |1.0 |
|1 |[1.4296598862118238E-42,1.0] |1.0 |
|1 |[6.971012507333794E-148,1.0] |1.0 |
|1 |[7.114643055746503E-99,1.0] |1.0 |
|1 |[6.245858335922584E-82,1.0] |1.0 |
|1 |[1.2987703764237814E-40,1.0] |1.0 |
|1 |[1.9927201000367488E-136,1.0] |1.0 |
|1 |[1.0586783109230259E-66,1.0] |1.0 |
|1 |[4.475148365444469E-34,1.0] |1.0 |
|1 |[6.42265390063968E-149,1.0] |1.0 |
|1 |[1.7670324846527388E-53,1.0] |1.0 |
|1 |[3.214215818203068E-29,1.0] |1.0 |
|1 |[2.7031436124761305E-91,1.0] |1.0 |
|1 |[1.1178924868445532E-27,1.0] |1.0 |
|1 |[7.097686736295291E-14,0.999999999999929] |1.0 |
|1 |[4.095785012296749E-222,1.0] |1.0 |
|1 |[3.89706391895739E-299,1.0] |1.0 |
|1 |[1.580874182227348E-116,1.0] |1.0 |
|1 |[1.6808571855446918E-26,1.0] |1.0 |
|1 |[2.122226182799178E-43,1.0] |1.0 |
|1 |[2.4510384299167502E-95,1.0] |1.0 |
|1 |[4.948393044822232E-71,1.0] |1.0 |
|1 |[9.214193434525995E-64,1.0] |1.0 |
|1 |[1.0829059305785953E-56,1.0] |1.0 |
|1 |[3.1353496199852386E-69,1.0] |1.0 |
|1 |[3.230419370748443E-89,1.0] |1.0 |
|1 |[1.208868987138325E-47,1.0] |1.0 |
|1 |[0.0,1.0] |1.0 |
|1 |[0.9999999999998364,1.6365818485051497E-13]|0.0 |
|1 |[1.526303473662237E-139,1.0] |1.0 |
|1 |[5.831580863834545E-109,1.0] |1.0 |
|1 |[9.693350571265665E-20,1.0] |1.0 |
|1 |[6.323437863366204E-103,1.0] |1.0 |
|1 |[1.0,1.4768880153660372E-17] |0.0 |
|1 |[2.8273558435432666E-115,1.0] |1.0 |
|1 |[1.3864521422689692E-198,1.0] |1.0 |
|1 |[1.0,4.0873159958445865E-38] |0.0 |
|1 |[4.379279554305029E-88,1.0] |1.0 |
|1 |[0.999999999925818,7.418191559676693E-11] |0.0 |
|1 |[5.020061656540557E-25,1.0] |1.0 |
|1 |[2.4363804471091126E-56,1.0] |1.0 |
|1 |[1.0,2.5598337472000046E-23] |0.0 |
|1 |[5.518535257484849E-36,1.0] |1.0 |
|1 |[5.1309103906743525E-33,1.0] |1.0 |
|1 |[1.6224323763499578E-29,1.0] |1.0 |
|1 |[4.536961252116455E-61,1.0] |1.0 |
|1 |[9.894093466435771E-76,1.0] |1.0 |
|1 |[1.2451317061247508E-17,1.0] |1.0 |
|1 |[9.024371536509077E-109,1.0] |1.0 |
|1 |[2.0559019081889856E-67,1.0] |1.0 |
|1 |[1.4145467277937284E-26,1.0] |1.0 |
|1 |[3.8234346133769476E-44,1.0] |1.0 |
|1 |[8.29804708513053E-69,1.0] |1.0 |
|1 |[1.8247535837043456E-11,0.9999999999817524]|1.0 |
|1 |[2.523500736280223E-28,1.0] |1.0 |
|1 |[9.94053781304072E-38,1.0] |1.0 |
|1 |[3.8448382830707386E-14,0.9999999999999616]|1.0 |
|1 |[2.7033352857318492E-39,1.0] |1.0 |
|1 |[0.07353055667343208,0.9264694433265679] |1.0 |
|1 |[2.562594162188542E-30,1.0] |1.0 |
|1 |[1.0,2.5308953260610245E-17] |0.0 |
|1 |[0.9999886571179135,1.134288208641545E-5] |0.0 |
|1 |[8.509917767568811E-222,1.0] |1.0 |
|1 |[6.086933620962101E-97,1.0] |1.0 |
|1 |[1.0,3.96823938785063E-18] |0.0 |
|1 |[8.568268333992698E-4,0.9991431731666007] |1.0 |
|1 |[1.6693327787350398E-60,1.0] |1.0 |
|1 |[7.210961234876599E-149,1.0] |1.0 |
|1 |[2.2681097229354707E-206,1.0] |1.0 |
|1 |[7.480989153144268E-24,1.0] |1.0 |
|1 |[3.5723396221655744E-47,1.0] |1.0 |
|1 |[0.9755649920158308,0.024435007984169203] |0.0 |
|1 |[5.725331858163097E-52,1.0] |1.0 |
|1 |[9.179762909765966E-15,0.9999999999999909] |1.0 |
|1 |[2.3344359417810775E-56,1.0] |1.0 |
|1 |[1.1420687103520189E-6,0.9999988579312895] |1.0 |
|1 |[1.9164354713814646E-12,0.9999999999980835]|1.0 |
+-----+-------------------------------------------+----------+
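Gender prediction
The feature data is as follows:
Most of the features are discrete, which suits the naive Bayes algorithm well. Two fields are continuous, "number of orders in the last 30 days" and "total amount spent in the last 30 days"; these two can be discretized into intervals.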
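Discretizing a continuous field into intervals can be sketched in a few lines of plain Scala. The name `BinningSketch` and the split points in the usage note are hypothetical; the original notes do not say which boundaries were used. Spark ML also ships a `Bucketizer` transformer for the same purpose.

```scala
// Minimal sketch: map a continuous value to the index of the interval it falls into.
// BinningSketch is an illustrative name; the split points are chosen by the caller.
object BinningSketch {
  // splits must be sorted in ascending order.
  // Bucket 0 covers values below splits(0), bucket i covers [splits(i - 1), splits(i)),
  // and the last bucket covers everything from splits.last upwards.
  def bucket(value: Double, splits: Seq[Double]): Int =
    splits.count(split => value >= split)
}
```

With hypothetical order-count boundaries `Seq(1.0, 5.0, 20.0)`, a user with 7 orders lands in bucket 2 and a user with no orders in bucket 0. The resulting bucket index can then be used as a discrete feature alongside the others.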
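User churn risk
The features are as follows:
The features are continuous, and the predicted value depends on the magnitude of the features, so the logistic regression algorithm is a good fit.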
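How logistic regression turns continuous features into a risk score can be sketched as follows. This is plain Scala with illustrative names (`LogisticSketch`, `churnProbability`); the weights and bias would come from training, and the numbers in the usage note are made up.

```scala
// Minimal sketch of the logistic regression prediction step.
// LogisticSketch and its method names are illustrative, not Spark APIs.
object LogisticSketch {
  // The sigmoid squashes any real number into the interval (0, 1)
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Probability of the positive class (here: the user churns) for one feature vector
  def churnProbability(features: Array[Double], weights: Array[Double], bias: Double): Double = {
    require(features.length == weights.length, "one weight per feature")
    // Weighted sum of the features plus the bias, then the sigmoid
    val z = features.zip(weights).map { case (x, w) => x * w }.sum + bias
    sigmoid(z)
  }
}
```

For example, with made-up weights `Array(1.0, -1.0)` and bias `-1.0`, the features `Array(2.0, 1.0)` give a weighted sum of 0 and therefore a probability of 0.5. Increasing a feature with a positive weight pushes the probability towards 1, which is the "predicted value depends on the magnitude of the features" behaviour described above.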


