Spaces:
Running
Running
File size: 12,076 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
import torch
import torch.distributed as dist
from torch.autograd.function import Function
class SyncBatchNorm(Function):
    """Autograd function for batch norm whose statistics are shared across a process group.

    ``forward`` computes per-channel statistics locally, gathers them from
    every rank of ``process_group`` and normalizes with the combined result.
    ``backward`` all-reduces the per-channel gradient sums before computing
    the gradient with respect to the input.
    """

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        """Normalize ``input`` with statistics combined over all ranks.

        Args:
            input: tensor to normalize; dimension 1 is used as the channel axis.
            weight: optional per-channel scale (may be ``None``).
            bias: optional per-channel shift, forwarded to ``torch.batch_norm_elemt``.
            running_mean: optional running mean, forwarded to
                ``torch.batch_norm_gather_stats_with_counts``.
            running_var: optional running variance, forwarded likewise.
            eps: epsilon forwarded to the batch-norm kernels.
            momentum: momentum forwarded to the gather-stats kernel.
            process_group: group used for the all-gather; also stored on the
                context for ``backward``.
            world_size: number of ranks contributing statistics.

        Returns:
            The normalized tensor, or an empty tensor shaped like ``input``
            when ``input`` has no elements.

        Raises:
            ValueError: if there is exactly one value per channel and fewer
                than two ranks participate.
        """
        is_channels_last = (
            input.is_contiguous(memory_format=torch.channels_last)
            or input.is_contiguous(memory_format=torch.channels_last_3d)
        )
        if not is_channels_last:
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        # Number of values contributing to each channel on this rank.
        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(f'Expected more than 1 value per channel when training, got input size {size}')

        num_channels = input.shape[1]
        if input.numel() > 0:
            # Local per-channel mean / inverse std plus the element count,
            # packed into one vector: C, C, 1 -> (2C + 1).
            local_mean, local_invstd = torch.batch_norm_stats(input, eps)
            local_count = torch.full(
                (1,),
                size,
                dtype=local_mean.dtype,
                device=local_mean.device,
            )
            local_stats = torch.cat([local_mean, local_invstd, local_count], dim=0)
        else:
            # An empty input contributes all-zero stats with a zero count.
            # Zero-count entries are filtered out below, but this rank must
            # still take part in the all-gather so its peers are not blocked.
            local_stats = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device,
            )

        # All-gather rather than all-reduce: counts may differ between ranks,
        # so a plain reduction cannot produce the right result. The gathered
        # mean / invstd / count are combined by
        # batch_norm_gather_stats_with_counts further down.
        if process_group._get_backend_name() == "gloo":
            # The Gloo backend does not support `all_gather_into_tensor`.
            per_rank_stats = [torch.empty_like(local_stats) for _ in range(world_size)]
            dist.all_gather(per_rank_stats, local_stats, process_group, async_op=False)
            gathered = torch.stack(per_rank_stats, dim=0)
        else:
            # Other backends (e.g. nccl) use the optimized flat all-gather.
            stats_len = local_stats.numel()
            flat_buffer = torch.empty(
                1,
                stats_len * world_size,
                dtype=local_stats.dtype,
                device=local_stats.device,
            )
            dist.all_gather_into_tensor(flat_buffer, local_stats, process_group, async_op=False)
            gathered = torch.reshape(flat_buffer, (world_size, stats_len))

        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(gathered, num_channels, dim=1)

        capturing_cuda_graph = torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
        if not capturing_cuda_graph:
            # Boolean-mask indexing forces a CUDA/CPU synchronization because
            # the result shape depends on the mask values, and such
            # synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.
            # Drop the stats contributed by ranks with empty inputs.
            non_empty = count_all.squeeze(-1) >= 1
            count_all = count_all[non_empty]
            mean_all = mean_all[non_empty]
            invstd_all = invstd_all[non_empty]

        # Combine the gathered stats into the global mean & invstd.
        counts = count_all.view(-1)
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # Element-wise normalization; nothing to compute for an empty input.
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
        """Compute gradients for ``input``, ``weight`` and ``bias``.

        Returns a 9-tuple matching the arguments of ``forward``; only the
        first three entries can be non-``None``.
        """
        is_channels_last = (
            grad_output.is_contiguous(memory_format=torch.channels_last)
            or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        )
        if not is_channels_last:
            grad_output = grad_output.contiguous()

        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        process_group = self.process_group
        grad_input = grad_weight = grad_bias = None

        if saved_input.numel() == 0:
            # This rank received an empty input in the forward pass. Its own
            # gradients can stay None, but it still has to join the
            # all-reduce so that peers with non-empty inputs are not blocked.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                placeholder = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device,
                )
                dist.all_reduce(placeholder, dist.ReduceOp.SUM, process_group, async_op=False)
            # grad_input, grad_weight and grad_bias are left as None, which
            # the autograd engine interprets as tensors full of zeros.
            return grad_input, grad_weight, grad_bias, None, None, None, None, None, None

        # Local reductions, which also yield grad_weight / grad_bias.
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2],
        )

        if self.needs_input_grad[0]:
            # Sum the per-channel reductions over all ranks in one call.
            num_channels = sum_dy.shape[0]
            reduced = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            dist.all_reduce(reduced, dist.ReduceOp.SUM, process_group, async_op=False)
            sum_dy, sum_dy_xmu = torch.split(reduced, num_channels)

            # Element-wise backward pass for the input gradient.
            if weight is not None and weight.dtype != mean.dtype:
                weight = weight.to(mean.dtype)
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                sum_dy,
                sum_dy_xmu,
                count_tensor,
            )

        # grad_weight / grad_bias are not synchronized here: distributed
        # training handles that all-reduce itself.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None
        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
