示例代码：Spark ML ALS 电影推荐（协同过滤）

  1. import org.apache.spark.ml.recommendation.ALS
  2. import org.apache.spark.SparkConf
  3. import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
  4. import org.apache.spark.ml.evaluation.RegressionEvaluator
  5. object ALSDemo {
  6. case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
  7. def main(args: Array[String]): Unit = {
  8. System.setProperty("hadoop.home.dir", "H:\\winutils\\winutils-master\\hadoop-2.6.0")
  9. val spark = SparkSession.builder()
  10. .config(
  11. new SparkConf()
  12. .setMaster("local[*]")
  13. .setAppName("text")
  14. )
  15. // 添加满足 inner join 的配置
  16. .config("spark.sql.crossJoin.enabled","true")
  17. .getOrCreate()
  18. import spark.implicits._
  19. def parseRating(str: String): Rating = {
  20. val fields = str.split("::")
  21. Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
  22. }
  23. val ratings = spark.sparkContext.textFile("data/scala/sample_movielens_ratings.txt")
  24. .map(parseRating)
  25. .toDF()
  26. ratings.show(false)
  27. val _als: ALS = new ALS()
  28. // 正则化参数
  29. .setRegParam(0.01)
  30. // 最大循环迭代
  31. .setMaxIter(10)
  32. // 打分列
  33. .setRatingCol("rating")
  34. // 物品列
  35. .setItemCol("movieId")
  36. // 用户列
  37. .setUserCol("userId")
  38. val Array(tranning, test) = ratings.randomSplit(Array(0.8, 0.2)) // 训练集和测试集
  39. val model = _als.fit(tranning) // 训练出模型
  40. val frame: DataFrame = model.transform(test) // 预测数据
  41. frame.show(false)
  42. // 以方差的方式评估
  43. val evaluator: RegressionEvaluator = new RegressionEvaluator()
  44. // 预测列
  45. .setPredictionCol("prediction")
  46. // 本来标签列
  47. .setLabelCol("rating")
  48. .setMetricName("rmse")
  49. val d: Double = evaluator.evaluate(frame)
  50. println(d)
  51. }
  52. }

ALS补充参数

  1. .setRank(10) // 特征（隐因子）的数量
  2. .setSeed(10) // 随机种子
  3. .setNonnegative(true) // 求解的最小二乘值是否非负
  4. .setRegParam(0.01) // 正则化参数
  5. .setMaxIter(10) // 最大迭代次数
  6. .setRatingCol("rating") // 打分列
  7. .setItemCol("movieId") // 物品列
  8. .setUserCol("userId") // 用户列
  9. .setImplicitPrefs(true) // 显式反馈还是隐式反馈（true 表示隐式反馈）

    ALS特殊补充(源码方法)

    1. @Since("1.3.0")
    2. override def transformSchema(schema: StructType): StructType = {
    3. // user and item will be cast to Int
    4. // user 和 item 必须是数值类型（随后会被 cast 成 Int），不能是 string
    5. SchemaUtils.checkNumericType(schema, $(userCol))
    6. SchemaUtils.checkNumericType(schema, $(itemCol))
    7. // 返回的预测列是float类型
    8. SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
    9. }
  10. transformSchema 检查 userCol、itemCol 是否为数值类型（随后会被 cast 成 Int），并在 schema 末尾追加 Float 类型的 prediction 列

  11. ALS.train 模型训练 得到用户特征矩阵和物品特征矩阵
  12. new ALSModel 创建ALS模型

    ALS 默认参数

    1. ratings: RDD[Rating[ID]],
    2. rank: Int = 10,
    3. numUserBlocks: Int = 10,
    4. numItemBlocks: Int = 10,
    5. maxIter: Int = 10,
    6. regParam: Double = 1.0,
    7. implicitPrefs: Boolean = false,
    8. alpha: Double = 1.0,
    9. nonnegative: Boolean = false,
    10. intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
    11. finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
    12. checkpointInterval: Int = 10,