|
"""Kernel test utils""" |
|
|
|
import itertools |
|
import random |
|
import unittest |
|
from numbers import Number |
|
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, |
|
Union) |
|
|
|
import pytest |
|
import torch |
|
from torch._prims_common import TensorLikeType |
|
|
|
# Default set of torch.library.opcheck sub-tests; opcheck() in this module
# forwards it as ``test_utils`` unless the caller narrows the selection.
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)
|
|
|
|
|
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose that compares in float64.

    The arguments are validated exactly as ``torch.allclose`` would validate
    them. Both inputs are then upcast to float64, and every element pair is
    compared with ``torch.isclose``. The upcast presumably lets low-precision
    dtypes (e.g. fp8, per the function name) be compared. Confirm this against
    the dtypes callers actually pass.

    Args:
        a: First tensor to compare.
        b: Second tensor to compare.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        equal_nan: If True, NaNs in the same position compare as equal.

    Returns:
        True when every element of ``a`` is close to the matching element
        of ``b``; otherwise False.
    """
    torch._refs._check_close_args(
        name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol
    )

    # Compare in double precision rather than in the inputs' own dtype.
    lhs = a.to(torch.float64)
    rhs = b.to(torch.float64)
    close_mask = torch.isclose(
        lhs, rhs, rtol=rtol, atol=atol, equal_nan=equal_nan
    )
    return bool(close_mask.all().item())
|
|
|
|
|
|
|
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    """Run ``torch.library.opcheck`` with ``torch.allclose`` patched to
    ``fp8_allclose``.

    Args:
        op: Operator to check.
        args: Positional arguments used to call ``op``.
        kwargs: Keyword arguments used to call ``op``, or None.
        test_utils: Which opcheck sub-tests to run. Defaults to
            ``ALL_OPCHECK_TEST_UTILS``.
        raise_exception: Forwarded unchanged to ``torch.library.opcheck``.
        cond: When False, skip the check entirely and return ``{}``.

    Returns:
        The result of ``torch.library.opcheck``, or ``{}`` when ``cond``
        is False.
    """
    # Bail out before installing any patch: a disabled check does no work.
    if not cond:
        return {}

    # ``import unittest`` does not load the ``unittest.mock`` submodule, so
    # import it explicitly instead of relying on some other module having
    # done so first.
    from unittest import mock

    # While the check runs, any lookup of ``torch.allclose`` resolves to
    # fp8_allclose, which upcasts both inputs to float64 before comparing.
    with mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(
            op,
            args,
            kwargs,
            test_utils=test_utils,
            raise_exception=raise_exception)
|
|