# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Distributed helpers."""

import functools
import logging
import pickle

import torch
import torch.distributed as dist

_LOCAL_PROCESS_GROUP = None


def all_gather(tensors):
    """
    All gathers the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all gather across all processes in
        all machines.
    """

    gather_list = []
    output_tensor = []
    world_size = dist.get_world_size()
    for tensor in tensors:
        tensor_placeholder = [
            torch.ones_like(tensor) for _ in range(world_size)
        ]
        dist.all_gather(tensor_placeholder, tensor, async_op=False)
        gather_list.append(tensor_placeholder)
    for gathered_tensor in gather_list:
        output_tensor.append(torch.cat(gathered_tensor, dim=0))
    return output_tensor
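
# Illustrative usage sketch (not part of the original module). Assuming the
# default process group has already been initialized (see init_process_group
# below), per-rank prediction/label tensors can be gathered for global
# evaluation like this:
#
#     preds, labels = model(inputs), batch_labels  # hypothetical per-rank tensors
#     all_preds, all_labels = all_gather([preds, labels])
#     # Each output is the concatenation over all ranks along dim 0.
#
# Note that with the NCCL backend every gathered tensor must live on the GPU.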


def all_reduce(tensors, average=True):
    """
    All reduce the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all reduce across all processes in
        all machines.
        average (bool): if True, divide the reduced tensors by the total
        number of processes across all machines.
    """

    for tensor in tensors:
        dist.all_reduce(tensor, async_op=False)
    if average:
        world_size = dist.get_world_size()
        for tensor in tensors:
            tensor.mul_(1.0 / world_size)
    return tensors
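
# Illustrative usage sketch (not part of the original module): averaging a
# scalar loss across all workers, e.g. for logging. `loss` is a hypothetical
# CUDA tensor produced on the current rank.
#
#     [reduced_loss] = all_reduce([loss.detach().clone()], average=True)
#
# The reduction modifies its inputs in place, so a clone is taken here to
# avoid touching the tensor that is still needed for backpropagation.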


def init_process_group(
    local_rank,
    local_world_size,
    shard_id,
    num_shards,
    init_method,
    dist_backend="nccl",
):
    """
    Initializes the default process group.
    Args:
        local_rank (int): the rank on the current local machine.
        local_world_size (int): the world size (number of processes running) on
        the current local machine.
        shard_id (int): the shard index (machine rank) of the current machine.
        num_shards (int): number of shards for distributed training.
        init_method (string): supporting two different methods for
            initializing process groups:
            "file": use shared file system to initialize the groups across
            different processes.
            "tcp": use tcp address to initialize the groups across different
            processes.
        dist_backend (string): backend to use for distributed training. Options
            include gloo, mpi and nccl, the details can be found here:
            https://pytorch.org/docs/stable/distributed.html
    """
    # Sets the GPU to use.
    torch.cuda.set_device(local_rank)
    # Initialize the process group.
    proc_rank = local_rank + shard_id * local_world_size
    world_size = local_world_size * num_shards
    dist.init_process_group(
        backend=dist_backend,
        init_method=init_method,
        world_size=world_size,
        rank=proc_rank,
    )
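
# Illustrative usage sketch (not part of the original module): launching one
# process per GPU on a single machine with torch.multiprocessing. The run
# helper, the address, and the port are arbitrary examples, not part of this
# module.
#
#     def run(local_rank, num_gpus, func):
#         init_process_group(
#             local_rank=local_rank,
#             local_world_size=num_gpus,
#             shard_id=0,
#             num_shards=1,
#             init_method="tcp://localhost:9999",
#         )
#         func()
#
#     torch.multiprocessing.spawn(run, args=(num_gpus, func), nprocs=num_gpus)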


def is_master_proc(num_gpus=8):
    """
    Determines if the current process is the master process, i.e. whether its
    rank is a multiple of `num_gpus`.
    Args:
        num_gpus (int): number of processes per machine (or the total number
        of processes, to treat only global rank 0 as master).
    """
    if torch.distributed.is_initialized():
        return dist.get_rank() % num_gpus == 0
    else:
        return True
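
# Illustrative usage sketch (not part of the original module): restricting
# side effects such as logging or checkpointing to a single process per job.
# `cfg` is a hypothetical config with NUM_GPUS and NUM_SHARDS fields.
#
#     if is_master_proc(cfg.NUM_GPUS * cfg.NUM_SHARDS):
#         print("only the global master process logs this")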


def is_root_proc():
    """
    Determines if the current process is the root process.
    """
    if torch.distributed.is_initialized():
        return dist.get_rank() == 0
    else:
        return True


def get_world_size():
    """
    Get the size of the world.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """
    Get the rank of the current process.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
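
# Illustrative usage sketch (not part of the original module): placing a
# barrier so that all workers finish an epoch before, say, the master process
# writes a checkpoint that the others will read.
#
#     train_one_epoch()   # hypothetical per-rank work
#     synchronize()       # no rank proceeds until every rank reaches this point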


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.
    The result is cached.
    Returns:
        (group): pytorch dist group.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    """
    Serialize the data to a ByteTensor. Note that only the `gloo` and `nccl`
    backends are supported.
    Args:
        data (data): data to be serialized.
        group (group): pytorch dist group.
    Returns:
        tensor (ByteTensor): the serialized tensor.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Pads the tensor so that tensors from all ranks have the size of the
    largest one.
    Args:
        tensor (tensor): tensor to pad.
        group (group): pytorch dist group.
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device
    )
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # We pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes.
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather_unaligned(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # Receiving tensors from all ranks.
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
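
# Illustrative usage sketch (not part of the original module): gathering
# arbitrary (non-tensor) python objects, e.g. per-rank metric dicts, on every
# rank. Payload sizes may differ across ranks; padding is handled internally.
#
#     local_stats = {"rank": get_rank(), "num_samples": n_local}  # hypothetical
#     all_stats = all_gather_unaligned(local_stats)
#     # all_stats is a list with one dict per rank, identical on every rank.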


def init_distributed_training(cfg):
    """
    Initialize variables needed for distributed training.
    Args:
        cfg (configs): global config object; uses `cfg.NUM_GPUS` (number of
        processes per machine) and `cfg.SHARD_ID` (index of the current
        machine).
    """
    if cfg.NUM_GPUS <= 1:
        return
    num_gpus_per_machine = cfg.NUM_GPUS
    num_machines = dist.get_world_size() // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == cfg.SHARD_ID:
            global _LOCAL_PROCESS_GROUP
            _LOCAL_PROCESS_GROUP = pg
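
# Illustrative usage sketch (not part of the original module): after the
# default process group has been initialized in every process (see
# init_process_group above), the per-machine group can be set up and queried
# like this. `cfg` is the same hypothetical config object as above.
#
#     init_distributed_training(cfg)
#     print(get_local_rank(), get_local_size())  # rank/size within this machine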


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)