Loading the Data

Download the data from MySQL and save it locally for training.

    def loadData(): Unit = {
      val startTime = DateTime.now()
      println("Start time: " + startTime)
      val url = "jdbc:mysql://127.0.0.1:3306/purchase?characterEncoding=UTF-8"
      val username = "root"
      val password = "123456"
      // A subquery with an alias; the JDBC source reads it as if it were a table
      val sql = "(select user_id,product_id from history LIMIT 10000) T"
      val properties = new Properties
      properties.put("user", username)
      properties.put("password", password)
      properties.put("driver", "com.mysql.cj.jdbc.Driver")
      properties.put("fetchsize", "100")
      val dataFrame = sparkSession.read.jdbc(url, sql, properties)
      dataFrame.write.option("header", true).csv("/data/recommend/data")
    }
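The load step boils down to: run a bounded query, then dump the rows to a CSV with a header. A minimal Python sketch of the same shape, using an in-memory SQLite database as a hypothetical stand-in for the MySQL `history` table:

```python
# Stand-in for the JDBC extract: SQLite replaces MySQL, and the sample
# rows are made up to mirror the tables shown later in this post.
import csv
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE history (user_id TEXT, product_id TEXT)")
conn.executemany("INSERT INTO history VALUES (?, ?)",
                 [("1", "1"), ("1", "3"), ("2", "1")])

# Bounded query, like "(select user_id,product_id from history LIMIT 10000) T"
rows = conn.execute(
    "SELECT user_id, product_id FROM history LIMIT 10000").fetchall()

with open("data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["user_id", "product_id"])  # like option("header", true)
    writer.writerows(rows)
```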

Data Preprocessing

Transform the data into a format MinHashLSH can compute on.

  • The data loaded from MySQL looks like this:

    | user_id | product_id |
    | --- | --- |
    | 1 | 1 |
    | 1 | 3 |
    | 2 | 1 |
    | 2 | 2 |
    | 2 | 3 |

  • Collect the product_id values into a set per user_id

    // Load the data back from local storage
    val dataFrame = sparkSession.read.option("header", true).csv("/data/recommend/data")
    val frame = dataFrame.groupBy("user_id").agg(collect_set("product_id").as("pids"))


    The aggregated data looks like this:

    | user_id | pids |
    | --- | --- |
    | 1 | [1,3] |
    | 2 | [1,2,3] |
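The `groupBy("user_id")` plus `collect_set("product_id")` aggregation is just "group rows by user and keep the distinct product ids as a set". A plain-Python sketch on the sample rows:

```python
# Equivalent of groupBy("user_id").agg(collect_set("product_id")),
# using the sample rows from the table above.
from collections import defaultdict

rows = [("1", "1"), ("1", "3"), ("2", "1"), ("2", "2"), ("2", "3")]

pids = defaultdict(set)
for user_id, product_id in rows:
    pids[user_id].add(product_id)  # set semantics: duplicates collapse

# pids == {"1": {"1", "3"}, "2": {"1", "2", "3"}}
```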
  • Preprocess the data into bag-of-words 0-1 vectors, represented as sparse vectors
    val vectorizer = new CountVectorizer().setInputCol("pids").setOutputCol("features").setBinary(true)
    val cvModel: CountVectorizerModel = vectorizer.fit(frame)
    val vectorDFSparse: DataFrame = cvModel.transform(frame).select("user_id", "features")

    The resulting sparse vectors:
    In `features`, the leading 3 is the vocabulary size; [0,2] means the words occupy positions 0 and 2 of the vocabulary; the trailing array gives each word's count in the document, which is always 1.0 here because setBinary(true) caps counts at one.

    | user_id | features |
    | --- | --- |
    | 1 | (3,[0,2],[1.0,1.0]) |
    | 2 | (3,[0,1,2],[1.0,1.0,1.0]) |
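What a binary CountVectorizer does can be sketched in plain Python: build a vocabulary over all product ids, then map each user's set to the sorted vocabulary indices it occupies. Note this sketch assigns indices in insertion order, while Spark's CountVectorizer orders the vocabulary by corpus frequency, so the actual indices may differ from the ones shown here:

```python
# Sketch of binary count-vectorization: vocabulary + index sets.
docs = {"1": ["1", "3"], "2": ["1", "2", "3"]}

vocab = {}
for pids in docs.values():
    for pid in pids:
        vocab.setdefault(pid, len(vocab))  # first-seen index

sparse = {user: sorted(vocab[pid] for pid in pids)
          for user, pids in docs.items()}
# With this insertion order, vocab is {"1": 0, "3": 1, "2": 2},
# so user 1 -> [0, 1] and user 2 -> [0, 1, 2].
```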
  • Filter out rows whose sparse features vector is empty
    val vectorDFSparseFilter = sparkSession.createDataFrame(vectorDFSparse.rdd
      .map(row => (row.getAs[String]("user_id"), row.getAs[SparseVector]("features")))
      .filter(x => x._2.numNonzeros >= 1))
      .toDF("user_id", "features")
  • Use MinHashLSH to approximate Jaccard distance
    • Fit the LSH model

      val mh = new MinHashLSH().setNumHashTables(100).setInputCol("features").setOutputCol("hashes")
      val model: MinHashLSHModel = mh.fit(vectorDFSparseFilter)
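To build intuition for what the 100 hash tables do, here is an illustrative MinHash sketch in plain Python (not Spark's implementation): each table keeps the minimum hash of a set's elements under a random hash function, and the fraction of tables where two sets agree estimates their Jaccard similarity. The `(a*x + b) mod p` family and the constants below are common but arbitrary choices:

```python
# MinHash signatures estimating Jaccard similarity for the two example
# users, {1, 3} and {1, 2, 3}.
import random

random.seed(42)
P = 2_147_483_647  # a large prime
NUM_TABLES = 100   # mirrors setNumHashTables(100)
hashes = [(random.randrange(1, P), random.randrange(0, P))
          for _ in range(NUM_TABLES)]

def signature(s):
    # One minimum per hash table
    return [min((a * x + b) % P for x in s) for a, b in hashes]

a_set, b_set = {1, 3}, {1, 2, 3}
sig_a, sig_b = signature(a_set), signature(b_set)
est = sum(x == y for x, y in zip(sig_a, sig_b)) / NUM_TABLES
true_sim = len(a_set & b_set) / len(a_set | b_set)  # exact Jaccard = 2/3
# est should land close to true_sim; more tables tighten the estimate.
```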
  • Run the similarity join

    val distance: DataFrame = model.approxSimilarityJoin(
        vectorDFSparseFilter, vectorDFSparseFilter, 0.7, "JaccardDistance")
      .select(col("datasetA.user_id").alias("user_id1"),
        col("datasetB.user_id").alias("user_id2"),
        col("JaccardDistance"))
  • Format the results and save them

    val ratio = distance.rdd.map(x => {
      val node1 = x.getString(0)
      val node2 = x.getString(1)
      val overlapRatio = 1 - x.getDouble(2)
      // Order each pair so the smaller id comes first
      if (node1 < node2) ((node1, node2), overlapRatio) else ((node2, node1), overlapRatio)
    }).filter(x => x._1._1 != x._1._2) // drop self-pairs from the self-join
      .map(x => (x._1.toString, x._2.toString))
    sparkSession.createDataFrame(ratio).write
      .format("jdbc")
      .option("url", "jdbc:mysql://localhost:3306/my?characterEncoding=UTF-8&useSSL=false")
      .option("dbtable", "recommend")
      .option("user", "root")
      .option("password", "123654")
      .save()
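The formatting step above has three parts: drop self-pairs, order each pair so the smaller id comes first, and turn distance into a similarity ("overlap ratio"). A plain-Python sketch on the toy join output (unlike the Scala code, using a dict here also collapses the two symmetric copies of each pair into one):

```python
# Canonicalize pairs from the symmetric self-join output.
raw = [("1", "1", 0.0), ("1", "2", 1 / 3),
       ("2", "1", 1 / 3), ("2", "2", 0.0)]

ratio = {}
for node1, node2, dist in raw:
    if node1 == node2:
        continue  # like filter(x => x._1._1 != x._1._2)
    key = (node1, node2) if node1 < node2 else (node2, node1)
    ratio[key] = 1 - dist  # overlapRatio
# ratio == {("1", "2"): 2/3}
```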