分布式数据并行入门
原文: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
作者:申力
DistributedDataParallel (DDP)在模块级别实现数据并行性。 它使用 Torch.distributed 程序包中的通信集合来同步梯度,参数和缓冲区。 并行性在流程内和跨流程均可用。 在一个过程中,DDP 将输入模块复制到device_ids中指定的设备,将输入沿批次维度分散,然后将输出收集到output_device,这与 DataParallel 相似。 在整个过程中,DDP 在正向传递中插入必要的参数同步,在反向传递中插入梯度同步。 用户可以将进程映射到可用资源,只要进程不共享 GPU 设备即可。 推荐的方法(通常是最快的方法)是为每个模块副本创建一个过程,即在一个过程中不进行任何模块复制。 本教程中的代码在 8-GPU 服务器上运行,但可以轻松地推广到其他环境。
DataParallel和DistributedDataParallel之间的比较
在深入探讨之前,让我们澄清一下为什么尽管增加了复杂性,但还是考虑使用DistributedDataParallel而不是DataParallel:
- 首先,请回顾先前的教程,如果模型太大而无法容纳在单个 GPU 上,则必须使用模型并行将其拆分到多个 GPU 中。
DistributedDataParallel与模型并行一起使用;DataParallel目前没有。 DataParallel是单进程,多线程,并且只能在单台机器上运行,而DistributedDataParallel是多进程,并且适用于单机和多机训练。 因此,即使在单机训练中,数据足够小以适合单机,DistributedDataParallel仍比DataParallel快。DistributedDataParallel还预先复制模型,而不是在每次迭代时复制模型,并避免了全局解释器锁定。- 如果您的两个数据都太大而无法容纳在一台计算机和上,而您的模型又太大了以至于无法安装在单个 GPU 上,则可以将模型并行(跨多个 GPU 拆分单个模型)与
DistributedDataParallel结合使用。 在这种情况下,每个DistributedDataParallel进程都可以并行使用模型,而所有进程都将并行使用数据。
基本用例
要创建 DDP 模块,请首先正确设置过程组。 更多细节可以在用 PyTorch 编写分布式应用程序中找到。
import osimport tempfileimport torchimport torch.distributed as distimport torch.nn as nnimport torch.optim as optimimport torch.multiprocessing as mpfrom torch.nn.parallel import DistributedDataParallel as DDPdef setup(rank, world_size):os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12355'# initialize the process groupdist.init_process_group("gloo", rank=rank, world_size=world_size)# Explicitly setting seed to make sure that models created in two processes# start from same random weights and biases.torch.manual_seed(42)def cleanup():dist.destroy_process_group()
现在,让我们创建一个玩具模块,将其与 DDP 封装在一起,并提供一些虚拟输入数据。 请注意,由于DDP将0级进程中的模型状态广播到DDP构造函数中的所有其他进程,因此无需担心不同的DDP进程从不同的模型参数初始值开始。
class ToyModel(nn.Module):def __init__(self):super(ToyModel, self).__init__()self.net1 = nn.Linear(10, 10)self.relu = nn.ReLU()self.net2 = nn.Linear(10, 5)def forward(self, x):return self.net2(self.relu(self.net1(x)))def demo_basic(rank, world_size):setup(rank, world_size)# setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and# rank 2 uses GPUs [4, 5, 6, 7].n = torch.cuda.device_count() // world_sizedevice_ids = list(range(rank * n, (rank + 1) * n))# create model and move it to device_ids[0]model = ToyModel().to(device_ids[0])# output_device defaults to device_ids[0]ddp_model = DDP(model, device_ids=device_ids)loss_fn = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)optimizer.zero_grad()outputs = ddp_model(torch.randn(20, 10))labels = torch.randn(20, 5).to(device_ids[0])loss_fn(outputs, labels).backward()optimizer.step()cleanup()def run_demo(demo_fn, world_size):mp.spawn(demo_fn,args=(world_size,),nprocs=world_size,join=True)
如您所见,DDP 包装了较低级别的分布式通信详细信息,并提供了干净的 API,就好像它是本地模型一样。 对于基本用例,DDP 仅需要几个 LoC 来设置流程组。 在将 DDP 应用到更高级的用例时,需要注意一些警告。
偏斜的处理速度
在 DDP 中,构造函数,转发方法和输出的微分是分布式同步点。 期望不同的过程以相同的顺序到达同步点,并在大致相同的时间进入每个同步点。 否则,快速流程可能会提早到达,并在等待流浪者时超时。 因此,用户负责平衡流程之间的工作负载分配。 有时,由于例如网络延迟,资源争用,不可预测的工作量峰值,不可避免地会出现偏斜的处理速度。 为了避免在这些情况下超时,请在调用 init_process_group 时传递足够大的timeout值。
保存和加载检查点
在训练过程中通常使用torch.save和torch.load来检查点模块并从检查点中恢复。 有关更多详细信息,请参见保存和加载模型。 使用 DDP 时,一种优化方法是仅在一个进程中保存模型,然后将其加载到所有进程中,从而减少写开销。 这是正确的,因为所有过程都从相同的参数开始,并且梯度在向后传递中同步,因此优化程序应将参数设置为相同的值。 如果使用此优化,请确保在保存完成之前不要启动所有进程。 此外,在加载模块时,您需要提供适当的map_location参数,以防止进程进入其他设备。 如果缺少map_location,则torch.load将首先将该模块加载到 CPU,然后将每个参数复制到其保存位置,这将导致同一台机器上的所有进程使用同一组设备。
def demo_checkpoint(rank, world_size):setup(rank, world_size)# setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and# rank 2 uses GPUs [4, 5, 6, 7].n = torch.cuda.device_count() // world_sizedevice_ids = list(range(rank * n, (rank + 1) * n))model = ToyModel().to(device_ids[0])# output_device defaults to device_ids[0]ddp_model = DDP(model, device_ids=device_ids)loss_fn = nn.MSELoss()optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"if rank == 0:# All processes should see same parameters as they all start from same# random parameters and gradients are synchronized in backward passes.# Therefore, saving it in one process is sufficient.torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)# Use a barrier() to make sure that process 1 loads the model after process# 0 saves it.dist.barrier()# configure map_location properlyrank0_devices = [x - rank * len(device_ids) for x in device_ids]device_pairs = zip(rank0_devices, device_ids)map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))optimizer.zero_grad()outputs = ddp_model(torch.randn(20, 10))labels = torch.randn(20, 5).to(device_ids[0])loss_fn = nn.MSELoss()loss_fn(outputs, labels).backward()optimizer.step()# Use a barrier() to make sure that all processes have finished reading the# checkpointdist.barrier()if rank == 0:os.remove(CHECKPOINT_PATH)cleanup()
将 DDP 与模型并行性结合
DDP 还可以与多 GPU 模型一起使用,但是不支持进程内的复制。 您需要为每个模块副本创建一个进程,与每个进程的多个副本相比,通常可以提高性能。 当训练具有大量数据的大型模型时,DDP 包装多 GPU 模型特别有用。 使用此功能时,需要小心地实现多 GPU 模型,以避免使用硬编码的设备,因为会将不同的模型副本放置到不同的设备上。
class ToyMpModel(nn.Module):def __init__(self, dev0, dev1):super(ToyMpModel, self).__init__()self.dev0 = dev0self.dev1 = dev1self.net1 = torch.nn.Linear(10, 10).to(dev0)self.relu = torch.nn.ReLU()self.net2 = torch.nn.Linear(10, 5).to(dev1)def forward(self, x):x = x.to(self.dev0)x = self.relu(self.net1(x))x = x.to(self.dev1)return self.net2(x)
将多 GPU 模型传递给 DDP 时,不得设置device_ids和output_device。 输入和输出数据将通过应用程序或模型forward()方法放置在适当的设备中。
def demo_model_parallel(rank, world_size):setup(rank, world_size)# setup mp_model and devices for this processdev0 = rank * 2dev1 = rank * 2 + 1mp_model = ToyMpModel(dev0, dev1)ddp_mp_model = DDP(mp_model)loss_fn = nn.MSELoss()optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)optimizer.zero_grad()# outputs will be on dev1outputs = ddp_mp_model(torch.randn(20, 10))labels = torch.randn(20, 5).to(dev1)loss_fn(outputs, labels).backward()optimizer.step()cleanup()if __name__ == "__main__":run_demo(demo_basic, 2)run_demo(demo_checkpoint, 2)if torch.cuda.device_count() >= 8:run_demo(demo_model_parallel, 4)
