import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{collect_list, lit, map, to_json}

// Register a DataFrame as a temporary view so later SQL can reference it by name.
// A SparkSession named `spark` is assumed to be in scope (e.g. spark-shell).
def saveDf(df: DataFrame, table: String): Unit = {
  df.createOrReplaceTempView(table)
}
// Build a DataFrame in which the t2-side columns of a join are collected into a JSON array column.
// joinSql must alias the left table as t1 and the right table as t2: fields selected from t1
// become the group-by keys, and fields from t2 (plus otherCollectFields) are collected per group.
def makeJsonDf(joinSql: String, jsonColumnName: String, otherCollectFields: String): DataFrame = {
  // Run the join first to obtain all fields, and locate the two source tables t1 and t2.
  val karr = joinSql.split(" ")
  val k1 = karr.indexOf("t1")
  val k2 = karr.indexOf("t2")
  val table1 = karr(k1 - 1)
  val table2 = karr(k2 - 1)
  val df1: DataFrame = spark.sql("select * from " + table1)
  val df2: DataFrame = spark.sql("select * from " + table2)
  val df: DataFrame = spark.sql(joinSql)
  val realSql: String = joinSql.replace("\n", "").trim()
  println(s"sql: ${realSql}")
  println("left table (t1) schema:")
  df1.printSchema()
  println()
  val mfield = realSql.slice("select ".length(), realSql.indexOf(" from "))
  // By default, the fields selected from t1 are used as the group-by fields.
  var groupByFields = mfield.split(",").filter(_.contains("t1.")).map(item => {
    val kstr = item.trim()
    if (kstr.contains("distinct ")) {
      kstr.substring("distinct t1.".length)
    } else {
      kstr.substring("t1.".length)
    }
  })
  // Fields not prefixed with t1 are collected into the JSON column.
  var fields = mfield.split(",").filter(_.contains(".")).filter(!_.contains("t1.")).map(item => {
    val kstr = item.trim()
    if (kstr.contains(" as ")) {
      kstr.split(" as ")(1)
    } else {
      kstr.substring("t2.".length)
    }
  })
  println("group fields....")
  groupByFields.foreach(println)
  // "select t1.*" expands to every column of the left table.
  if (groupByFields.contains("*")) {
    groupByFields = groupByFields ++ df1.schema.fieldNames
  }
  // "select t2.*" expands to every column of the right table.
  if (fields.contains("*")) {
    fields = fields ++ df2.schema.fieldNames
  }
  // Extra collect fields passed in explicitly (comma-separated, may be empty).
  otherCollectFields.split(",").map(_.trim).filter(_.nonEmpty).foreach(one => {
    fields = fields :+ one
  })
  var groupByColumns: Array[Column] = Array[Column]()
  groupByFields.distinct.foreach(one => {
    if (df1.columns.contains(one)) {
      groupByColumns = groupByColumns :+ df("t1." + one)
    }
  })
  val mts = df.columns
  mts.foreach(println)
  // Each collect field contributes a (name, value) pair to the per-row map.
  var collectColumns: Array[Column] = Array[Column]()
  fields.distinct.foreach(field => {
    if (df.columns.contains(field)) {
      collectColumns = collectColumns :+ lit(field)
      collectColumns = collectColumns :+ df("t2." + field)
    }
  })
  println("group fields:")
  groupByColumns.foreach(println)
  println()
  println("collect fields:")
  collectColumns.foreach(println)
  println()
  // Group by the t1 fields, build a map from the t2 fields of each matching row,
  // collect the maps into a list, and serialize that list to a JSON string column.
  val df5: DataFrame = df.groupBy(groupByColumns: _*).agg(collect_list(map(collectColumns: _*)).as(jsonColumnName))
  val df6 = df5.withColumn(jsonColumnName, to_json(df5(jsonColumnName)))
  df6
}
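
The core of makeJsonDf is the final aggregation: each matching t2 row is turned into a map of field name to value, the maps are gathered with collect_list, and to_json serializes the resulting array into a string column. Below is a minimal sketch of that same pattern on made-up toy data (the values and the local SparkSession setup are purely illustrative; in spark-shell `spark` already exists):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, lit, map, to_json}

val spark = SparkSession.builder().master("local[*]").appName("json-agg-sketch").getOrCreate()
import spark.implicits._

// Hypothetical joined rows: one group key and two detail columns.
val joined = Seq(
  ("R001", "caseA", "X100"),
  ("R001", "caseB", "X200"),
  ("R002", "caseC", "X300")
).toDF("registno", "casetype", "frameno")

// Same pattern as makeJsonDf: map(name, value, ...) per row, collect_list per group, then to_json.
val result = joined
  .groupBy($"registno")
  .agg(collect_list(map(lit("casetype"), $"casetype", lit("frameno"), $"frameno")).as("incident"))
  .withColumn("incident", to_json($"incident"))

result.show(truncate = false)
// registno R001 ends up with incident like
// [{"casetype":"caseA","frameno":"X100"},{"casetype":"caseB","frameno":"X200"}]
// (element order from collect_list is not guaranteed)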

    Usage example

    var sql = "select t1.registno,t1.accidentno,t1.policyno,t2.casetype as caseflag from ods_new.ods_car_lregist t1 left join ods_new.ods_car_lclaim t2 on t1.accidentno = t2.accidentno and t1.policyno = t2.policyno"
    var claimDf = spark.sql(sql)
    saveDf(claimDf, "tmp_claim")
    // Claims joined with reports (regist): nest the t2 fields as a JSON array column "incident".
    sql = "select t1.*,t2.frameno,t2.licenseno,t2.reportormobile,t2.comcode,t2.accidentno,t2.policyno from tmp_claim t1 left join ods_new.ods_car_lregist t2 on t1.accidentno = t2.accidentno and t1.registno = t2.registno"
    claimDf = makeJsonDf(sql, "incident", "")
    saveDf(claimDf, "tmp_claim")
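
    After the second call, each claim row carries its matching regist rows serialized in the "incident" column. A quick way to check the shape of the result (column names taken from the example above, output not shown here):

    claimDf.printSchema()
    claimDf.select("registno", "incident").show(5, truncate = false)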