一、环境准备(Yarn集群)
1、Driver、Executor
二、组件通信
1、Driver => Executor
2、Executor => Driver
3、Executor => Executor
三、作业执行
1、RDD依赖
2、阶段的划分
3、任务的切分
4、任务的调度
5、任务的执行
四、Shuffle
1、Shuffle的原理和执行过程
2、Shuffle写磁盘
3、Shuffle读取磁盘
五、内存的管理
1、内存的分类
2、内存的配置

一、环境准备

1、起点:SparkSubmit

我们要向yarn集群提交任务,需要使用bin/spark-submit脚本:

  1. #!/usr/bin/env bash
  2. if [ -z "${SPARK_HOME}" ]; then
  3. source "$(dirname "$0")"/find-spark-home
  4. fi
  5. # disable randomized hash for string in Python 3.3+
  6. export PYTHONHASHSEED=0
  7. # 执行spark-class
  8. exec "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@"

可以看到,spark-submit最终调用spark-class脚本去执行org.apache.spark.deploy.SparkSubmit类。
而spark-class脚本的内容如下:

  1. #!/usr/bin/env bash
  2. #
  3. # Licensed to the Apache Software Foundation (ASF) under one or more
  4. # contributor license agreements. See the NOTICE file distributed with
  5. # this work for additional information regarding copyright ownership.
  6. # The ASF licenses this file to You under the Apache License, Version 2.0
  7. # (the "License"); you may not use this file except in compliance with
  8. # the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. if [ -z "${SPARK_HOME}" ]; then
  19. source "$(dirname "$0")"/find-spark-home
  20. fi
  21. . "${SPARK_HOME}"/bin/load-spark-env.sh
  22. # Find the java binary
  23. if [ -n "${JAVA_HOME}" ]; then
  24. RUNNER="${JAVA_HOME}/bin/java"
  25. else
  26. if [ "$(command -v java)" ]; then
  27. RUNNER="java"
  28. else
  29. echo "JAVA_HOME is not set" >&2
  30. exit 1
  31. fi
  32. fi
  33. # Find Spark jars.
  34. if [ -d "${SPARK_HOME}/jars" ]; then
  35. SPARK_JARS_DIR="${SPARK_HOME}/jars"
  36. else
  37. SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars"
  38. fi
  39. if [ ! -d "$SPARK_JARS_DIR" ] && [ -z "$SPARK_TESTING$SPARK_SQL_TESTING" ]; then
  40. echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2
  41. echo "You need to build Spark with the target \"package\" before running this program." 1>&2
  42. exit 1
  43. else
  44. LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*"
  45. fi
  46. # Add the launcher build dir to the classpath if requested.
  47. if [ -n "$SPARK_PREPEND_CLASSES" ]; then
  48. LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH"
  49. fi
  50. # For tests
  51. if [[ -n "$SPARK_TESTING" ]]; then
  52. unset YARN_CONF_DIR
  53. unset HADOOP_CONF_DIR
  54. fi
  55. # The launcher library will print arguments separated by a NULL character, to allow arguments with
  56. # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating
  57. # an array that will be used to exec the final command.
  58. #
  59. # The exit code of the launcher is appended to the output, so the parent shell removes it from the
  60. # command array and checks the value to see if the launcher succeeded.
  61. build_command() {
  62. "$RUNNER" -Xmx128m $SPARK_LAUNCHER_OPTS -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@"
  63. printf "%d\0" $?
  64. }
  65. # Turn off posix mode since it does not allow process substitution
  66. set +o posix
  67. CMD=()
  68. DELIM=$'\n'
  69. CMD_START_FLAG="false"
  70. while IFS= read -d "$DELIM" -r ARG; do
  71. if [ "$CMD_START_FLAG" == "true" ]; then
  72. CMD+=("$ARG")
  73. else
  74. if [ "$ARG" == $'\0' ]; then
  75. # After NULL character is consumed, change the delimiter and consume command string.
  76. DELIM=''
  77. CMD_START_FLAG="true"
  78. elif [ "$ARG" != "" ]; then
  79. echo "$ARG"
  80. fi
  81. fi
  82. done < <(build_command "$@")
  83. COUNT=${#CMD[@]}
  84. LAST=$((COUNT - 1))
  85. LAUNCHER_EXIT_CODE=${CMD[$LAST]}
  86. # Certain JVM failures result in errors being printed to stdout (instead of stderr), which causes
  87. # the code that parses the output of the launcher to get confused. In those cases, check if the
  88. # exit code is an integer, and if it's not, handle it as a special error case.
  89. if ! [[ $LAUNCHER_EXIT_CODE =~ ^[0-9]+$ ]]; then
  90. echo "${CMD[@]}" | head -n-1 1>&2
  91. exit 1
  92. fi
  93. if [ $LAUNCHER_EXIT_CODE != 0 ]; then
  94. exit $LAUNCHER_EXIT_CODE
  95. fi
  96. CMD=("${CMD[@]:0:$LAST}")
  97. exec "${CMD[@]}"

launcher计算出最终要exec的,就是一条形如下面的java命令:

  1. java -cp <classpath> -Xmx1g org.apache.spark.deploy.SparkSubmit xxxxxxxx

java命令启动这个类,执行的就是它的main方法,所以一切的起点就在SparkSubmit中;Scala中的main方法定义在object伴生对象里。

2、向Yarn提交程序

  • super.doSubmit(args)
    • 解析参数 parseArguments(args: Array[String])
      • new SparkSubmitArguments
        • 解析参数 parse(args.asJava),通过正则表达式分割出参数名和值
          • handle处理参数,给各个成员变量赋值(参数名常量定义在SparkSubmitOptionParser中)
        • 校验参数 validateArguments,此时action已被默认赋值为SUBMIT,进入validateSubmitArguments
    • 提交任务 submit(appArgs, uninitLog)
      • runMain执行main方法
        • 准备环境 prepareSubmitEnvironment(args):判断当前是哪种环境(kubernetes、yarn、standalone等),获取childMainClass信息
        • 反射加载childMainClass对应的Class
        • 根据是否继承了SparkApplication创建SparkApplication或JavaMainApplication对象
        • start启动

1、SparkSubmit的main方法开始提交任务

  1. override def main(args: Array[String]): Unit = {
  2. // 创建一个SparkSubmit
  3. val submit = new SparkSubmit() {
  4. self =>
  5. override protected def parseArguments(args: Array[String]): SparkSubmitArguments = {
  6. // 准备一个SparkSubmitArguments将命令行启动参数传入
  7. new SparkSubmitArguments(args) {
  8. override protected def logInfo(msg: => String): Unit = self.logInfo(msg)
  9. override protected def logWarning(msg: => String): Unit = self.logWarning(msg)
  10. override protected def logError(msg: => String): Unit = self.logError(msg)
  11. }
  12. }
  13. override protected def logInfo(msg: => String): Unit = printMessage(msg)
  14. override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg")
  15. override protected def logError(msg: => String): Unit = printMessage(s"Error: $msg")
  16. override def doSubmit(args: Array[String]): Unit = {
  17. try {
  18. // 父类提交
  19. super.doSubmit(args)
  20. } catch {
  21. case e: SparkUserAppException =>
  22. // 异常退出
  23. exitFn(e.exitCode)
  24. }
  25. }
  26. }
  27. // 执行提交
  28. submit.doSubmit(args)
  29. }

1.1、doSubmit处理提交:解析参数parseArguments并按action分发

  1. def doSubmit(args: Array[String]): Unit = {
  2. // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to
  3. // be reset before the application starts.
  4. val uninitLog = initializeLogIfNecessary(true, silent = true)
  5. // 解析参数
  6. val appArgs = parseArguments(args)
  7. if (appArgs.verbose) {
  8. logInfo(appArgs.toString)
  9. }
  10. // 匹配执行
  11. appArgs.action match {
  12. case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog)
  13. case SparkSubmitAction.KILL => kill(appArgs)
  14. case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)
  15. case SparkSubmitAction.PRINT_VERSION => printVersion()
  16. }
  17. }

1.1.1、解析命令行启动参数:new 了一个SparkSubmitArguments,准备了很多参数并且调用 parse(args.asJava)解析

  1. private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env)
  2. extends SparkSubmitArgumentsParser with Logging {
  3. var master: String = null
  4. var deployMode: String = null
  5. var executorMemory: String = null
  6. var executorCores: String = null
  7. var totalExecutorCores: String = null
  8. var propertiesFile: String = null
  9. var driverMemory: String = null
  10. var driverExtraClassPath: String = null
  11. var driverExtraLibraryPath: String = null
  12. var driverExtraJavaOptions: String = null
  13. var queue: String = null
  14. var numExecutors: String = null
  15. var files: String = null
  16. var archives: String = null
  17. var mainClass: String = null
  18. .........................
  19. // 解析参数
  20. parse(args.asJava)
  21. // Populate `sparkProperties` map from properties file
  22. mergeDefaultSparkProperties()
  23. // Remove keys that don't start with "spark." from `sparkProperties`.
  24. ignoreNonSparkProperties()
  25. // Use `sparkProperties` map along with env vars to fill in any missing parameters
  26. loadEnvironmentArguments()
  27. useRest = sparkProperties.getOrElse("spark.master.rest.enabled", "false").toBoolean
  28. ....................
  29. // Action should be SUBMIT unless otherwise specified
  30. // action被赋予默认值SUBMIT
  31. action = Option(action).getOrElse(SUBMIT)
  32. // 在这里校验参数
  33. validateArguments()
  34. private def validateArguments(): Unit = {
  35. action match {
  36. // 校验合法参数
  37. case SUBMIT => validateSubmitArguments()
  38. case KILL => validateKillArguments()
  39. case REQUEST_STATUS => validateStatusRequestArguments()
  40. case PRINT_VERSION =>
  41. }
  42. }
  43. }

1.1.2、解析参数,通过正则表达式解析得到参数名和值,并调用handle处理

  1. protected final void parse(List<String> args) {
  2. // 通过正则表达式解析
  3. Pattern eqSeparatedOpt = Pattern.compile("(--[^=]+)=(.+)");
  4. int idx = 0;
  5. for (idx = 0; idx < args.size(); idx++) {
  6. String arg = args.get(idx);
  7. String value = null;
  8. Matcher m = eqSeparatedOpt.matcher(arg);
  9. if (m.matches()) {
  10. // 解析到命令行参数和值
  11. arg = m.group(1);
  12. value = m.group(2);
  13. }
  14. // Look for options with a value.
  15. String name = findCliOption(arg, opts);
  16. if (name != null) {
  17. if (value == null) {
  18. if (idx == args.size() - 1) {
  19. throw new IllegalArgumentException(
  20. String.format("Missing argument for option '%s'.", arg));
  21. }
  22. idx++;
  23. value = args.get(idx);
  24. }
  25. if (!handle(name, value)) {
  26. break;
  27. }
  28. continue;
  29. }
  30. // Look for a switch.
  31. name = findCliOption(arg, switches);
  32. if (name != null) {
  33. // 处理这些参数和值
  34. if (!handle(name, null)) {
  35. break;
  36. }
  37. continue;
  38. }
  39. if (!handleUnknown(arg)) {
  40. break;
  41. }
  42. }
  43. if (idx < args.size()) {
  44. idx++;
  45. }
  46. handleExtraArgs(args.subList(idx, args.size()));
  47. }
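
为了直观理解这段正则切分逻辑,下面给出一个脱离Spark的最小示意(参数列表是随手假设的,仅演示 (--[^=]+)=(.+) 这个正则的匹配效果):

```scala
import java.util.regex.Pattern

object ParseDemo {
  def main(args: Array[String]): Unit = {
    // 与源码中相同的正则:匹配 --key=value 形式的参数
    val eqSeparatedOpt = Pattern.compile("(--[^=]+)=(.+)")
    // 假设的一组命令行参数
    val cliArgs = Seq("--master=yarn", "--deploy-mode", "cluster", "--conf", "spark.executor.memory=2g")

    cliArgs.foreach { arg =>
      val m = eqSeparatedOpt.matcher(arg)
      if (m.matches()) {
        // --key=value:一次就能切出参数名和值
        println(s"name = ${m.group(1)}, value = ${m.group(2)}")
      } else {
        // --key value 的形式:源码中会继续读取下一个参数作为值,这里只打印原样
        println(s"raw  = $arg")
      }
    }
  }
}
```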

1.1.3、handle处理参数,得到所有值

  1. override protected def handle(opt: String, value: String): Boolean = {
  2. opt match {
  3. case NAME =>
  4. name = value
  5. case MASTER =>
  6. master = value
  7. case CLASS =>
  8. mainClass = value
  9. case DEPLOY_MODE =>
  10. if (value != "client" && value != "cluster") {
  11. error("--deploy-mode must be either \"client\" or \"cluster\"")
  12. }
  13. deployMode = value
  14. case NUM_EXECUTORS =>
  15. numExecutors = value
  16. case TOTAL_EXECUTOR_CORES =>
  17. totalExecutorCores = value
  18. case EXECUTOR_CORES =>
  19. executorCores = value
  20. case EXECUTOR_MEMORY =>
  21. executorMemory = value
  22. case DRIVER_MEMORY =>
  23. driverMemory = value
  24. case DRIVER_CORES =>
  25. driverCores = value
  26. case DRIVER_CLASS_PATH =>
  27. driverExtraClassPath = value
  28. case DRIVER_JAVA_OPTIONS =>
  29. driverExtraJavaOptions = value
  30. case DRIVER_LIBRARY_PATH =>
  31. driverExtraLibraryPath = value
  32. case PROPERTIES_FILE =>
  33. propertiesFile = value
  34. case KILL_SUBMISSION =>
  35. submissionToKill = value
  36. if (action != null) {
  37. error(s"Action cannot be both $action and $KILL.")
  38. }
  39. action = KILL
  40. case STATUS =>
  41. submissionToRequestStatusFor = value
  42. if (action != null) {
  43. error(s"Action cannot be both $action and $REQUEST_STATUS.")
  44. }
  45. action = REQUEST_STATUS
  46. case SUPERVISE =>
  47. supervise = true
  48. case QUEUE =>
  49. queue = value
  50. case FILES =>
  51. files = Utils.resolveURIs(value)
  52. case PY_FILES =>
  53. pyFiles = Utils.resolveURIs(value)
  54. case ARCHIVES =>
  55. archives = Utils.resolveURIs(value)
  56. case JARS =>
  57. jars = Utils.resolveURIs(value)
  58. case PACKAGES =>
  59. packages = value
  60. case PACKAGES_EXCLUDE =>
  61. packagesExclusions = value
  62. case REPOSITORIES =>
  63. repositories = value
  64. case CONF =>
  65. val (confName, confValue) = SparkSubmitUtils.parseSparkConfProperty(value)
  66. sparkProperties(confName) = confValue
  67. case PROXY_USER =>
  68. proxyUser = value
  69. case PRINCIPAL =>
  70. principal = value
  71. case KEYTAB =>
  72. keytab = value
  73. case HELP =>
  74. printUsageAndExit(0)
  75. case VERBOSE =>
  76. verbose = true
  77. case VERSION =>
  78. action = SparkSubmitAction.PRINT_VERSION
  79. case USAGE_ERROR =>
  80. printUsageAndExit(1)
  81. case _ =>
  82. error(s"Unexpected argument '$opt'.")
  83. }
  84. action != SparkSubmitAction.PRINT_VERSION
  85. }

这些参数名常量定义在SparkSubmitOptionParser中,对应命令行传入的选项:

  1. protected final String CLASS = "--class";
  2. protected final String CONF = "--conf";
  3. protected final String DEPLOY_MODE = "--deploy-mode";
  4. protected final String DRIVER_CLASS_PATH = "--driver-class-path";
  5. protected final String DRIVER_CORES = "--driver-cores";
  6. protected final String DRIVER_JAVA_OPTIONS = "--driver-java-options";
  7. protected final String DRIVER_LIBRARY_PATH = "--driver-library-path";
  8. protected final String DRIVER_MEMORY = "--driver-memory";
  9. protected final String EXECUTOR_MEMORY = "--executor-memory";
  10. protected final String FILES = "--files";
  11. protected final String JARS = "--jars";
  12. protected final String KILL_SUBMISSION = "--kill";
  13. protected final String MASTER = "--master";
  14. protected final String NAME = "--name";
  15. protected final String PACKAGES = "--packages";
  16. protected final String PACKAGES_EXCLUDE = "--exclude-packages";
  17. protected final String PROPERTIES_FILE = "--properties-file";
  18. protected final String PROXY_USER = "--proxy-user";
  19. protected final String PY_FILES = "--py-files";
  20. protected final String REPOSITORIES = "--repositories";
  21. protected final String STATUS = "--status";
  22. protected final String TOTAL_EXECUTOR_CORES = "--total-executor-cores";
  23. ..........................

action有了值之后,回到doSubmit中,就可以提交了。

1.2、提交任务,runMain

  1. private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {
  2. // 3、
  3. def doRunMain(): Unit = {
  4. if (args.proxyUser != null) {
  5. val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser,
  6. UserGroupInformation.getCurrentUser())
  7. try {
  8. // 有没有代理服务器,
  9. proxyUser.doAs(new PrivilegedExceptionAction[Unit]() {
  10. override def run(): Unit = {
  11. runMain(args, uninitLog)
  12. }
  13. })
  14. } catch {
  15. case e: Exception =>
  16. // Hadoop's AuthorizationException suppresses the exception's stack trace, which
  17. // makes the message printed to the output by the JVM not very helpful. Instead,
  18. // detect exceptions with empty stack traces here, and treat them differently.
  19. if (e.getStackTrace().length == 0) {
  20. error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
  21. } else {
  22. throw e
  23. }
  24. }
  25. } else {
  26. // runMain
  27. runMain(args, uninitLog)
  28. }
  29. }
  30. // In standalone cluster mode, there are two submission gateways:
  31. // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper
  32. // (2) The new REST-based gateway introduced in Spark 1.3
  33. // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over
  34. // to use the legacy gateway if the master endpoint turns out to be not a REST server.
  35. // 1、判断是不是spark集群,是rest风格,我们是yarn集群,所以走下面
  36. if (args.isStandaloneCluster && args.useRest) {
  37. try {
  38. logInfo("Running Spark using the REST application submission protocol.")
  39. doRunMain()
  40. } catch {
  41. // Fail over to use the legacy submission gateway
  42. case e: SubmitRestConnectionException =>
  43. logWarning(s"Master endpoint ${args.master} was not a REST server. " +
  44. "Falling back to legacy submission gateway instead.")
  45. args.useRest = false
  46. submit(args, false)
  47. }
  48. // In all other modes, just run the main class as prepared
  49. } else {
  50. // 2、走这里
  51. doRunMain()
  52. }
  53. }

1.2.1、prepareSubmitEnvironment判断当前是哪种环境,得到childMainClass。当前是yarn集群环境,获取到:
private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = "org.apache.spark.deploy.yarn.YarnClusterApplication"

  1. if (isYarnCluster) {
  2. childMainClass = YARN_CLUSTER_SUBMIT_CLASS
  3. if (args.isPython) {
  4. childArgs += ("--primary-py-file", args.primaryResource)
  5. childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
  6. } else if (args.isR) {
  7. val mainFile = new Path(args.primaryResource).getName
  8. childArgs += ("--primary-r-file", mainFile)
  9. childArgs += ("--class", "org.apache.spark.deploy.RRunner")
  10. } else {
  11. if (args.primaryResource != SparkLauncher.NO_RESOURCE) {
  12. childArgs += ("--jar", args.primaryResource)
  13. }
  14. childArgs += ("--class", args.mainClass)
  15. }
  16. if (args.childArgs != null) {
  17. args.childArgs.foreach { arg => childArgs += ("--arg", arg) }
  18. }
  19. }
  20. if (isMesosCluster) {
  21. assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API")
  22. childMainClass = REST_CLUSTER_SUBMIT_CLASS
  23. if (args.isPython) {
  24. // Second argument is main class
  25. childArgs += (args.primaryResource, "")
  26. if (args.pyFiles != null) {
  27. sparkConf.set(SUBMIT_PYTHON_FILES, args.pyFiles.split(",").toSeq)
  28. }
  29. } else if (args.isR) {
  30. // Second argument is main class
  31. childArgs += (args.primaryResource, "")
  32. } else {
  33. childArgs += (args.primaryResource, args.mainClass)
  34. }
  35. if (args.childArgs != null) {
  36. childArgs ++= args.childArgs
  37. }
  38. }
  39. if (isKubernetesCluster) {
  40. childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS
  41. if (args.primaryResource != SparkLauncher.NO_RESOURCE) {
  42. if (args.isPython) {
  43. childArgs ++= Array("--primary-py-file", args.primaryResource)
  44. childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner")
  45. } else if (args.isR) {
  46. childArgs ++= Array("--primary-r-file", args.primaryResource)
  47. childArgs ++= Array("--main-class", "org.apache.spark.deploy.RRunner")
  48. }
  49. else {
  50. childArgs ++= Array("--primary-java-resource", args.primaryResource)
  51. childArgs ++= Array("--main-class", args.mainClass)
  52. }
  53. } else {
  54. childArgs ++= Array("--main-class", args.mainClass)
  55. }
  56. if (args.childArgs != null) {
  57. args.childArgs.foreach { arg =>
  58. childArgs += ("--arg", arg)
  59. }
  60. }
  61. }

1.2.2、反射加载当前环境的类
1.2.3、判断是否继承SparkApplication,创建SparkApplication或者JavaMainApplication
1.2.4、启动

  1. private def runMain(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {
  2. // 1、判断环境
  3. val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args)
  4. // Let the main class re-initialize the logging system once it starts.
  5. if (uninitLog) {
  6. Logging.uninitialize()
  7. }
  8. if (args.verbose) {
  9. logInfo(s"Main class:\n$childMainClass")
  10. logInfo(s"Arguments:\n${childArgs.mkString("\n")}")
  11. // sysProps may contain sensitive information, so redact before printing
  12. logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}")
  13. logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}")
  14. logInfo("\n")
  15. }
  16. val loader = getSubmitClassLoader(sparkConf)
  17. for (jar <- childClasspath) {
  18. addJarToClasspath(jar, loader)
  19. }
  20. var mainClass: Class[_] = null
  21. try {
  22. // 2、利用反射加载childMainClass
  23. mainClass = Utils.classForName(childMainClass)
  24. } catch {
  25. case e: ClassNotFoundException =>
  26. logError(s"Failed to load class $childMainClass.")
  27. if (childMainClass.contains("thriftserver")) {
  28. logInfo(s"Failed to load main class $childMainClass.")
  29. logInfo("You need to build Spark with -Phive and -Phive-thriftserver.")
  30. }
  31. throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS)
  32. case e: NoClassDefFoundError =>
  33. logError(s"Failed to load $childMainClass: ${e.getMessage()}")
  34. if (e.getMessage.contains("org/apache/hadoop/hive")) {
  35. logInfo(s"Failed to load hive class.")
  36. logInfo("You need to build Spark with -Phive and -Phive-thriftserver.")
  37. }
  38. throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS)
  39. }
  40. // 3、判断当前类是否是继承了SparkApplication,如果继承,创建一个SparkApplication实例,
  41. // 没有就创建一个JavaMainApplication
  42. val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) {
  43. mainClass.getConstructor().newInstance().asInstanceOf[SparkApplication]
  44. } else {
  45. new JavaMainApplication(mainClass)
  46. }
  47. @tailrec
  48. def findCause(t: Throwable): Throwable = t match {
  49. case e: UndeclaredThrowableException =>
  50. if (e.getCause() != null) findCause(e.getCause()) else e
  51. case e: InvocationTargetException =>
  52. if (e.getCause() != null) findCause(e.getCause()) else e
  53. case e: Throwable =>
  54. e
  55. }
  56. // 4、最终启动应用程序
  57. try {
  58. app.start(childArgs.toArray, sparkConf)
  59. } catch {
  60. case t: Throwable =>
  61. throw findCause(t)
  62. }
  63. }
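
runMain中"反射加载类 + 判断是否实现了某个接口再实例化"的套路,可以用下面这个独立的小例子来体会(其中的DemoApplication、MyYarnApp都是为演示随手假设的,并不是Spark里的类):

```scala
trait DemoApplication {
  def start(args: Array[String]): Unit
}

class MyYarnApp extends DemoApplication {
  override def start(args: Array[String]): Unit =
    println(s"start with args: ${args.mkString(", ")}")
}

object ReflectDemo {
  def main(args: Array[String]): Unit = {
    // 对应 Utils.classForName(childMainClass):通过全类名反射加载
    val clazz = Class.forName("MyYarnApp")

    // 对应 classOf[SparkApplication].isAssignableFrom(mainClass)
    val app: DemoApplication =
      if (classOf[DemoApplication].isAssignableFrom(clazz)) {
        // 实现了接口:直接用无参构造器实例化
        clazz.getConstructor().newInstance().asInstanceOf[DemoApplication]
      } else {
        // 源码中否则会包装成 JavaMainApplication,这里简单抛异常代替
        throw new IllegalArgumentException(s"${clazz.getName} is not a DemoApplication")
      }

    // 对应 app.start(childArgs.toArray, sparkConf)
    app.start(Array("--deploy-mode", "cluster"))
  }
}
```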

3、启动应用程序

此时在SparkApplication的实现类中找不到Yarn环境对应的实现,是因为我们没有引入yarn的依赖:

  1. <dependency>
  2. <groupId>org.apache.spark</groupId>
  3. <artifactId>spark-yarn_2.12</artifactId>
  4. <version>3.0.0</version>
  5. </dependency>
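
上面是Maven的写法;如果工程用sbt构建,等价的依赖声明大致如下(仅为示意,版本号以实际使用的Scala/Spark版本为准):

```scala
// build.sbt 片段:引入 spark-yarn 模块
libraryDependencies += "org.apache.spark" %% "spark-yarn" % "3.0.0"
```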

引入依赖之后,就可以找到Yarn环境对应的实现类org.apache.spark.deploy.yarn.YarnClusterApplication了。

启动程序

  • 1、创建Client对象
    • new YarnClientImpl
      • 创建一个rmClient:ResourceManager
  • 2、Client.run()

    • 提交任务
      • 从rm创建一个YarnClientApplication
      • 获取响应
      • 准备容器环境 createContainerLaunchContext
        • 如果是集群,创建org.apache.spark.deploy.yarn.ApplicationMaster;否则创建org.apache.spark.deploy.yarn.ExecutorLauncher
        • 封装applicationMaster的指令,发送给ResourceManager,让rm选择一个NodeManager启动am
      • 提交任务 yarnClient.submitApplication(appContext)

YarnClusterApplication的start方法中创建Client对象并运行:

```scala
private[spark] class YarnClusterApplication extends SparkApplication {

  override def start(args: Array[String], conf: SparkConf): Unit = {
    // SparkSubmit would use yarn cache to distribute files & jars in yarn mode,
    // so remove them from sparkConf here for yarn mode.
    conf.remove(JARS)
    conf.remove(FILES)

    // 创建Client对象并启动
    new Client(new ClientArguments(args), conf, null).run()
  }

}
```

创建Client:

```scala
private[spark] class Client(
    val args: ClientArguments,
    val sparkConf: SparkConf,
    val rpcEnv: RpcEnv)
  extends Logging {

  import Client._
  import YarnSparkHadoopUtil._

  // 一上来就创建YarnClient
  private val yarnClient = YarnClient.createYarnClient
  private val hadoopConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf))

  private val isClusterMode = sparkConf.get(SUBMIT_DEPLOY_MODE) == "cluster"

  private val isClientUnmanagedAMEnabled = sparkConf.get(YARN_UNMANAGED_AM) && !isClusterMode
  private var appMaster: ApplicationMaster = _
  private var stagingDirPath: Path = _
  // ......(其余成员省略)
```

YarnClient.createYarnClient:
  1. public abstract class YarnClient extends AbstractService {
  2. @Public
  3. public static YarnClient createYarnClient() {
  4. YarnClient client = new YarnClientImpl();
  5. return client;
  6. }

创建一个实现类对象,并创建ResourceManager

  1. public YarnClientImpl() {
  2. super(YarnClientImpl.class.getName());
  3. }
  4. // 创建了一个rmClient:就是Yarn的ResourceManager
  5. @Override
  6. protected void serviceStart() throws Exception {
  7. try {
  8. rmClient = ClientRMProxy.createRMProxy(getConfig(),
  9. ApplicationClientProtocol.class);
  10. // 如果配置了历史服务器和时间服务器,将这两个再启动
  11. if (historyServiceEnabled) {
  12. historyClient.start();
  13. }
  14. if (timelineServiceEnabled) {
  15. timelineClient.start();
  16. }
  17. } catch (IOException e) {
  18. throw new YarnRuntimeException(e);
  19. }
  20. super.serviceStart();
  21. }

Client.run()一上来就提交任务

  1. def run(): Unit = {
  2. // 获取到yarn的全局任务id
  3. this.appId = submitApplication()
  4. if (!launcherBackend.isConnected() && fireAndForget) {
  5. val report = getApplicationReport(appId)
  6. val state = report.getYarnApplicationState
  7. logInfo(s"Application report for $appId (state: $state)")
  8. logInfo(formatReportDetails(report))
  9. if (state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) {
  10. throw new SparkException(s"Application $appId finished with status: $state")
  11. }
  12. } else {
  13. val YarnAppReport(appState, finalState, diags) = monitorApplication(appId)
  14. if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) {
  15. diags.foreach { err =>
  16. logError(s"Application diagnostics message: $err")
  17. }
  18. throw new SparkException(s"Application $appId finished with failed status")
  19. }
  20. if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) {
  21. throw new SparkException(s"Application $appId is killed")
  22. }
  23. if (finalState == FinalApplicationStatus.UNDEFINED) {
  24. throw new SparkException(s"The final status of application $appId is undefined")
  25. }
  26. }
  27. }

提交任务的时候,让YarnClient连接yarn集群,并创建一个Application

  1. def submitApplication(): ApplicationId = {
  2. ResourceRequestHelper.validateResources(sparkConf)
  3. var appId: ApplicationId = null
  4. try {
  5. // 连接,初始化,启动client
  6. launcherBackend.connect()
  7. yarnClient.init(hadoopConf)
  8. yarnClient.start()
  9. logInfo("Requesting a new application from cluster with %d NodeManagers"
  10. .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers))
  11. // Get a new application from our RM
  12. // 创建一个application
  13. val newApp = yarnClient.createApplication()
  14. val newAppResponse = newApp.getNewApplicationResponse()
  15. appId = newAppResponse.getApplicationId()
  16. // The app staging dir based on the STAGING_DIR configuration if configured
  17. // otherwise based on the users home directory.
  18. val appStagingBaseDir = sparkConf.get(STAGING_DIR)
  19. .map { new Path(_, UserGroupInformation.getCurrentUser.getShortUserName) }
  20. .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory())
  21. stagingDirPath = new Path(appStagingBaseDir, getAppStagingDir(appId))
  22. new CallerContext("CLIENT", sparkConf.get(APP_CALLER_CONTEXT),
  23. Option(appId.toString)).setCurrentContext()
  24. // Verify whether the cluster has enough resources for our AM
  25. verifyClusterResources(newAppResponse)
  26. // Set up the appropriate contexts to launch our AM
  27. val containerContext = createContainerLaunchContext(newAppResponse)
  28. val appContext = createApplicationSubmissionContext(newApp, containerContext)
  29. // Finally, submit and monitor the application
  30. logInfo(s"Submitting application $appId to ResourceManager")
  31. yarnClient.submitApplication(appContext)
  32. launcherBackend.setAppId(appId.toString)
  33. reportLauncherState(SparkAppHandle.State.SUBMITTED)
  34. appId
  35. } catch {
  36. case e: Throwable =>
  37. if (stagingDirPath != null) {
  38. cleanupStagingDir()
  39. }
  40. throw e
  41. }
  42. }

YarnClient创建应用,创建一个YarnClientApplication

  1. @Override
  2. public YarnClientApplication createApplication()
  3. throws YarnException, IOException {
  4. ApplicationSubmissionContext context = Records.newRecord
  5. (ApplicationSubmissionContext.class);
  6. GetNewApplicationResponse newApp = getNewApplication();
  7. ApplicationId appId = newApp.getApplicationId();
  8. context.setApplicationId(appId);
  9. return new YarnClientApplication(newApp, context);
  10. }

准备容器环境createContainerLaunchContext:
开头是一些JVM参数配置,后面准备AM(ApplicationMaster)的启动命令,用java命令运行一个进程。

  1. 。。。。。。。。。。。。。。。。。
  2. val javaOpts = ListBuffer[String]()
  3. // Set the environment variable through a command prefix
  4. // to append to the existing value of the variable
  5. var prefixEnv: Option[String] = None
  6. // Add Xmx for AM memory
  7. javaOpts += "-Xmx" + amMemory + "m"
  8. val tmpDir = new Path(Environment.PWD.$$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
  9. javaOpts += "-Djava.io.tmpdir=" + tmpDir
  10. // TODO: Remove once cpuset version is pushed out.
  11. // The context is, default gc for server class machines ends up using all cores to do gc -
  12. // hence if there are multiple containers in same node, Spark GC affects all other containers'
  13. // performance (which can be that of other Spark containers)
  14. // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in
  15. // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset
  16. // of cores on a node.
  17. val useConcurrentAndIncrementalGC = launchEnv.get("SPARK_USE_CONC_INCR_GC").exists(_.toBoolean)
  18. if (useConcurrentAndIncrementalGC) {
  19. // In our expts, using (default) throughput collector has severe perf ramifications in
  20. // multi-tenant machines
  21. javaOpts += "-XX:+UseConcMarkSweepGC"
  22. javaOpts += "-XX:MaxTenuringThreshold=31"
  23. javaOpts += "-XX:SurvivorRatio=8"
  24. javaOpts += "-XX:+CMSIncrementalMode"
  25. javaOpts += "-XX:+CMSIncrementalPacing"
  26. javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
  27. javaOpts += "-XX:CMSIncrementalDutyCycle=10"
  28. }
  29. 。。。。。。。。。。。。。。。。。。。。。。。。。。。
  30. val amClass =
  31. // 集群环境用org.apache.spark.deploy.yarn.ApplicationMaster
  32. // 否则用org.apache.spark.deploy.yarn.ExecutorLauncher
  33. if (isClusterMode) {
  34. Utils.classForName("org.apache.spark.deploy.yarn.ApplicationMaster").getName
  35. } else {
  36. Utils.classForName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName
  37. }
  38. val amArgs =
  39. Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++
  40. Seq("--properties-file",
  41. buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) ++
  42. Seq("--dist-cache-conf",
  43. buildPath(Environment.PWD.$$(), LOCALIZED_CONF_DIR, DIST_CACHE_CONF_FILE))
  44. // Command for the ApplicationMaster
  45. val commands = prefixEnv ++
  46. // 使用java命令运行一个java进程
  47. Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++
  48. javaOpts ++ amArgs ++
  49. Seq(
  50. "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
  51. "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
  52. 。。。。。。。。。。。。
  53. // send the acl settings into YARN to control who has access via YARN interfaces
  54. val securityManager = new SecurityManager(sparkConf)
  55. amContainer.setApplicationACLs(
  56. YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava)
  57. setupSecurityToken(amContainer)
  58. amContainer
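
把createContainerLaunchContext里这些片段拼起来,AM容器中最终执行的大致是一条这样的命令。下面用一小段独立的Scala代码把它"拼"出来看一下形状(其中的内存大小、--class、--jar等取值都是假设的):

```scala
object AmCommandDemo {
  def main(args: Array[String]): Unit = {
    // 假设的取值,仅用于展示命令的形状
    val amMemory = 1024
    val javaOpts = Seq(s"-Xmx${amMemory}m", "-Djava.io.tmpdir=./tmp")
    val amClass = "org.apache.spark.deploy.yarn.ApplicationMaster" // 集群模式
    val amArgs = Seq(amClass,
      "--class", "com.example.WordCount",
      "--jar", "hdfs:///user/demo/wordcount.jar",
      "--properties-file", "__spark_conf__/__spark_conf__.properties")

    val commands = Seq("{{JAVA_HOME}}/bin/java", "-server") ++
      javaOpts ++ amArgs ++
      Seq("1><LOG_DIR>/stdout", "2><LOG_DIR>/stderr")

    // ResourceManager会让某个NodeManager用这条命令启动ApplicationMaster进程
    println(commands.mkString(" "))
  }
}
```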

4、ApplicationMaster-启动Driver线程

我们已经看到ApplicationMaster被加载进来,所以接下来从ApplicationMaster入手。

  • 1、创建ApplicationMaster
    • 创建YarnRMClient
      • 通过AMRMClient.createAMRMClient创建AMRMClient,并启动
  • 2、启动am进程: master.run()

    • 通过--class参数判断集群模式,分别创建Driver或Executor
    • 创建Driver:runDriver
      • startUserApplication启动用户线程
        • 获取main方法准备一个线程
        • 调用main方法 invoke(mainMethod)
        • Driver线程启动
      • 等待SparkContext上下文环境准备完成:ThreadUtils.awaitResult(sparkContextPromise.future, Duration(totalWaitTime, TimeUnit.MILLISECONDS))

main方法进入,创建ApplicationMaster

  1. def main(args: Array[String]): Unit = {
  2. SignalUtils.registerLogger(log)
  3. val amArgs = new ApplicationMasterArguments(args)
  4. val sparkConf = new SparkConf()
  5. if (amArgs.propertiesFile != null) {
  6. Utils.getPropertiesFromFile(amArgs.propertiesFile).foreach { case (k, v) =>
  7. sparkConf.set(k, v)
  8. }
  9. }
  10. // Set system properties for each config entry. This covers two use cases:
  11. // - The default configuration stored by the SparkHadoopUtil class
  12. // - The user application creating a new SparkConf in cluster mode
  13. //
  14. // Both cases create a new SparkConf object which reads these configs from system properties.
  15. sparkConf.getAll.foreach { case (k, v) =>
  16. sys.props(k) = v
  17. }
  18. val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf))
  19. // new一个ApplicationMaster
  20. master = new ApplicationMaster(amArgs, sparkConf, yarnConf)
  21. val ugi = sparkConf.get(PRINCIPAL) match {
  22. // We only need to log in with the keytab in cluster mode. In client mode, the driver
  23. // handles the user keytab.
  24. case Some(principal) if master.isClusterMode =>
  25. val originalCreds = UserGroupInformation.getCurrentUser().getCredentials()
  26. SparkHadoopUtil.get.loginUserFromKeytab(principal, sparkConf.get(KEYTAB).orNull)
  27. val newUGI = UserGroupInformation.getCurrentUser()
  28. if (master.appAttemptId == null || master.appAttemptId.getAttemptId > 1) {
  29. // Re-obtain delegation tokens if this is not a first attempt, as they might be outdated
  30. // as of now. Add the fresh tokens on top of the original user's credentials (overwrite).
  31. // Set the context class loader so that the token manager has access to jars
  32. // distributed by the user.
  33. Utils.withContextClassLoader(master.userClassLoader) {
  34. val credentialManager = new HadoopDelegationTokenManager(sparkConf, yarnConf, null)
  35. credentialManager.obtainDelegationTokens(originalCreds)
  36. }
  37. }
  38. // Transfer the original user's tokens to the new user, since it may contain needed tokens
  39. // (such as those user to connect to YARN).
  40. newUGI.addCredentials(originalCreds)
  41. newUGI
  42. case _ =>
  43. SparkHadoopUtil.get.createSparkUser()
  44. }
  45. ugi.doAs(new PrivilegedExceptionAction[Unit]() {
  46. override def run(): Unit = System.exit(master.run())
  47. })
  48. }

ApplicationMaster中new了一个YarnRMClient,即与ResourceManager通信的客户端。

  1. private[spark] class ApplicationMaster(
  2. args: ApplicationMasterArguments,
  3. sparkConf: SparkConf,
  4. yarnConf: YarnConfiguration) extends Logging {
  5. // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
  6. // optimal as more containers are available. Might need to handle this better.
  7. private val appAttemptId =
  8. if (System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) != null) {
  9. YarnSparkHadoopUtil.getContainerId.getApplicationAttemptId()
  10. } else {
  11. null
  12. }
  13. private val isClusterMode = args.userClass != null
  14. private val securityMgr = new SecurityManager(sparkConf)
  15. private var metricsSystem: Option[MetricsSystem] = None
  16. private val userClassLoader = {
  17. val classpath = Client.getUserClasspath(sparkConf)
  18. val urls = classpath.map { entry =>
  19. new URL("file:" + new File(entry.getPath()).getAbsolutePath())
  20. }
  21. if (isClusterMode) {
  22. if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) {
  23. new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
  24. } else {
  25. new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
  26. }
  27. } else {
  28. new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
  29. }
  30. }
  31. private val client = new YarnRMClient()

在YarnRMClient的register方法中,创建AMRMClient并启动:

  1. private[spark] class YarnRMClient extends Logging {
  2. private var amClient: AMRMClient[ContainerRequest] = _
  3. private var uiHistoryAddress: String = _
  4. private var registered: Boolean = false
  5. def register(
  6. driverHost: String,
  7. driverPort: Int,
  8. conf: YarnConfiguration,
  9. sparkConf: SparkConf,
  10. uiAddress: Option[String],
  11. uiHistoryAddress: String): Unit = {
  12. // 创建AM、RM通信的客户端并启动
  13. amClient = AMRMClient.createAMRMClient()
  14. amClient.init(conf)
  15. amClient.start()
  16. this.uiHistoryAddress = uiHistoryAddress
  17. val trackingUrl = uiAddress.getOrElse {
  18. if (sparkConf.get(ALLOW_HISTORY_SERVER_TRACKING_URL)) uiHistoryAddress else ""
  19. }
  20. logInfo("Registering the ApplicationMaster")
  21. synchronized {
  22. amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl)
  23. registered = true
  24. }
  25. }

准备好之后,执行ApplicationMaster的run方法:

  1. ugi.doAs(new PrivilegedExceptionAction[Unit]() {
  2. override def run(): Unit = System.exit(master.run())
  3. })

通过--class参数是否设置来判断是否是集群模式,从而决定走runDriver(创建Driver)还是runExecutorLauncher:

  1. try {
  2. val attemptID = if (isClusterMode) {
  3. // Set the web ui port to be ephemeral for yarn so we don't conflict with
  4. // other spark processes running on the same box
  5. System.setProperty(UI_PORT.key, "0")
  6. // Set the master and deploy mode property to match the requested mode.
  7. System.setProperty("spark.master", "yarn")
  8. System.setProperty(SUBMIT_DEPLOY_MODE.key, "cluster")
  9. // Set this internal configuration if it is running on cluster mode, this
  10. // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode.
  11. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString())
  12. Option(appAttemptId.getAttemptId.toString)
  13. } else {
  14. None
  15. }
  16. new CallerContext(
  17. "APPMASTER", sparkConf.get(APP_CALLER_CONTEXT),
  18. Option(appAttemptId.getApplicationId.toString), attemptID).setCurrentContext()
  19. logInfo("ApplicationAttemptId: " + appAttemptId)
  20. // This shutdown hook should run *after* the SparkContext is shut down.
  21. val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1
  22. ShutdownHookManager.addShutdownHook(priority) { () =>
  23. val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf)
  24. val isLastAttempt = appAttemptId.getAttemptId() >= maxAppAttempts
  25. if (!finished) {
  26. // The default state of ApplicationMaster is failed if it is invoked by shut down hook.
  27. // This behavior is different compared to 1.x version.
  28. // If user application is exited ahead of time by calling System.exit(N), here mark
  29. // this application as failed with EXIT_EARLY. For a good shutdown, user shouldn't call
  30. // System.exit(0) to terminate the application.
  31. finish(finalStatus,
  32. ApplicationMaster.EXIT_EARLY,
  33. "Shutdown hook called before final status was reported.")
  34. }
  35. if (!unregistered) {
  36. // we only want to unregister if we don't want the RM to retry
  37. if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) {
  38. unregister(finalStatus, finalMsg)
  39. cleanupStagingDir(new Path(System.getenv("SPARK_YARN_STAGING_DIR")))
  40. }
  41. }
  42. }
  43. // 判断是否是集群模式,通过--class参数是否设置判断
  44. // private val isClusterMode = args.userClass != null
  45. if (isClusterMode) {
  46. runDriver()
  47. } else {
  48. runExecutorLauncher()
  49. }
  50. } catch {
  51. case e: Exception =>
  52. // catch everything else if not specifically handled
  53. logError("Uncaught exception: ", e)
  54. finish(FinalApplicationStatus.FAILED,
  55. ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION,
  56. "Uncaught exception: " + StringUtils.stringifyException(e))
  57. } finally {
  58. try {
  59. metricsSystem.foreach { ms =>
  60. ms.report()
  61. ms.stop()
  62. }
  63. } catch {
  64. case e: Exception =>
  65. logWarning("Exception during stopping of the metric system: ", e)
  66. }
  67. }
  68. exitCode
  69. }

启动Driver线程

  1. private def runDriver(): Unit = {
  2. addAmIpFilter(None, System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV))
  3. // 1、startUserApplication先启动用户应用程序,将--class的类加载并启动
  4. userClassThread = startUserApplication()
  5. // This a bit hacky, but we need to wait until the spark.driver.port property has
  6. // been set by the Thread executing the user class.
  7. logInfo("Waiting for spark context initialization...")
  8. val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
  9. try {
  10. // 2、当前线程阻塞等待context执行
  11. val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
  12. Duration(totalWaitTime, TimeUnit.MILLISECONDS))
  13. if (sc != null) {
  14. val rpcEnv = sc.env.rpcEnv
  15. val userConf = sc.getConf
  16. val host = userConf.get(DRIVER_HOST_ADDRESS)
  17. val port = userConf.get(DRIVER_PORT)
  18. registerAM(host, port, userConf, sc.ui.map(_.webUrl), appAttemptId)
  19. val driverRef = rpcEnv.setupEndpointRef(
  20. RpcAddress(host, port),
  21. YarnSchedulerBackend.ENDPOINT_NAME)
  22. createAllocator(driverRef, userConf, rpcEnv, appAttemptId, distCacheConf)
  23. } else {
  24. // Sanity check; should never happen in normal operation, since sc should only be null
  25. // if the user app did not create a SparkContext.
  26. throw new IllegalStateException("User did not initialize spark context!")
  27. }
  28. resumeDriver()
  29. userClassThread.join()
  30. } catch {
  31. case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
  32. logError(
  33. s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
  34. "Please check earlier log output for errors. Failing the application.")
  35. finish(FinalApplicationStatus.FAILED,
  36. ApplicationMaster.EXIT_SC_NOT_INITED,
  37. "Timed out waiting for SparkContext.")
  38. } finally {
  39. resumeDriver()
  40. }
  41. }
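
"Driver线程准备SparkContext、AM主线程阻塞等待"这套机制,核心就是一次Promise/Future的交接。下面是一个脱离Spark的最小示意(线程里的休眠只是用来模拟new SparkContext的耗时,属于假设):

```scala
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

object ContextHandoffDemo {
  def main(args: Array[String]): Unit = {
    // 对应源码中的 sparkContextPromise
    val contextPromise = Promise[String]()

    // 对应 startUserApplication 启动的用户线程("Driver"线程)
    val userThread = new Thread(() => {
      Thread.sleep(500) // 模拟用户代码中 new SparkContext(...) 的耗时
      contextPromise.trySuccess("SparkContext(模拟)")
    })
    userThread.setName("Driver")
    userThread.start()

    // 对应 ThreadUtils.awaitResult(sparkContextPromise.future, Duration(totalWaitTime, ...)):
    // AM主线程阻塞,直到Driver线程把上下文塞进Promise
    val sc = Await.result(contextPromise.future, 10.seconds)
    println(s"got: $sc, 接下来 registerAM / createAllocator ...")

    userThread.join()
  }
}
```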

启动用户线程startUserApplication

  1. private def startUserApplication(): Thread = {
  2. logInfo("Starting the user application in a separate Thread")
  3. var userArgs = args.userArgs
  4. if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
  5. // When running pyspark, the app is run using PythonRunner. The second argument is the list
  6. // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty.
  7. userArgs = Seq(args.primaryPyFile, "") ++ userArgs
  8. }
  9. if (args.primaryRFile != null &&
  10. (args.primaryRFile.endsWith(".R") || args.primaryRFile.endsWith(".r"))) {
  11. // TODO(davies): add R dependencies here
  12. }
  13. // 加载--class的类,获取main方法
  14. val mainMethod = userClassLoader.loadClass(args.userClass)
  15. .getMethod("main", classOf[Array[String]])
  16. val userThread = new Thread {
  17. override def run(): Unit = {
  18. try {
  19. if (!Modifier.isStatic(mainMethod.getModifiers)) {
  20. logError(s"Could not find static main method in object ${args.userClass}")
  21. finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS)
  22. } else {
  23. // 调用main方法
  24. mainMethod.invoke(null, userArgs.toArray)
  25. finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
  26. logDebug("Done running user class")
  27. }
  28. } catch {
  29. case e: InvocationTargetException =>
  30. e.getCause match {
  31. case _: InterruptedException =>
  32. // Reporter thread can interrupt to stop user class
  33. case SparkUserAppException(exitCode) =>
  34. val msg = s"User application exited with status $exitCode"
  35. logError(msg)
  36. finish(FinalApplicationStatus.FAILED, exitCode, msg)
  37. case cause: Throwable =>
  38. logError("User class threw exception: " + cause, cause)
  39. finish(FinalApplicationStatus.FAILED,
  40. ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
  41. "User class threw exception: " + StringUtils.stringifyException(cause))
  42. }
  43. sparkContextPromise.tryFailure(e.getCause())
  44. } finally {
  45. // Notify the thread waiting for the SparkContext, in case the application did not
  46. // instantiate one. This will do nothing when the user code instantiates a SparkContext
  47. // (with the correct master), or when the user code throws an exception (due to the
  48. // tryFailure above).
  49. sparkContextPromise.trySuccess(null)
  50. }
  51. }
  52. }
  53. // 设置该线程为Driver线程并启动
  54. userThread.setContextClassLoader(userClassLoader)
  55. userThread.setName("Driver")
  56. userThread.start()
  57. userThread
  58. }
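
可以看到,所谓Driver并不是一个单独的进程,而是ApplicationMaster进程里一个名字叫"Driver"的线程,它做的事就是反射调用用户类的main方法。下面是一个独立的小示意(UserApp是假设的用户类):

```scala
object UserApp {
  def main(args: Array[String]): Unit =
    println(s"用户main方法运行在线程:${Thread.currentThread().getName},args = ${args.mkString(" ")}")
}

object DriverThreadDemo {
  def main(args: Array[String]): Unit = {
    // 对应 userClassLoader.loadClass(args.userClass).getMethod("main", classOf[Array[String]])
    val mainMethod = Class.forName("UserApp").getMethod("main", classOf[Array[String]])

    val userThread = new Thread {
      override def run(): Unit = {
        // 对应 mainMethod.invoke(null, userArgs.toArray)
        mainMethod.invoke(null, Array("--input", "hdfs:///demo"))
      }
    }
    // 源码中把这个线程命名为 "Driver"
    userThread.setName("Driver")
    userThread.start()
    userThread.join()
  }
}
```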

5、ApplicationMaster-启动Executor线程

从am.run开始说起

  • 创建Driver:runDriver

    • startUserApplication启动用户线程
      • 获取main方法准备一个线程
      • 调用main方法 invoke(mainMethod)
      • Driver线程启动
    • 等待SparkContext上下文环境准备完成:ThreadUtils.awaitResult(sparkContextPromise.future, Duration(totalWaitTime, TimeUnit.MILLISECONDS))
    • 准备rpc环境,向ResourceManager注册ApplicationMaster
      • rmClient.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress)
    • 创建分配器,分配资源 createAllocator(driverRef, userConf, rpcEnv, appAttemptId, distCacheConf)

      • rpc建立am到rm的端点
      • 分配资源 allocator.allocateResources()
        • am分配资源 amClient.allocate(progressIndicator)
        • 获取所有已经分配的容器
        • 处理分配 handleAllocatedContainers(allocatedContainers.asScala)
          • 运行所有容器 runAllocatedContainers
            • 启动executor线程 run()
            • startContainer
              • prepareEnvironment 准备环境
              • prepareCommand 创建启动命令,启动的是YarnCoarseGrainedExecutorBackend
              • nmClient发送启动容器命令

先看runDriver:SparkContext就绪后注册AM,并创建分配器分配资源:

```scala
private def runDriver(): Unit = {
  addAmIpFilter(None, System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV))
  // 1、startUserApplication先启动用户应用程序,将--class的类加载并启动
  userClassThread = startUserApplication()

  // This a bit hacky, but we need to wait until the spark.driver.port property has
  // been set by the Thread executing the user class.
  logInfo("Waiting for spark context initialization...")
  val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
  try {
    // 2、当前线程阻塞等待context执行
    val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
      Duration(totalWaitTime, TimeUnit.MILLISECONDS))
    if (sc != null) {
      // 3、初始化rpc远程调用信息
      val rpcEnv = sc.env.rpcEnv

      val userConf = sc.getConf
      val host = userConf.get(DRIVER_HOST_ADDRESS)
      val port = userConf.get(DRIVER_PORT)
      // 向ResourceManager注册AM,建立通信
      registerAM(host, port, userConf, sc.ui.map(_.webUrl), appAttemptId)

      val driverRef = rpcEnv.setupEndpointRef(
        RpcAddress(host, port),
        YarnSchedulerBackend.ENDPOINT_NAME)
      // 创建分配器,分配资源
      createAllocator(driverRef, userConf, rpcEnv, appAttemptId, distCacheConf)
    } else {
      // Sanity check; should never happen in normal operation, since sc should only be null
      // if the user app did not create a SparkContext.
      throw new IllegalStateException("User did not initialize spark context!")
    }
    resumeDriver()
    userClassThread.join()
  } catch {
    case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
      logError(
        s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
          "Please check earlier log output for errors. Failing the application.")
      finish(FinalApplicationStatus.FAILED,
        ApplicationMaster.EXIT_SC_NOT_INITED,
        "Timed out waiting for SparkContext.")
  } finally {
    resumeDriver()
  }
}
```

分配资源 allocateResources:

```scala
def allocateResources(): Unit = synchronized {
  // 1、更新资源请求
  updateResourceRequests()

  val progressIndicator = 0.1f
  // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container
  // requests.
  // 2、amClient分配资源
  val allocateResponse = amClient.allocate(progressIndicator)

  // 3、获取所有已分配的容器
  val allocatedContainers = allocateResponse.getAllocatedContainers()
  allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes)

  if (allocatedContainers.size > 0) {
    logDebug(("Allocated containers: %d. Current executor count: %d. " +
      "Launching executor count: %d. Cluster resources: %s.")
      .format(
        allocatedContainers.size,
        runningExecutors.size,
        numExecutorsStarting.get,
        allocateResponse.getAvailableResources))
    // 4、处理已分配的容器
    handleAllocatedContainers(allocatedContainers.asScala)
  }

  val completedContainers = allocateResponse.getCompletedContainersStatuses()
  if (completedContainers.size > 0) {
    logDebug("Completed %d containers".format(completedContainers.size))
    processCompletedContainers(completedContainers.asScala)
    logDebug("Finished processing %d completed containers. Current running executor count: %d."
      .format(completedContainers.size, runningExecutors.size))
  }
}
```

运行所有容器runAllocatedContainers:如果正在运行的executor数量小于目标数量,就从线程池launcherPool中取线程启动新的ExecutorRunnable

  1. private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = {
  2. for (container <- containersToUse) {
  3. executorIdCounter += 1
  4. val executorHostname = container.getNodeId.getHost
  5. val containerId = container.getId
  6. val executorId = executorIdCounter.toString
  7. assert(container.getResource.getMemory >= resource.getMemory)
  8. logInfo(s"Launching container $containerId on host $executorHostname " +
  9. s"for executor with ID $executorId")
  10. def updateInternalState(): Unit = synchronized {
  11. runningExecutors.add(executorId)
  12. numExecutorsStarting.decrementAndGet()
  13. executorIdToContainer(executorId) = container
  14. containerIdToExecutorId(container.getId) = executorId
  15. val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname,
  16. new HashSet[ContainerId])
  17. containerSet += containerId
  18. allocatedContainerToHostMap.put(containerId, executorHostname)
  19. }
  20. // 如果正在运行的executor的数量小于目标所需要的数量,从launcherPool一个线程池中再次执行线程
  21. if (runningExecutors.size() < targetNumExecutors) {
  22. numExecutorsStarting.incrementAndGet()
  23. if (launchContainers) {
  24. launcherPool.execute(() => {
  25. try {
  26. new ExecutorRunnable(
  27. Some(container),
  28. conf,
  29. sparkConf,
  30. driverUrl,
  31. executorId,
  32. executorHostname,
  33. executorMemory,
  34. executorCores,
  35. appAttemptId.getApplicationId.toString,
  36. securityMgr,
  37. localResources,
  38. ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID // use until fully supported
  39. // 每个executor线程要运行的run方法
  40. ).run()
  41. updateInternalState()
  42. } catch {
  43. case e: Throwable =>
  44. numExecutorsStarting.decrementAndGet()
  45. if (NonFatal(e)) {
  46. logError(s"Failed to launch executor $executorId on container $containerId", e)
  47. // Assigned container should be released immediately
  48. // to avoid unnecessary resource occupation.
  49. amClient.releaseAssignedContainer(containerId)
  50. } else {
  51. throw e
  52. }
  53. }
  54. })
  55. } else {
  56. // For test only
  57. updateInternalState()
  58. }
  59. } else {
  60. logInfo(("Skip launching executorRunnable as running executors count: %d " +
  61. "reached target executors count: %d.").format(
  62. runningExecutors.size, targetNumExecutors))
  63. }
  64. }
  65. }

启动executor

  1. def run(): Unit = {
  2. logDebug("Starting Executor Container")
  3. // 创建NodeManager客户端
  4. nmClient = NMClient.createNMClient()
  5. nmClient.init(conf)
  6. nmClient.start()
  7. // 告诉NodeManager启动容器
  8. startContainer()
  9. }

startContainer()

  1. def startContainer(): java.util.Map[String, ByteBuffer] = {
  2. val ctx = Records.newRecord(classOf[ContainerLaunchContext])
  3. .asInstanceOf[ContainerLaunchContext]
  4. // 准备环境
  5. val env = prepareEnvironment().asJava
  6. ctx.setLocalResources(localResources.asJava)
  7. ctx.setEnvironment(env)
  8. val credentials = UserGroupInformation.getCurrentUser().getCredentials()
  9. val dob = new DataOutputBuffer()
  10. credentials.writeTokenStorageToStream(dob)
  11. ctx.setTokens(ByteBuffer.wrap(dob.getData()))
  12. // 准备启动命令
  13. val commands = prepareCommand()
  14. ctx.setCommands(commands.asJava)
  15. ctx.setApplicationACLs(
  16. YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava)
  17. // If external shuffle service is enabled, register with the Yarn shuffle service already
  18. // started on the NodeManager and, if authentication is enabled, provide it with our secret
  19. // key for fetching shuffle files later
  20. if (sparkConf.get(SHUFFLE_SERVICE_ENABLED)) {
  21. val secretString = securityMgr.getSecretKey()
  22. val secretBytes =
  23. if (secretString != null) {
  24. // This conversion must match how the YarnShuffleService decodes our secret
  25. JavaUtils.stringToBytes(secretString)
  26. } else {
  27. // Authentication is not enabled, so just provide dummy metadata
  28. ByteBuffer.allocate(0)
  29. }
  30. ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes))
  31. }
  32. // Send the start request to the ContainerManager
  33. try {
  34. // 向nodeManger 发送启动命令
  35. nmClient.startContainer(container.get, ctx)
  36. } catch {
  37. case ex: Exception =>
  38. throw new SparkException(s"Exception while starting container ${container.get.getId}" +
  39. s" on host $hostname", ex)
  40. }
  41. }

在prepareCommand中创建启动命令,启动的是org.apache.spark.executor.YarnCoarseGrainedExecutorBackend:

  1. val commands = prefixEnv ++
  2. Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++
  3. javaOpts ++
  4. Seq("org.apache.spark.executor.YarnCoarseGrainedExecutorBackend",
  5. "--driver-url", masterAddress,
  6. "--executor-id", executorId,
  7. "--hostname", hostname,
  8. "--cores", executorCores.toString,
  9. "--app-id", appId,
  10. "--resourceProfileId", resourceProfileId.toString) ++
  11. userClassPath ++
  12. Seq(
  13. s"1>${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stdout",
  14. s"2>${ApplicationConstants.LOG_DIR_EXPANSION_VAR}/stderr")
  15. // TODO: it would be nicer to just make sure there are no null commands here
  16. commands.map(s => if (s == null) "null" else s).toList
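
同样地,把prepareCommand里的片段拼起来,Executor容器里最终执行的命令大致长这样(下面的主机名、核数、appId等取值是假设的,仅看形状):

```scala
object ExecutorCommandDemo {
  def main(args: Array[String]): Unit = {
    // 假设的取值,仅用于展示命令的形状
    val javaOpts = Seq("-Xmx2048m", "-Djava.io.tmpdir=./tmp")

    val commands = Seq("{{JAVA_HOME}}/bin/java", "-server") ++
      javaOpts ++
      Seq("org.apache.spark.executor.YarnCoarseGrainedExecutorBackend",
        "--driver-url", "spark://CoarseGrainedScheduler@driver-host:7078",
        "--executor-id", "1",
        "--hostname", "nm-host-01",
        "--cores", "2",
        "--app-id", "application_1600000000000_0001",
        "--resourceProfileId", "0") ++
      Seq("1><LOG_DIR>/stdout", "2><LOG_DIR>/stderr")

    // NodeManager收到startContainer请求后,就用这条命令拉起Executor进程
    println(commands.mkString(" "))
  }
}
```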

6、ApplicationMaster-建立通信环境以及Executor计算节点

从org.apache.spark.executor.YarnCoarseGrainedExecutorBackend入手。

  • 1、创建YarnCoarseGrainedExecutorBackend执行环境,解析参数并run起来
    • 1、根据rpcEnv创建远程调用信息 RpcEnv.create
    • 2、创建远程端点的引用: driver: RpcEndpointRef = fetcher.setupEndpointRefByURI(arguments.driverUrl)
    • 3、创建环境 SparkEnv.createExecutorEnv
      • new NettyRpcEnvFactory().create(config)
        • nettyEnv.startServer(config.bindAddress, actualPort)
          • 1、创建netty的server : server = transportContext.createServer(bindAddress, port, bootstraps)
          • 2、注册远程通信端点:dispatcher.registerRpcEndpoint(RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))}
          • 3、注册端点,获取DedicatedMessageLoop不停收发消息
            • 1、注册inbox信箱,维护线程池来接收信息
              • 1、创建信箱的时候,会发送启动消息
              • 2、Inbox有process方法,处理各种消息,现在处理OnStart
              • 3、CoarseGrainedExecutorBackend处理,得到driver
              • 4、driver.ask():向已连接的Driver发送RegisterExecutor请求
              • 5、SparkContext中有一个属性private var _schedulerBackend: SchedulerBackend,用来和后端通信
              • 6、CoarseGrainedSchedulerBackend的receiveAndReply()处理RegisterExecutor请求(这个握手过程可参考下面列表之后的小示例)
                • 1、totalCoreCount.addAndGet(cores)、totalRegisteredExecutors.addAndGet(1)累加核数和已注册的executor数
                • 2、给driver返回消息true:context.reply(true)
              • 7、Executor端收到回复后,如果success,给自己发送一条消息:case Success(_) => self.send(RegisteredExecutor),表示注册成功
              • 8、Inbox自己receive之后,开始new一个Executor,发送LaunchedExecutor消息
              • 9、接收LaunchedExecutor消息,makeOffers()执行任务
    • 4、和executor端建立端点
  • 2、ApplicationMaster的resumeDriver()、userClassThread.join()让应用程序继续执行。

【从这里开始分成两条线:一条准备计算资源(Executor),一条执行任务】

  • 3、SparkContext的后置处理_taskScheduler.postStartHook(),抽象方法,来到YarnClusterScheduler中,ApplicationMaster.sparkContextInitialized(sc)初始化上下文环境
  • 4、resumeDriver()中通知上下文sparkContextPromise.notify(),让Driver程序继续往下,执行用户的应用程序
    1. // 和Executor建立端点
    2. env.rpcEnv.setupEndpoint("Executor",
    3. backendCreateFn(env.rpcEnv, arguments, env, cfg.resourceProfile))
    4. arguments.workerUrl.foreach { url =>
    5. env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
    6. }
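
上面第4~9步的注册握手,可以用几个样例类简单模拟一下(这不是Spark真实的RPC实现,只是按上面描述的消息顺序写的示意):

```scala
// 模拟注册握手用到的几种消息
sealed trait Message
case class RegisterExecutor(executorId: String, cores: Int) extends Message
case object RegisteredExecutor extends Message
case class LaunchedExecutor(executorId: String) extends Message

object RegisterHandshakeDemo {
  // 模拟Driver端CoarseGrainedSchedulerBackend:累加核数、executor数,然后reply(true)
  private var totalCoreCount = 0
  private var totalRegisteredExecutors = 0

  private def driverReceiveAndReply(msg: Message): Boolean = msg match {
    case RegisterExecutor(id, cores) =>
      totalCoreCount += cores
      totalRegisteredExecutors += 1
      println(s"driver: executor $id 注册成功, totalCores=$totalCoreCount, executors=$totalRegisteredExecutors")
      true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    // Executor端:相当于 driver.ask(RegisterExecutor(...))
    val success = driverReceiveAndReply(RegisterExecutor("1", cores = 2))

    if (success) {
      // 相当于 case Success(_) => self.send(RegisteredExecutor)
      val selfMessage: Message = RegisteredExecutor
      selfMessage match {
        case RegisteredExecutor =>
          // 相当于收到RegisteredExecutor后 new Executor(...),再上报LaunchedExecutor
          println("executor: new Executor(...)")
          println(s"executor -> driver: ${LaunchedExecutor("1")}")
        case _ =>
      }
    }
  }
}
```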

从main方法进入:

  1. def main(args: Array[String]): Unit = {
  2. // 创建一个YarnCoarseGrainedExecutorBackend,执行Executor
  3. val createFn: (RpcEnv, CoarseGrainedExecutorBackend.Arguments, SparkEnv, ResourceProfile) =>
  4. CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env, resourceProfile) =>
  5. new YarnCoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId,
  6. arguments.bindAddress, arguments.hostname, arguments.cores, arguments.userClassPath, env,
  7. arguments.resourcesFileOpt, resourceProfile)
  8. }
  9. // 解析参数
  10. val backendArgs = CoarseGrainedExecutorBackend.parseArguments(args,
  11. this.getClass.getCanonicalName.stripSuffix("$"))
  12. // run起来,带着环境
  13. CoarseGrainedExecutorBackend.run(backendArgs, createFn)
  14. System.exit(0)
  15. }

run

  1. def run(
  2. arguments: Arguments,
  3. backendCreateFn: (RpcEnv, Arguments, SparkEnv, ResourceProfile) =>
  4. CoarseGrainedExecutorBackend): Unit = {
  5. Utils.initDaemon(log)
  6. SparkHadoopUtil.get.runAsSparkUser { () =>
  7. // Debug code
  8. Utils.checkHost(arguments.hostname)
  9. // Bootstrap to fetch the driver's Spark properties.
  10. val executorConf = new SparkConf
  11. // 根据rpcEnv创建远程调用信息,
  12. val fetcher = RpcEnv.create(
  13. "driverPropsFetcher",
  14. arguments.bindAddress,
  15. arguments.hostname,
  16. -1,
  17. executorConf,
  18. new SecurityManager(executorConf),
  19. numUsableCores = 0,
  20. clientMode = true)
  21. // 创建远程端点引用
  22. var driver: RpcEndpointRef = null
  23. val nTries = 3
  24. for (i <- 0 until nTries if driver == null) {
  25. try {
  26. // 根据url启动端点
  27. driver = fetcher.setupEndpointRefByURI(arguments.driverUrl)
  28. } catch {
  29. case e: Throwable => if (i == nTries - 1) {
  30. throw e
  31. }
  32. }
  33. }
  34. val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig(arguments.resourceProfileId))
  35. val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", arguments.appId))
  36. fetcher.shutdown()
  37. // Create SparkEnv using properties we fetched from the driver.
  38. val driverConf = new SparkConf()
  39. for ((key, value) <- props) {
  40. // this is required for SSL in standalone mode
  41. if (SparkConf.isExecutorStartupConf(key)) {
  42. driverConf.setIfMissing(key, value)
  43. } else {
  44. driverConf.set(key, value)
  45. }
  46. }
  47. cfg.hadoopDelegationCreds.foreach { tokens =>
  48. SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf)
  49. }
  50. driverConf.set(EXECUTOR_ID, arguments.executorId)
  51. // 创建Executor的环境
  52. val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.bindAddress,
  53. arguments.hostname, arguments.cores, cfg.ioEncryptionKey, isLocal = false)
  54. // 和Executor建立端点
  55. env.rpcEnv.setupEndpoint("Executor",
  56. backendCreateFn(env.rpcEnv, arguments, env, cfg.resourceProfile))
  57. arguments.workerUrl.foreach { url =>
  58. env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
  59. }
  60. env.rpcEnv.awaitTermination()
  61. }
  62. }

创建Executor环境createExecutorEnv

  1. private[spark] def createExecutorEnv(
  2. conf: SparkConf,
  3. executorId: String,
  4. bindAddress: String,
  5. hostname: String,
  6. numCores: Int,
  7. ioEncryptionKey: Option[Array[Byte]],
  8. isLocal: Boolean): SparkEnv = {
  9. // 创建环境
  10. val env = create(
  11. conf,
  12. executorId,
  13. bindAddress,
  14. hostname,
  15. None,
  16. isLocal,
  17. numCores,
  18. ioEncryptionKey
  19. )
  20. SparkEnv.set(env)
  21. env
  22. }

create方法

  1. private def create(
  2. conf: SparkConf,
  3. executorId: String,
  4. bindAddress: String,
  5. advertiseAddress: String,
  6. port: Option[Int],
  7. isLocal: Boolean,
  8. numUsableCores: Int,
  9. ioEncryptionKey: Option[Array[Byte]],
  10. listenerBus: LiveListenerBus = null,
  11. mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
  12. val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
  13. // Listener bus is only used on the driver
  14. if (isDriver) {
  15. assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
  16. }
  17. val authSecretFileConf = if (isDriver) AUTH_SECRET_FILE_DRIVER else AUTH_SECRET_FILE_EXECUTOR
  18. val securityManager = new SecurityManager(conf, ioEncryptionKey, authSecretFileConf)
  19. if (isDriver) {
  20. securityManager.initializeAuth()
  21. }
  22. ioEncryptionKey.foreach { _ =>
  23. if (!securityManager.isEncryptionEnabled()) {
  24. logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
  25. "wire.")
  26. }
  27. }
  28. val systemName = if (isDriver) driverSystemName else executorSystemName
  29. // 1、通过netty创建远程调用环境,非常重要
  30. val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
  31. securityManager, numUsableCores, !isDriver)
  32. // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
  33. if (isDriver) {
  34. conf.set(DRIVER_PORT, rpcEnv.address.port)
  35. }

通过NettyRpcEnvFactory创建netty环境

  1. def create(
  2. name: String,
  3. bindAddress: String,
  4. advertiseAddress: String,
  5. port: Int,
  6. conf: SparkConf,
  7. securityManager: SecurityManager,
  8. numUsableCores: Int,
  9. clientMode: Boolean): RpcEnv = {
  10. val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
  11. numUsableCores, clientMode)
  12. // 创建环境
  13. new NettyRpcEnvFactory().create(config)
  14. }

Netty的创建

  1. def create(config: RpcEnvConfig): RpcEnv = {
  2. val sparkConf = config.conf
  3. // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
  4. // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
  5. val javaSerializerInstance =
  6. new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
  7. val nettyEnv =
  8. new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
  9. config.securityManager, config.numUsableCores)
  10. if (!config.clientMode) {
  11. val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
  12. // 启动netty服务器
  13. nettyEnv.startServer(config.bindAddress, actualPort)
  14. (nettyEnv, nettyEnv.address.port)
  15. }
  16. try {
  17. Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
  18. } catch {
  19. case NonFatal(e) =>
  20. nettyEnv.shutdown()
  21. throw e
  22. }
  23. }
  24. nettyEnv
  25. }

注册通信端点,获取消息循环器DedicatedMessageLoop

  1. def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
  2. // 1、获取通讯端点地址
  3. val addr = RpcEndpointAddress(nettyEnv.address, name)
  4. // 2、获取端点引用
  5. val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
  6. synchronized {
  7. if (stopped) {
  8. throw new IllegalStateException("RpcEnv has been stopped")
  9. }
  10. if (endpoints.containsKey(name)) {
  11. throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
  12. }
  13. // This must be done before assigning RpcEndpoint to MessageLoop, as MessageLoop sets Inbox be
  14. // active when registering, and endpointRef must be put into endpointRefs before onStart is
  15. // called.
  16. endpointRefs.put(endpoint, endpointRef)
  17. // 3、获取消息循环器,不停收发消息
  18. var messageLoop: MessageLoop = null
  19. try {
  20. messageLoop = endpoint match {
  21. case e: IsolatedRpcEndpoint =>
  22. // 创建DedicatedMessageLoop
  23. new DedicatedMessageLoop(name, e, this)
  24. case _ =>
  25. sharedLoop.register(name, endpoint)
  26. sharedLoop
  27. }
  28. endpoints.put(name, messageLoop)
  29. } catch {
  30. case NonFatal(e) =>
  31. endpointRefs.remove(endpoint)
  32. throw e
  33. }
  34. }
  35. endpointRef
  36. }

获取消息循环器之后,内部有inbox收件箱,持有线程池来监听事件

  1. private class DedicatedMessageLoop(
  2. name: String,
  3. endpoint: IsolatedRpcEndpoint,
  4. dispatcher: Dispatcher)
  5. extends MessageLoop(dispatcher) {
  6. // 创建收件箱
  7. private val inbox = new Inbox(name, endpoint)
  8. // 维护线程池
  9. override protected val threadpool = if (endpoint.threadCount() > 1) {
  10. ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount())
  11. } else {
  12. ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name")
  13. }
  14. (1 to endpoint.threadCount()).foreach { _ =>
  15. // 循环处理接收到的信息
  16. threadpool.submit(receiveLoopRunnable)
  17. }
  18. // Mark active to handle the OnStart message.
  19. setActive(inbox)
  20. override def post(endpointName: String, message: InboxMessage): Unit = {
  21. require(endpointName == name)
  22. inbox.post(message)
  23. setActive(inbox)
  24. }
  25. override def unregister(endpointName: String): Unit = synchronized {
  26. require(endpointName == name)
  27. inbox.stop()
  28. // Mark active to handle the OnStop message.
  29. setActive(inbox)
  30. setActive(MessageLoop.PoisonPill)
  31. threadpool.shutdown()
  32. }
  33. }

Inbox信箱处理各种事件

  1. private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint)
  2. extends Logging {
  3. inbox => // Give this an alias so we can use it more clearly in closures.
  4. @GuardedBy("this")
  5. protected val messages = new java.util.LinkedList[InboxMessage]()
  6. /** True if the inbox (and its associated endpoint) is stopped. */
  7. @GuardedBy("this")
  8. private var stopped = false
  9. /** Allow multiple threads to process messages at the same time. */
  10. @GuardedBy("this")
  11. private var enableConcurrent = false
  12. /** The number of threads processing messages for this inbox. */
  13. @GuardedBy("this")
  14. private var numActiveThreads = 0
  15. // OnStart should be the first message to process
  16. // 一创建就处理启动消息
  17. inbox.synchronized {
  18. messages.add(OnStart)
  19. }
  20. /**
  21. * Process stored messages.
  22. */
  23. def process(dispatcher: Dispatcher): Unit = {
  24. var message: InboxMessage = null
  25. inbox.synchronized {
  26. if (!enableConcurrent && numActiveThreads != 0) {
  27. return
  28. }
  29. message = messages.poll()
  30. if (message != null) {
  31. numActiveThreads += 1
  32. } else {
  33. return
  34. }
  35. }
  36. while (true) {
  37. safelyCall(endpoint) {
  38. message match {
  39. case RpcMessage(_sender, content, context) =>
  40. try {
  41. endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
  42. throw new SparkException(s"Unsupported message $message from ${_sender}")
  43. })
  44. } catch {
  45. case e: Throwable =>
  46. context.sendFailure(e)
  47. // Throw the exception -- this exception will be caught by the safelyCall function.
  48. // The endpoint's onError function will be called.
  49. throw e
  50. }
  51. case OneWayMessage(_sender, content) =>
  52. endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
  53. throw new SparkException(s"Unsupported message $message from ${_sender}")
  54. })
  55. case OnStart =>
  56. endpoint.onStart()
  57. if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
  58. inbox.synchronized {
  59. if (!stopped) {
  60. enableConcurrent = true
  61. }
  62. }
  63. }
  64. case OnStop =>
  65. val activeThreads = inbox.synchronized { inbox.numActiveThreads }
  66. assert(activeThreads == 1,
  67. s"There should be only a single active thread but found $activeThreads threads.")
  68. dispatcher.removeRpcEndpointRef(endpoint)
  69. endpoint.onStop()
  70. assert(isEmpty, "OnStop should be the last message")
  71. case RemoteProcessConnected(remoteAddress) =>
  72. endpoint.onConnected(remoteAddress)
  73. case RemoteProcessDisconnected(remoteAddress) =>
  74. endpoint.onDisconnected(remoteAddress)
  75. case RemoteProcessConnectionError(cause, remoteAddress) =>
  76. endpoint.onNetworkError(cause, remoteAddress)
  77. }
  78. }
  79. inbox.synchronized {
  80. // "enableConcurrent" will be set to false after `onStop` is called, so we should check it
  81. // every time.
  82. if (!enableConcurrent && numActiveThreads != 1) {
  83. // If we are not the only one worker, exit
  84. numActiveThreads -= 1
  85. return
  86. }
  87. message = messages.poll()
  88. if (message == null) {
  89. numActiveThreads -= 1
  90. return
  91. }
  92. }
  93. }
  94. }

处理启动事件,去启动端点

  1. case OnStart =>
  2. // 启动端点
  3. endpoint.onStart()
  4. if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
  5. inbox.synchronized {
  6. if (!stopped) {
  7. enableConcurrent = true
  8. }
  9. }
  10. }

来到CoarseGrainedExecutorBackend的onStart方法,得到driver

  1. override def onStart(): Unit = {
  2. logInfo("Connecting to driver: " + driverUrl)
  3. try {
  4. _resources = parseOrFindResources(resourcesFileOpt)
  5. } catch {
  6. case NonFatal(e) =>
  7. exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
  8. }
  9. rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
  10. // This is a very fast action so we can use "ThreadUtils.sameThread"
  11. // 得到driver
  12. driver = Some(ref)
  13. // driver向我们的连接的Driver发送请求
  14. ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls,
  15. extractAttributes, _resources, resourceProfile.id))
  16. }(ThreadUtils.sameThread).onComplete {
  17. case Success(_) =>
  18. self.send(RegisteredExecutor)
  19. case Failure(e) =>
  20. exitExecutor(1, s"Cannot register with driver: $driverUrl", e, notifyDriver = false)
  21. }(ThreadUtils.sameThread)
  22. }

处理消息的后端

CoarseGrainedSchedulerBackend中,有方法receiveAndReply(),用来处理请求和应答,正好处理注册请求

  1. override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  2. case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls,
  3. attributes, resources, resourceProfileId) =>
  4. if (executorDataMap.contains(executorId)) {
  5. context.sendFailure(new IllegalStateException(s"Duplicate executor ID: $executorId"))
  6. } else if (scheduler.nodeBlacklist.contains(hostname) ||
  7. isBlacklisted(executorId, hostname)) {
  8. // If the cluster manager gives us an executor on a blacklisted node (because it
  9. // already started allocating those resources before we informed it of our blacklist,
  10. // or if it ignored our blacklist), then we reject that executor immediately.
  11. logInfo(s"Rejecting $executorId as it has been blacklisted.")
  12. context.sendFailure(new IllegalStateException(s"Executor is blacklisted: $executorId"))
  13. } else {
  14. // If the executor's rpc env is not listening for incoming connections, `hostPort`
  15. // will be null, and the client connection should be used to contact the executor.
  16. val executorAddress = if (executorRef.address != null) {
  17. executorRef.address
  18. } else {
  19. context.senderAddress
  20. }
  21. logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")
  22. addressToExecutorId(executorAddress) = executorId
  23. totalCoreCount.addAndGet(cores)
  24. totalRegisteredExecutors.addAndGet(1)
  25. val resourcesInfo = resources.map{ case (k, v) =>
  26. (v.name,
  27. new ExecutorResourceInfo(v.name, v.addresses,
  28. // tell the executor it can schedule resources up to numParts times,
  29. // as configured by the user, or set to 1 as that is the default (1 task/resource)
  30. taskResourceNumParts.getOrElse(v.name, 1)))
  31. }
  32. val data = new ExecutorData(executorRef, executorAddress, hostname,
  33. 0, cores, logUrlHandler.applyPattern(logUrls, attributes), attributes,
  34. resourcesInfo, resourceProfileId)
  35. // This must be synchronized because variables mutated
  36. // in this block are read when requesting executors
  37. CoarseGrainedSchedulerBackend.this.synchronized {
  38. executorDataMap.put(executorId, data)
  39. if (currentExecutorIdCounter < executorId.toInt) {
  40. currentExecutorIdCounter = executorId.toInt
  41. }
  42. if (numPendingExecutors > 0) {
  43. numPendingExecutors -= 1
  44. logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
  45. }
  46. }
  47. listenerBus.post(
  48. SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
  49. // Note: some tests expect the reply to come after we put the executor in the map
  50. context.reply(true)
  51. }
  52. case StopDriver =>
  53. context.reply(true)
  54. stop()
  55. case StopExecutors =>
  56. logInfo("Asking each executor to shut down")
  57. for ((_, executorData) <- executorDataMap) {
  58. executorData.executorEndpoint.send(StopExecutor)
  59. }
  60. context.reply(true)
  61. case RemoveWorker(workerId, host, message) =>
  62. removeWorker(workerId, host, message)
  63. context.reply(true)
  64. case RetrieveSparkAppConfig(resourceProfileId) =>
  65. // note this will be updated in later prs to get the ResourceProfile from a
  66. // ResourceProfileManager that matches the resource profile id
  67. // for now just use default profile
  68. val rp = ResourceProfile.getOrCreateDefaultProfile(conf)
  69. val reply = SparkAppConfig(
  70. sparkProperties,
  71. SparkEnv.get.securityManager.getIOEncryptionKey(),
  72. Option(delegationTokens.get()),
  73. rp)
  74. context.reply(reply)
  75. }

收到注册成功的消息之后,创建Executor,并向Driver发送LaunchedExecutor消息

  1. override def receive: PartialFunction[Any, Unit] = {
  2. case RegisteredExecutor =>
  3. logInfo("Successfully registered with driver")
  4. try {
  5. // 真正创建Executor对象,后续用它来执行任务
  6. executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false,
  7. resources = _resources)
  8. driver.get.send(LaunchedExecutor(executorId))
  9. } catch {
  10. case NonFatal(e) =>
  11. exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
  12. }

Driver端的CoarseGrainedSchedulerBackend收到LaunchedExecutor消息后,增加可用核数并触发makeOffers:

  1. case LaunchedExecutor(executorId) =>
  2. executorDataMap.get(executorId).foreach { data =>
  3. // 增加核数
  4. data.freeCores = data.totalCores
  5. }
  6. // 执行任务
  7. makeOffers(executorId)

二、通信原理

Spark底层使用Netty作为通信框架。Netty支持NIO、AIO等模式,但Linux对AIO的支持不够好(Windows支持AIO),因此在Linux上采用Epoll方式的NIO来达到类似AIO的效果。下面先梳理通信环境的创建流程,并在流程之后给出一个最小的使用示意。

  • 1、NettyRpcEnv.create()准备通信环境
    • 1、启动服务器 nettyEnv.startServer(config.bindAddress, actualPort)
      • 1、创建服务器 transportContext.createServer(bindAddress, port, bootstraps)
        • 1、new TransportServer().init()
          • 1、new ServerBootstrap()
          • 2、initializePipeline初始化管道
      • 2、注册通信端点 dispatcher.registerRpcEndpoint()
        • 1、new NettyRpcEndpointRef注册通信端点的引用(有ask、send等方法,用来发送消息),内部有outboxes发件箱,可以向多个远端发送消息。
          • 发送时通过多个TransportClient向对端的TransportServer发送消息
        • 2、new DedicatedMessageLoop注册消息循环,内持有Inbox收信箱,有receive、reply等方法,用来接收消息
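
为了更直观地理解上面的 RpcEndpoint / RpcEndpointRef 模型,下面给出一个最小的使用示意(非Spark源码)。假设代码放在org.apache.spark的子包下以便访问private[spark]的rpc API,其中PingEndpoint、端口9999等都是示例中假设的名字和数值:

  package org.apache.spark.rpcdemo

  import org.apache.spark.{SecurityManager, SparkConf}
  import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpoint, RpcEnv}

  // 一个最小的通信端点:收到"ping"就回复"pong"
  class PingEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
      case "ping" => context.reply("pong")
    }
  }

  object RpcDemo {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf()
      // 服务端:创建RpcEnv(内部就是NettyRpcEnv)并注册端点
      val serverEnv = RpcEnv.create("server", "localhost", 9999, conf, new SecurityManager(conf))
      serverEnv.setupEndpoint("ping-endpoint", new PingEndpoint(serverEnv))
      // 客户端:以clientMode创建RpcEnv,拿到端点引用之后ask
      val clientEnv = RpcEnv.create("client", "localhost", 0, conf,
        new SecurityManager(conf), clientMode = true)
      val ref = clientEnv.setupEndpointRef(RpcAddress("localhost", 9999), "ping-endpoint")
      println(ref.askSync[String]("ping")) // 预期输出 pong
      clientEnv.shutdown()
      serverEnv.shutdown()
    }
  }

在Spark内部,Driver端的CoarseGrainedSchedulerBackend和Executor端的CoarseGrainedExecutorBackend扮演的就是类似示例中端点的角色。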


1、通信组件

Driver、Executor之间如何通信?
我们知道有一个RpcEnv作为远程通信环境,底层使用Netty进行通信,所以先来看Netty通信环境是如何创建的。

  1. private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  2. def create(config: RpcEnvConfig): RpcEnv = {
  3. val sparkConf = config.conf
  4. // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
  5. // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
  6. val javaSerializerInstance =
  7. new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
  8. // 1、这里创建一个Netty的通信环境
  9. val nettyEnv =
  10. new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
  11. config.securityManager, config.numUsableCores)
  12. if (!config.clientMode) {
  13. // 启动服务器
  14. val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
  15. nettyEnv.startServer(config.bindAddress, actualPort)
  16. (nettyEnv, nettyEnv.address.port)
  17. }
  18. try {
  19. // 在指定端口启动服务器
  20. Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
  21. } catch {
  22. case NonFatal(e) =>
  23. nettyEnv.shutdown()
  24. throw e
  25. }
  26. }
  27. nettyEnv
  28. }
  29. }

启动服务器之前,要先把TransportServer服务器创建出来

  1. def startServer(bindAddress: String, port: Int): Unit = {
  2. val bootstraps: java.util.List[TransportServerBootstrap] =
  3. if (securityManager.isAuthenticationEnabled()) {
  4. java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager))
  5. } else {
  6. java.util.Collections.emptyList()
  7. }
  8. // 创建服务器,
  9. server = transportContext.createServer(bindAddress, port, bootstraps)
  10. // 注册通信端点
  11. dispatcher.registerRpcEndpoint(
  12. RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
  13. }

创建Transport Server

  1. public TransportServer(
  2. TransportContext context,
  3. String hostToBind,
  4. int portToBind,
  5. RpcHandler appRpcHandler,
  6. List<TransportServerBootstrap> bootstraps) {
  7. this.context = context;
  8. this.conf = context.getConf();
  9. this.appRpcHandler = appRpcHandler;
  10. if (conf.sharedByteBufAllocators()) {
  11. this.pooledAllocator = NettyUtils.getSharedPooledByteBufAllocator(
  12. conf.preferDirectBufsForSharedByteBufAllocators(), true /* allowCache */);
  13. } else {
  14. this.pooledAllocator = NettyUtils.createPooledByteBufAllocator(
  15. conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
  16. }
  17. this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
  18. boolean shouldClose = true;
  19. try {
  20. // 执行init初始化方法
  21. init(hostToBind, portToBind);
  22. shouldClose = false;
  23. } finally {
  24. if (shouldClose) {
  25. JavaUtils.closeQuietly(this);
  26. }
  27. }
  28. }

三、应用程序的执行

应用程序必然运行在前面准备好的环境之上,所以我们从环境入手。

1、上下文对象-SparkContext

SparkContext里面有以下几个关键对象,列表之后给出一个创建SparkContext的最小示意:

  • SparkConf
    • 基础环境配置
  • SparkEnv
    • 通信环境
  • SchedulerBackend
    • 通信后端,与Executor进行通信
  • TaskScheduler
    • 任务调度器,主要用于任务的调度
  • DAGScheduler
    • 阶段调度器,主要用于阶段的划分和任务的切分
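
下面是一个创建SparkContext的最小示意(假设以yarn方式提交,应用名等均为示例值),上面这些对象都是在new SparkContext的构造过程中由框架内部依次创建的:

  import org.apache.spark.{SparkConf, SparkContext}

  object ContextDemo {
    def main(args: Array[String]): Unit = {
      // SparkConf:基础环境配置
      val conf = new SparkConf().setAppName("wordcount").setMaster("yarn")
      // 构造SparkContext时,内部会依次创建SparkEnv、SchedulerBackend、TaskScheduler、DAGScheduler
      val sc = new SparkContext(conf)
      // ...... 用户的作业逻辑 ......
      sc.stop()
    }
  }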

2、RDD依赖

  1. // 2.2、将一行分割后转换
  2. val value: RDD[(String, Int)] = lines.flatMap(_.split(" "))
  3. .groupBy(word => word)
  4. .map {
  5. case (word, list) => {
  6. (word, list.size)
  7. }
  8. }

最原始的RDD先经过flatMap,被包装成一个MapPartitionsRDD

  1. def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope {
  2. val cleanF = sc.clean(f)
  3. // 当前对象转换
  4. new MapPartitionsRDD[U, T](this, (_, _, iter) => iter.flatMap(cleanF))
  5. }
  1. private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
  2. var prev: RDD[T],
  3. f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
  4. preservesPartitioning: Boolean = false,
  5. isFromBarrier: Boolean = false,
  6. isOrderSensitive: Boolean = false)
  7. // 继承自这个有参的RDD,
  8. extends RDD[U](prev) {

构建了一个OneToOneDependency并将之前的RDD传进去

  1. /** Construct an RDD with just a one-to-one dependency on one parent */
  2. def this(@transient oneParent: RDD[_]) =
  3. this(oneParent.context, List(new OneToOneDependency(oneParent)))

OneToOneDependency继承了NarrowDependency,NarrowDependency中保存了父RDD。

  1. class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
  2. override def getParents(partitionId: Int): List[Int] = List(partitionId)
  3. }
  4. abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  5. /**
  6. * Get the parent partitions for a child partition.
  7. * @param partitionId a partition of the child RDD
  8. * @return the partitions of the parent RDD that the child partition depends upon
  9. */
  10. def getParents(partitionId: Int): Seq[Int]
  11. override def rdd: RDD[T] = _rdd
  12. }

第二步,经过groupBy的时候,

  1. def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null)
  2. : RDD[(K, Iterable[T])] = withScope {
  3. val cleanF = sc.clean(f)
  4. // 底层调用groupByKey
  5. this.map(t => (cleanF(t), t)).groupByKey(p)
  6. }

groupByKey底层走combineByKeyWithClassTag

  1. def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope {
  2. // groupByKey shouldn't use map side combine because map side combine does not
  3. // reduce the amount of data shuffled and requires all map side data be inserted
  4. // into a hash table, leading to more objects in the old gen.
  5. val createCombiner = (v: V) => CompactBuffer(v)
  6. val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v
  7. val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2
  8. // 进这里
  9. val bufs = combineByKeyWithClassTag[CompactBuffer[V]](
  10. createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false)
  11. bufs.asInstanceOf[RDD[(K, Iterable[V])]]
  12. }

里面会new出一个ShuffledRDD

  1. new ShuffledRDD[K, V, C](self, partitioner)

点进去看继承关系,父依赖传的是一个Nil?没关系,获取依赖关系的时候,会通过getDependencies方法创建依赖

  1. class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
  2. @transient var prev: RDD[_ <: Product2[K, V]],
  3. part: Partitioner)
  4. extends RDD[(K, C)](prev.context, Nil) {
  1. override def getDependencies: Seq[Dependency[_]] = {
  2. val serializer = userSpecifiedSerializer.getOrElse {
  3. val serializerManager = SparkEnv.get.serializerManager
  4. if (mapSideCombine) {
  5. serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
  6. } else {
  7. serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
  8. }
  9. }
  10. // 创建一个ShuffleDependency,将先前的prev的RDD传入并指向
  11. List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
  12. }

所以依赖关系就是:每个RDD通过Dependency指向它所依赖的前一个RDD,从而形成有向无环图。
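
可以用一小段代码验证这一点(沿用上文的sc,输入路径为假设值),通过toDebugString和dependencies查看依赖链:

  val lines = sc.textFile("data/wordcount.txt") // 假设的输入路径
  val wordCounts = lines.flatMap(_.split(" "))
    .groupBy(word => word)
    .map { case (word, list) => (word, list.size) }

  // 打印整条血缘:可以看到ShuffledRDD、MapPartitionsRDD、HadoopRDD之间的指向关系
  println(wordCounts.toDebugString)
  // 查看当前RDD的直接依赖类型(这里是OneToOneDependency,ShuffledRDD那一层才是ShuffleDependency)
  wordCounts.dependencies.foreach(dep => println(dep.getClass.getSimpleName))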

3、阶段的划分

collect算子的执行会触发作业的提交,就会进行阶段的划分

  1. def collect(): Array[T] = withScope {
  2. val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  3. Array.concat(results: _*)
  4. }

运行任务

  1. dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
  2. // 提交任务
  3. val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  4. // 提交一个JobSubmitted事件
  5. eventProcessLoop.post(JobSubmitted(
  6. jobId, rdd, func2, partitions.toArray, callSite, waiter,
  7. Utils.cloneProperties(properties)))

提交的时候将事件放入事件队列中

  1. /**
  2. * Put the event into the event queue. The event thread will process it later.
  3. */
  4. def post(event: E): Unit = {
  5. if (!stopped.get) {
  6. // 判断当前线程是不是还活着,往事件队列中放事件
  7. if (eventThread.isAlive) {
  8. eventQueue.put(event)
  9. } else {
  10. onError(new IllegalStateException(s"$name has already been stopped accidentally."))
  11. }
  12. }
  13. }

事件线程启动后会不断从事件队列中取出事件进行处理

  1. // Exposed for testing.
  2. private[spark] val eventThread = new Thread(name) {
  3. setDaemon(true)
  4. override def run(): Unit = {
  5. try {
  6. while (!stopped.get) {
  7. // 拿事件
  8. val event = eventQueue.take()
  9. try {
  10. // 进行接收事件状态
  11. onReceive(event)
  12. } catch {
  13. case NonFatal(e) =>
  14. try {
  15. onError(e)
  16. } catch {
  17. case NonFatal(e) => logError("Unexpected error in " + name, e)
  18. }
  19. }
  20. }
  21. } catch {
  22. case ie: InterruptedException => // exit even if eventQueue is not empty
  23. case NonFatal(e) => logError("Unexpected error in " + name, e)
  24. }
  25. }
  26. }

EventLoop中接收事件最终来到doOnReceive方法,这里面定义了各种事件如何进行处理

  1. private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
  2. case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
  3. dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
  4. case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
  5. dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)
  6. case StageCancelled(stageId, reason) =>
  7. dagScheduler.handleStageCancellation(stageId, reason)
  8. case JobCancelled(jobId, reason) =>
  9. dagScheduler.handleJobCancellation(jobId, reason)
  10. case JobGroupCancelled(groupId) =>
  11. dagScheduler.handleJobGroupCancelled(groupId)

在handleJobSubmitted中,通过createResultStage创建最终的结果阶段

  1. finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)

在创建结果阶段之前,还要判断是否有上级阶段,如果有就创建出来

  1. private def createResultStage(
  2. rdd: RDD[_],
  3. func: (TaskContext, Iterator[_]) => _,
  4. partitions: Array[Int],
  5. jobId: Int,
  6. callSite: CallSite): ResultStage = {
  7. checkBarrierStageWithDynamicAllocation(rdd)
  8. checkBarrierStageWithNumSlots(rdd)
  9. checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size)
  10. // 创建上级阶段
  11. val parents = getOrCreateParentStages(rdd, jobId)
  12. val id = nextStageId.getAndIncrement()
  13. // 创建结果阶段
  14. val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
  15. stageIdToStage(id) = stage
  16. updateJobIdStageIdMaps(jobId, stage)
  17. stage
  18. }

创建上级阶段的依据是判断是否存在Shuffle依赖,如果有就创建一个ShuffleMapStage,它对应Shuffle写磁盘数据的过程

  1. private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
  2. getShuffleDependencies(rdd).map { shuffleDep =>
  3. getOrCreateShuffleMapStage(shuffleDep, firstJobId)
  4. }.toList
  5. }

获取ShuffleDependencies的时候,从上一级rdd中拿出依赖,通过模式匹配遍历,如果是Shuffle依赖,就加入parents集合

  1. private[scheduler] def getShuffleDependencies(
  2. rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
  3. val parents = new HashSet[ShuffleDependency[_, _, _]]
  4. val visited = new HashSet[RDD[_]]
  5. val waitingForVisit = new ListBuffer[RDD[_]]
  6. waitingForVisit += rdd
  7. while (waitingForVisit.nonEmpty) {
  8. val toVisit = waitingForVisit.remove(0)
  9. if (!visited(toVisit)) {
  10. visited += toVisit
  11. toVisit.dependencies.foreach {
  12. // 如果是Shuffle依赖,加入parents集合
  13. case shuffleDep: ShuffleDependency[_, _, _] =>
  14. parents += shuffleDep
  15. case dependency =>
  16. waitingForVisit.prepend(dependency.rdd)
  17. }
  18. }
  19. }
  20. parents
  21. }

创建ShuffleMapStage的时候调用createShuffleMapStage,获取到ShuffleDependency中的rdd,也就是上一个阶段的rdd,接着再递归判断上一个rdd是否还有Shuffle依赖,依此类推。

  1. val parents = getOrCreateParentStages(rdd, jobId)
  2. val stage = new ShuffleMapStage(
  3. id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)

总结:
阶段划分时会先创建一个ResultStage,如果依赖链中存在Shuffle依赖,就会再创建ShuffleMapStage;有多少次shuffle,就多出多少个阶段。
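
以上文的wordcount为例做一个小验证(沿用上文的lines,仅作示意):

  // 只有groupBy产生一次shuffle:1个ShuffleMapStage + 1个ResultStage,共2个阶段
  lines.flatMap(_.split(" "))
    .groupBy(word => word)               // ShuffleDependency,切出一个ShuffleMapStage
    .map { case (w, l) => (w, l.size) }
    .collect()                           // 行动算子触发提交,对应最终的ResultStage

  // 如果再追加一次shuffle(例如按词频排序),阶段数就变成3个
  lines.flatMap(_.split(" "))
    .groupBy(word => word)
    .map { case (w, l) => (w, l.size) }
    .sortBy(_._2, ascending = false)     // 第二次shuffle,再切出一个ShuffleMapStage
    .collect()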

4、任务的切分

在handleJobSubmitted方法中,会创建一个ActiveJob,并在最后提交阶段

  1. submitStage(finalStage)

提交阶段的时候,先查看缺失的阶段

  1. /** Submits stage, but first recursively submits any missing parents. */
  2. private def submitStage(stage: Stage): Unit = {
  3. val jobId = activeJobForStage(stage)
  4. if (jobId.isDefined) {
  5. logDebug(s"submitStage($stage (name=${stage.name};" +
  6. s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))")
  7. if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
  8. // 先获取丢失的阶段,就是获取shuffle阶段,
  9. val missing = getMissingParentStages(stage).sortBy(_.id)
  10. logDebug("missing: " + missing)
  11. if (missing.isEmpty) {
  12. logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
  13. // 提交缺失的任务
  14. submitMissingTasks(stage, jobId.get)
  15. } else {
  16. // 如果有缺失的阶段,先提交缺失的阶段
  17. for (parent <- missing) {
  18. submitStage(parent)
  19. }
  20. waitingStages += stage
  21. }
  22. }
  23. } else {
  24. abortStage(stage, "No active job for stage " + stage.id, None)
  25. }
  26. }

提交缺失任务的时候,先判断阶段是不是ShuffleMapStage,然后遍历每个需要计算的分区id,创建同数量的任务

  1. val tasks: Seq[Task[_]] = try {
  2. val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
  3. stage match {
  4. // 判断阶段
  5. case stage: ShuffleMapStage =>
  6. stage.pendingPartitions.clear()
  7. // 遍历每个需要计算的分区id,并创建同数量的任务去执行!
  8. partitionsToCompute.map { id =>
  9. val locs = taskIdToLocations(id)
  10. val part = partitions(id)
  11. stage.pendingPartitions += id
  12. new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
  13. taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
  14. Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
  15. }
  16. case stage: ResultStage =>
  17. partitionsToCompute.map { id =>
  18. val p: Int = stage.partitions(id)
  19. val part = partitions(p)
  20. val locs = taskIdToLocations(id)
  21. // 创建一个ResultTask
  22. new ResultTask(stage.id, stage.latestInfo.attemptNumber,
  23. taskBinary, part, locs, id, properties, serializedTaskMetrics,
  24. Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
  25. stage.rdd.isBarrier())
  26. }
  27. }

到底有多少个任务呢?就需要看partitionsToCompute了

  1. // Figure out the indexes of partition ids to compute.
  2. val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
  3. // ShuffleMapStage
  4. /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
  5. override def findMissingPartitions(): Seq[Int] = {
  6. mapOutputTrackerMaster
  7. .findMissingPartitions(shuffleDep.shuffleId)
  8. // 有多少分区就创建多少个ShuffleMapTask任务,而且是所有的Shuffle阶段,如果有两个shuffle阶段就是两倍
  9. .getOrElse(0 until numPartitions)
  10. }
  11. // ResultStage
  12. val job = activeJob.get
  13. (0 until job.numPartitions).filter(id => !job.finished(id))

总结:
任务的数量等于所有阶段中每个阶段的最后一个Rdd的分区数量之和
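
继续用上面的例子估算一下任务数(分区数为假设值):

  // 假设读入的lines有2个分区,且没有设置spark.default.parallelism,
  // groupBy的默认分区器通常沿用上游的分区数(这里也假设为2)
  val counted = lines.flatMap(_.split(" "))
    .groupBy(word => word)
    .map { case (w, l) => (w, l.size) }
  counted.collect()
  // ShuffleMapStage的最后一个RDD有2个分区 -> 2个ShuffleMapTask
  // ResultStage的最后一个RDD有2个分区   -> 2个ResultTask
  // 任务总数 = 2 + 2 = 4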

5、任务的调度

在submitMissingTasks中,创建完该阶段的Task之后,接下来就提交Task:将所有Task封装成一个TaskSet(任务集)提交

  1. taskScheduler.submitTasks(new TaskSet(
  2. tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))

提交任务集。

  1. override def submitTasks(taskSet: TaskSet): Unit = {
  2. // 获取任务集中的任务
  3. val tasks = taskSet.tasks
  4. logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
  5. this.synchronized {
  6. // 创建TaskSetManager管理任务集
  7. val manager = createTaskSetManager(taskSet, maxTaskFailures)
  8. val stage = taskSet.stageId
  9. val stageTaskSets =
  10. taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
  11. stageTaskSets.foreach { case (_, ts) =>
  12. ts.isZombie = true
  13. }
  14. stageTaskSets(taskSet.stageAttemptId) = manager
  15. schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
  16. if (!isLocal && !hasReceivedTask) {
  17. starvationTimer.scheduleAtFixedRate(new TimerTask() {
  18. override def run(): Unit = {
  19. if (!hasLaunchedTask) {
  20. logWarning("Initial job has not accepted any resources; " +
  21. "check your cluster UI to ensure that workers are registered " +
  22. "and have sufficient resources")
  23. } else {
  24. this.cancel()
  25. }
  26. }
  27. }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
  28. }
  29. hasReceivedTask = true
  30. }
  31. // 后端拉取任务
  32. backend.reviveOffers()
  33. }

CoarseGrainedSchedulerBackend集群后端处理消息

  1. case ReviveOffers =>
  2. makeOffers()
  3. // ...........................
  4. // Make fake resource offers on all executors
  5. private def makeOffers(): Unit = {
  6. // Make sure no executor is killed while some task is launching on it
  7. // 获取任务的描述信息
  8. val taskDescs = withLock {
  9. // Filter out executors under killing
  10. val activeExecutors = executorDataMap.filterKeys(isExecutorActive)
  11. val workOffers = activeExecutors.map {
  12. case (id, executorData) =>
  13. new WorkerOffer(id, executorData.executorHost, executorData.freeCores,
  14. Some(executorData.executorAddress.hostPort),
  15. executorData.resourcesInfo.map { case (rName, rInfo) =>
  16. (rName, rInfo.availableAddrs.toBuffer)
  17. })
  18. }.toIndexedSeq
  19. // 刷新offers
  20. scheduler.resourceOffers(workOffers)
  21. }
  22. // 运行任务
  23. if (taskDescs.nonEmpty) {
  24. launchTasks(taskDescs)
  25. }
  26. }

在resourceOffers获取资源时,会从rootPool中获取排好序的任务集(rootPool就是一个调度池,里面包含多个任务集),遍历每个任务集,把任务分配给executor执行。

  1. val sortedTaskSets = rootPool.getSortedTaskSetQueue.filterNot(_.isZombie)
  2. for (taskSet <- sortedTaskSets) {
  3. logDebug("parentName: %s, name: %s, runningTasks: %s".format(
  4. taskSet.parent.name, taskSet.name, taskSet.runningTasks))
  5. if (newExecAvail) {
  6. // 让任务交给executor执行
  7. taskSet.executorAdded()
  8. }
  9. }

在把任务分配给executor之前,TaskSetManager会通过recomputeLocality重新计算本地化级别

  1. def recomputeLocality(): Unit = {
  2. // A zombie TaskSetManager may reach here while executorLost happens
  3. if (isZombie) return
  4. // 计算本地化级别
  5. val previousLocalityLevel = myLocalityLevels(currentLocalityIndex)
  6. myLocalityLevels = computeValidLocalityLevels()
  7. localityWaits = myLocalityLevels.map(getLocalityWait)
  8. // 依次尝试,适配本地化级别
  9. currentLocalityIndex = getLocalityIndex(previousLocalityLevel)
  10. }

本地化级别:描述的是数据与执行计算的任务之间的位置关系。
如果一个任务和它要处理的数据在同一个进程中,就不需要移动数据,正所谓"移动数据不如移动计算"。

但数据和任务不一定在一起,这种位置关系就用本地化级别来衡量,常见的本地化级别如下:

  • 进程本地化:数据和计算都在同一个进程
  • 节点本地化:数据和计算在一台机器上但是不在同一个进程
  • 机架本地化:数据和计算在一个机架
  • 其他任意

所以在本地化计算的时候,会经过以下步骤:

  1. private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
  2. import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY}
  3. val levels = new ArrayBuffer[TaskLocality.TaskLocality]
  4. // 进程本地化
  5. if (!pendingTasks.forExecutor.isEmpty &&
  6. pendingTasks.forExecutor.keySet.exists(sched.isExecutorAlive(_))) {
  7. levels += PROCESS_LOCAL
  8. }
  9. // 节点本地化
  10. if (!pendingTasks.forHost.isEmpty &&
  11. pendingTasks.forHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
  12. levels += NODE_LOCAL
  13. }
  14. if (!pendingTasks.noPrefs.isEmpty) {
  15. levels += NO_PREF
  16. }
  17. // 机架本地化
  18. if (!pendingTasks.forRack.isEmpty &&
  19. pendingTasks.forRack.keySet.exists(sched.hasHostAliveOnRack(_))) {
  20. levels += RACK_LOCAL
  21. }
  22. // 随意
  23. levels += ANY
  24. logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
  25. levels.toArray
  26. }

进行本地化级别依次尝试,否则降低本地化级别

  1. def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
  2. var index = 0
  3. while (locality > myLocalityLevels(index)) {
  4. index += 1
  5. }
  6. index
  7. }
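
降级并不是立刻发生的:每个级别会先等待一段时间,等待时长由spark.locality.wait系列配置控制(下面是一个配置示意,3s为默认值,仅作说明):

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .set("spark.locality.wait", "3s")         // 各级别默认的等待时间
    .set("spark.locality.wait.process", "3s") // PROCESS_LOCAL的等待时间
    .set("spark.locality.wait.node", "3s")    // NODE_LOCAL的等待时间
    .set("spark.locality.wait.rack", "3s")    // RACK_LOCAL的等待时间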

最后将这些任务返回return tasks,然后后端将任务序列化之后发给executor端执行。

  1. // Launch tasks returned by a set of resource offers
  2. private def launchTasks(tasks: Seq[Seq[TaskDescription]]): Unit = {
  3. for (task <- tasks.flatten) {
  4. // 将任务编码
  5. val serializedTask = TaskDescription.encode(task)
  6. if (serializedTask.limit() >= maxRpcMessageSize) {
  7. Option(scheduler.taskIdToTaskSetManager.get(task.taskId)).foreach { taskSetMgr =>
  8. try {
  9. var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
  10. s"${RPC_MESSAGE_MAX_SIZE.key} (%d bytes). Consider increasing " +
  11. s"${RPC_MESSAGE_MAX_SIZE.key} or using broadcast variables for large values."
  12. msg = msg.format(task.taskId, task.index, serializedTask.limit(), maxRpcMessageSize)
  13. taskSetMgr.abort(msg)
  14. } catch {
  15. case e: Exception => logError("Exception in error callback", e)
  16. }
  17. }
  18. }
  19. else {
  20. val executorData = executorDataMap(task.executorId)
  21. // Do resources allocation here. The allocated resources will get released after the task
  22. // finishes.
  23. executorData.freeCores -= scheduler.CPUS_PER_TASK
  24. task.resources.foreach { case (rName, rInfo) =>
  25. assert(executorData.resourcesInfo.contains(rName))
  26. executorData.resourcesInfo(rName).acquire(rInfo.addresses)
  27. }
  28. logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
  29. s"${executorData.executorHost}.")
  30. // 将这些任务序列化之后发给executor端执行
  31. executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
  32. }
  33. }
  34. }
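
从上面launchTasks的代码可以看到,任务序列化后如果超过RPC消息上限,任务集会被abort。常见的处理方式是调大spark.rpc.message.maxSize(单位MiB,默认128),或者改用广播变量传递大对象(下面的数值仅作示意):

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .set("spark.rpc.message.maxSize", "256") // 调大RPC消息上限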

6、任务的执行

SchedulerBackend已经把任务发给了Executor端,所以Executor端一定有接收该消息的地方。

executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))

CoarseGrainedExecutorBackend的receive方法中接收运行任务的消息,先判断有没有executor:如果没有就退出;如果有,就将任务解码之后交给executor运行。

    case LaunchTask(data) =>
      if (executor == null) {
        exitExecutor(1, "Received LaunchTask command but executor was null")
      } else {
        val taskDesc = TaskDescription.decode(data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        taskResources(taskDesc.taskId) = taskDesc.resources
        // 运行任务
        executor.launchTask(this, taskDesc)
      }

每个任务被封装成一个TaskRunner,运行任务是从线程池中拿出一个线程来执行这个TaskRunner的run方法。run方法中有一处调用task.run(),让任务真正执行。

  def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
    // 创建一个TaskRunner来跑任务
    val tr = new TaskRunner(context, taskDescription)
    // 将任务放入正在运行的任务集合中
    runningTasks.put(taskDescription.taskId, tr)
    // 执行这个TaskRunner
    threadPool.execute(tr)
  }
  // Maintains the list of running tasks.
  private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
  private val threadPool = {
    val threadFactory = new ThreadFactoryBuilder()
      .setDaemon(true)
      .setNameFormat("Executor task launch worker-%d")
      .setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
      .build()
 class TaskRunner(
      execBackend: ExecutorBackend,
      private val taskDescription: TaskDescription)
    extends Runnable {
      ..........
              val value = Utils.tryWithSafeFinally {
          val res = task.run(
            taskAttemptId = taskId,
            attemptNumber = taskDescription.attemptNumber,
            metricsSystem = env.metricsSystem,
            resources = taskDescription.resources)
          threwException = false
          res
        }

task.run()中会调用runTask(context)执行任务,这是一个抽象方法,具体执行取决于该任务是什么类型(ShuffleMapTask还是ResultTask),二者都有各自的实现去完成自己的任务。

def runTask(context: TaskContext): T


四、Shuffle


为了提高性能,Shuffle采取生成一个File数据文件和Index索引文件的方式,让其他下游的Task读取该文件找到自己的数据去处理。
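
用一小段代码示意索引文件的作用(数值为假设):数据文件中依次存放各下游分区的数据,索引文件记录每个分区的起始偏移量,下游Task按偏移量区间读取属于自己的那一段。

  // 假设某个map任务为3个下游分区分别写出了100、50、200字节
  val lengths = Array(100L, 50L, 200L)
  val offsets = lengths.scanLeft(0L)(_ + _) // Array(0, 100, 150, 350),即索引文件的内容
  // 下游第i个分区读取数据文件的[offsets(i), offsets(i + 1))区间
  // 例如reduce任务1读取第100到150字节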

1、流程梳理

  • 1、DAGScheduler调度任务时,判断是ShuffleMapStage还是ResultStage,分别创建ShuffleMapTask和ResultTask对象,最终由TaskScheduler提交任务。任务被SchedulerBackend发给ExecutorBackend,封装成TaskRunner后由executor从线程池中取出线程执行,每个任务运行自己的runTask()逻辑。Shuffle发生在ShuffleMapTask中,所以从ShuffleMapTask入手
  • 2、ShuffleMapTask的runTask()方法中会拿取一个shuffleWriterProcessor,调用write方法写出数据
    • 1、先获取到ShuffleManager,从ShuffleManager中获取writer,写出数据,获取到的是SortShuffleWriter
    • 2、对数据排个序,并准备好写出器 mapOutputWriter
    • 3、sorter.insertAll(records)
    • 4、按照分区写出数据 writePartitionedMapOutput
      • 获取按照分区的迭代器,一个一个分区写出
    • 5、提交所有分区 commitAllPartitions
      • 实现类LocalDiskShuffleMapOutputWriter
      • 写出索引文件并提交 blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
        • 获取索引和数据文件
        • writeLong()写出
  • 3、ShuffleMapTask完成任务,接着往后执行,比如ResultTask的runTask()方法
    • 遍历每一个rdd,func(context, rdd.iterator(partition, context))
      • 判断是否设置过存储级别(缓存等)
      • 如果设置了存储级别,走getOrCompute(split, context)
        • 其内部最终也会回到计算或读取检查点的逻辑
      • 如果没有设置存储级别,直接computeOrReadCheckpoint(split, context)计算或读取检查点
        • compute()计算,是一个抽象方法,每个rdd都有自己的计算规则。但我们知道是ShuffledRdd,需要将落盘的文件读取进来。
          • ShuffledRDD的compute()中通过ShuffleManager获取reader,来到BlockStoreShuffleReader.read()
            • readMetrics.incRecordsRead(1)
            • 在TempShuffleReadMetrics中累计读取的记录数:override def incRecordsRead(v: Long): Unit = _recordsRead += v
            • 完成shuffle操作

ShuffleMapTask的runTask()方法:

 override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTimeNs = System.nanoTime()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val rddAndDep = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    val rdd = rddAndDep._1
    val dep = rddAndDep._2
    // While we use the old shuffle fetch protocol, we use partitionId as mapId in the
    // ShuffleBlockId construction.
    val mapId = if (SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
      partitionId
    } else context.taskAttemptId()
   // 写出数据
    dep.shuffleWriterProcessor.write(rdd, dep, mapId, context, partition)
  }

写出数据

 /**
   * The write process for particular partition, it controls the life circle of [[ShuffleWriter]]
   * get from [[ShuffleManager]] and triggers rdd compute, finally return the [[MapStatus]] for
   * this task.
   */
  def write(
      rdd: RDD[_],
      dep: ShuffleDependency[_, _, _],
      mapId: Long,
      context: TaskContext,
      partition: Partition): MapStatus = {
    var writer: ShuffleWriter[Any, Any] = null
    try {
      // 获取shuffleManager
      val manager = SparkEnv.get.shuffleManager
      // 从shuffleManager获取writer准备写出
      writer = manager.getWriter[Any, Any](
        dep.shuffleHandle,
        mapId,
        context,
        createMetricsReporter(context))
      // 写出数据
      writer.write(
        rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

write写数据是个抽象方法,这里实际拿到的是SortShuffleWriter。因为ShuffleManager是特质,其实现类SortShuffleManager会根据ShuffleHandle的类型来处理不同的Shuffle,给出不同的写出器。

    handle match {
      // 不安全的Shuffle
      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
        new UnsafeShuffleWriter(
          env.blockManager,
          context.taskMemoryManager(),
          unsafeShuffleHandle,
          mapId,
          context,
          env.conf,
          metrics,
          shuffleExecutorComponents)
      // bypass归并排序
      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
        new BypassMergeSortShuffleWriter(
          env.blockManager,
          bypassMergeSortHandle,
          mapId,
          env.conf,
          metrics,
          shuffleExecutorComponents)
      // 基本Shuffle
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriter(
          shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
    }


SortShuffleWriter写出数据,

  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    sorter.insertAll(records)

    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).

    // 准备写出器
    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
      dep.shuffleId, mapId, dep.partitioner.numPartitions)
    // 按照分区写出
    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
    // 提交所有分区
    val partitionLengths = mapOutputWriter.commitAllPartitions()
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
  }

按照分区写出数据

  def writePartitionedMapOutput(
      shuffleId: Int,
      mapId: Long,
      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
    var nextPartitionId = 0
    if (spills.isEmpty) {
      // Case where we only have in-memory data
      val collection = if (aggregator.isDefined) map else buffer
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      while (it.hasNext()) {
        val partitionId = it.nextPartition()
        var partitionWriter: ShufflePartitionWriter = null
        var partitionPairsWriter: ShufflePartitionPairsWriter = null
        TryUtils.tryWithSafeFinally {
          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
          partitionPairsWriter = new ShufflePartitionPairsWriter(
            partitionWriter,
            serializerManager,
            serInstance,
            blockId,
            context.taskMetrics().shuffleWriteMetrics)
          while (it.hasNext && it.nextPartition() == partitionId) {
            it.writeNext(partitionPairsWriter)
          }
        } {
          if (partitionPairsWriter != null) {
            partitionPairsWriter.close()
          }
        }
        nextPartitionId = partitionId + 1
      }
    } else {
      // We must perform merge-sort; get an iterator by partition and write everything directly.
      for ((id, elements) <- this.partitionedIterator) {
        val blockId = ShuffleBlockId(shuffleId, mapId, id)
        var partitionWriter: ShufflePartitionWriter = null
        var partitionPairsWriter: ShufflePartitionPairsWriter = null
        TryUtils.tryWithSafeFinally {
          partitionWriter = mapOutputWriter.getPartitionWriter(id)
          partitionPairsWriter = new ShufflePartitionPairsWriter(
            partitionWriter,
            serializerManager,
            serInstance,
            blockId,
            context.taskMetrics().shuffleWriteMetrics)
          if (elements.hasNext) {
            for (elem <- elements) {
              partitionPairsWriter.write(elem._1, elem._2)
            }
          }
        } {
          if (partitionPairsWriter != null) {
            partitionPairsWriter.close()
          }
        }
        nextPartitionId = id + 1
      }
    }

    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
  }

写出索引文件并提交

  def writeIndexFileAndCommit(
      shuffleId: Int,
      mapId: Long,
      lengths: Array[Long],
      dataTmp: File): Unit = {
    // 获取索引文件
    val indexFile = getIndexFile(shuffleId, mapId)
    val indexTmp = Utils.tempFileWith(indexFile)
    try {
      // 获取数据文件
      val dataFile = getDataFile(shuffleId, mapId)
      // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
      // the following check and rename are atomic.
      synchronized {
        // 检查索引和数据文件
        val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
        if (existingLengths != null) {
          // Another attempt for the same task has already written our map outputs successfully,
          // so just use the existing partition lengths and delete our temporary map outputs.
          System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
          if (dataTmp != null && dataTmp.exists()) {
            dataTmp.delete()
          }
        } else {
          // This is the first successful attempt in writing the map outputs for this task,
          // so override any existing index and data files with the ones we wrote.
          val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
          Utils.tryWithSafeFinally {
            // We take in lengths of each block, need to convert it to offsets.
            var offset = 0L

            // 写出
            out.writeLong(offset)
            for (length <- lengths) {
              offset += length
              out.writeLong(offset)
            }
          } {
            out.close()
          }

          if (indexFile.exists()) {
            indexFile.delete()
          }
          if (dataFile.exists()) {
            dataFile.delete()
          }
          if (!indexTmp.renameTo(indexFile)) {
            throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
          }
          if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
            throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
          }
        }
      }
    } finally {
      if (indexTmp.exists() && !indexTmp.delete()) {
        logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
      }
    }
  }

ResultTask读取shuffle的数据

  override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTimeNs = System.nanoTime()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    // rdd.iterator迭代获取
    func(context, rdd.iterator(partition, context))
  }

迭代rdd,获取数据

  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      // 计算执行
      getOrCompute(split, context)
    } else {
      // 读取检查点
      computeOrReadCheckpoint(split, context)
    }
  }
  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
      readCachedBlock = false
      // 计算或读取检查点
      computeOrReadCheckpoint(partition, context)
    }) match {
      case Left(blockResult) =>
        if (readCachedBlock) {
          val existingMetrics = context.taskMetrics().inputMetrics
          existingMetrics.incBytesRead(blockResult.bytes)
          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
            override def next(): T = {
              existingMetrics.incRecordsRead(1)
              delegate.next()
            }
          }
        } else {
          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
        }
      case Right(iter) =>
        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
    }
  }

读取shuffle数据

override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      blockManager.blockStoreClient,
      blockManager,
      blocksByAddress,
      serializerManager.wrapStream,
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, // 48m
      SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
      readMetrics,
      fetchContinuousBlocksInBatch).toCompletionIterator

    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the context task metrics for each record read.
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        // 读取记录
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

2、写流程

获取到写出器,对数据排序之后按分区写出。

获取到什么样的写出器?

在ShuffleWriterProcessor的write方法中,获取写出器时传入了一个shuffleHandle,根据这个handle的类型判断使用哪种写出器。shuffleHandle是ShuffleDependency在创建时向ShuffleManager注册Shuffle得到的:

writer = manager.getWriter[Any, Any](
        dep.shuffleHandle,

  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
    shuffleId, this)

所以来到SortShuffleManager中注册Shuffle

  override def registerShuffle[K, V, C](
      shuffleId: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    // 如果应该忽略合并排序,就使用BypassMergeSortShuffleHandle
    if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
      // need map-side aggregation, then write numPartitions files directly and just concatenate
      // them at the end. This avoids doing serialization and deserialization twice to merge
      // together the spilled files, which would happen with the normal code path. The downside is
      // having multiple files open at a time and thus more memory allocated to buffers.
      new BypassMergeSortShuffleHandle[K, V](
        shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
      // 如果能序列化shuffle,使用SerializedShuffleHandle
    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
      new SerializedShuffleHandle[K, V](
        shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
    } else {
      // otherwise use BaseShuffleHandle
      // Otherwise, buffer map outputs in a deserialized form:
      new BaseShuffleHandle(shuffleId, dependency)
    }
  }
  • BypassMergeSortShuffleHandle
    • 1、Merge sort should be bypassed: shouldBypassMergeSort
      • 1、Not usable if map-side pre-aggregation is needed: if (dep.mapSideCombine) false
      • 2、Partition count is at most 200: dep.partitioner.numPartitions <= bypassMergeThreshold
        • SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD reads the config spark.shuffle.sort.bypassMergeThreshold
        • default 200: createWithDefault(200)
  • SerializedShuffleHandle
    • 1、The serialized shuffle is supported: canUseSerializedShuffle
      • 1、The serializer must support relocation of serialized objects
        • Java serialization does not; Kryo does
      • 2、Not usable if map-side pre-aggregation is needed
      • 3、Partition count must be at most 16777216
        • fails when numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE
        • val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
        • static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
  • BaseShuffleHandle (used otherwise; a small selection sketch follows below)
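
A small sketch that mirrors the three branches above (the selectHandle helper and its parameters are made up for illustration; the real decision lives in SortShuffleManager.registerShuffle shown earlier):

  object HandleSelectionSketch {
    sealed trait Handle
    case object BypassMergeSortShuffleHandle extends Handle
    case object SerializedShuffleHandle extends Handle
    case object BaseShuffleHandle extends Handle

    // Hypothetical parameters standing in for the ShuffleDependency fields
    // consulted by SortShuffleManager.registerShuffle.
    def selectHandle(
        mapSideCombine: Boolean,
        numPartitions: Int,
        serializerSupportsRelocation: Boolean,
        bypassMergeThreshold: Int = 200): Handle = {
      val maxSerializedPartitions = 1 << 24  // PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
      if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
        BypassMergeSortShuffleHandle
      } else if (serializerSupportsRelocation && !mapSideCombine &&
          numPartitions <= maxSerializedPartitions) {
        SerializedShuffleHandle
      } else {
        BaseShuffleHandle
      }
    }

    def main(args: Array[String]): Unit = {
      // No map-side combine, 200 partitions -> bypass merge sort.
      println(selectHandle(mapSideCombine = false, numPartitions = 200, serializerSupportsRelocation = false))
      // Map-side combine (e.g. reduceByKey) rules out both bypass and serialized mode.
      println(selectHandle(mapSideCombine = true, numPartitions = 1000, serializerSupportsRelocation = true))
      // No combine, many partitions, relocatable serializer (Kryo) -> serialized shuffle.
      println(selectHandle(mapSideCombine = false, numPartitions = 1000, serializerSupportsRelocation = true))
    }
  }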

When should merge sort be bypassed?

  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
    // We cannot bypass sorting if we need to do map-side aggregation.
    if (dep.mapSideCombine) {
      false
    } else {
      val bypassMergeThreshold: Int = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
      dep.partitioner.numPartitions <= bypassMergeThreshold
    }
  }

When can the serialized shuffle be used?

  def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
    val shufId = dependency.shuffleId
    val numPartitions = dependency.partitioner.numPartitions

    // does the serializer support relocation of serialized objects?
    if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
        s"${dependency.serializer.getClass.getName}, does not support object relocation")
      false
    } else if (dependency.mapSideCombine) {
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " +
        s"map-side aggregation")
      false
    } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
      false
    } else {
      log.debug(s"Can use serialized shuffle for shuffle $shufId")
      true
    }
  }
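
Because only a relocatable serializer qualifies (Kryo, not Java serialization), switching to Kryo is the usual prerequisite for the serialized shuffle path. A minimal configuration sketch; the spark.serializer key and the KryoSerializer class are standard Spark, the rest is illustrative:

  import org.apache.spark.SparkConf

  object KryoConfSketch {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf()
        .setAppName("serialized-shuffle-demo")
        // KryoSerializer supports relocation of serialized objects, one of the
        // preconditions checked by canUseSerializedShuffle above.
        .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      println(conf.get("spark.serializer"))
    }
  }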

3、Merge Sort and Read

  • 1、After obtaining the SortShuffleWriter, write() is called: the records are inserted into an in-memory collection that is later sorted and written out
  • 2、sorter.insertAll(records)
    • 1、If an aggregator is defined (map-side pre-aggregation), a map is created to pre-aggregate records with the same key: @volatile private var map = new PartitionedAppendOnlyMap[K, C]
    • 2、Without pre-aggregation, a pair buffer is created instead: @volatile private var buffer = new PartitionedPairBuffer[K, C]
    • 3、While inserting, memory may run short, producing spilled temporary files: maybeSpillCollection(usingMap = true)
      • first estimate the collection's size
      • should it spill? maybeSpill
        • spill if the current size exceeds the memory threshold (initially 5 MB) and not enough extra memory can be acquired, or if the number of elements read exceeds the force-spill threshold (Integer.MAX_VALUE by default)
        • 1、log the spill
        • 2、spill: spill(collection)
          • the map variant spills the map
            • write the in-memory data to disk: spillMemoryIteratorToDisk(inMemoryIterator)
              • 1、create a temporary block for the spill
              • 2、get a disk writer whose buffer is 32 KB
                • private val fileBufferSize = sparkConf.get(config.SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024
          • the buffer variant spills the buffer
        • 3、release the surplus memory: releaseMemory()
          • on-heap memory
          • off-heap memory
      • if a map was spilled, create a new map
      • if a buffer was spilled, create a new buffer
  • 3、Write the data out by partition: sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
    • 1、Check whether any spill happened: if (spills.isEmpty) {
      • if not, write the in-memory data directly
      • if spills exist, iterate over this.partitionedIterator,
        • the partitioned iterator checks whether spills occurred,
          • if so, the spill files are merged: merge(spills, destructiveIterator(collection.partitionedDestructiveSortedIterator(comparator)))
          • performing aggregation and merge sort along the way
  • 4、Commit all partitions: mapOutputWriter.commitAllPartitions()
    • 1、obtain the index and data from the temporary files
    • 2、delete the temporary data file
    • 3、write the data out
    • 4、delete the stale index and data files
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) { // with map-side pre-aggregation, pass the aggregator and order by key
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V]( // no aggregation, so neither an aggregator nor an ordering is passed
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }
    // insert the records here
    sorter.insertAll(records)

    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
      dep.shuffleId, mapId, dep.partitioner.numPartitions)
    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
    val partitionLengths = mapOutputWriter.commitAllPartitions()
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
  }

Sorting

 def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined

    if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        maybeSpillCollection(usingMap = true)
      }
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }
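
A tiny plain-Scala sketch of what the update closure and PartitionedAppendOnlyMap achieve: pre-aggregation keyed by (partition, key), here with an ordinary mutable HashMap instead of Spark's collection (all names are illustrative):

  import scala.collection.mutable

  object MapSideCombineSketch {
    def main(args: Array[String]): Unit = {
      val numPartitions = 4
      def getPartition(key: String): Int = math.abs(key.hashCode) % numPartitions

      // Word count combiners: createCombiner builds the first value, mergeValue folds
      // further values in -- exactly the role of the update closure above.
      val createCombiner: Int => Int = identity
      val mergeValue: (Int, Int) => Int = _ + _

      // Keyed by (partition, key), as PartitionedAppendOnlyMap is.
      val map = mutable.HashMap.empty[(Int, String), Int]
      val records = Iterator("a" -> 1, "b" -> 1, "a" -> 1, "a" -> 1)

      records.foreach { case (k, v) =>
        val mapKey = (getPartition(k), k)
        map.get(mapKey) match {
          case Some(c) => map(mapKey) = mergeValue(c, v)    // hadValue == true
          case None    => map(mapKey) = createCombiner(v)   // hadValue == false
        }
      }
      println(map)  // e.g. HashMap((1,a) -> 3, (2,b) -> 1)
    }
  }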

Spilling

 private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {
      estimatedSize = map.estimateSize()
      if (maybeSpill(map, estimatedSize)) {
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      estimatedSize = buffer.estimateSize()
      if (maybeSpill(buffer, estimatedSize)) {
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
  }

Should it spill? The check only fires when the number of elements read is a multiple of 32 and the collection's estimated size exceeds the current threshold (initially 5 MB); even then, the sorter first tries to acquire more execution memory and only spills if it cannot grow enough (or if the elements read exceed the force-spill threshold):

  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      myMemoryThreshold += granted
      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
      // or we already had more memory than myMemoryThreshold), spill the current collection
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }
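
A short worked example of the threshold-doubling arithmetic in maybeSpill above (the numbers are illustrative; the 5 MB initial threshold is the default shown just below):

  object MaybeSpillArithmetic {
    def main(args: Array[String]): Unit = {
      val myMemoryThreshold = 5L * 1024 * 1024   // initial threshold, 5 MB by default
      val currentMemory     = 6L * 1024 * 1024   // the collection has grown to 6 MB

      // maybeSpill asks for enough memory to roughly double the collection:
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      println(amountToRequest / 1024 / 1024)                         // 7 (MB)

      // Case 1: the full request is granted -> threshold becomes 12 MB, no spill yet.
      println(currentMemory >= myMemoryThreshold + amountToRequest)  // false
      // Case 2: the execution pool grants nothing -> threshold stays 5 MB, spill.
      println(currentMemory >= myMemoryThreshold + 0L)               // true
    }
  }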

  @volatile private[this] var myMemoryThreshold = initialMemoryThreshold

  private[this] val initialMemoryThreshold: Long =
    SparkEnv.get.conf.get(SHUFFLE_SPILL_INITIAL_MEM_THRESHOLD)

private[spark] val SHUFFLE_SPILL_INITIAL_MEM_THRESHOLD =
    ConfigBuilder("spark.shuffle.spill.initialMemoryThreshold")
      .internal()
      .doc("Initial threshold for the size of a collection before we start tracking its " +
        "memory usage.")
      .version("1.1.1")
      .bytesConf(ByteUnit.BYTE)
      .createWithDefault(5 * 1024 * 1024)

The spill itself (this variant spills the in-memory map: it sorts the in-memory data, writes it to disk, and remembers the resulting iterator):

  override protected[this] def spill(collection: SizeTracker): Unit = {
    val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator)
    // write the in-memory data out to disk
    val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
    spilledMaps += diskMapIterator
  }

The spill file writer uses a 32 KB buffer by default:

  private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
    ConfigBuilder("spark.shuffle.file.buffer")
      .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
        "otherwise specified. These buffers reduce the number of disk seeks and system calls " +
        "made in creating intermediate shuffle files.")
      .version("1.4.0")
      .bytesConf(ByteUnit.KiB)
      .checkValue(v => v > 0 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024,
        s"The file buffer size must be positive and less than or equal to" +
          s" ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 1024}.")
      .createWithDefaultString("32k")

Merging the spill files:

  private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    val readers = spills.map(new SpillReader(_))
    val inMemBuffered = inMemory.buffered
    (0 until numPartitions).iterator.map { p =>
      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
      // aggregation
      if (aggregator.isDefined) {
        // Perform partial aggregation across partitions
        (p, mergeWithAggregation(
          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
        // merge sort
      } else if (ordering.isDefined) {
        // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
        // sort the elements without trying to merge them
        (p, mergeSort(iterators, ordering.get))
      } else {
        (p, iterators.iterator.flatten)
      }
    }
  }
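
For the ordering-only branch, mergeSort merges the already-sorted per-partition iterators (one per spill file plus the in-memory one). A minimal plain-Scala sketch of such a k-way merge using a priority queue; this shows the idea, not Spark's actual implementation:

  import scala.collection.mutable

  object KWayMergeSketch {
    // Merge several iterators that are already sorted by key into one sorted iterator,
    // the same idea as ExternalSorter.mergeSort for a single partition.
    def mergeSort[K, V](iterators: Seq[Iterator[(K, V)]])(implicit ord: Ordering[K]): Iterator[(K, V)] = {
      val sources = iterators.map(_.buffered).filter(_.hasNext)
      // Min-heap ordered by each iterator's current head key.
      val heap = mutable.PriorityQueue(sources: _*)(
        Ordering.by[BufferedIterator[(K, V)], K](_.head._1).reverse)
      new Iterator[(K, V)] {
        def hasNext: Boolean = heap.nonEmpty
        def next(): (K, V) = {
          val it = heap.dequeue()
          val kv = it.next()
          if (it.hasNext) heap.enqueue(it)  // re-insert with its new head
          kv
        }
      }
    }

    def main(args: Array[String]): Unit = {
      val spill1 = Iterator("a" -> 1, "c" -> 1)
      val spill2 = Iterator("b" -> 1, "d" -> 1)
      val inMem  = Iterator("a" -> 1, "e" -> 1)
      println(mergeSort(Seq(spill1, spill2, inMem)).toList)
      // List((a,1), (a,1), (b,1), (c,1), (d,1), (e,1))
    }
  }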

五、Memory Management

1、Memory Categories

  • Storage memory (50% of the 60% unified region = 30%)
    • cached data
    • broadcast variables
  • Execution memory (50% of the 60% unified region = 30%)
    • shuffle operations
  • Other memory (the remaining 40%)
    • system objects, RDD metadata
  • Reserved memory
    • 300 MB (the arithmetic is worked out in the sketch below)
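
The percentages above follow from the unified memory manager's defaults (spark.memory.fraction = 0.6, spark.memory.storageFraction = 0.5, 300 MB reserved). A small worked sketch of the arithmetic, assuming a 1 GB executor heap (values hard-coded for illustration):

  object UnifiedMemorySketch {
    def main(args: Array[String]): Unit = {
      val reservedMemory  = 300L * 1024 * 1024   // fixed reserved memory
      val memoryFraction  = 0.6                  // spark.memory.fraction default
      val storageFraction = 0.5                  // spark.memory.storageFraction default

      val systemMemory = 1024L * 1024 * 1024     // pretend a 1 GB executor heap
      val usableMemory = systemMemory - reservedMemory
      val unified      = (usableMemory * memoryFraction).toLong  // storage + execution
      val storage      = (unified * storageFraction).toLong      // initial storage region
      val execution    = unified - storage                       // initial execution region
      val other        = usableMemory - unified                  // user / "other" memory

      def mb(bytes: Long): Long = bytes / 1024 / 1024
      println(s"unified=${mb(unified)} MB, storage=${mb(storage)} MB, " +
        s"execution=${mb(execution)} MB, other=${mb(other)} MB")
      // unified=434 MB, storage=217 MB, execution=217 MB, other=289 MB
    }
  }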

The abstract class MemoryManager handles unified memory management:

private[spark] abstract class MemoryManager(
    conf: SparkConf,
    numCores: Int,
    onHeapStorageMemory: Long,
    onHeapExecutionMemory: Long) extends Logging {

  require(onHeapExecutionMemory > 0, "onHeapExecutionMemory must be > 0")

  // -- Methods related to memory allocation policies and bookkeeping ------------------------------

  @GuardedBy("this")
  protected val onHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.ON_HEAP)
  @GuardedBy("this")
  protected val offHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.OFF_HEAP)
  @GuardedBy("this")
  protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP)
  @GuardedBy("this")
  protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP)

Dynamic occupancy (borrowing) mechanism:

  • Storage memory and execution memory can borrow from each other (a toy sketch follows below)
    • 1、When storage and execution together fill the unified memory, storage data is spilled to disk; if the storage level does not allow disk, the cached data may simply be dropped
    • 2、When storage has borrowed execution's space and execution needs it back, storage must release that memory, spilling to disk where possible
    • 3、When execution has borrowed storage's space and fills it up, it does not give the memory back and can keep encroaching on storage!
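
A deliberately simplified toy model of the borrowing rules above (it mirrors the idea behind UnifiedMemoryManager, not its actual code; all names and numbers are illustrative):

  object BorrowingSketch {
    // Toy model of the unified pool, sizes in MB.
    final class UnifiedPool(val total: Long, val storageRegion: Long) {
      var storageUsed: Long = 0L
      var executionUsed: Long = 0L
      private def free: Long = total - storageUsed - executionUsed

      // Storage may only use free space; it can never evict execution memory.
      def acquireStorage(request: Long): Boolean =
        if (request <= free) { storageUsed += request; true } else false

      // Execution may use free space and, if needed, evict cached blocks until
      // storage is back down to its guaranteed region.
      def acquireExecution(request: Long): Long = {
        val evictable = math.max(0L, storageUsed - storageRegion)
        val granted   = math.min(request, free + evictable)
        val toEvict   = math.max(0L, granted - free)
        storageUsed -= toEvict        // "evict" cached blocks
        executionUsed += granted
        granted
      }
    }

    def main(args: Array[String]): Unit = {
      val pool = new UnifiedPool(total = 600, storageRegion = 300)
      println(pool.acquireStorage(500))    // true: storage borrows 200 MB of execution's region
      println(pool.acquireExecution(250))  // 250: uses 100 MB free and evicts 150 MB of cache
      println(pool.acquireExecution(100))  // 50: only 50 MB of cache sits above the storage region
      println(pool.acquireStorage(10))     // false: storage cannot take memory back from execution
    }
  }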
