danieldk's picture
danieldk HF Staff
Import CUTLASS tests and add missing scaled mm with zp signature
2dd62c9
raw
history blame
1.97 kB
"""Kernel test utils"""
import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union)
import pytest
import torch
from torch._prims_common import TensorLikeType
# Names of the torch.library.opcheck sub-tests that ``opcheck`` runs when the
# caller does not pass an explicit ``test_utils``.
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = ("test_schema",
                                           "test_autograd_registration",
                                           "test_faketensor",
                                           "test_aot_dispatch_dynamic")
# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """Drop-in stand-in for ``torch.allclose`` that compares in float64.

    Adapted from the reference implementation of ``torch.allclose`` in
    ``torch._refs``. The only modification is that both operands are upcast
    with ``.double()`` before ``torch.isclose`` is applied, so that low
    precision dtypes (the fp8 types this helper is named for) can be compared.

    Args:
        a: first tensor; validated together with ``b`` by
            ``torch._refs._check_close_args``.
        b: second tensor.
        rtol: relative tolerance forwarded to ``torch.isclose``.
        atol: absolute tolerance forwarded to ``torch.isclose``.
        equal_nan: if True, NaNs in matching positions count as equal.

    Returns:
        True if every element pair is close under the given tolerances,
        otherwise False.
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a, b=b, rtol=rtol, atol=atol)
    # Compare in float64 rather than in the inputs' native dtype.
    elementwise = torch.isclose(a.double(), b.double(),
                                rtol=rtol, atol=atol, equal_nan=equal_nan)
    return bool(elementwise.all().item())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    """Run ``torch.library.opcheck`` on ``op`` with fp8-aware comparisons.

    For the duration of the check, ``torch.allclose`` is patched to
    ``fp8_allclose`` so that result comparisons work for fp8 tensors.

    Args:
        op: the operator (overload, overload packet or custom op) to check.
        args: positional arguments to call ``op`` with.
        kwargs: keyword arguments to call ``op`` with, or None.
        test_utils: name or names of the opcheck sub-tests to run; defaults
            to ``ALL_OPCHECK_TEST_UTILS``.
        raise_exception: forwarded to ``torch.library.opcheck``.
        cond: when False the check is skipped entirely and ``{}`` is
            returned.

    Returns:
        The result mapping produced by ``torch.library.opcheck``, or an
        empty dict when ``cond`` is False.
    """
    # ``import unittest`` does not load the ``unittest.mock`` submodule, so
    # import it explicitly instead of relying on another module having done so.
    from unittest import mock

    if not cond:
        return {}
    with mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(op,
                                     args,
                                     kwargs,
                                     test_utils=test_utils,
                                     raise_exception=raise_exception)