Spark currently supports Hash partitioning, Range partitioning, and user-defined partitioning. Hash partitioning is the default. The partitioner directly determines the number of partitions in an RDD and which partition each record goes to after a shuffle, which in turn determines the number of reduce tasks.
➢ Only Key-Value RDDs have a partitioner; for a non-Key-Value RDD the partitioner is None (a quick check is sketched after this list).
➢ The partition IDs of an RDD range from 0 to (numPartitions - 1); this ID determines which partition a given record belongs to.
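As a quick illustration of the first point, a plain RDD reports no partitioner, while a key-value RDD gains one after a shuffle operation such as reduceByKey. The following is a minimal sketch; the sample data, partition counts, and object name PartitionerCheck are made up for illustration.

import org.apache.spark.{SparkConf, SparkContext}

object PartitionerCheck {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("partitioner-check"))

    // A non key-value RDD has no partitioner
    val nums = sc.makeRDD(Seq(1, 2, 3, 4), 2)
    println(nums.partitioner)          // None

    // A key-value RDD gets the default HashPartitioner after reduceByKey
    val pairs = sc.makeRDD(Seq(("a", 1), ("b", 1), ("a", 1)), 2)
    val reduced = pairs.reduceByKey(_ + _, 4)
    println(reduced.partitioner)       // Some(org.apache.spark.HashPartitioner@...)
    println(reduced.getNumPartitions)  // 4

    sc.stop()
  }
}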
Hash partitioning: for a given key, compute its hashCode and take it modulo the number of partitions (adjusted to be non-negative) to get the partition index.
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
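For reference, a minimal usage sketch: repartition a pair RDD with an explicit HashPartitioner, and reproduce the non-negative modulo rule by hand to see where a key would land. The sample keys, the partition count, and the helper expectedPartition are assumptions for illustration only; Utils.nonNegativeMod is Spark's internal helper and is not called directly here.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object HashPartitionerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("hash-partitioner-demo"))

    val pairs = sc.makeRDD(Seq(("a", 1), ("b", 2), ("c", 3), ("a", 4)))
    // Spread the data over 3 partitions by key hash
    val partitioned = pairs.partitionBy(new HashPartitioner(3))
    println(partitioned.getNumPartitions) // 3

    // Same rule as the source above: hashCode % n, shifted into [0, n) if negative
    def expectedPartition(key: Any, n: Int): Int = {
      val raw = key.hashCode % n
      if (raw < 0) raw + n else raw
    }
    println(expectedPartition("a", 3)) // identical keys always map to the same partition

    sc.stop()
  }
}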
Range partitioning: maps keys within a given range to the same partition, trying to keep the amount of data in each partition roughly even while keeping the partitions ordered relative to one another.
class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true,
    val samplePointsPerPartitionHint: Int = 20)
  extends Partitioner {

  // A constructor declared in order to maintain backward compatibility for Java, when we add the
  // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160.
  // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor.
  def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = {
    this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20)
  }

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
  require(samplePointsPerPartitionHint > 0,
    s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      // Cast to double to avoid overflowing ints or longs
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}
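For context, sortByKey is the typical place where a RangePartitioner is created under the hood; the partitioner can also be constructed explicitly and passed to partitionBy. Below is a minimal sketch under those assumptions; the sample data, partition counts, and object name RangePartitionerDemo are made up for illustration.

import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

object RangePartitionerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("range-partitioner-demo"))

    val pairs = sc.makeRDD(Seq((5, "e"), (1, "a"), (9, "i"), (3, "c"), (7, "g")), 2)

    // sortByKey builds a RangePartitioner internally: partitions are ordered and roughly balanced
    val sorted = pairs.sortByKey(ascending = true, numPartitions = 3)
    println(sorted.partitioner)   // Some(org.apache.spark.RangePartitioner@...)

    // The partitioner can also be created explicitly and passed to partitionBy
    val ranged = pairs.partitionBy(new RangePartitioner(3, pairs))
    println(ranged.getNumPartitions)

    sc.stop()
  }
}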
Custom partitioner
package part

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Partitioner, SparkConf, SparkContext}

object Spark_Part {
  def main(args: Array[String]): Unit = {
    // Suppress log output
    Logger.getLogger("org").setLevel(Level.ERROR)
    // Create the SparkConf
    val conf = new SparkConf().setMaster("local[2]").setAppName("wc")
    // Create the Spark entry point
    val sc = new SparkContext(conf)
    // Create the input collection
    val list = List(("nba", "************"), ("cba", "************"), ("wnba", "************"), ("nba", "************"))
    // Put the collection into an RDD with three partitions
    val inputRDD = sc.makeRDD(list, 3)
    // Repartition the RDD with the custom partitioner via partitionBy
    val value = inputRDD.partitionBy(new MyPartitioner)
    // Save the result to files
    value.saveAsTextFile("output")
    sc.stop()
  }

  /**
   * First: define a custom partitioner.
   * Second: override the required methods.
   */
  class MyPartitioner extends Partitioner {
    // Number of partitions
    override def numPartitions: Int = 3

    // Return the partition index (starting from 0) for a record based on its key
    override def getPartition(key: Any): Int = {
      // Option 1: use if/else
      // if (key == "nba") {
      //   0
      // } else if (key == "cba") {
      //   1
      // } else {
      //   2
      // }

      // Option 2: use pattern matching
      // "nba" goes to partition 0, "cba" to partition 1, everything else to partition 2
      key match {
        case "nba" => 0
        case "cba" => 1
        case _ => 2
      }
    }
  }
}
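To verify the result without opening the output files, one optional check is to tag each record with the index of the partition it landed in. This snippet is an addition, not part of the original example; it would go right before sc.stop() in the program above and reuses its value RDD.

// Optional check, placed before sc.stop() in the example above:
// print each record together with the partition it ended up in.
value
  .mapPartitionsWithIndex((idx, iter) => iter.map(kv => s"partition $idx -> $kv"))
  .collect()
  .foreach(println)
// Expected distribution: "nba" records in partition 0, "cba" in partition 1, "wnba" in partition 2.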
