"""Kernel test utils""" import itertools import random import unittest from numbers import Number from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union) import pytest import torch from torch._prims_common import TensorLikeType ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ( "test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic", ) # Copied/modified from torch._refs.__init__.py def fp8_allclose( a: TensorLikeType, b: TensorLikeType, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, ) -> bool: """ Reference implementation of torch.allclose """ torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) return bool( torch.all( torch.isclose(a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan)).item()) # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, torch._library.custom_ops.CustomOpDef], args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None, *, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, cond: bool = True) -> Dict[str, str]: with unittest.mock.patch('torch.allclose', new=fp8_allclose): return torch.library.opcheck( op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception) if cond else {}