import torch
import torch.distributed as dist
from typing import Tuple


def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)


def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to original device.
    """
    if tensor.is_cuda and orig_device == "cpu":
        tensor = tensor.cpu()
    return tensor
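

# Illustrative sketch, not part of the original utilities: a hypothetical
# round trip around a collective call, showing how the two helpers above are
# meant to be paired. Assumes the default process group is initialized; the
# function name and the choice of broadcast are for illustration only.
def _broadcast_with_device_round_trip(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    # Move to GPU if the NCCL backend requires it, remembering where the
    # tensor originally lived.
    tensor, orig_device = convert_to_distributed_tensor(tensor)
    dist.broadcast(tensor, src=src)
    # Move the result back to the original device ("cpu" inputs return to CPU).
    return convert_to_normal_tensor(tensor, orig_device)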


def is_distributed_training_run() -> bool:
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        # Collect one copy of `x` from every rank into a list of tensors.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the gradients contributed by every rank, then return the slice
        # corresponding to this rank's input.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
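

# Illustrative sketch, not part of the original module: a hypothetical helper
# contrasting GatherLayer with torch.distributed.all_gather. Assumes the
# default process group is initialized and that `features` requires grad.
def _gather_layer_example(features: torch.Tensor) -> torch.Tensor:
    # GatherLayer returns one tensor per rank and, unlike dist.all_gather,
    # its backward pass routes gradients back into `features`.
    gathered = GatherLayer.apply(features)
    all_features = torch.cat(gathered, dim=0)
    # A loss computed on `all_features` will backpropagate into the local
    # `features` shard on this rank.
    return all_features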


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients.
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be concatenated; promote to a 1-dim tensor.
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
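

# Illustrative sketch, not part of the original module: how gather_from_all
# might feed a cross-worker similarity computation (e.g. for a contrastive
# loss). The names `embeddings` and `temperature` are assumptions for this
# example; when training is not distributed the gather is just a no-op concat.
def _cross_worker_similarity_example(
    embeddings: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    # Gather embeddings from every rank while keeping them differentiable.
    all_embeddings = gather_from_all(embeddings)
    # Cosine-similarity logits between local and globally gathered embeddings.
    local = torch.nn.functional.normalize(embeddings, dim=1)
    gathered = torch.nn.functional.normalize(all_embeddings, dim=1)
    return local @ gathered.t() / temperature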