PyTorch Distributed Operations: Barrier

The concept of a barrier

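For the concept of a barrier, see the introduction on Wikipedia: a synchronization barrier is a synchronization method in parallel computing. For a group of processes or threads, a barrier in the program means that any thread/process reaching that point must wait, and may continue only after all threads/processes have reached it.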
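The definition can be tried out with nothing but the Python standard library. The following is a minimal sketch that uses `threading.Barrier` with threads rather than PyTorch processes; the names `run_workers` and `worker` are made up for this illustration.

```python
import threading
import time


def run_workers(num_workers: int = 3) -> list:
    """Start workers that arrive at different times and meet at a barrier."""
    barrier = threading.Barrier(num_workers)
    events = []              # shared log of (phase, worker_id) tuples
    lock = threading.Lock()  # protects the shared log

    def worker(worker_id: int) -> None:
        time.sleep(0.05 * worker_id)   # workers arrive at different times
        with lock:
            events.append(("arrived", worker_id))
        barrier.wait()                 # nobody passes until all have arrived
        with lock:
            events.append(("passed", worker_id))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events


if __name__ == "__main__":
    for event in run_workers():
        print(event)
```

Every "arrived" entry is logged before any "passed" entry, because no worker can leave `barrier.wait()` until the last one has called it.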

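Note that barriers are not unique to PyTorch. They are a basic concept in parallel computing, and other parallel computing settings may involve the same concept and operation. This article focuses on the PyTorch case.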

torch.distributed.barrier(group=None, async_op=False, device_ids=None)

    Synchronizes all processes.

    This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait().

    Parameters
        group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        device_ids ([int], optional): List of device/GPU ids. Valid only for NCCL backend.

    Returns
        Async work handle, if async_op is set to True. None, if not async_op or if not part of the group
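The two return values described above can be observed in a very small setup. This is only a sketch under stated assumptions: it uses a single-process group with the `gloo` backend and a file-based rendezvous so that it runs on a CPU-only machine, and the function name `barrier_return_values` is made up for this illustration. With one process there is nothing to wait for, so it shows the call shapes only, not real synchronization.

```python
import os
import tempfile

import torch.distributed as dist


def barrier_return_values():
    """Call barrier() synchronously and asynchronously in a one-process group."""
    # Assumption: gloo + file rendezvous, chosen only so this runs without GPUs.
    store_path = os.path.join(tempfile.mkdtemp(), "rendezvous")
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{store_path}",
        rank=0,
        world_size=1,
    )
    try:
        sync_result = dist.barrier()        # blocks inside the call, returns None
        work = dist.barrier(async_op=True)  # returns a work handle right away
        work.wait()                         # the blocking happens here instead
        return sync_result, work
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    sync_result, work = barrier_return_values()
    print(sync_result, type(work).__name__)
```

With `async_op=True` the process is free to do other work between issuing the barrier and calling `wait()` on the handle.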

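In multi-GPU training, different GPUs are usually assigned to different processes. Sometimes a task should run in a single process while the progress of the other processes is held back, and that is where barrier is needed.
A practical scenario is preparing a dataset: only process 0 needs to do the processing, and the other processes have no need to repeat it, yet their subsequent work depends on the prepared data. The other processes therefore have to be blocked and kept waiting while process 0 works, and all of them are released together once the processing is done.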

For this need, a typical construct based on a context manager looks like this:

# https://github.com/ultralytics/yolov5/blob/7d56d451241e94cd9dbe4fcb9bfba0e92c6e0e23/utils/torch_utils.py#L29-L38
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training
    wait for each local_master to do something.
    """
    # Ranks other than 0 wait here until rank 0 has finished the body.
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    # Rank 0 reaches its barrier only after the body, releasing the others.
    if local_rank == 0:
        dist.barrier(device_ids=[0])
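The ordering this context manager produces can be checked without GPUs. The sketch below is a thread-based analogy, not the YOLOv5 code: it swaps `dist.barrier` for `threading.Barrier`, and the names `zero_first`, `run_ranks`, and `cache` are hypothetical stand-ins (the `cache` dict plays the role of a dataset prepared on disk).

```python
import threading
from contextlib import contextmanager


@contextmanager
def zero_first(rank: int, barrier: threading.Barrier):
    """Thread-based stand-in for torch_distributed_zero_first."""
    if rank != 0:
        barrier.wait()   # non-zero ranks wait before running the body
    yield
    if rank == 0:
        barrier.wait()   # rank 0 releases the others after its body


def run_ranks(world_size: int = 3) -> list:
    barrier = threading.Barrier(world_size)
    lock = threading.Lock()
    cache = {}   # stands in for a dataset cache written by rank 0
    order = []   # (rank, what the rank found in the cache), in execution order

    def worker(rank: int) -> None:
        with zero_first(rank, barrier):
            if rank == 0:
                cache["dataset"] = "prepared"   # only rank 0 does the work
            with lock:
                order.append((rank, cache.get("dataset")))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order


if __name__ == "__main__":
    print(run_ranks())
```

Rank 0 always appears first in the returned list, and every other rank finds the cache already prepared, because each rank calls the barrier exactly once, just at a different point in the code.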

Details of barrier

# -*- coding: utf-8 -*-
# @Time : 2021/1/6
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang
import os
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def ddp_test_v0(local_rank, world_size):
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    print("first before barrier{}\n".format(local_rank))
    if local_rank != 0:
        dist.barrier()
    print("first after barrier{}\n".format(local_rank))

    print("inter {}".format(local_rank))

    print("second before barrier{}\n".format(local_rank))
    if local_rank == 0:
        dist.barrier()
    print("second after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))


def ddp_test_v1(local_rank, world_size):
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))


def main():
    world_size = 2
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    # Swap in ddp_test_v1 to run the second example.
    mp.spawn(ddp_test_v0, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()

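These two examples bring out an important property of dist.barrier that the official documentation does not spell out: every process has to call it the same number of times before the blocked processes resume normal execution.
Consider the first example: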

def ddp_test_v0(local_rank, world_size):
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    print("first before barrier{}\n".format(local_rank))
    if local_rank != 0:
        dist.barrier()
    print("first after barrier{}\n".format(local_rank))

    print("inter {}".format(local_rank))

    print("second before barrier{}\n".format(local_rank))
    if local_rank == 0:
        dist.barrier()
    print("second after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

Its output is:

first before barrier1
first before barrier0


first after barrier0

inter 0
second before barrier0

second after barrier0

0 exit
first after barrier1

inter 1
second before barrier1

second after barrier1

1 exit

Process finished with exit code 0

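A few details stand out:

  • Before the barrier, each GPU process prints its own output independently.
    • local_rank=0 skips the first barrier, so it prints several lines before reaching the barrier that applies to it, whereas local_rank=1 prints only first before barrier1 and then blocks.
  • After second before barrier0, rank 0 reaches its own barrier, which unblocks the other process so that it resumes running. Because the intermediate operations take some time, rank 0 first prints its second after barrier0 and exits, and then rank 1 goes on to print its own results.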

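One point is worth noting here: the barriers in different processes correspond to one another, and every process must execute barrier once before all of them are released to continue.
Now the second piece of code: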

def ddp_test_v1(local_rank, world_size):
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

This produces the output:

1 before barrier1
0 before barrier0


5.002117395401001
5.0021262168884281 after barrier1


0 after barrier0

0 after barrier0

0 after barrier0

0 exit
1 after barrier1

1 exit

Process finished with exit code 0

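One important point is that the two print(time.time() - start) outputs are essentially identical (in the log above, the second value runs straight into the following 1 after barrier1 line because both processes wrote to the terminal at the same time). No matter how long each process delays beforehand, the time measured after the barrier is set by the process that takes longest to reach and execute it. This makes the mutual constraint between the barriers in different processes even clearer. Once rank 0 reaches its second barrier, rank 1 is allowed to run again, although in this run rank 0 finishes first.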
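The same timing effect can be reproduced on a CPU with threads. This is a sketch by analogy, using `threading.Barrier` in place of `dist.barrier`; the function name `measure_waits` and the delay values are chosen only for illustration.

```python
import threading
import time


def measure_waits(delays=(0.0, 0.3)) -> list:
    """Each worker sleeps for its own delay, hits the barrier, and reports elapsed time."""
    barrier = threading.Barrier(len(delays))
    elapsed = [None] * len(delays)

    def worker(idx: int, delay: float) -> None:
        start = time.monotonic()
        time.sleep(delay)       # simulate work of different lengths
        barrier.wait()          # the fast worker is held here for the slow one
        elapsed[idx] = time.monotonic() - start

    threads = [
        threading.Thread(target=worker, args=(i, d)) for i, d in enumerate(delays)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return elapsed


if __name__ == "__main__":
    print(measure_waits())
```

Both measured times come out at roughly the longer delay, even for the worker that did not sleep at all.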
One can also verify that if one of the two barriers in the code for one rank is removed, the other rank waits forever.
For example:


def ddp_test_v1(local_rank, world_size):
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        # dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(3)
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

Output:

0 before barrier0
1 before barrier1


5.002458572387695
1 after barrier1

1 after barrier1

1 exit
5.002473831176758
0 after barrier0

0 after barrier0

Traceback (most recent call last):
  File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 67, in <module>
    main()
  File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 63, in main
    mp.spawn(ddp_test_v1, args=(world_size,), nprocs=world_size, join=True)
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 75, in join
    ready = multiprocessing.connection.wait(
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

Process finished with exit code 137 (interrupted by signal 9: SIGKILL)

Rank 0 waits forever at its second barrier.
This behaviour is also mentioned in this answer:

  • when a process encounters a barrier it will block
  • the position of the barrier is not important (not all processes have to enter the same if-statement, for instance)
  • a process is blocked by a barrier until all processes have encountered a barrier, upon which the barrier is lifted for all processes

https://stackoverflow.com/a/59766443
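The effect of an unmatched barrier call can be reproduced safely with threads and a timeout, so that the stuck worker gives up instead of hanging. This is an analogy built on `threading.Barrier`, not PyTorch behaviour: `unmatched_barrier` is a made-up name, and the `BrokenBarrierError` raised here belongs to the standard library. `torch.distributed.init_process_group` also accepts a `timeout` argument, but how it applies to a stuck barrier depends on the backend and version.

```python
import threading


def unmatched_barrier(timeout: float = 0.5) -> dict:
    """Worker 0 calls the barrier twice, worker 1 only once."""
    barrier = threading.Barrier(2)
    outcome = {}

    def worker(rank: int, calls: int) -> None:
        try:
            for _ in range(calls):
                barrier.wait(timeout=timeout)   # give up instead of hanging
            outcome[rank] = "finished"
        except threading.BrokenBarrierError:
            outcome[rank] = "timed out waiting for a peer"

    threads = [
        threading.Thread(target=worker, args=(0, 2)),   # two barrier calls
        threading.Thread(target=worker, args=(1, 1)),   # only one barrier call
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome


if __name__ == "__main__":
    print(unmatched_barrier())
```

The first pair of calls matches and both workers pass. The second call from worker 0 has no partner, so it times out, which mirrors rank 0 waiting indefinitely in the run above.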

Important references