The analysis below uses SortShuffleManager as the example.

/**
 * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
 * written to a single map output file. Reducers fetch contiguous regions of this file in order to
 * read their portion of the map output. In cases where the map output data is too large to fit in
 * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
 * to produce the final output file.
 * Key point: records are sorted by partition id and stored in one map output file; each reducer
 * reads the contiguous region for its partition. When the in-memory buffer is too small, sorted
 * data is spilled to disk, and the spill files are finally merged into a single output file.
 *
 * Sort-based shuffle has two different write paths for producing its map output files:
 *
 *  - Serialized sorting: used when all three of the following conditions hold:
 *    1. The shuffle dependency specifies no aggregation or output ordering.
 *    2. The shuffle serializer supports relocation of serialized values (this is currently
 *       supported by KryoSerializer and Spark SQL's custom serializers).
 *    3. The shuffle produces fewer than 16777216 output partitions.
 *  - Deserialized sorting: used to handle all other cases.
 *
 * -----------------------
 * Serialized sorting mode
 * -----------------------
 *
 * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the
 * shuffle writer and are buffered in a serialized form during sorting. This write path implements
 * several optimizations:
 * Key point: in serialized mode, map records are serialized as soon as they reach the shuffle writer.
 *
 *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
 *    consumption and GC overheads. This optimization requires the record serializer to have certain
 *    properties to allow serialized records to be re-ordered without requiring deserialization.
 *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
 *
 *    Key point: this mode sorts serialized binary data rather than Java objects, which reduces
 *    memory consumption and GC overhead.
 *  - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts
 *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
 *    record in the sorting array, this fits more of the array into cache.
 *    Key point: ShuffleExternalSorter sorts an array of compressed record pointers and partition
 *    ids; each record takes only 8 bytes in that array, so more of the array fits in the CPU cache.
 *  - The spill merging procedure operates on blocks of serialized records that belong to the same
 *    partition and does not need to deserialize records during the merge.
 *    Key point: merging spills does not require deserialization.
 *  - When the spill compression codec supports concatenation of compressed data, the spill merge
 *    simply concatenates the serialized and compressed spill partitions to produce the final output
 *    partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used
 *    and avoids the need to allocate decompression or copying buffers during the merge.
 *
 * For more details on these optimizations, see SPARK-7081.
 */
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
  // ... omitted ...
}

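In short, SortShuffleManager sorts each map task's output by target partition id and writes it to a single file. When the data does not fit in memory, sorted runs are spilled to disk and later merged into that file. In serialized mode, records are serialized on arrival and sorted in binary form.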

Shuffle registration

 override def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
      // need map-side aggregation, then write numPartitions files directly and just concatenate
      // them at the end. This avoids doing serialization and deserialization twice to merge
      // together the spilled files, which would happen with the normal code path. The downside is
      // having multiple files open at a time and thus more memory allocated to buffers.
      new BypassMergeSortShuffleHandle[K, V](
        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
      new SerializedShuffleHandle[K, V](
        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
    } else {
      // Otherwise, buffer map outputs in a deserialized form:
      new BaseShuffleHandle(shuffleId, numMaps, dependency)
    }
  }

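There are three kinds of handle: BypassMergeSortShuffleHandle, SerializedShuffleHandle, and BaseShuffleHandle.
BaseShuffleHandle is the fallback. BypassMergeSortShuffleHandle is registered when there is no map-side aggregation and the number of output (reduce) partitions does not exceed spark.shuffle.sort.bypassMergeThreshold (200 by default). Otherwise, SerializedShuffleHandle is registered if the dependency meets the three conditions for serialized sorting listed in the class comment.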
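The selection order can be modelled in a few lines. This is only a sketch: `Dep`, `Handle`, and `choose` are hypothetical stand-ins rather than Spark classes, and the exact boundary comparisons (`<=` versus `<`) are assumptions, since this section does not show the helper methods that registerShuffle calls.

```scala
object HandleSelection {
  // Hypothetical, simplified stand-in for the ShuffleDependency fields that matter here.
  final case class Dep(
      mapSideCombine: Boolean,
      hasAggregator: Boolean,
      numReducePartitions: Int,
      serializerSupportsRelocation: Boolean)

  sealed trait Handle
  case object Bypass extends Handle
  case object Serialized extends Handle
  case object Base extends Handle

  // Default value of spark.shuffle.sort.bypassMergeThreshold.
  val BypassMergeThreshold = 200
  // Partition limit quoted in the class comment (16777216 = 2^24).
  val MaxSerializedPartitions = 1 << 24

  // Mirrors the if / else-if / else order of registerShuffle.
  // Boundary comparisons are assumptions of this sketch.
  def choose(dep: Dep): Handle =
    if (!dep.mapSideCombine && dep.numReducePartitions <= BypassMergeThreshold) {
      Bypass
    } else if (dep.serializerSupportsRelocation &&
        !dep.hasAggregator &&
        dep.numReducePartitions < MaxSerializedPartitions) {
      Serialized
    } else {
      Base
    }
}
```

Because the bypass check comes first, a small shuffle without map-side combine takes the bypass path even when its serializer would also allow serialized sorting.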

Shuffle write

getWriter returns a different writer depending on the type of the handle:

 override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Int,
      context: TaskContext): ShuffleWriter[K, V] = {
    numMapsForShuffle.putIfAbsent(
      handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
    val env = SparkEnv.get
    handle match {
      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
        new UnsafeShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          context.taskMemoryManager(),
          unsafeShuffleHandle,
          mapId,
          context,
          env.conf)
      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
        new BypassMergeSortShuffleWriter(
          env.blockManager,
          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
          bypassMergeSortHandle,
          mapId,
          context,
          env.conf)
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
    }
  }

SortShuffleWriter

The write method

/** Write a bunch of records to this task's output */
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    sorter.insertAll(records)

    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
    val tmp = Utils.tempFileWith(output)
    try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }
  }

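Key code of ExternalSorter.insertAll. With map-side aggregation, each record is first merged into an in-memory map (a PartitionedAppendOnlyMap, which extends SizeTrackingAppendOnlyMap), and the sorter then checks whether it needs to spill. Without aggregation, the record is inserted directly into a buffer, followed by the same spill check.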

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined

    if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        maybeSpillCollection(usingMap = true)
      }
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }
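The aggregation branch boils down to "create a combiner for a new key, otherwise merge into the existing one", keyed by (partition id, key). The sketch below reproduces that with a standard mutable HashMap. `combineAll` and its hash-based partitioning are assumptions for illustration; Spark uses its own map class and the dependency's partitioner.

```scala
import scala.collection.mutable

object CombineSketch {
  // createCombiner / mergeValue play the role of the Aggregator functions used in insertAll.
  def combineAll[K, V, C](
      records: Iterator[(K, V)],
      numPartitions: Int,
      createCombiner: V => C,
      mergeValue: (C, V) => C): Map[(Int, K), C] = {
    val map = mutable.HashMap.empty[(Int, K), C]
    records.foreach { case (k, v) =>
      // Assumed hash partitioning; floorMod keeps the result non-negative.
      val partition = Math.floorMod(k.##, numPartitions)
      val key = (partition, k)
      val updated = map.get(key) match {
        case Some(old) => mergeValue(old, v) // key seen before: merge
        case None      => createCombiner(v)  // first occurrence: create
      }
      map.update(key, updated)
    }
    map.toMap
  }
}
```

For a word count, `createCombiner` is `v => v` and `mergeValue` is `_ + _`, so repeated keys collapse into one entry per (partition, key) before anything is written.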

In-memory map for aggregation

The changeValue code below is implemented in AppendOnlyMap, the parent class of SizeTrackingAppendOnlyMap:

def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    assert(!destroyed, destructionMessage)
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue)
      haveNullValue = true
      return nullValue
    }
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (curKey.eq(null)) {
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        incrementSize()
        return newValue
      } else if (k.eq(curKey) || k.equals(curKey)) {
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

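Key logic: hash the key to find its slot pos in the backing array, probing forward on collisions. If the slot is empty, compute the new value and store both key and value; if it already holds the same key, update the value in place. The two stores for a new key are: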

data(2 * pos) = k
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]

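Here data is a flat Array laid out as key0, value0, key1, value1, ..., so its initial length is 2 * capacity. When the number of entries exceeds the growth threshold, the capacity doubles: a new array is allocated and the existing entries are rehashed into it.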
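The layout and probing scheme can be tried in isolation. The following is a minimal sketch, not Spark's AppendOnlyMap: it is fixed to String keys and Int values, skips the null-key case and the rehash of the hash code, and its load factor of 0.7 is an assumed value.

```scala
object FlatMapSketch {
  final class FlatMap(initialCapacity: Int = 8) {
    require(Integer.bitCount(initialCapacity) == 1, "capacity must be a power of two")

    private val LoadFactor = 0.7 // assumed value for this sketch
    private var capacity = initialCapacity
    private var mask = capacity - 1
    // Keys at even indices, values at odd indices: key0, value0, key1, value1, ...
    private var data = new Array[AnyRef](2 * capacity)
    private var curSize = 0

    def size: Int = curSize

    def changeValue(key: String, update: (Boolean, Int) => Int): Int = {
      var pos = key.hashCode & mask
      var i = 1
      while (true) {
        val curKey = data(2 * pos)
        if (curKey eq null) {
          // Empty slot: this is a new key.
          val v = update(false, 0)
          data(2 * pos) = key
          data(2 * pos + 1) = Int.box(v)
          curSize += 1
          if (curSize > LoadFactor * capacity) grow()
          return v
        } else if (curKey == key) {
          // Same key: update the value in place.
          val v = update(true, data(2 * pos + 1).asInstanceOf[Int])
          data(2 * pos + 1) = Int.box(v)
          return v
        } else {
          // Collision: the step grows by one each time (1, 2, 3, ...).
          pos = (pos + i) & mask
          i += 1
        }
      }
      0 // never reached
    }

    def get(key: String): Option[Int] = {
      var pos = key.hashCode & mask
      var i = 1
      while (data(2 * pos) ne null) {
        if (data(2 * pos) == key) return Some(data(2 * pos + 1).asInstanceOf[Int])
        pos = (pos + i) & mask
        i += 1
      }
      None
    }

    // Double the capacity and rehash every existing entry into the new array.
    private def grow(): Unit = {
      val old = data
      val oldCapacity = capacity
      capacity *= 2
      mask = capacity - 1
      data = new Array[AnyRef](2 * capacity)
      var j = 0
      while (j < oldCapacity) {
        val k = old(2 * j)
        if (k ne null) {
          var pos = k.hashCode & mask
          var i = 1
          while (data(2 * pos) ne null) { pos = (pos + i) & mask; i += 1 }
          data(2 * pos) = k
          data(2 * pos + 1) = old(2 * j + 1)
        }
        j += 1
      }
    }
  }
}
```

Keeping the capacity a power of two lets `& mask` replace a modulo, and a simple copy is not enough on growth because each key's slot depends on the mask.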

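After every update, Spark also tracks an estimate of the collection's memory footprint. The estimate is extrapolated from periodic samples instead of being measured on each update.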

/**
   * Take a new sample of the current collection's size.
   */
  private def takeSample(): Unit = {
    samples.enqueue(Sample(SizeEstimator.estimate(this), numUpdates))
    // Only use the last two samples to extrapolate
    if (samples.size > 2) {
      samples.dequeue()
    }
    val bytesDelta = samples.toList.reverse match {
      case latest :: previous :: tail =>
        (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
      // If fewer than 2 samples, assume no change
      case _ => 0
    }
    bytesPerUpdate = math.max(0, bytesDelta)
    nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
  }

  /**
   * Estimate the current size of the collection in bytes. O(1) time.
   */
  def estimateSize(): Long = {
    assert(samples.nonEmpty)
    val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
    (samples.last.size + extrapolatedDelta).toLong
  }
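The sampling scheme can be exercised on its own. In this sketch, `measure` is a hypothetical stand-in for the expensive SizeEstimator.estimate call, and the growth rate of 1.1 is an assumed value for SAMPLE_GROWTH_RATE, which the listing above references but does not define.

```scala
import scala.collection.mutable

object SizeTrackerSketch {
  final case class Sample(size: Long, numUpdates: Long)

  final class Tracker(measure: () => Long, growthRate: Double = 1.1) {
    private val samples = mutable.Queue.empty[Sample]
    private var numUpdates = 0L
    private var nextSampleNum = 1L
    private var bytesPerUpdate = 0.0
    var samplesTaken = 0 // exposed only to show how rarely measure() runs

    // Call once after every insert or update of the tracked collection.
    def afterUpdate(): Unit = {
      numUpdates += 1
      if (numUpdates == nextSampleNum) takeSample()
    }

    private def takeSample(): Unit = {
      samplesTaken += 1
      samples.enqueue(Sample(measure(), numUpdates))
      // Only the last two samples are used to extrapolate.
      if (samples.size > 2) samples.dequeue()
      bytesPerUpdate = samples.toList.reverse match {
        case latest :: previous :: _ =>
          math.max(0.0,
            (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates))
        case _ => 0.0 // fewer than two samples: assume no growth
      }
      // Sampling points are spaced exponentially further apart.
      nextSampleNum = math.ceil(numUpdates * growthRate).toLong
    }

    // O(1): last measured size plus linear extrapolation since that sample.
    def estimateSize(): Long = {
      require(samples.nonEmpty, "call afterUpdate() at least once first")
      (samples.last.size + bytesPerUpdate * (numUpdates - samples.last.numUpdates)).toLong
    }
  }
}
```

With exponentially spaced samples, the costly measurement runs only a logarithmic number of times relative to the number of updates, while estimateSize stays cheap enough to call per record.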

Plain buffer without aggregation

Without aggregation, records are buffered by calling PartitionedPairBuffer.insert directly:

def insert(partition: Int, key: K, value: V): Unit = {
    if (curSize == capacity) {
      growArray()
    }
    data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
    curSize += 1
    afterUpdate()
  }
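The buffer uses the same flat key/value layout as the map, but appends instead of hashing. A minimal sketch, assuming String keys and values, doubling growth, and a hypothetical `sortedByPartition` helper that illustrates the later sort by partition id:

```scala
object PairBufferSketch {
  final class PairBuffer(initialCapacity: Int = 4) {
    private var capacity = initialCapacity
    private var curSize = 0
    // Slot 2*i holds (partition, key); slot 2*i + 1 holds the value.
    private var data = new Array[AnyRef](2 * capacity)

    def size: Int = curSize

    def insert(partition: Int, key: String, value: String): Unit = {
      if (curSize == capacity) growArray()
      data(2 * curSize) = (partition, key)
      data(2 * curSize + 1) = value
      curSize += 1
    }

    // Appends never move, so a plain copy is enough (no rehash needed).
    private def growArray(): Unit = {
      val newCapacity = capacity * 2 // doubling is an assumption of this sketch
      val newData = new Array[AnyRef](2 * newCapacity)
      System.arraycopy(data, 0, newData, 0, 2 * capacity)
      data = newData
      capacity = newCapacity
    }

    // Hypothetical helper: records ordered by partition id, insertion order kept within one.
    def sortedByPartition: Seq[((Int, String), String)] =
      (0 until curSize)
        .map(i => (data(2 * i).asInstanceOf[(Int, String)], data(2 * i + 1).asInstanceOf[String]))
        .sortBy(_._1._1)
  }
}
```

Storing the partition id next to the key is what lets the sorter order records by partition without recomputing the partitioner.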

Spill handling

ExternalSorter.maybeSpillCollection()

private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {
      estimatedSize = map.estimateSize()
      if (maybeSpill(map, estimatedSize)) {
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      estimatedSize = buffer.estimateSize()
      if (maybeSpill(buffer, estimatedSize)) {
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
  }

Key logic (maybeSpill, inherited from Spillable):

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      myMemoryThreshold += granted
      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
      // or we already had more memory than myMemoryThreshold), spill the current collection
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }
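The decision can be simulated without Spark. In the sketch below, `Pool` is a hypothetical memory pool that grants at most what it has left, and `Spiller` keeps only the threshold bookkeeping; the forced spill on numElementsForceSpillThreshold and the actual write to disk are left out.

```scala
object SpillSketch {
  // Hypothetical memory pool: grants at most the bytes still free.
  final class Pool(var free: Long) {
    def acquire(n: Long): Long = {
      val granted = math.max(0L, math.min(n, free))
      free -= granted
      granted
    }
    def release(n: Long): Unit = free += n
  }

  final class Spiller(pool: Pool, initialThreshold: Long) {
    private var threshold = initialThreshold
    private var elementsRead = 0L
    var spillCount = 0

    def recordRead(): Unit = elementsRead += 1

    def maybeSpill(currentMemory: Long): Boolean = {
      var shouldSpill = false
      // Only check every 32 records, and only once the threshold is reached.
      if (elementsRead % 32 == 0 && currentMemory >= threshold) {
        // Try to grow the threshold to twice the current usage.
        val granted = pool.acquire(2 * currentMemory - threshold)
        threshold += granted
        // Spill only if the pool could not cover the current usage.
        shouldSpill = currentMemory >= threshold
      }
      if (shouldSpill) {
        spillCount += 1
        elementsRead = 0
        // Return everything acquired beyond the initial threshold.
        pool.release(threshold - initialThreshold)
        threshold = initialThreshold
      }
      shouldSpill
    }
  }
}
```

A spill is therefore a last resort: the collection first asks for more memory and writes to disk only when the pool cannot supply enough.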

BypassMergeSortShuffleWriter