class CrossMapLRN2d(Function):
    """Local response normalization across the channels of a 4D input.

    For every position, channel ``c`` is divided by
    ``(k + alpha / size * sum(x[j] ** 2)) ** beta`` where ``j`` runs over the
    channel window ``[c - size // 2, c + (size - 1) // 2]`` clipped to the
    valid channels. This is the same window that
    ``torch.nn.functional.local_response_norm`` uses.
    """

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        """Compute the normalized output.

        Args:
            ctx: autograd context; receives ``size``, ``alpha``, ``beta``,
                ``k`` and the intermediate ``scale`` tensor as attributes.
            input: 4D tensor; dimension 1 is the channel axis.
            size: number of neighbouring channels in the normalization window.
            alpha: multiplicative factor applied to the windowed sum.
            beta: exponent applied to the scale.
            k: additive constant of the scale.

        Returns:
            A new tensor with the same shape as ``input``.

        Raises:
            ValueError: if ``input`` is not 4-dimensional.
        """
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        if input.dim() != 4:
            raise ValueError(f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead.")

        channels = input.size(1)
        ctx.scale = input.new_empty(input.size())
        output = input.new_empty(input.size())

        # Use the output storage as a temporary buffer for the squared input.
        input_square = output
        torch.pow(input, 2, out=input_square)

        # The window of channel c spans [c - before, c + after].
        before = int(ctx.size) // 2
        after = (int(ctx.size) - 1) // 2

        # Channel 0: only the channels at or after it fall inside the window.
        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        for c in range(min(after + 1, channels)):
            scale_first.add_(input_square.select(1, c))

        # Slide the window: start from the previous channel's sum, add the
        # channel entering at the far end and drop the one that just left.
        for c in range(1, channels):
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(ctx.scale.select(1, c - 1))

            entering = c + after
            if entering < channels:
                scale_current.add_(input_square.select(1, entering))

            # Channel c - before - 1 was the first member of the previous
            # window and is no longer part of this one. (The check used to be
            # `c > pre_pad`, which dropped it one step late and never dropped
            # channel 0.)
            leaving = c - before - 1
            if leaving >= 0:
                scale_current.sub_(input_square.select(1, leaving))

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        # output = input * scale ** -beta (overwrites the temporary squares).
        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient with respect to ``input``.

        Returns a 5-tuple matching the arguments of ``forward``; only the
        first entry is non-``None``.
        """
        input, output = ctx.saved_tensors
        batch_size, channels, input_height, input_width = input.size()
        size = int(ctx.size)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size

        # Input channel i contributes to the outputs of channels
        # [i - after, i + before]. Storing the per-channel ratios with `after`
        # zero planes in front makes padded_ratio[i : i + size] exactly that
        # window. (The offset used to be one plane too large, which shifted
        # the window by one channel.)
        after = (size - 1) // 2
        padded_ratio = input.new_zeros(channels + size - 1, input_height, input_width)
        padded_ratio_center = padded_ratio.narrow(0, after, channels)

        # Direct term: grad_output * scale ** -beta.
        grad_input = grad_output.new_empty(input.size())
        torch.pow(ctx.scale, -ctx.beta, out=grad_input)
        grad_input.mul_(grad_output)

        for n in range(batch_size):
            # ratio[c] = grad_output[c] * output[c] / scale[c]
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])

            # Running sum over the sliding window of `size` padded planes.
            accum_ratio = padded_ratio.narrow(0, 0, size - 1).sum(dim=0)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.sub_(padded_ratio[c])

        return grad_input, None, None, None, None
class BackwardHookFunction(torch.autograd.Function):
    """Identity autograd function: inputs and gradients pass through unchanged."""

    @staticmethod
    def forward(ctx, *args):
        """Return ``args`` as-is, marking those that do not require grad as non-differentiable."""
        without_grad = tuple(tensor for tensor in args if not tensor.requires_grad)
        ctx.mark_non_differentiable(*without_grad)
        return args

    @staticmethod
    def backward(ctx, *args):
        """Return the incoming gradients unchanged, one per forward argument."""
        return args
|