ShuffleMapTask.scala

    // ShuffleMapTask's runTask returns a MapStatus
    override def runTask(context: TaskContext): MapStatus = {
      // Deserialize the RDD using the broadcast variable.
      val threadMXBean = ManagementFactory.getThreadMXBean
      val deserializeStartTime = System.currentTimeMillis()
      val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
      } else 0L
      // Deserialize the data this task needs to process.
      /*
       Question: multiple tasks run concurrently in an executor, the data may not all live on one
       machine, and every task in a stage processes the same RDD. How does a task get hold of the
       data it is supposed to process?
       Answer: through the broadcast variable (the broadcast task binary).
       */
      val ser = SparkEnv.get.closureSerializer.newInstance()
      val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
      _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
      _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
      } else 0L

      var writer: ShuffleWriter[Any, Any] = null
      try {
        // Get the ShuffleManager
        val manager = SparkEnv.get.shuffleManager
        // Get the ShuffleWriter for this shuffle and partition
        writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
        // First, call the RDD's iterator() method, passing in the partition this task should
        // process; the core computation happens inside rdd.iterator(). The processed partition
        // data is then handed to the ShuffleWriter, which routes each record through the
        // partitioner (e.g. HashPartitioner) into its output partition.
        writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
        // Return a MapStatus, which records where this ShuffleMapTask's output is stored
        // (essentially the relevant BlockManager information).
        writer.stop(success = true).get
      } catch {
        case e: Exception =>
          try {
            if (writer != null) {
              writer.stop(success = false)
            }
          } catch {
            case e: Exception =>
              log.debug("Could not stop writer", e)
          }
          throw e
      }
    }
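
To make the write path concrete, here is a small, self-contained sketch (plain Scala, not Spark code; all names are illustrative) of the idea the comments describe: each map-side record is routed to a reduce bucket by a hash partitioner, and the per-bucket sizes are what a MapStatus-style summary would report back to the driver.

    // Illustrative sketch only: mimics how a hash partitioner buckets map output.
    object ShuffleWriteSketch {
      // The same trick HashPartitioner uses: a non-negative modulo of the key's hash code.
      def nonNegativeMod(hash: Int, numBuckets: Int): Int = {
        val m = hash % numBuckets
        if (m < 0) m + numBuckets else m
      }

      def main(args: Array[String]): Unit = {
        val numReducers = 3
        val mapOutput = Seq("a" -> 1, "b" -> 2, "c" -> 3, "a" -> 4)
        // Route every record to its reduce bucket, as the shuffle write does.
        val buckets = mapOutput.groupBy { case (k, _) => nonNegativeMod(k.hashCode, numReducers) }
        // A MapStatus-style summary: one size per reduce bucket (here simply record counts).
        val sizes = (0 until numReducers).map(i => buckets.getOrElse(i, Seq.empty).size)
        println(s"per-bucket sizes reported to the driver: $sizes")
      }
    }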

ShuffledRDD.scala

    override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
      val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
      // When a ResultTask or ShuffleMapTask computes a partition of a ShuffledRDD, it calls
      // ShuffleManager.getReader() to obtain a shuffle reader (the BlockStoreShuffleReader shown
      // below) for exactly this partition range, and then calls read() to pull the map-side
      // output this task depends on.
      SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
        .read()
        .asInstanceOf[Iterator[(K, C)]]
    }
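
For context, the easiest way to exercise this path (assuming a Spark dependency and a local master; the application itself is just an example) is any wide transformation: reduceByKey below produces a ShuffledRDD with three reduce partitions, so compute() runs once per reduce task, each time asking getReader for exactly the range [split.index, split.index + 1).

    import org.apache.spark.{SparkConf, SparkContext}

    object ShuffledRddReadExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("shuffled-rdd-read").setMaster("local[2]"))
        val counts = sc.parallelize(Seq("a" -> 1, "b" -> 1, "a" -> 1), 2)
          .reduceByKey(_ + _, 3) // 3 reduce partitions, so compute() is called for splits 0, 1 and 2
        counts.collect().foreach(println)
        sc.stop()
      }
    }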

BlockStoreShuffleReader.scala

    override def read(): Iterator[Product2[K, C]] = {
      // When a reduce-side task (ResultTask or ShuffleMapTask) reads its shuffle input, it builds a
      // ShuffleBlockFetcherIterator: the MapOutputTracker (on the driver) supplies the locations and
      // sizes of the map outputs, and the BlockManager is used to read the blocks from those locations.
      val blockFetcherItr = new ShuffleBlockFetcherIterator(
        context,
        blockManager.shuffleClient,
        blockManager,
        mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
        // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
        SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

      // Wrap the streams for compression based on configuration
      val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
        blockManager.wrapForCompression(blockId, inputStream)
      }
      // ... rest of read() (deserialization, aggregation and sorting of the fetched records) omitted
    }

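The only tunable surfaced in this excerpt is spark.reducer.maxSizeInFlight. As a quick illustration of how it is read (the config key, SparkConf and getSizeAsMb are real Spark APIs; the values here are only examples):

    import org.apache.spark.SparkConf

    object ReducerFetchConfigExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("fetch-tuning")
          // Default is 48m; a larger budget lets the reducer keep more block data in flight at once.
          .set("spark.reducer.maxSizeInFlight", "96m")
        // With no size suffix the value is interpreted as MB, for backwards compatibility.
        println(conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m")) // prints 96
      }
    }
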
ShuffleBlockFetcherIterator.scala

    private[this] def initialize(): Unit = {
      // Add a task completion callback (called in both success case and failure case) to cleanup.
      context.addTaskCompletionListener(_ => cleanup())
      // Split local and remote blocks.
      val remoteRequests = splitLocalRemoteBlocks()
      // Add the remote requests into our queue in a random order
      fetchRequests ++= Utils.randomize(remoteRequests)
      // Send out initial requests for blocks, up to our maxBytesInFlight
      fetchUpToMaxBytes()
      val numFetches = remoteRequests.size - fetchRequests.size
      logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
      // Get Local Blocks
      fetchLocalBlocks()
      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
    }
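
Before the next method, a simplified, self-contained sketch (hosts and block ids are illustrative, not Spark internals) of what initialize() sets up: blocks hosted on the local executor are read directly, while blocks on other executors are grouped into fetch requests and queued in random order so the first wave of fetches is spread across hosts.

    import scala.util.Random

    object SplitLocalRemoteSketch {
      case class Block(host: String, id: String, size: Long)

      def main(args: Array[String]): Unit = {
        val localHost = "host-1"
        val blocks = Seq(
          Block("host-1", "shuffle_0_0_2", 10), Block("host-2", "shuffle_0_1_2", 20),
          Block("host-3", "shuffle_0_2_2", 15), Block("host-2", "shuffle_0_3_2", 5))
        // Split local and remote blocks, then randomize the per-host remote requests.
        val (localBlocks, remoteBlocks) = blocks.partition(_.host == localHost)
        val fetchRequests = Random.shuffle(remoteBlocks.groupBy(_.host).toSeq)
        println(s"read locally: ${localBlocks.map(_.id)}")
        fetchRequests.foreach { case (host, bs) => println(s"fetch from $host: ${bs.map(_.id)}") }
      }
    }
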
    private def fetchUpToMaxBytes(): Unit = {
      // Send fetch requests up to maxBytesInFlight
      // The key setting here is spark.reducer.maxSizeInFlight (maxBytesInFlight): it caps how much
      // remote block data may be in flight to this node at once; the fetched data then feeds the
      // user-defined reduce-side operators.
      while (fetchRequests.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
        // Send a request to a remote executor to fetch its blocks
        sendRequest(fetchRequests.dequeue())
      }
    }
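
Finally, a runnable toy version of the throttling loop above (illustrative request sizes, not Spark code): requests keep being dequeued only while the bytes already in flight plus the next request still fit under maxBytesInFlight; anything that does not fit stays queued until earlier fetches complete and free up budget.

    import scala.collection.mutable

    object MaxBytesInFlightSketch {
      case class FetchRequest(address: String, size: Long)

      def main(args: Array[String]): Unit = {
        val maxBytesInFlight = 48L * 1024 * 1024
        var bytesInFlight = 0L
        val fetchRequests = mutable.Queue(
          FetchRequest("exec-1", 20L * 1024 * 1024),
          FetchRequest("exec-2", 20L * 1024 * 1024),
          FetchRequest("exec-3", 20L * 1024 * 1024)) // will not fit and must wait for earlier fetches

        // Mirrors the condition on fetchRequests.front.size in the real code.
        while (fetchRequests.nonEmpty &&
          (bytesInFlight == 0 || bytesInFlight + fetchRequests.head.size <= maxBytesInFlight)) {
          val req = fetchRequests.dequeue()
          bytesInFlight += req.size
          println(s"sending request to ${req.address}, bytes in flight: $bytesInFlight")
        }
        println(s"requests still queued: ${fetchRequests.size}")
      }
    }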