Spark ships with Hash partitioning and Range partitioning, and also supports user-defined partitioners. Hash partitioning is the default. The partitioner determines how many partitions an RDD has and which partition each record lands in after a shuffle, and therefore how many reduce tasks run.
➢ Only key-value RDDs have a partitioner; for a non-key-value RDD the partitioner is None.
➢ Partition IDs of an RDD range from 0 to (numPartitions - 1); this ID decides which partition a record belongs to (see the sketch below).
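
A quick way to observe this is rdd.partitioner, which returns an Option[Partitioner]: None for an RDD that has never been partitioned by key, Some(...) after an explicit partitionBy. A minimal sketch, assuming a Spark shell with an existing SparkContext sc (the names pairs and shuffled are placeholders for this example):

import org.apache.spark.HashPartitioner

// A key-value RDD built straight from a collection has no partitioner yet.
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)), 4)
println(pairs.partitioner)         // None

// After partitionBy with an explicit partitioner, the partitioner is remembered on the RDD.
val shuffled = pairs.partitionBy(new HashPartitioner(2))
println(shuffled.partitioner)      // Some(org.apache.spark.HashPartitioner@...)
println(shuffled.getNumPartitions) // 2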

Hash partitioning: for a given key, compute its hashCode and take it modulo the number of partitions, mapped to a non-negative index (a small demo of that mapping follows the source below).

class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
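
Utils.nonNegativeMod is a Spark-internal helper, so the sketch below re-implements its effect locally to show how a key is mapped to a partition index; the object and method names here are made up for illustration, not Spark APIs:

object HashPartitionDemo {
  // Stand-in for Utils.nonNegativeMod: a modulo that never yields a negative index,
  // even when key.hashCode is negative.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    val keys: Seq[Any] = Seq("nba", "cba", "wnba", -7, null)
    for (key <- keys) {
      // A null key always goes to partition 0, mirroring HashPartitioner.getPartition.
      val p = if (key == null) 0 else nonNegativeMod(key.hashCode, numPartitions)
      println(s"$key -> partition $p")
    }
  }
}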

Range partitioning: keys that fall within the same range are mapped to the same partition. It tries to keep the amount of data in each partition roughly balanced, and the partitions are ordered relative to each other: every key in one partition is smaller than every key in the next. A short usage sketch follows the source below.

class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true,
    val samplePointsPerPartitionHint: Int = 20)
  extends Partitioner {

  // A constructor declared in order to maintain backward compatibility for Java, when we add the
  // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160.
  // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor.
  def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = {
    this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20)
  }

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
  require(samplePointsPerPartitionHint > 0,
    s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      // Cast to double to avoid overflowing ints or longs
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}
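
In everyday code you rarely construct a RangePartitioner yourself: sortByKey builds one from a sample of its input, which is why its output partitions are internally sorted and ordered relative to each other. You can also pass one to partitionBy explicitly. A minimal sketch, again assuming a Spark shell with a SparkContext sc (pairs, ranged, and sorted are placeholder names):

import org.apache.spark.RangePartitioner

val pairs = sc.parallelize((1 to 100).map(i => (i, s"value-$i")), 4)

// Explicitly range-partition into 3 partitions; the bounds come from sampling pairs.
val ranged = pairs.partitionBy(new RangePartitioner(3, pairs))
println(ranged.partitioner)   // Some(org.apache.spark.RangePartitioner@...)

// sortByKey builds a RangePartitioner internally before shuffling.
val sorted = pairs.sortByKey(ascending = true, numPartitions = 3)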

Custom partitioner: extend the abstract class Partitioner and override numPartitions and getPartition.

package part

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Partitioner, SparkConf, SparkContext}

object Spark_Part {

  def main(args: Array[String]): Unit = {
    // Silence Spark's log output so the result is easier to read
    Logger.getLogger("org").setLevel(Level.ERROR)
    // Create the SparkConf
    val conf = new SparkConf().setMaster("local[2]").setAppName("wc")
    // Create the SparkContext, the entry point of the Spark application
    val sc = new SparkContext(conf)
    // Build a key-value collection
    val list = List(("nba", "************"), ("cba", "************"),
      ("wnba", "************"), ("nba", "************"))
    // Turn the collection into an RDD with three partitions
    val inputRDD = sc.makeRDD(list, 3)
    // Repartition the RDD with the custom partitioner via partitionBy
    val value = inputRDD.partitionBy(new Mypartitioner)
    // Write each partition out as a file
    value.saveAsTextFile("output")
    sc.stop()
  }

  /**
   * Step 1: define a custom partitioner by extending Partitioner.
   * Step 2: override numPartitions and getPartition.
   */
  class Mypartitioner extends Partitioner {
    // Number of partitions
    override def numPartitions: Int = 3

    // Return the partition index (starting from 0) for a record, based on its key
    override def getPartition(key: Any): Int = {
      // Option 1: an if/else chain
      // if (key == "nba") {
      //   0
      // } else if (key == "cba") {
      //   1
      // } else {
      //   2
      // }

      // Option 2: pattern matching
      // "nba" goes to partition 0, "cba" to partition 1, everything else to partition 2
      key match {
        case "nba" => 0
        case "cba" => 1
        case _ => 2
      }
    }
  }
}
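
To check which record landed in which partition without opening the files under output, you can tag every element with its partition index and collect the result. A small sketch, assuming the same inputRDD and Mypartitioner from the program above (it would replace the saveAsTextFile call while debugging):

// Attach the partition index to each element and bring the result back to the driver.
val placed = inputRDD
  .partitionBy(new Mypartitioner)
  .mapPartitionsWithIndex((idx, iter) => iter.map(kv => (idx, kv)))
  .collect()

placed.foreach(println)
// Expected: both "nba" records tagged 0, "cba" tagged 1, "wnba" tagged 2.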