# mypy: ignore-errors

import functools
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch._ops.ops.aten


def outputs_alias_inputs(outputs, inputs):
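    """Return True if any tensor in ``outputs`` shares a storage with a tensor in ``inputs``."""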
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )


def outputs_are_inputs(outputs, inputs):
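    """Return True if any tensor in ``outputs`` is the same Python object as a tensor in ``inputs``."""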
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))


def output_alias_each_other(outputs):
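    """Return True if any storage is shared by more than one tensor in ``outputs``."""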
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False


def is_sdpa_error(func, idx, e):
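    """Return True if ``e`` is a known device-mismatch error on the RNG seed/offset
    outputs of the SDPA kernels (output indices 6/7 for flash attention, 2/3 for
    efficient attention), which the caller chooses to skip rather than report."""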
    if (
        (
            func is aten._scaled_dot_product_flash_attention.default
            or func is aten._flash_attention_forward.default
        )
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    if (
        (
            func is aten._scaled_dot_product_efficient_attention.default
            or func is aten._efficient_attention_forward.default
        )
        and idx in (2, 3)
        and "Devices" in repr(e)
    ):
        return True
    return False


class CrossRefFakeMode(TorchDispatchMode):
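    """TorchDispatchMode that cross-references real op execution with FakeTensorMode.

    Every dispatched op (unless excluded by ``ignore_op_fn``, an op tag, or a small
    skip list) is additionally run on fake copies of its inputs; the fake outputs
    are then compared against the real outputs for number of returns, aliasing
    behavior, requires_grad, storage offset, and other tensor metadata.
    """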
    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
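        # Run the op on fake copies of the inputs under a fresh FakeTensorMode,
        # then on the real inputs, and compare the two results below.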
        kwargs = kwargs or {}
        fake_r = None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
            # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        r = func(*args, **kwargs)
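        # Compare the fake result against the real one, if we managed to compute it.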
        if fake_r is not None:
            r_flat = pytree.tree_leaves(r)
            f_flat = pytree.tree_leaves(fake_r)
            assert len(f_flat) == len(
                r_flat
            ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

            if self.check_aliasing:
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert r_output_alias_each_other == f_output_alias_each_other, (
                    f"{context} mismatch in outputs_alias_each_other check "
                    f"{f_output_alias_each_other} != {r_output_alias_each_other}"
                )
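            # Per-output checks: tensor-ness, requires_grad, storage offset,
            # and remaining tensor metadata via compare_tensor_meta.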
            for idx, (r_out, fake_out) in enumerate(
                zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
            ):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"{context} mismatched number of tensor outputs"
                if r_is_ten:
                    assert r_out.requires_grad == fake_out.requires_grad, (
                        f"{context} mismatched requires_grad-ness of outputs. "
                        f"This usually means that you have added autograd support "
                        f"for your operator at a dispatch key other than Autograd, "
                        f"which will lead to problems"
                    )
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"{context} mismatched storage offset"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out,
                            fake_out,
                            check_strides=self.check_strides,
                            allow_rhs_unbacked=True,
                        )
                    except Exception as e:
                        if is_sdpa_error(func, idx, e):
                            continue
                        error_message = (
                            f"{context} mismatched tensor metadata: {e}"
                            if len(r_flat) == 1
                            else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                        )
                        raise RuntimeError(error_message) from e
        return r
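

# A minimal usage sketch (not part of the original module), assuming CrossRefFakeMode
# is entered like any other TorchDispatchMode: every op dispatched inside the `with`
# block is also run on fake tensors and cross-checked, raising on any mismatch in
# aliasing, requires_grad, or tensor metadata.
if __name__ == "__main__":
    x = torch.randn(4, 4)
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        y = torch.mm(x, x)  # cross-checked against its FakeTensorMode execution
        z = torch.add(y, 1.0)  # an assertion failure here would indicate a fake-tensor bug
    print(z.shape)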