Loading the data

Download the data from MySQL and save it locally for training.
```scala
def loadData(): Unit = {
  startTime = DateTime.now()
  println("Start time: " + startTime)
  // JDBC connection settings
  val url = "jdbc:mysql://127.0.0.1:3306/purchase?characterEncoding=UTF-8"
  val username = "root"
  val password = "123456"
  // Subquery pushed down to MySQL; LIMIT keeps the training sample small
  val sql = "(select user_id,product_id from history LIMIT 10000) T"
  val properties = new Properties
  properties.put("user", username)
  properties.put("password", password)
  properties.put("driver", "com.mysql.cj.jdbc.Driver")
  properties.put("fetchsize", "100")
  // Read from MySQL and persist locally as CSV
  val dataFrame = sparkSession.read.jdbc(url, sql, properties)
  dataFrame.write.option("header", true).csv("/data/recommend/data")
}
```
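The function above references `sparkSession` and `startTime` without declaring them. A minimal sketch of the assumed surrounding setup follows; the app name and master URL are illustrative, and `DateTime` is assumed to come from joda-time:

```scala
import java.util.Properties
import org.apache.spark.sql.SparkSession
import org.joda.time.DateTime // assumption: joda-time supplies DateTime

// Hypothetical setup; in the original these are presumably members of the enclosing object
val sparkSession: SparkSession = SparkSession.builder()
  .appName("minhash-recommend") // hypothetical app name
  .master("local[*]")           // assumption: local run
  .getOrCreate()

var startTime: DateTime = null // set inside loadData()
```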
Data preprocessing

Transform the data into a format that MinHashLSH can operate on.

The data loaded from MySQL looks like this:
| user_id | product_id |
|---|---|
| 1 | 1 |
| 1 | 2 |
| 2 | 1 |
| 2 | 2 |
| 2 | 3 |

Collect the product_id values into a set per user_id:

```scala
import org.apache.spark.sql.functions._ // collect_set, col

// Load the data back from local storage
val dataFrame = sparkSession.read.option("header", true).csv("/data/recommend/data")
// Aggregate: one row per user, holding the set of products the user bought
val frame = dataFrame.groupBy("user_id").agg(collect_set("product_id").as("pids"))
```
The aggregated data looks like this:
| user_id | pids |
|---|---|
| 1 | [1,2] |
| 2 | [1,2,3] |
- Preprocess the data into bag-of-words 0-1 vectors, represented as sparse vectors
```scala
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}
import org.apache.spark.sql.DataFrame

// Binary bag-of-words: every product becomes a vocabulary entry,
// and setBinary(true) records presence as 1.0 instead of a count
val vectorizer = new CountVectorizer()
  .setInputCol("pids")
  .setOutputCol("features")
  .setBinary(true)
val cvModel: CountVectorizerModel = vectorizer.fit(frame)
val vectorDFSparse: DataFrame = cvModel.transform(frame).select("user_id", "features")
```
The output sparse vectors:

In features, the leading 3 is the vocabulary size (three distinct products). The index array, e.g. [0,1], lists which vocabulary positions are present, and the value array, e.g. [1.0,1.0], gives each product's count within the user's "document"; with setBinary(true), presence is always recorded as 1.0. The index-to-product mapping can be inspected via the model's vocabulary, as sketched after the table.
| user_id | features |
|---|---|
| 1 | (3,[0,1],[1.0,1.0]) |
| 2 | (3,[0,1,2],[1.0,1.0,1.0]) |
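To see which product each vector index maps to, the fitted model's vocabulary can be printed (a quick inspection sketch; the actual order depends on document frequencies in the real data):

```scala
// cvModel.vocabulary maps vector index -> product_id string
cvModel.vocabulary.zipWithIndex.foreach { case (pid, idx) =>
  println(s"index $idx -> product_id $pid")
}
```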
- Filter out rows whose sparse features vector is empty; MinHashLSH requires at least one non-zero entry per vector
```scala
import org.apache.spark.ml.linalg.SparseVector

// Keep only users whose feature vector has at least one non-zero entry
val vectorDFSparseFilter = sparkSession.createDataFrame(
  vectorDFSparse.rdd
    .map(row => (row.getAs[String]("user_id"), row.getAs[SparseVector]("features")))
    .map(x => (x._1, x._2, x._2.numNonzeros))
    .filter(x => x._3 >= 1)
    .map(x => (x._1, x._2))
).toDF("user_id", "features")
```
- Use MinHashLSH to approximately compute Jaccard distances
- Fit the LSH model
```scala
import org.apache.spark.ml.feature.{MinHashLSH, MinHashLSHModel}

// More hash tables reduce false negatives at the cost of extra computation
val mh = new MinHashLSH()
  .setNumHashTables(100)
  .setInputCol("features")
  .setOutputCol("hashes")
val model: MinHashLSHModel = mh.fit(vectorDFSparseFilter)
```
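Transforming the filtered data shows the MinHash signatures that drive the join (an inspection sketch):

```scala
// Each row gains a "hashes" column: an array of 100 single-element vectors,
// one MinHash value per hash table
model.transform(vectorDFSparseFilter).show(5, truncate = false)
```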
- Compute pairwise distances
```scala
// Self-join: keep candidate pairs whose approximate Jaccard distance is at most 0.7
val distance: DataFrame = model
  .approxSimilarityJoin(vectorDFSparseFilter, vectorDFSparseFilter, 0.7, "JaccardDistance")
  .select(
    col("datasetA.user_id").alias("user_id1"),
    col("datasetB.user_id").alias("user_id2"),
    col("JaccardDistance"))
```
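For the running example, users 1 and 2 share two products ({1,2}) out of three in their union ({1,2,3}), so their Jaccard similarity is 2/3 and their Jaccard distance is 1 - 2/3 = 1/3. Since 1/3 ≤ 0.7, the pair survives the join.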
Format the results and save them back to MySQL:
```scala
val ratio = distance.rdd
  .map(x => {
    val node1 = x.getString(0)
    val node2 = x.getString(1)
    // Convert Jaccard distance back to a similarity score
    val overlapRatio = 1 - x.getDouble(2)
    // Order each pair canonically so (a,b) and (b,a) collapse to the same key
    if (node1 < node2) ((node1, node2), overlapRatio) else ((node2, node1), overlapRatio)
  })
  .filter(x => x._1._1 != x._1._2) // drop self-pairs produced by the self-join
  .distinct()                      // the self-join emits each pair twice
  .map(x => (x._1.toString, x._2.toString))

sparkSession.createDataFrame(ratio).write
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/my?characterEncoding=UTF-8&useSSL=false")
  .option("dbtable", "recommend")
  .option("user", "root")
  .option("password", "123654")
  .save()
```
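To spot-check the write, the table can be read back through the same JDBC connector (a sketch reusing the connection options from the save call above):

```scala
// Read the saved similarity pairs back for a quick sanity check
val saved = sparkSession.read
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/my?characterEncoding=UTF-8&useSSL=false")
  .option("dbtable", "recommend")
  .option("user", "root")
  .option("password", "123654")
  .load()
saved.show(10, truncate = false)
```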
