1. 概要

Spark RDD主要由DependencyPartition、Partitioner组成。Partition记录了数据Split的逻辑,Dependency记录的是Transformation操作过程中Partition的演化,Partitioner是Shuffle过程中key重分区时的策略,即计算key决定k-v属于哪个分区。
Spark中分区器直接决定了RDD中分区的个数、RDD中每条数据经过Shuffle过程属于哪个分区和Reduce的个数。
注意:

  1. 只有Key-Value类型的RDD才有分区函数,非Key-Value类型的RDD无分区函数(None),但是也是有分区的(ParallelCollectionPartition)。
  2. 每个RDD的分区ID范围:0 ~ (numPartitions-1),决定这个值是属于那个分区的。

    2. 分区器作用

    Partitioner是在Shuffle阶段起作用,无论对于MapReduce还是Spark,Shuffle都是重中之重,因为Shuffle的性能直接影响着整个程序。先了解下shuffle:详细探究Spark的shuffle实现,Shuffle涉及到网络开销及可能导致的数据倾斜问题,是调优关注的重点。
    image.png

    3. ParallelCollectionPartition

    一般我们在spark-shell练习RDD的一些算子时,都喜欢用sc.parallelize()生成一个RDD。通过这种方式生成的RDD就是ParallelCollectionRDD。下面是SparkContext的parallelize函数实现,最核心的代码就是创建了一ParallelCollectionRDD对象。

    1. def parallelize[T: ClassTag](
    2. seq: Seq[T],
    3. numSlices: Int = defaultParallelism): RDD[T] = withScope {
    4. assertNotStopped()
    5. new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
    6. }

    下面是ParallelCollectionRDD类的实现代码,构造函数有4个参数,分别是SparkContext,集合数据,分区数以及优选位置信息,它继承了RDD抽象类,调用RDD构造函数时,第二个参数填了Nil,表示该RDD是没有依赖的父RDD的,它就是RDD生成的一个源头。如果通过map等一系列转换操作后,生成的子RDD最终指向的RDD依赖就是它了。

    1. private[spark] class ParallelCollectionRDD[T: ClassTag](
    2. sc: SparkContext,
    3. @transient private val data: Seq[T],
    4. numSlices: Int,
    5. locationPrefs: Map[Int, Seq[String]])
    6. extends RDD[T](sc, Nil) {
    7. override def getPartitions: Array[Partition] = {
    8. val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    9. slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
    10. }
    11. override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    12. new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
    13. }
    14. override def getPreferredLocations(s: Partition): Seq[String] = {
    15. locationPrefs.getOrElse(s.index, Nil)
    16. }
    17. }

    它实现了三个方法,分别是getPartitions,compute,getPreferredLocations。其实它只需要实现前面两个方法就可以,后面一个实现方法是多余的。PreferredLocation的作用是在需要计算某个分区的数据时,如果知道这个数据在什么位置,那么就在该位置上提交任务进行计算,这样可以减少IO开销。当然我们这个ParallelCollectionRDD是没有优先位置的,在parallelize函数中,这个信息就填了一个空的map。
    getPartitions方法,获取该RDD的所有分区信息。该函数首先把数据集合均匀的切分为numSlices份,然后每一份数据生成一个ParallelCollectionPartition分区对象,然后返回所有的ParallelCollectionPartition分区。
    ParallelCollectionPartition分区,主要有三个数据,rddId,slice(切片号,其实就是分区号),values(分区的数据)。它首先定义了一个iterator,指向values.iterator,紧接着重载了hashCode()方法,然后再重载了equals方法,需要类型相同,rddId以及slice相同才认为是同一个分区。后面把index字段重载为slice,最后writeObject,readObject函数是序列化,反序列化使用的,这里不深入研究。

    1. private[spark] class ParallelCollectionPartition[T: ClassTag](
    2. var rddId: Long,
    3. var slice: Int,
    4. var values: Seq[T]
    5. ) extends Partition with Serializable {
    6. def iterator: Iterator[T] = values.iterator
    7. override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
    8. override def equals(other: Any): Boolean = other match {
    9. case that: ParallelCollectionPartition[_] =>
    10. this.rddId == that.rddId && this.slice == that.slice
    11. case _ => false
    12. }
    13. override def index: Int = slice
    14. @throws(classOf[IOException])
    15. private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    16. ...
    17. }
    18. @throws(classOf[IOException])
    19. private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    20. ...
    21. }
    22. }

    有了分区后,就可以计算分区的数据了,ParallelCollectionRDD的compute函数,首先把传入的Partition对象动态转换为ParallelCollectionPartition对象,然后取得ParallelCollectionPartition对象的iterator,最后用InterruptibleIterator函数把这个iterator重新包装了一下,并返回该迭代器。返回的迭代器其本质就是分区中数据的迭代器,有了这个迭代器,就可以获取这个分区的数据了。
    从上面的分析我们可以看出,ParallelCollectionRDD只有分区,没有分区器,也不需要依赖任何其它的RDD。

    4. HashPartitioner

    HashPartitioner分区的原理:对于给定的key,计算其hashCode,并除于分区的个数取余,如果余数小于0,则用“余数+分区”的个数,最后返回的值就是这个key所属的分区ID。实现如下:

    1. /**
    2. * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
    3. * Java's `Object.hashCode`.
    4. *
    5. * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
    6. * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
    7. * produce an unexpected or incorrect result.
    8. */
    9. class HashPartitioner(partitions: Int) extends Partitioner {
    10. require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
    11. def numPartitions: Int = partitions
    12. def getPartition(key: Any): Int = key match {
    13. case null => 0
    14. case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
    15. }
    16. override def equals(other: Any): Boolean = other match {
    17. case h: HashPartitioner =>
    18. h.numPartitions == numPartitions
    19. case _ =>
    20. false
    21. }
    22. override def hashCode: Int = numPartitions
    23. }

    5. RangePartitioner

    RangePartitioner基于水塘抽样算法实现,其目的在于从包含n个项目的集合S中选取k个样本,其中n为一很大或未知的数量,尤其适用于不能把所有n个项目都存放到内存的情况。算法如下:

    1. S中抽取首k项放入「水塘」中
    2. 对于每一个S[j]项(j k):
    3. 随机产生一个范围0j的整数r
    4. r < k 则把水塘中的第r项换成S[j]项

    RangePartitioner作用:将一定范围内的数映射到某一个分区内,在实现中,分界的算法尤为重要。RDD的Transformation中,sortBy、sortByKey,使用RangePartitioner实现。

    6. 自定义分区器

  • 示例

