# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Part of the code is from
# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
# Modified by Yue Zhao
# The original code is under MIT License

import torch
import torch.distributed as dist
from typing import Tuple


def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)


def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to the original device.
    """
    if tensor.is_cuda and orig_device == "cpu":
        tensor = tensor.cpu()
    return tensor


def is_distributed_training_run() -> bool:
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients, as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so unsqueeze first
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
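

# --------------------------------------------------------------------------- #
# Illustrative usage (not part of the original module): a minimal sketch of
# how `gather_from_all` might be used to pool embeddings across workers, e.g.
# before computing a contrastive loss. The names `batch_size`, `embedding_dim`
# and `local_embeddings` are hypothetical. In a single-process run,
# `is_distributed_training_run()` is False and the tensor is returned as-is;
# when launched under torch.distributed (e.g. via torchrun with the NCCL
# backend), the gathered tensor stacks every worker's batch along dim 0 and,
# because GatherLayer keeps the autograd graph, gradients flow back to each
# worker's local embeddings.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    batch_size, embedding_dim = 4, 8  # hypothetical sizes for this sketch
    local_embeddings = torch.randn(batch_size, embedding_dim, requires_grad=True)

    gathered = gather_from_all(local_embeddings)
    print(f"local shape: {tuple(local_embeddings.shape)}, "
          f"gathered shape: {tuple(gathered.shape)}")

    # A loss computed on the gathered tensor still produces gradients for the
    # local embeddings, unlike a plain torch.distributed.all_gather.
    loss = gathered.pow(2).mean()
    loss.backward()
    print(f"grad norm on local embeddings: {local_embeddings.grad.norm():.4f}")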