from __future__ import annotations

import collections
import collections.abc
import ctypes
import functools
import os
from datetime import timedelta
from typing import Any, Callable, Optional

import pynvml
import torch
import torch.distributed as dist

from .device import Device
from .log import log


def init() -> int | None:
    """Initialize distributed training."""
    # Set CPU affinity of this process based on the NVML topology of the local GPU.
    pynvml.nvmlInit()
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    device = Device(local_rank)
    os.sched_setaffinity(0, device.get_cpu_affinity())

    # Set up NCCL communication.
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
        torch.cuda.set_device(local_rank)
        timeout_seconds = int(os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "1800"))
        timeout_timedelta = timedelta(seconds=timeout_seconds)
        dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
        log.critical(
            f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
            rank0_only=False,
        )

    # Increase the maximum L2 fetch granularity (cudaLimit 0x05) to 128 bytes for faster memory access.
    _libcudart = ctypes.CDLL("libcudart.so")
    p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
    log.info(f"Training with {get_world_size()} GPUs.")
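
# Example usage (a sketch, assuming a torchrun-style launcher that sets LOCAL_RANK,
# RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT in the environment):
#
#     init()
#     log.info(f"rank {get_rank()} of {get_world_size()} ready")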


def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank


def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the world size, i.e. the total number of GPUs available in this job.

    Returns:
        world_size (int): The total number of GPUs available in this job.
    """
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size(group)
    return world_size


def is_rank0() -> bool:
    """Check if current process is the master GPU.

    Returns:
        (bool): True if this function is called from the master GPU, else False.
    """
    return get_rank() == 0


def rank0_only(func: Callable) -> Callable:
    """Apply this function only to the master GPU.

    Example usage:
        @rank0_only
        def func(x):
            return x + 3

    Args:
        func (Callable): a function.

    Returns:
        (Callable): A function wrapper executing the function only on the master GPU.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_rank0():
            return func(*args, **kwargs)
        else:
            return None

    return wrapper


def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
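
# Example usage (a sketch, assuming init() has been called and "./output" is a
# hypothetical checkpoint directory): rank 0 creates the directory while the
# other ranks wait at the barrier before writing into it.
#
#     @rank0_only
#     def make_output_dir():
#         os.makedirs("./output", exist_ok=True)
#
#     make_output_dir()
#     barrier()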


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """This extends torch.nn.parallel.DistributedDataParallel with .training_step().

    This borrows the concept of `forward-redirection` from PyTorch Lightning. It wraps a coreModel such that
    model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
    model() for PyTorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
    training_step), allowing us to preserve the function names and signatures.
    """

    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        super().__init__(model, *args, **kwargs)

    def training_step(self, *args, **kwargs) -> Any:
        # Cache the original forward() so it can be restored inside the redirected call.
        original_forward = self.module.forward

        def wrapped_training_step(*_args, **_kwargs):
            # Restore the original forward() before running the actual training step.
            self.module.forward = original_forward
            return self.module.training_step(*_args, **_kwargs)

        # Temporarily redirect forward() to the training step so that DDP's hooks
        # (gradient synchronization, etc.) still run through self.__call__().
        self.module.forward = wrapped_training_step
        # Calling self() invokes DDP's forward(), which now runs wrapped_training_step().
        return self(*args, **kwargs)
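
# Example usage (a sketch, assuming `MyModel` is a hypothetical torch.nn.Module that
# defines both forward() and training_step(), and that init() has been called):
#
#     model = MyModel().cuda()
#     ddp_model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
#     loss = ddp_model.training_step(batch)   # runs model.training_step() under DDP hooks
#     output = ddp_model(inputs)              # plain forward() still works as usual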


def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
    """Aggregate the list of data batches from all devices and process the results.

    This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler.
    It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
    in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
    created before calling dist.all_gather().

    Args:
        data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
            leaf entries are tensors.

    Returns:
        data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
            leaf entries are concatenated tensors.
    """
    if isinstance(data_batches[0], torch.Tensor):
        data_concat = torch.cat(data_batches, dim=0)
        # Pad the ranks that hold one fewer sample with a dummy so all_gather() sees equal sizes.
        max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
        dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
        if len(data_concat) < max_num_local_samples:
            assert len(data_concat) + 1 == max_num_local_samples
            dummy = torch.empty_like(data_concat[:1])
            data_concat = torch.cat([data_concat, dummy], dim=0)
            dummy_count = torch.tensor(1, device="cuda")
        else:
            dummy_count = torch.tensor(0, device="cuda")
        dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
        data_concat = all_gather_tensor(data_concat.contiguous())
        # Interleave the gathered shards so samples come back in their original index order.
        data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
        # Drop the dummy samples appended above.
        if dummy_count > 0:
            data_collate = data_collate[:-dummy_count]
    elif isinstance(data_batches[0], collections.abc.Mapping):
        data_collate = dict()
        for key in data_batches[0].keys():
            data_collate[key] = collate_batches([data[key] for data in data_batches])
    else:
        raise TypeError(f"Unsupported batch type: {type(data_batches[0])}")
    return data_collate
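
# Example usage (a sketch, assuming a validation loop where every rank collects its own
# per-batch model outputs into `local_outputs`, a list of dicts of CUDA tensors):
#
#     local_outputs = []
#     for batch in val_dataloader:
#         local_outputs.append({"pred": model(batch["input"]), "target": batch["target"]})
#     full_outputs = collate_batches(local_outputs)  # same full-validation-set dict on every rank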


@torch.no_grad()
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
    """Gather the corresponding tensor from all GPU devices to a list.

    Args:
        tensor (torch.Tensor): PyTorch tensor.

    Returns:
        tensor_list (list[torch.Tensor]): A list of PyTorch tensors gathered from all GPU devices.
    """
    tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(tensor_list, tensor)
    return tensor_list
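
# Example usage (a sketch, assuming init() has been called and `loss` is a scalar
# CUDA tensor on every rank): gather the per-rank losses and average them locally.
#
#     losses = all_gather_tensor(loss.detach())
#     mean_loss = torch.stack(losses).mean()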


def broadcast(tensor: torch.Tensor, src: int, group: Optional[dist.ProcessGroup] = None, async_op: bool = False) -> torch.Tensor:
    """Broadcast the tensor from rank `src` to all other ranks (in place) and return it."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    dist.broadcast(tensor, src=src, group=group, async_op=async_op)
    return tensor
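
# Example usage (a sketch, assuming init() has been called): rank 0 draws a random
# seed and broadcasts it so that every rank uses the same value.
#
#     seed = torch.randint(0, 2**31, (1,), device="cuda")
#     broadcast(seed, src=0)
#     torch.manual_seed(int(seed.item()))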