需要继承“org.apache.spark.Partitioner”类,实现如下:

  1. import org.apache.spark.Partitioner
  2. class MySparkPartition(numParts: Int) extends Partitioner {
  3. override def numPartitions: Int = numParts
  4. override def getPartition(key: Any): Int = {
  5. val domain = new java.net.URL(key.toString).getHost()
  6. val code = (domain.hashCode % numPartitions)
  7. if (code < 0) {
  8. code + numPartitions
  9. } else {
  10. code
  11. }
  12. }
  13. override def equals(other: Any): Boolean = other match {
  14. case mypartition: MySparkPartition =>
  15. mypartition.numPartitions == numPartitions
  16. case _ =>
  17. false
  18. }
  19. override def hashCode: Int = numPartitions
  20. }
  21. /**
  22. * def numPartitions:这个方法需要返回你想要创建分区的个数;
  23. * def getPartition:这个函数需要对输入的key做计算,然后返回该key的分区ID,范围一定是0到numPartitions-1;
  24. * equals():这个是Java标准的判断相等的函数,之所以要求用户实现这个函数是因为Spark内部会比较两个RDD的分区是否一样。
  25. * /
  • 应用 ```scala import org.apache.spark.{SparkConf, SparkContext}

object UseMyPartitioner {

def main(args: Array[String]) { val conf=new SparkConf() .setMaster(“local[2]”) .setAppName(“UseMyPartitioner”) .set(“spark.app.id”,”test-partition-id”) val sc=new SparkContext(conf)

  1. // 读取hdfs文件
  2. val lines=sc.textFile("hdfs://hadoop3:8020/user/test/word.txt")
  3. val splitMap=lines.flatMap(line=>line.split("\t")).map(word=>(word,2)) // 注意:RDD一定要是key-value
  4. // 保存
  5. splitMap.partitionBy(new MySparkPartition(3)).saveAsTextFile("F:/partrion/test")
  6. sc.stop()

}

} ```

参考

CSDN:Spark RDD之Partitioner
https://blog.csdn.net/u011564172/article/details/54667057