Creating an RDD from a collection with a specified number of partitions

val conf: SparkConf = new SparkConf().setAppName("MYTest").setMaster("local[*]")
val sc: SparkContext = new SparkContext(conf)
// 4 elements in the collection --- numSlices set to 3 --- 3 partitions are actually produced
val rdd: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 3)
// Data distribution across partitions: partition 0 -> 1, partition 1 -> 2, partition 2 -> 3, 4
rdd.saveAsTextFile("D:\\studycoderun\\SparkStudy\\output")
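To check that distribution directly from code, here is a quick sketch reusing the same sc and rdd as above; glom collects each partition into an array:

// Print the contents of every partition
rdd.glom().collect().zipWithIndex.foreach { case (part, idx) =>
  println(s"partition $idx -> ${part.mkString(", ")}")
}
// Expected output: partition 0 -> 1, partition 1 -> 2, partition 2 -> 3, 4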

By what rule is the data assigned to its partition?

Source code

makeRDD

def makeRDD[T: ClassTag](
  seq: Seq[T],
  numSlices: Int = defaultParallelism): RDD[T] = withScope {
  parallelize(seq, numSlices)
}

Here numSlices: Int is no longer the default parallelism; it is the partition count that we passed in ourselves.

parallelize(seq, numSlices) - creating the RDD

// Methods for creating RDDs

/** Distribute a local Scala collection to form an RDD.
   *
   * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
   * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
   * modified collection. Pass a copy of the argument to avoid this.
   * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an
   * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions.
   */
def parallelize[T: ClassTag](
  seq: Seq[T],
  numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
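Since makeRDD simply forwards to parallelize, the opening example could just as well call parallelize directly; as a sketch reusing the same sc:

// Equivalent to sc.makeRDD(List(1, 2, 3, 4), 3)
val rdd2: RDD[Int] = sc.parallelize(List(1, 2, 3, 4), 3)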

ParallelCollectionRDD

private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
    extends RDD[T](sc, Nil) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.
  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}

getPartitions in RDD - the template method design pattern

Before digging into ParallelCollectionRDD, let's first look at getPartitions in RDD.

/**
   * Get the array of partitions of this RDD, taking into account whether the
   * RDD is checkpointed or not.
   */
final def partitions: Array[Partition] = {
  checkpointRDD.map(_.partitions).getOrElse {
    if (partitions_ == null) {
      partitions_ = getPartitions
      partitions_.zipWithIndex.foreach { case (partition, index) =>
        require(partition.index == index,
                s"partitions($index).partition == ${partition.index}, but it should equal $index")
      }
    }
    partitions_
  }
}

getPartitions is an abstract method, and this is the template method design pattern: RDD only defines the template above, while the concrete behavior is implemented by its subclasses.

/**
   * Implemented by subclasses to return the set of partitions in this RDD. This method will only
   * be called once, so it is safe to implement a time-consuming computation in it.
   *
   * The partitions in this array must satisfy the following property:
   *   `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
   */
protected def getPartitions: Array[Partition]
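To make the pattern concrete, here is a minimal standalone sketch of the same idea (hypothetical class names, not Spark code): the parent class fixes the cache-and-validate skeleton once, and subclasses only supply the step that actually builds the partitions.

// Template method pattern in miniature (hypothetical example)
abstract class Dataset {
  private var cached: Array[Int] = _

  // The "template": the workflow is defined once in the parent and never overridden
  final def partitions: Array[Int] = {
    if (cached == null) {
      cached = getPartitions          // defer the concrete step to the subclass
      require(cached.nonEmpty, "at least one partition is required")
    }
    cached
  }

  // The "hook": subclasses decide how the partitions are actually built
  protected def getPartitions: Array[Int]
}

class FixedDataset(n: Int) extends Dataset {
  override protected def getPartitions: Array[Int] = Array.tabulate(n)(identity)
}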

getPartitions in ParallelCollectionRDD - the concrete implementation

override def getPartitions: Array[Partition] = {
  val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
  slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}

It passes the data and the number of partitions to the slice method, so let's take a look at slice.

slice

  • Pattern-match on the seq that was passed in; since we passed a List, it matches the last case
  • Convert the List into an array
  • Pass the array's length and the number of partitions to the positions method defined inside slice

/**
 * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
 * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
 * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
 * is an inclusive Range, we use inclusive range for the last slice.
 */
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
  // Fewer than 1 slice: throw an exception
  if (numSlices < 1) {
    throw new IllegalArgumentException("Positive number of slices required")
  }
  // Sequences need to be sliced at the same set of index positions for operations
  // like RDD.zip() to behave as expected
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }
  }
  // Pattern-match on the seq that was passed in; our List matches the last case
  seq match {
    case r: Range =>
      positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
        // If the range is inclusive, use inclusive range for the last slice
        if (r.isInclusive && index == numSlices - 1) {
          new Range.Inclusive(r.start + start * r.step, r.end, r.step)
        }
        else {
          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
        }
      }.toSeq.asInstanceOf[Seq[Seq[T]]]
    case nr: NumericRange[_] =>
      // For ranges of Long, Double, BigInteger, etc
      val slices = new ArrayBuffer[Seq[T]](numSlices)
      var r = nr
      for ((start, end) <- positions(nr.length, numSlices)) {
        val sliceSize = end - start
        slices += r.take(sliceSize).asInstanceOf[Seq[T]]
        r = r.drop(sliceSize)
      }
      slices
    case _ =>
      // Convert the List into an array (to prevent O(n^2) operations for List etc.)
      val array = seq.toArray
      // Pass the array's length and the number of partitions to positions
      positions(array.length, numSlices).map { case (start, end) =>
        array.slice(start, end).toSeq
      }.toSeq
  }
}

positions

The array's length and the number of partitions are passed to the positions method defined inside slice.
It returns an iterator of tuples, one per partition, each holding the start and end index of that partition's data:
[start, end) - start is included, end is excluded.

def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
  (0 until numSlices).iterator.map { i =>
    val start = ((i * length) / numSlices).toInt
    val end = (((i + 1) * length) / numSlices).toInt
    (start, end)
  }
}

The original List passed in is List(1, 2, 3, 4)
length = 4 elements
numSlices = 3 partitions
i = 0 -> start = (0 * 4) / 3 = 0, end = ((0 + 1) * 4) / 3 = 1, so [0, 1)
i = 1 -> start = (1 * 4) / 3 = 1, end = ((1 + 1) * 4) / 3 = 2, so [1, 2)
i = 2 -> start = (2 * 4) / 3 = 2, end = ((2 + 1) * 4) / 3 = 4, so [2, 4)
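The same arithmetic can be verified outside Spark with a few lines of plain Scala (a standalone sketch that mirrors positions; PositionsDemo is just an illustrative name):

object PositionsDemo extends App {
  // Mirror of the index arithmetic in positions
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }

  val data = List(1, 2, 3, 4)
  positions(data.length, 3).zipWithIndex.foreach { case ((start, end), i) =>
    // prints: partition 0 -> List(1), partition 1 -> List(2), partition 2 -> List(3, 4)
    println(s"partition $i -> ${data.slice(start, end)}")
  }
}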

array.slice(start, end)

The actual slicing of the data follows exactly the index arithmetic demonstrated above.

def slice(from: Int, until: Int): Repr = {
  val lo    = math.max(from, 0)
  val hi    = math.min(math.max(until, 0), length)
  val elems = math.max(hi - lo, 0)
  val b     = newBuilder
  b.sizeHint(elems)

  var i = lo
  while (i < hi) {
    b += self(i)
    i += 1
  }
  b.result()
}

Once slice returns its Seq of sub-collections, control comes back to getPartitions in ParallelCollectionRDD. At this point we know exactly which data each partition holds: for our array Array(1, 2, 3, 4) the slices are array.slice(0, 1) = [1], array.slice(1, 2) = [2] and array.slice(2, 4) = [3, 4], and each slice is wrapped into its corresponding partition.

override def getPartitions: Array[Partition] = {
  val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
  slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}

The compute method in ParallelCollectionRDD then iterates over the data of each partition.

override def compute(s: Partition, context: TaskContext): Iterator[T] = {
  new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}

Summary

  • When an RDD is created from a collection, the number of partitions you specify is exactly the number of partitions that get created.

  • The key to determining which data lands in which partition is the positions method, which computes each partition's [start, end) index range; the detailed walkthrough is shown above.

    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
      (0 until numSlices).iterator.map { i =>
        val start = ((i * length) / numSlices).toInt
        val end = (((i + 1) * length) / numSlices).toInt
        (start, end)
      }
    }
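To tie the walkthrough together, the resulting distribution can be confirmed from user code; as a hedged sketch reusing sc and rdd from the opening example, mapPartitionsWithIndex exposes exactly the per-partition iterator that compute hands to each task:

// Emit one descriptive line per partition instead of per element
rdd.mapPartitionsWithIndex { (idx, it) =>
  Iterator(s"partition $idx -> ${it.mkString(", ")}")
}.collect().foreach(println)
// Expected: partition 0 -> 1, partition 1 -> 2, partition 2 -> 3, 4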