import torch
import torch.distributed as dist
from torch.autograd.function import Function


class SyncBatchNorm(Function):
    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not (
            input.is_contiguous(memory_format=torch.channels_last) or
            input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(f'Expected more than 1 value per channel when training, got input size {size}')

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # For an empty input, set the stats and the count to zero. Stats with
            # a zero count are filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock the peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use all_gather instead of all_reduce because the count can differ across
        # ranks; a simple all_reduce cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates the global mean & invstd
        # from the gathered means, invstds and counts.
        # For the NCCL backend, use the optimized all_gather_into_tensor;
        # the Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the resulting count_all depends on the values in the mask
            # tensor. Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate the global mean & invstd
        counts = count_all.view(-1)
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)

        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)
    @staticmethod
    def backward(self, grad_output):
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last) or
            grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2]
            )

            if self.needs_input_grad[0]:
                # synchronize the stats used to calculate the input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for the gradient calculation
                if weight is not None and weight.dtype != mean.dtype:
                    weight = weight.to(mean.dtype)
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor
                )

            # synchronizing grad_weight / grad_bias is not needed because
            # distributed training handles that all_reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch an all_reduce to unblock the other peer processes
                combined = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
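

# Illustrative usage sketch (not part of the original module): the SyncBatchNorm
# Function above is normally driven by torch.nn.SyncBatchNorm, which prepares the
# affine parameters, running stats, eps, momentum and process group before calling
# Function.apply. The helper below is a hypothetical, minimal direct invocation; it
# assumes torch.distributed has already been initialized, that `input` is this
# rank's (N, C, ...) shard of the batch, and the eps/momentum defaults are
# assumptions matching the usual BatchNorm values.
def _example_sync_batch_norm_apply(input, weight, bias, running_mean, running_var,
                                   process_group, eps=1e-5, momentum=0.1):
    world_size = dist.get_world_size(process_group)
    return SyncBatchNorm.apply(
        input, weight, bias, running_mean, running_var,
        eps, momentum, process_group, world_size)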


class CrossMapLRN2d(Function):
    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        if input.dim() != 4:
            raise ValueError(f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead.")

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use the output storage as a temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = min(pre_pad, channels)

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute the normalization for the first feature map
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for the next feature maps' normalization
        # by adding the next feature map and removing the previous one
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height,
                                 input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inverse_pre_pad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inverse_pre_pad,
                                                  channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None
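

# Illustrative usage sketch (not part of the original module): CrossMapLRN2d
# implements cross-channel local response normalization over a 4D (N, C, H, W)
# tensor and is meant to be called through .apply rather than instantiated.
# The window size of 5 below is an assumption chosen for demonstration; the
# remaining defaults mirror the Function's own.
def _example_cross_map_lrn2d(input, size=5, alpha=1e-4, beta=0.75, k=1):
    return CrossMapLRN2d.apply(input, size, alpha, beta, k)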


class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args
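

# Illustrative note (not part of the original module): BackwardHookFunction acts as
# an identity in the forward pass; it appears to exist so that PyTorch's full
# backward hook machinery has a grad_fn node to attach to. A hypothetical
# pass-through call for a single tensor:
def _example_backward_hook_identity(x):
    # .apply returns a tuple because forward returns its *args tuple unchanged
    (y,) = BackwardHookFunction.apply(x)
    return y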