代码
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.ml.evaluation.RegressionEvaluator

/**
 * Minimal ALS (alternating least squares) collaborative-filtering demo.
 *
 * Reads `userId::movieId::rating::timestamp` lines, trains an ALS model on a
 * random 80% of them, predicts ratings for the remaining 20% and prints the
 * RMSE of those predictions.
 */
object ALSDemo {

  /** One parsed line of the ratings file. */
  case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)

  /**
   * Runs the demo end to end on a local Spark session.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    System.setProperty("hadoop.home.dir", "H:\\winutils\\winutils-master\\hadoop-2.6.0")

    val spark = SparkSession
      .builder()
      .config(new SparkConf().setMaster("local[*]").setAppName("text"))
      // Allow the join performed while scoring to be planned as a cross join.
      .config("spark.sql.crossJoin.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    /**
     * Parses one `userId::movieId::rating::timestamp` line.
     * Fields beyond the fourth are ignored; a line with fewer than four
     * fields fails with a message that names the offending line.
     */
    def parseRating(str: String): Rating =
      str.split("::") match {
        case Array(user, movie, score, ts, _*) =>
          Rating(user.toInt, movie.toInt, score.toFloat, ts.toLong)
        case _ =>
          throw new IllegalArgumentException(s"Malformed rating line: $str")
      }

    try {
      val ratings = spark.sparkContext
        .textFile("data/scala/sample_movielens_ratings.txt")
        .map(parseRating)
        .toDF()
      ratings.show(false)

      val als: ALS = new ALS()
        .setRegParam(0.01)      // regularization parameter
        .setMaxIter(10)         // maximum number of iterations
        .setRatingCol("rating") // rating column
        .setItemCol("movieId")  // item column
        .setUserCol("userId")   // user column

      // Training set and test set.
      val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

      val model = als.fit(training)
      val predictions: DataFrame = model.transform(test)
      predictions.show(false)

      // A user or item that only occurs in the test split has no learned
      // factors, so its prediction is NaN and a single NaN turns the whole
      // RMSE into NaN. Score only the rows that received a real prediction.
      // (Newer Spark versions offer ALS.setColdStartStrategy("drop") for this.)
      val scored = predictions.na.drop(Seq("prediction"))

      // Evaluate with root-mean-square error.
      val evaluator: RegressionEvaluator = new RegressionEvaluator()
        .setPredictionCol("prediction") // predicted rating
        .setLabelCol("rating")          // ground-truth rating
        .setMetricName("rmse")

      val rmse: Double = evaluator.evaluate(scored)
      println(rmse)
    } finally {
      // Release the local Spark resources even when the job fails.
      spark.stop()
    }
  }
}
ALS补充参数
- .setRank(10) // 隐特征(潜在因子)的数量,默认 10
- .setSeed(10) // 种子
- .setNonnegative(true) // 求解的最小二乘值是否非负
- .setRegParam(0.01) // 正则化
- .setMaxIter(10) // 迭代次数
- .setRatingCol("rating") // 打分列
- .setItemCol("movieId") // 物品列
- .setUserCol("userId") // 用户列
- .setImplicitPrefs(false) // 使用显式反馈还是隐式反馈(默认 false,即显式反馈)
ALS特殊补充(源码方法)
/**
 * Validates the input schema and derives the output schema.
 *
 * Checks that the user and item columns are numeric, then returns the schema
 * with the prediction column appended as `FloatType`.
 *
 * @param schema schema of the dataset to be transformed
 * @return the input schema plus the prediction column
 */
@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
  // The user and item columns must be numeric (not strings); they will be cast to Int.
  SchemaUtils.checkNumericType(schema, $(userCol))
  SchemaUtils.checkNumericType(schema, $(itemCol))
  // The returned prediction column is of float type.
  SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
- transformSchema 检查用户列和物品列是否为数值类型(之后会被转换为 Int),并追加 Float 类型的预测列
- ALS.train 模型训练 得到用户特征矩阵和物品特征矩阵
- new ALSModel 创建ALS模型
ALS 默认参数
ratings: RDD[Rating[ID]],rank: Int = 10,numUserBlocks: Int = 10,numItemBlocks: Int = 10,maxIter: Int = 10,regParam: Double = 1.0,implicitPrefs: Boolean = false,alpha: Double = 1.0,nonnegative: Boolean = false,intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,checkpointInterval: Int = 10,
