Creating an RDD from a collection, specifying the number of partitions
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

val conf: SparkConf = new SparkConf().setAppName("MYTest").setMaster("local[*]")
val sc: SparkContext = new SparkContext(conf)
// 4 elements in the collection --- number of partitions set to 3 --- 3 partitions are actually written out
val rdd: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 3)
// Data distribution across partitions: partition 0 -> 1, partition 1 -> 2, partition 2 -> 3, 4
rdd.saveAsTextFile("D:\\studycoderun\\SparkStudy\\output")
Source code
makeRDD
def makeRDD[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  parallelize(seq, numSlices)
}
At this point numSlices: Int is no longer the default parallelism (defaultParallelism); it is the partition count we passed in ourselves.
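For comparison, a minimal sketch (reusing the sc from the example above) of calling makeRDD with and without an explicit partition count:

// Omitting numSlices falls back to defaultParallelism (with local[*], typically the number of local cores)
val rddDefault: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4))
// Passing 3 overrides the default
val rddExplicit: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 3)
println(rddDefault.getNumPartitions)  // typically the local core count
println(rddExplicit.getNumPartitions) // 3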
parallelize(seq, numSlices) - creating the RDD
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD.
*
* @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
* to parallelize and before the first action on the RDD, the resultant RDD will reflect the
* modified collection. Pass a copy of the argument to avoid this.
* @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an
* RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions.
*/
def parallelize[T: ClassTag](
    seq: Seq[T],
    numSlices: Int = defaultParallelism): RDD[T] = withScope {
  assertNotStopped()
  new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
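The @note about laziness above can be observed directly. A small sketch (reusing the sc from the opening example and assuming a Scala 2.12 build, where a mutable ArrayBuffer is accepted as a Seq):

import scala.collection.mutable.ArrayBuffer

val buf = ArrayBuffer(1, 2, 3, 4)
val lazyRdd: RDD[Int] = sc.parallelize(buf)   // nothing is sliced or computed yet
buf += 5                                      // mutate before the first action
println(lazyRdd.collect().mkString(","))      // per the @note: reflects the mutation -> 1,2,3,4,5
val safeRdd: RDD[Int] = sc.parallelize(buf.toList) // pass a copy to avoid this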
ParallelCollectionRDD
private[spark] class ParallelCollectionRDD[T: ClassTag](
    sc: SparkContext,
    @transient private val data: Seq[T],
    numSlices: Int,
    locationPrefs: Map[Int, Seq[String]])
  extends RDD[T](sc, Nil) {
  // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
  // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
  // instead.
  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.

  override def getPartitions: Array[Partition] = {
    val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
    slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
  }

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    locationPrefs.getOrElse(s.index, Nil)
  }
}
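The locationPrefs parameter is empty in the parallelize call above (Map[Int, Seq[String]]()), and it is exactly what getPreferredLocations reads from. SparkContext also has a makeRDD overload that fills it in from per-element location preferences; a hedged sketch (the host names are placeholders):

// makeRDD overload taking (element, preferred hosts) pairs; it builds one partition per element
val prefRdd: RDD[Int] = sc.makeRDD(Seq(
  (1, Seq("host1")),
  (2, Seq("host2"))
))
// Each partition reports its preference back through getPreferredLocations
prefRdd.partitions.foreach(p => println(prefRdd.preferredLocations(p)))
// e.g. List(host1), then List(host2)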
getPartitions in RDD - the template method design pattern
Before rushing into ParallelCollectionRDD, let's first look at how the base class RDD uses getPartitions, starting from its partitions method.
/**
* Get the array of partitions of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
final def partitions: Array[Partition] = {
  checkpointRDD.map(_.partitions).getOrElse {
    if (partitions_ == null) {
      partitions_ = getPartitions
      partitions_.zipWithIndex.foreach { case (partition, index) =>
        require(partition.index == index,
          s"partitions($index).partition == ${partition.index}, but it should equal $index")
      }
    }
    partitions_
  }
}
getPartitions is an abstract method; this is the template method design pattern: RDD only defines the template shown above, while the concrete behavior is implemented by its subclasses (a minimal sketch of the pattern follows the declaration below).
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*
* The partitions in this array must satisfy the following property:
* `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
*/
protected def getPartitions: Array[Partition]
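To make the pattern itself concrete, here is a minimal, self-contained sketch (illustrative only; SimpleRdd and SinglePartitionRdd are made-up names, not Spark classes) of the same idea: the parent fixes the caching workflow once, and subclasses only fill in the getPartitions hook.

// Minimal template-method sketch (not Spark source code)
abstract class SimpleRdd[T] {
  private var cached: Array[List[T]] = _

  // Template method: the "compute once, then cache" workflow is fixed in the parent class
  final def partitions: Array[List[T]] = {
    if (cached == null) {
      cached = getPartitions
    }
    cached
  }

  // Hook method: each subclass decides how its partitions are actually built
  protected def getPartitions: Array[List[T]]
}

// Trivial subclass: everything lands in a single partition
class SinglePartitionRdd[T](data: List[T]) extends SimpleRdd[T] {
  override protected def getPartitions: Array[List[T]] = Array(data)
}

ParallelCollectionRDD plays the subclass role in Spark: the partitions template stays in RDD, and only the getPartitions hook below is overridden.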
getPartitions in ParallelCollectionRDD - the concrete implementation
override def getPartitions: Array[Partition] = {
  val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
  slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
slice
- The incoming seq is pattern matched; since we passed in a List, the last case (case _) matches
- The List is converted into an array
- The array length and the number of slices are passed to the positions method defined inside slice
/**
 * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
 * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
 * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
 * is an inclusive Range, we use inclusive range for the last slice.
 */
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
  // Fewer than one slice: throw an exception
  if (numSlices < 1) {
    throw new IllegalArgumentException("Positive number of slices required")
  }
  // Sequences need to be sliced at the same set of index positions for operations
  // like RDD.zip() to behave as expected
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
    (0 until numSlices).iterator.map { i =>
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    }
  }
  // The incoming seq is pattern matched; since we passed in a List, the last case matches
  seq match {
    case r: Range =>
      positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
        // If the range is inclusive, use inclusive range for the last slice
        if (r.isInclusive && index == numSlices - 1) {
          new Range.Inclusive(r.start + start * r.step, r.end, r.step)
        } else {
          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
        }
      }.toSeq.asInstanceOf[Seq[Seq[T]]]
    case nr: NumericRange[_] =>
      // For ranges of Long, Double, BigInteger, etc
      val slices = new ArrayBuffer[Seq[T]](numSlices)
      var r = nr
      for ((start, end) <- positions(nr.length, numSlices)) {
        val sliceSize = end - start
        slices += r.take(sliceSize).asInstanceOf[Seq[T]]
        r = r.drop(sliceSize)
      }
      slices
    case _ =>
      // Convert the List into an array
      val array = seq.toArray // To prevent O(n^2) operations for List etc
      // Pass the array length and the number of slices to the positions method defined in slice
      positions(array.length, numSlices).map { case (start, end) =>
        array.slice(start, end).toSeq
      }.toSeq
  }
}

positions
slice passes the array length and the number of slices into the positions method it defines internally.
positions returns an iterator of tuples, each holding the start and end index of one partition's data.
[start, end): start is inclusive, end is exclusive

def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
  (0 until numSlices).iterator.map { i =>
    val start = ((i * length) / numSlices).toInt
    val end = (((i + 1) * length) / numSlices).toInt
    (start, end)
  }
}

For the original List(1, 2, 3, 4) we passed in:
length = 4 elements
numSlices = 3 partitions
i = 0 -> start = (0*4)/3 = 0, end = ((0+1)*4)/3 = 1  => [0, 1)
i = 1 -> start = (1*4)/3 = 1, end = ((1+1)*4)/3 = 2  => [1, 2)
i = 2 -> start = (2*4)/3 = 2, end = ((2+1)*4)/3 = 4  => [2, 4)
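As a quick sanity check, a standalone copy of the same arithmetic (plain Scala, not going through Spark) reproduces the three ranges above:

// Local copy of the positions arithmetic, evaluated for length = 4, numSlices = 3
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
  (0 until numSlices).iterator.map { i =>
    (((i * length) / numSlices).toInt, (((i + 1) * length) / numSlices).toInt)
  }

positions(4, 3).foreach(println)
// (0,1) -> partition 0 gets element 1
// (1,2) -> partition 1 gets element 2
// (2,4) -> partition 2 gets elements 3 and 4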
array.slice(start, end)
The actual slicing of the data follows exactly the arithmetic demonstrated with the numbers above.
def slice(from: Int, until: Int): Repr = {
  val lo = math.max(from, 0)
  val hi = math.min(math.max(until, 0), length)
  val elems = math.max(hi - lo, 0)
  val b = newBuilder
  b.sizeHint(elems)
  var i = lo
  while (i < hi) {
    b += self(i)
    i += 1
  }
  b.result()
}
Once this returns, slice hands back a Seq of sub-collections and control comes back to getPartitions in ParallelCollectionRDD: at this point we know exactly which data belongs to each partition, and each freshly cut slice is placed into its corresponding partition.
override def getPartitions: Array[Partition] = {
  val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
  slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
The compute method of ParallelCollectionRDD then iterates over the data of each partition.
override def compute(s: Partition, context: TaskContext): Iterator[T] = {
  new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}
Summary
When an RDD is created from a collection, however many partitions you specify is exactly how many partitions get created.
- To determine which data each partition holds, the key is the positions method, which computes the [start, end) range for every partition; see the walkthrough above for the step-by-step process.
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
  (0 until numSlices).iterator.map { i =>
    val start = ((i * length) / numSlices).toInt
    val end = (((i + 1) * length) / numSlices).toInt
    (start, end)
  }
}
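To verify the conclusion end to end, a small sketch (reusing the sc from the opening example) uses glom() to collect each partition's contents as an array:

// glom() turns each partition into an Array, making the [start, end) slices visible
val checkRdd: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4), 3)
checkRdd.glom().collect().foreach(part => println(part.mkString("[", ",", "]")))
// [1]
// [2]
// [3,4]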
