import functools
import logging
import math
import sys
import typing
from typing import Optional

import torch
import torch._decomp as decomp
import torch._prims_common as utils
import torch.ao.quantization.fx._decomposed
from torch._decomp import (
    core_aten_decompositions,
    get_decompositions,
    remove_decompositions,
)
from torch._decomp.decompositions import (
    _grid_sampler_2d as decomp_grid_sampler_2d,
    pw_cast_for_opmath,
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._higher_order_ops.out_dtype import out_dtype
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    type_to_dtype,
)

from . import config, inductor_prims

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
    ]
)

decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)
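

# Hedged sanity sketch (a hypothetical `_example_*` helper, not part of the
# upstream module): after the exclusions above, no overload of aten.sum should
# remain as a key of the table, while ops pulled in via
# `inductor_decompositions` (e.g. aten.gelu) should. Wrapped in a function so
# importing this file stays side-effect free.
def _example_check_decomp_table():
    sum_ops = [getattr(aten.sum, name) for name in aten.sum.overloads()]
    assert all(op not in decompositions for op in sum_ops)
    gelu_ops = [getattr(aten.gelu, name) for name in aten.gelu.overloads()]
    assert any(op in decompositions for op in gelu_ops)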


def register_decomposition(ops):
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            log.warning("duplicate decomp: %s", op)
    return decomp.register_decomposition(ops, decompositions)


# TODO: for now, Inductor doesn't handle asserts,
# because the condition is a SymBool -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
    return


# Follows `assert_async_msg_decomp`; also implemented as a no-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor, msg):
    return


@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(symbol, *, min=None, max=None):
    return


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
    if min is not None:
        x = x.clamp_min(min)
    if max is not None:
        x = x.clamp_max(max)
    return x


@register_decomposition([aten.full])
def full(size, fill_value, **kwargs):
    dtype = kwargs.get("dtype")
    if dtype is None:
        # Only redispatch when we had to infer the dtype from the fill value;
        # otherwise return NotImplemented and fall back to the default impl.
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return aten.full(size, fill_value, **kwargs)
    return NotImplemented


# Not really sure how to put this into the main library. PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but Inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided).
@register_decomposition([aten.empty_permuted])
def empty_permuted(size, physical_layout, **kwargs):
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
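

# Hedged worked example (a hypothetical `_example_*` helper, not part of the
# upstream module): with size=(2, 3, 4) and physical_layout=(2, 0, 1), storage
# is allocated as a contiguous (4, 2, 3) tensor and `perm` is the inverse
# permutation [1, 2, 0], so the result has logical shape (2, 3, 4) with
# logical dim 2 outermost (stride 6) and logical dim 1 densest (stride 1).
def _example_empty_permuted():
    t = empty_permuted((2, 3, 4), (2, 0, 1), dtype=torch.float32)
    assert tuple(t.shape) == (2, 3, 4)
    assert t.stride() == (3, 1, 6)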


@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output,
    input,
    weight,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    if not output_mask[2] or grad_output.device.type != "cuda":
        return NotImplemented
    # On CUDA, peel off the bias gradient as a plain reduction over all
    # non-channel dims, then redispatch the rest with the bias mask disabled.
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)


@register_decomposition([aten.log2])
def log2(x):
    # Change of base: log2(x) = ln(x) / ln(2).
    return torch.log(x) * (1.0 / math.log(2.0))


@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
    ten_pow_decimals = 10.0**decimals
    return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
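

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): round_dec scales by 10**decimals, rounds to the nearest integer,
# then scales back, e.g. round_dec(3.14159, 2) computes round(314.159) / 100.
def _example_round_dec():
    x = torch.tensor([3.14159, -2.71828])
    assert torch.allclose(round_dec(x, 2), torch.tensor([3.14, -2.72]))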


@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(self, batch2):
    if config.coordinate_descent_tuning:
        if self.shape[1] == 1 or batch2.shape[2] == 1:
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        if self.size(1) == 1 and batch2.size(-1) == 1:
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    return NotImplemented


@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(self, mat1, mat2, beta=1, alpha=1):
    if self.device.type == "cpu":
        if mat1.size(0) == 1 and mat2.size(-1) == 1:
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16:
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented
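

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): for a 1-row mat1 with small mat2, the decomposition rewrites the
# matmul as a broadcasted multiply plus a sum; on CPU the result should match
# eager torch.addmm up to float accumulation order.
def _example_addmm_small():
    bias, mat1, mat2 = torch.randn(1, 8), torch.randn(1, 16), torch.randn(16, 8)
    assert torch.allclose(addmm(bias, mat1, mat2), torch.addmm(bias, mat1, mat2))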


@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(self, input2):
    from torch.fx.experimental.symbolic_shapes import (
        definitely_true,
        guard_size_oblivious,
    )

    # Our matrix-vector multiplies only achieve peak bandwidth with
    # coordinate-descent tuning.
    # TODO: figure out why and fix it (hopefully).
    if config.coordinate_descent_tuning:
        if self.shape[0] == 1 or input2.shape[1] == 1:
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented


# This pass does two things:
# - Eliminates cat when there is only one tensor input
# - Normalizes cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x):
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat. The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious? A case
        # where it could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But the ONLY way that cat could have passed is if u0 == 0, so by the
        # time we get here we have already installed a deferred runtime assert
        # forcing u0 to be zero. So if that hasn't happened, we know the
        # unbacked SymInt is appropriately sized and there are no problems.
        return len(x.shape) != 1 or guard_size_oblivious(x.shape[0] > 0)

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # On the first call, when we remove empty tensors, we redispatch recursively.
        return aten.cat.default(filtered_tensors, dim)
    # When no filtering has occurred, return NotImplemented to prevent infinite
    # recursion (no more decomposition is needed).
    return NotImplemented
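

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): the legacy empty 1-D tensor is dropped, so the cat decomposition
# reduces a 3-input cat to a 2-input cat before redispatching.
def _example_cat_drops_legacy_empty():
    a, b = torch.randn(2, 3), torch.randn(4, 3)
    out = cat([a, torch.randn(0), b], dim=0)
    assert torch.allclose(out, torch.cat([a, b], dim=0))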


@register_decomposition([aten.angle])
def angle(x):
    if x.is_complex():
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # For real x:
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is NaN, return NaN
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)


@register_decomposition([aten.add])
def add(x, y, *, alpha=None):
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)
    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
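

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): complex add is performed on the interleaved (real, imag) float
# view and then reinterpreted as complex; it should match eager complex math.
def _example_complex_add():
    x = torch.randn(4, dtype=torch.complex64)
    y = torch.randn(4, dtype=torch.complex64)
    assert torch.allclose(add(x, y, alpha=2.0), x + 2.0 * y)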


@register_decomposition([aten.conj_physical])
def conj_physical(self):
    assert not self.is_complex(), "TODO: implement this"
    return self


@register_decomposition([aten.lift, aten.detach_, aten.lift_fresh])
def lift(self):
    return self


@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
    assert generator is None
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)


@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
    return torch.where(torch.isnan(other) | (other > self), self, other)


@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
    return torch.where(torch.isnan(other) | (other < self), self, other)
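

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): unlike aten.minimum, fmin prefers the non-NaN operand and only
# returns NaN when both inputs are NaN (IEEE-754 fmin/fmax semantics).
def _example_fmin_nan():
    nan = float("nan")
    a = torch.tensor([1.0, nan, nan])
    b = torch.tensor([nan, 2.0, nan])
    out = fmin(a, b)
    assert torch.allclose(out[:2], torch.tensor([1.0, 2.0]))
    assert torch.isnan(out[2])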


@register_decomposition(aten.amax)
def amax(self, dim=None, keepdim=False):
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition(aten.amin)
def amin(self, dim=None, keepdim=False):
    if self.dtype == torch.bool:
        return torch.all(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
    return torch.narrow(self, dim, start, length).clone()


@register_decomposition([aten.expand_copy])
def expand_copy(self, size, *, implicit=False):
    return aten.expand(self, size, implicit=implicit).clone()


@register_decomposition([aten.view_copy.default])
def view_copy_default(self, size):
    return aten.view(self, size).clone()


@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(self, dtype):
    return self.to(dtype).clone()


def get_like_layout(
    tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
) -> torch.memory_format:
    # TODO: _to_copy tensor to stride permutation
    if memory_format is torch.preserve_format or memory_format is None:
        return utils.suggest_memory_format(tensor)
    else:
        return memory_format


@register_decomposition(aten.rand_like)
def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    return torch.rand(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randn_like)
def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    return torch.randn(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
def full_like(
    self,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    requires_grad=False,
    memory_format=torch.preserve_format,
):
    return torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
    return aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
):
    return aten.randint.low(
        low,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint.default)
def randint(high, size, **kwargs):
    return aten.randint.low(0, high, size, **kwargs)
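

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): the *_like decompositions recreate the tensor and then restore the
# suggested memory format, so a channels-last input stays channels-last.
def _example_rand_like_preserves_layout():
    x = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
    out = rand_like(x)
    assert out.is_contiguous(memory_format=torch.channels_last)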


# quantize_per_tensor.default and quantize_per_tensor.tensor differ only in
# whether scale and zero_point are Python scalars or scalar tensors.
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
def quantize_per_tensor_default_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)


# dequantize_per_tensor.default and dequantize_per_tensor.tensor differ only in
# whether scale and zero_point are Python scalars or scalar tensors.
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
def dequantize_per_tensor_default_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return (input.to(torch.float32) - zero_point) * scale


@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)


@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
        torch.float32
    )
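

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): the affine round trip q = clamp(round(x / scale) + zp) followed by
# (q - zp) * scale reproduces any in-range input up to scale / 2 of error.
def _example_quant_round_trip():
    x = torch.randn(8)
    scale, zp = 0.05, 10
    q = quantize_per_tensor_default_decomp_impl(x, scale, zp, 0, 255, torch.uint8)
    dq = dequantize_per_tensor_default_decomp_impl(q, scale, zp, 0, 255, torch.uint8)
    x_repr = x.clamp((0 - zp) * scale, (255 - zp) * scale)  # representable range
    assert (x_repr - dq).abs().max() <= scale / 2 + 1e-6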


@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack.default)
def q_embedding_bag_byte_unpack_decomp(packed):
    # Each row stores uint8 quantized values followed by 8 trailing bytes that
    # encode the row's fp32 scale and offset; unpack as value * scale + offset.
    def bitcast_u8_to_f32(u8):
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets


@register_decomposition(aten.grid_sampler_2d)
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    # We do not expand the grid (_expand_grid=False) on CPU for performance reasons.
    # Experimenting locally, we found that compiled CUDA code is accelerated by ~5x
    # and CPU code by ~2x in bicubic mode if the grid is expanded from (N, H, W, 2)
    # into (N, C, H, W, 2). However, expanding causes a slowdown of ~0.8x for
    # channels-first bilinear mode on CPU, so we skip the expansion in that case.
    _expand_grid = not (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )

    output = decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
    return output
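

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): sampling with an identity affine grid (align_corners=True) should
# reproduce the input up to float error in the grid coordinates.
def _example_grid_sampler_identity():
    a = torch.randn(1, 1, 4, 4)
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # identity affine
    grid = torch.nn.functional.affine_grid(theta, (1, 1, 4, 4), align_corners=True)
    out = grid_sampler_2d(a, grid, interpolation_mode=0, align_corners=True)
    assert torch.allclose(out, a, atol=1e-5)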


@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
    return aten._foreach_add.List(
        self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
    return aten._foreach_add.List(
        self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
    return aten._foreach_add.List(
        start_tensors,
        aten._foreach_mul.Scalar(
            aten._foreach_sub.List(end_tensors, start_tensors), weight
        ),
    )
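

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): the lerp decomposition uses the identity
# lerp(start, end, w) == start + w * (end - start), applied listwise.
def _example_foreach_lerp():
    starts = [torch.zeros(3), torch.ones(3)]
    ends = [torch.ones(3), torch.full((3,), 3.0)]
    outs = _foreach_lerp_scalar(starts, ends, 0.5)
    assert torch.allclose(outs[0], torch.full((3,), 0.5))
    assert torch.allclose(outs[1], torch.full((3,), 2.0))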


@register_decomposition(aten.miopen_batch_norm)
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if training:
        return (a, b, c)
    # In inference mode, miopen_batch_norm returns empty save_mean / save_var
    # buffers; match that with zero-sized tensors.
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
    )


@functools.lru_cache(None)
def fast_random_decomps():
    return {**decompositions, **extra_random_decomps}


def select_decomp_table():
    """Decomps can change based on config."""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()


@register_decomposition(aten.masked_scatter)
def masked_scatter(self, mask, source):
    if self.device.type == "cuda":
        # This two-step algorithm matches the eager CUDA implementation; eager
        # CPU instead uses a one-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        source_idx = mask.reshape(-1).cumsum(0) - 1
        return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
    return NotImplemented
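

# Hedged illustration (a hypothetical `_example_*` helper, not part of the
# upstream module) of the index mapping above: cumsum(mask) - 1 gives, for
# each True position in the flattened mask, its read index into `source`.
# Entries at False positions are never read, so their values don't matter.
def _example_masked_scatter_index():
    mask = torch.tensor([False, True, False, True])
    source_idx = mask.reshape(-1).cumsum(0) - 1
    assert source_idx.tolist() == [-1, 0, 0, 1]  # only the True slots are used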


@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
):
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.max(scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)
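

# Hedged worked example (a hypothetical `_example_*` helper, not part of the
# upstream module): for an input spanning [-1, 3] with an 8-bit [0, 255]
# range, scale = (3 - (-1)) / 255 and the zero_point shifts min_val onto
# quant_min: zp = 0 - round(-1 / scale) = 64.
def _example_choose_qparams():
    x = torch.tensor([-1.0, 0.0, 3.0])
    scale, zp = choose_qparams_tensor(x, 0, 255, eps=1e-5, dtype=torch.uint8)
    assert abs(scale.item() - 4.0 / 255) < 1e-6
    assert zp.item() == 64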


@register_decomposition(aten.put)
def put(self, index, source, accumulate=False):
    flattened = self.flatten()
    flattened = torch.index_put(
        flattened, [index], source.reshape(index.shape), accumulate
    )
    return flattened.reshape(self.shape)


@register_decomposition(aten.put_)
def put_(self, index, source, accumulate=False):
    out = aten.put(self, index, source, accumulate=accumulate)
    return self.copy_(out)
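

# Hedged example (a hypothetical `_example_*` helper, not part of the upstream
# module): aten.put treats `index` as positions in the flattened tensor; the
# decomposition round-trips through flatten + index_put + reshape.
def _example_put():
    x = torch.zeros(2, 3)
    out = put(x, torch.tensor([0, 4]), torch.tensor([1.0, 2.0]))
    expected = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
    assert torch.allclose(out, expected)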