Spaces:
Sleeping
Sleeping
import abc | |
import cmath | |
import collections.abc | |
import contextlib | |
import warnings | |
from typing import ( | |
Any, | |
Callable, | |
Collection, | |
Dict, | |
List, | |
NoReturn, | |
Optional, | |
Sequence, | |
Tuple, | |
Type, | |
Union, | |
) | |
import torch | |
try: | |
import numpy as np | |
NUMPY_AVAILABLE = True | |
except ModuleNotFoundError: | |
NUMPY_AVAILABLE = False | |
class ErrorMeta(Exception): | |
"""Internal testing exception that makes that carries error metadata.""" | |
def __init__( | |
self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = () | |
) -> None: | |
super().__init__( | |
"If you are a user and see this message during normal operation " | |
"please file an issue at https://github.com/pytorch/pytorch/issues. " | |
"If you are a developer and working on the comparison functions, please `raise ErrorMeta().to_error()` " | |
"for user facing errors." | |
) | |
self.type = type | |
self.msg = msg | |
self.id = id | |
def to_error( | |
self, msg: Optional[Union[str, Callable[[str], str]]] = None | |
) -> Exception: | |
if not isinstance(msg, str): | |
generated_msg = self.msg | |
if self.id: | |
generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}" | |
msg = msg(generated_msg) if callable(msg) else generated_msg | |
return self.type(msg) | |
# Some analysis of tolerance by logging tests from test_torch.py can be found in | |
# https://github.com/pytorch/pytorch/pull/32538. | |
# {dtype: (rtol, atol)} | |
_DTYPE_PRECISIONS = { | |
torch.float16: (0.001, 1e-5), | |
torch.bfloat16: (0.016, 1e-5), | |
torch.float32: (1.3e-6, 1e-5), | |
torch.float64: (1e-7, 1e-7), | |
torch.complex32: (0.001, 1e-5), | |
torch.complex64: (1.3e-6, 1e-5), | |
torch.complex128: (1e-7, 1e-7), | |
} | |
# The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in | |
# their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values` | |
_DTYPE_PRECISIONS.update( | |
dict.fromkeys( | |
(torch.quint8, torch.quint2x4, torch.quint4x2, torch.qint8, torch.qint32), | |
_DTYPE_PRECISIONS[torch.float32], | |
) | |
) | |
def default_tolerances( | |
*inputs: Union[torch.Tensor, torch.dtype], | |
dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None, | |
) -> Tuple[float, float]: | |
"""Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype. | |
See :func:`assert_close` for a table of the default tolerance for each dtype. | |
Returns: | |
(Tuple[float, float]): Loosest tolerances of all input dtypes. | |
""" | |
dtypes = [] | |
for input in inputs: | |
if isinstance(input, torch.Tensor): | |
dtypes.append(input.dtype) | |
elif isinstance(input, torch.dtype): | |
dtypes.append(input) | |
else: | |
raise TypeError( | |
f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead." | |
) | |
dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS | |
rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes]) | |
return max(rtols), max(atols) | |
def get_tolerances( | |
*inputs: Union[torch.Tensor, torch.dtype], | |
rtol: Optional[float], | |
atol: Optional[float], | |
id: Tuple[Any, ...] = (), | |
) -> Tuple[float, float]: | |
"""Gets absolute and relative to be used for numeric comparisons. | |
If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of | |
:func:`default_tolerances` is used. | |
Raises: | |
ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified. | |
Returns: | |
(Tuple[float, float]): Valid absolute and relative tolerances. | |
""" | |
if (rtol is None) ^ (atol is None): | |
# We require both tolerance to be omitted or specified, because specifying only one might lead to surprising | |
# results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0. | |
raise ErrorMeta( | |
ValueError, | |
f"Both 'rtol' and 'atol' must be either specified or omitted, " | |
f"but got no {'rtol' if rtol is None else 'atol'}.", | |
id=id, | |
) | |
elif rtol is not None and atol is not None: | |
return rtol, atol | |
else: | |
return default_tolerances(*inputs) | |
def _make_mismatch_msg( | |
*, | |
default_identifier: str, | |
identifier: Optional[Union[str, Callable[[str], str]]] = None, | |
extra: Optional[str] = None, | |
abs_diff: float, | |
abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None, | |
atol: float, | |
rel_diff: float, | |
rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None, | |
rtol: float, | |
) -> str: | |
"""Makes a mismatch error message for numeric values. | |
Args: | |
default_identifier (str): Default description of the compared values, e.g. "Tensor-likes". | |
identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides | |
``default_identifier``. Can be passed as callable in which case it will be called with | |
``default_identifier`` to create the description at runtime. | |
extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics. | |
abs_diff (float): Absolute difference. | |
abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference. | |
atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are | |
``> 0``. | |
rel_diff (float): Relative difference. | |
rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference. | |
rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are | |
``> 0``. | |
""" | |
equality = rtol == 0 and atol == 0 | |
def make_diff_msg( | |
*, | |
type: str, | |
diff: float, | |
idx: Optional[Union[int, Tuple[int, ...]]], | |
tol: float, | |
) -> str: | |
if idx is None: | |
msg = f"{type.title()} difference: {diff}" | |
else: | |
msg = f"Greatest {type} difference: {diff} at index {idx}" | |
if not equality: | |
msg += f" (up to {tol} allowed)" | |
return msg + "\n" | |
if identifier is None: | |
identifier = default_identifier | |
elif callable(identifier): | |
identifier = identifier(default_identifier) | |
msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n" | |
if extra: | |
msg += f"{extra.strip()}\n" | |
msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol) | |
msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol) | |
return msg.strip() | |
def make_scalar_mismatch_msg( | |
actual: Union[bool, int, float, complex], | |
expected: Union[bool, int, float, complex], | |
*, | |
rtol: float, | |
atol: float, | |
identifier: Optional[Union[str, Callable[[str], str]]] = None, | |
) -> str: | |
"""Makes a mismatch error message for scalars. | |
Args: | |
actual (Union[bool, int, float, complex]): Actual scalar. | |
expected (Union[bool, int, float, complex]): Expected scalar. | |
rtol (float): Relative tolerance. | |
atol (float): Absolute tolerance. | |
identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed | |
as callable in which case it will be called by the default value to create the description at runtime. | |
Defaults to "Scalars". | |
""" | |
abs_diff = abs(actual - expected) | |
rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected) | |
return _make_mismatch_msg( | |
default_identifier="Scalars", | |
identifier=identifier, | |
extra=f"Expected {expected} but got {actual}.", | |
abs_diff=abs_diff, | |
atol=atol, | |
rel_diff=rel_diff, | |
rtol=rtol, | |
) | |
def make_tensor_mismatch_msg( | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
matches: torch.Tensor, | |
*, | |
rtol: float, | |
atol: float, | |
identifier: Optional[Union[str, Callable[[str], str]]] = None, | |
): | |
"""Makes a mismatch error message for tensors. | |
Args: | |
actual (torch.Tensor): Actual tensor. | |
expected (torch.Tensor): Expected tensor. | |
matches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the | |
location of matches. | |
rtol (float): Relative tolerance. | |
atol (float): Absolute tolerance. | |
identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed | |
as callable in which case it will be called by the default value to create the description at runtime. | |
Defaults to "Tensor-likes". | |
""" | |
def unravel_flat_index(flat_index: int) -> Tuple[int, ...]: | |
if not matches.shape: | |
return () | |
inverse_index = [] | |
for size in matches.shape[::-1]: | |
div, mod = divmod(flat_index, size) | |
flat_index = div | |
inverse_index.append(mod) | |
return tuple(inverse_index[::-1]) | |
number_of_elements = matches.numel() | |
total_mismatches = number_of_elements - int(torch.sum(matches)) | |
extra = ( | |
f"Mismatched elements: {total_mismatches} / {number_of_elements} " | |
f"({total_mismatches / number_of_elements:.1%})" | |
) | |
actual_flat = actual.flatten() | |
expected_flat = expected.flatten() | |
matches_flat = matches.flatten() | |
if not actual.dtype.is_floating_point and not actual.dtype.is_complex: | |
# TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid | |
# overflow | |
actual_flat = actual_flat.to(torch.int64) | |
expected_flat = expected_flat.to(torch.int64) | |
abs_diff = torch.abs(actual_flat - expected_flat) | |
# Ensure that only mismatches are used for the max_abs_diff computation | |
abs_diff[matches_flat] = 0 | |
max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0) | |
rel_diff = abs_diff / torch.abs(expected_flat) | |
# Ensure that only mismatches are used for the max_rel_diff computation | |
rel_diff[matches_flat] = 0 | |
max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0) | |
return _make_mismatch_msg( | |
default_identifier="Tensor-likes", | |
identifier=identifier, | |
extra=extra, | |
abs_diff=max_abs_diff.item(), | |
abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)), | |
atol=atol, | |
rel_diff=max_rel_diff.item(), | |
rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)), | |
rtol=rtol, | |
) | |
class UnsupportedInputs(Exception): # noqa: B903 | |
"""Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs.""" | |
class Pair(abc.ABC): | |
"""ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`. | |
Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison. | |
Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the | |
super class. Raising an :class:`UnsupportedInputs` during constructions indicates that the pair is not able to | |
handle the inputs and the next pair type will be tried. | |
All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can | |
be used to automatically handle overwriting the message with a user supplied one and id handling. | |
""" | |
def __init__( | |
self, | |
actual: Any, | |
expected: Any, | |
*, | |
id: Tuple[Any, ...] = (), | |
**unknown_parameters: Any, | |
) -> None: | |
self.actual = actual | |
self.expected = expected | |
self.id = id | |
self._unknown_parameters = unknown_parameters | |
def _inputs_not_supported() -> NoReturn: | |
raise UnsupportedInputs() | |
def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]): | |
"""Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise.""" | |
if not all(isinstance(input, cls) for input in inputs): | |
Pair._inputs_not_supported() | |
def _fail( | |
self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = () | |
) -> NoReturn: | |
"""Raises an :class:`ErrorMeta` from a given exception type and message and the stored id. | |
.. warning:: | |
If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id`` | |
explicitly. | |
""" | |
raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id) | |
def compare(self) -> None: | |
"""Compares the inputs and raises an :class`ErrorMeta` in case they mismatch.""" | |
def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]: | |
"""Returns extra information that will be included in the representation. | |
Should be overwritten by all subclasses that use additional options. The representation of the object will only | |
be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of | |
key-value-pairs or attribute names. | |
""" | |
return [] | |
def __repr__(self) -> str: | |
head = f"{type(self).__name__}(" | |
tail = ")" | |
body = [ | |
f" {name}={value!s}," | |
for name, value in [ | |
("id", self.id), | |
("actual", self.actual), | |
("expected", self.expected), | |
*[ | |
(extra, getattr(self, extra)) if isinstance(extra, str) else extra | |
for extra in self.extra_repr() | |
], | |
] | |
] | |
return "\n".join((head, *body, *tail)) | |
class ObjectPair(Pair): | |
"""Pair for any type of inputs that will be compared with the `==` operator. | |
.. note:: | |
Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs | |
couldn't handle the inputs. | |
""" | |
def compare(self) -> None: | |
try: | |
equal = self.actual == self.expected | |
except Exception as error: | |
# We are not using `self._raise_error_meta` here since we need the exception chaining | |
raise ErrorMeta( | |
ValueError, | |
f"{self.actual} == {self.expected} failed with:\n{error}.", | |
id=self.id, | |
) from error | |
if not equal: | |
self._fail(AssertionError, f"{self.actual} != {self.expected}") | |
class NonePair(Pair): | |
"""Pair for ``None`` inputs.""" | |
def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None: | |
if not (actual is None or expected is None): | |
self._inputs_not_supported() | |
super().__init__(actual, expected, **other_parameters) | |
def compare(self) -> None: | |
if not (self.actual is None and self.expected is None): | |
self._fail( | |
AssertionError, f"None mismatch: {self.actual} is not {self.expected}" | |
) | |
class BooleanPair(Pair): | |
"""Pair for :class:`bool` inputs. | |
.. note:: | |
If ``numpy`` is available, also handles :class:`numpy.bool_` inputs. | |
""" | |
def __init__( | |
self, | |
actual: Any, | |
expected: Any, | |
*, | |
id: Tuple[Any, ...], | |
**other_parameters: Any, | |
) -> None: | |
actual, expected = self._process_inputs(actual, expected, id=id) | |
super().__init__(actual, expected, **other_parameters) | |
def _supported_types(self) -> Tuple[Type, ...]: | |
cls: List[Type] = [bool] | |
if NUMPY_AVAILABLE: | |
cls.append(np.bool_) | |
return tuple(cls) | |
def _process_inputs( | |
self, actual: Any, expected: Any, *, id: Tuple[Any, ...] | |
) -> Tuple[bool, bool]: | |
self._check_inputs_isinstance(actual, expected, cls=self._supported_types) | |
actual, expected = ( | |
self._to_bool(bool_like, id=id) for bool_like in (actual, expected) | |
) | |
return actual, expected | |
def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool: | |
if isinstance(bool_like, bool): | |
return bool_like | |
elif isinstance(bool_like, np.bool_): | |
return bool_like.item() | |
else: | |
raise ErrorMeta( | |
TypeError, f"Unknown boolean type {type(bool_like)}.", id=id | |
) | |
def compare(self) -> None: | |
if self.actual is not self.expected: | |
self._fail( | |
AssertionError, | |
f"Booleans mismatch: {self.actual} is not {self.expected}", | |
) | |
class NumberPair(Pair): | |
"""Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs. | |
.. note:: | |
If ``numpy`` is available, also handles :class:`numpy.number` inputs. | |
Kwargs: | |
rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default | |
values based on the type are selected with the below table. | |
atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default | |
values based on the type are selected with the below table. | |
equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. | |
check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``. | |
The following table displays correspondence between Python number type and the ``torch.dtype``'s. See | |
:func:`assert_close` for the corresponding tolerances. | |
+------------------+-------------------------------+ | |
| ``type`` | corresponding ``torch.dtype`` | | |
+==================+===============================+ | |
| :class:`int` | :attr:`~torch.int64` | | |
+------------------+-------------------------------+ | |
| :class:`float` | :attr:`~torch.float64` | | |
+------------------+-------------------------------+ | |
| :class:`complex` | :attr:`~torch.complex64` | | |
+------------------+-------------------------------+ | |
""" | |
_TYPE_TO_DTYPE = { | |
int: torch.int64, | |
float: torch.float64, | |
complex: torch.complex128, | |
} | |
_NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys()) | |
def __init__( | |
self, | |
actual: Any, | |
expected: Any, | |
*, | |
id: Tuple[Any, ...] = (), | |
rtol: Optional[float] = None, | |
atol: Optional[float] = None, | |
equal_nan: bool = False, | |
check_dtype: bool = False, | |
**other_parameters: Any, | |
) -> None: | |
actual, expected = self._process_inputs(actual, expected, id=id) | |
super().__init__(actual, expected, id=id, **other_parameters) | |
self.rtol, self.atol = get_tolerances( | |
*[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)], | |
rtol=rtol, | |
atol=atol, | |
id=id, | |
) | |
self.equal_nan = equal_nan | |
self.check_dtype = check_dtype | |
def _supported_types(self) -> Tuple[Type, ...]: | |
cls = list(self._NUMBER_TYPES) | |
if NUMPY_AVAILABLE: | |
cls.append(np.number) | |
return tuple(cls) | |
def _process_inputs( | |
self, actual: Any, expected: Any, *, id: Tuple[Any, ...] | |
) -> Tuple[Union[int, float, complex], Union[int, float, complex]]: | |
self._check_inputs_isinstance(actual, expected, cls=self._supported_types) | |
actual, expected = ( | |
self._to_number(number_like, id=id) for number_like in (actual, expected) | |
) | |
return actual, expected | |
def _to_number( | |
self, number_like: Any, *, id: Tuple[Any, ...] | |
) -> Union[int, float, complex]: | |
if NUMPY_AVAILABLE and isinstance(number_like, np.number): | |
return number_like.item() | |
elif isinstance(number_like, self._NUMBER_TYPES): | |
return number_like # type: ignore[return-value] | |
else: | |
raise ErrorMeta( | |
TypeError, f"Unknown number type {type(number_like)}.", id=id | |
) | |
def compare(self) -> None: | |
if self.check_dtype and type(self.actual) is not type(self.expected): | |
self._fail( | |
AssertionError, | |
f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.", | |
) | |
if self.actual == self.expected: | |
return | |
if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected): | |
return | |
abs_diff = abs(self.actual - self.expected) | |
tolerance = self.atol + self.rtol * abs(self.expected) | |
if cmath.isfinite(abs_diff) and abs_diff <= tolerance: | |
return | |
self._fail( | |
AssertionError, | |
make_scalar_mismatch_msg( | |
self.actual, self.expected, rtol=self.rtol, atol=self.atol | |
), | |
) | |
def extra_repr(self) -> Sequence[str]: | |
return ( | |
"rtol", | |
"atol", | |
"equal_nan", | |
"check_dtype", | |
) | |
class TensorLikePair(Pair): | |
"""Pair for :class:`torch.Tensor`-like inputs. | |
Kwargs: | |
allow_subclasses (bool): | |
rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default | |
values based on the type are selected. See :func:assert_close: for details. | |
atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default | |
values based on the type are selected. See :func:assert_close: for details. | |
equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``. | |
check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same | |
:attr:`~torch.Tensor.device`. If this check is disabled, tensors on different | |
:attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. | |
check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this | |
check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to | |
:func:`torch.promote_types`) before being compared. | |
check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this | |
check is disabled, tensors with different ``layout``'s are converted to strided tensors before being | |
compared. | |
check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. | |
""" | |
def __init__( | |
self, | |
actual: Any, | |
expected: Any, | |
*, | |
id: Tuple[Any, ...] = (), | |
allow_subclasses: bool = True, | |
rtol: Optional[float] = None, | |
atol: Optional[float] = None, | |
equal_nan: bool = False, | |
check_device: bool = True, | |
check_dtype: bool = True, | |
check_layout: bool = True, | |
check_stride: bool = False, | |
**other_parameters: Any, | |
): | |
actual, expected = self._process_inputs( | |
actual, expected, id=id, allow_subclasses=allow_subclasses | |
) | |
super().__init__(actual, expected, id=id, **other_parameters) | |
self.rtol, self.atol = get_tolerances( | |
actual, expected, rtol=rtol, atol=atol, id=self.id | |
) | |
self.equal_nan = equal_nan | |
self.check_device = check_device | |
self.check_dtype = check_dtype | |
self.check_layout = check_layout | |
self.check_stride = check_stride | |
def _process_inputs( | |
self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
directly_related = isinstance(actual, type(expected)) or isinstance( | |
expected, type(actual) | |
) | |
if not directly_related: | |
self._inputs_not_supported() | |
if not allow_subclasses and type(actual) is not type(expected): | |
self._inputs_not_supported() | |
actual, expected = (self._to_tensor(input) for input in (actual, expected)) | |
for tensor in (actual, expected): | |
self._check_supported(tensor, id=id) | |
return actual, expected | |
def _to_tensor(self, tensor_like: Any) -> torch.Tensor: | |
if isinstance(tensor_like, torch.Tensor): | |
return tensor_like | |
try: | |
return torch.as_tensor(tensor_like) | |
except Exception: | |
self._inputs_not_supported() | |
def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None: | |
if tensor.layout not in { | |
torch.strided, | |
torch.sparse_coo, | |
torch.sparse_csr, | |
torch.sparse_csc, | |
torch.sparse_bsr, | |
torch.sparse_bsc, | |
}: | |
raise ErrorMeta( | |
ValueError, f"Unsupported tensor layout {tensor.layout}", id=id | |
) | |
def compare(self) -> None: | |
actual, expected = self.actual, self.expected | |
self._compare_attributes(actual, expected) | |
if any(input.device.type == "meta" for input in (actual, expected)): | |
return | |
actual, expected = self._equalize_attributes(actual, expected) | |
self._compare_values(actual, expected) | |
def _compare_attributes( | |
self, | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
) -> None: | |
"""Checks if the attributes of two tensors match. | |
Always checks | |
- the :attr:`~torch.Tensor.shape`, | |
- whether both inputs are quantized or not, | |
- and if they use the same quantization scheme. | |
Checks for | |
- :attr:`~torch.Tensor.layout`, | |
- :meth:`~torch.Tensor.stride`, | |
- :attr:`~torch.Tensor.device`, and | |
- :attr:`~torch.Tensor.dtype` | |
are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair. | |
""" | |
def raise_mismatch_error( | |
attribute_name: str, actual_value: Any, expected_value: Any | |
) -> NoReturn: | |
self._fail( | |
AssertionError, | |
f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.", | |
) | |
if actual.shape != expected.shape: | |
raise_mismatch_error("shape", actual.shape, expected.shape) | |
if actual.is_quantized != expected.is_quantized: | |
raise_mismatch_error( | |
"is_quantized", actual.is_quantized, expected.is_quantized | |
) | |
elif actual.is_quantized and actual.qscheme() != expected.qscheme(): | |
raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme()) | |
if actual.layout != expected.layout: | |
if self.check_layout: | |
raise_mismatch_error("layout", actual.layout, expected.layout) | |
elif ( | |
actual.layout == torch.strided | |
and self.check_stride | |
and actual.stride() != expected.stride() | |
): | |
raise_mismatch_error("stride()", actual.stride(), expected.stride()) | |
if self.check_device and actual.device != expected.device: | |
raise_mismatch_error("device", actual.device, expected.device) | |
if self.check_dtype and actual.dtype != expected.dtype: | |
raise_mismatch_error("dtype", actual.dtype, expected.dtype) | |
def _equalize_attributes( | |
self, actual: torch.Tensor, expected: torch.Tensor | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
"""Equalizes some attributes of two tensors for value comparison. | |
If ``actual`` and ``expected`` are ... | |
- ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory. | |
- ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to | |
:func:`torch.promote_types`). | |
- ... not of the same ``layout``, they are converted to strided tensors. | |
Args: | |
actual (Tensor): Actual tensor. | |
expected (Tensor): Expected tensor. | |
Returns: | |
(Tuple[Tensor, Tensor]): Equalized tensors. | |
""" | |
# The comparison logic uses operators currently not supported by the MPS backends. | |
# See https://github.com/pytorch/pytorch/issues/77144 for details. | |
# TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend | |
if actual.is_mps or expected.is_mps: # type: ignore[attr-defined] | |
actual = actual.cpu() | |
expected = expected.cpu() | |
if actual.device != expected.device: | |
actual = actual.cpu() | |
expected = expected.cpu() | |
if actual.dtype != expected.dtype: | |
actual_dtype = actual.dtype | |
expected_dtype = expected.dtype | |
# For uint64, this is not sound in general, which is why promote_types doesn't | |
# allow it, but for easy testing, we're unlikely to get confused | |
# by large uint64 overflowing into negative int64 | |
if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]: | |
actual_dtype = torch.int64 | |
if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]: | |
expected_dtype = torch.int64 | |
dtype = torch.promote_types(actual_dtype, expected_dtype) | |
actual = actual.to(dtype) | |
expected = expected.to(dtype) | |
if actual.layout != expected.layout: | |
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided | |
actual = actual.to_dense() if actual.layout != torch.strided else actual | |
expected = ( | |
expected.to_dense() if expected.layout != torch.strided else expected | |
) | |
return actual, expected | |
def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None: | |
if actual.is_quantized: | |
compare_fn = self._compare_quantized_values | |
elif actual.is_sparse: | |
compare_fn = self._compare_sparse_coo_values | |
elif actual.layout in { | |
torch.sparse_csr, | |
torch.sparse_csc, | |
torch.sparse_bsr, | |
torch.sparse_bsc, | |
}: | |
compare_fn = self._compare_sparse_compressed_values | |
else: | |
compare_fn = self._compare_regular_values_close | |
compare_fn( | |
actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan | |
) | |
def _compare_quantized_values( | |
self, | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
*, | |
rtol: float, | |
atol: float, | |
equal_nan: bool, | |
) -> None: | |
"""Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness. | |
.. note:: | |
A detailed discussion about why only the dequantized variant is checked for closeness rather than checking | |
the individual quantization parameters for closeness and the integer representation for equality can be | |
found in https://github.com/pytorch/pytorch/issues/68548. | |
""" | |
return self._compare_regular_values_close( | |
actual.dequantize(), | |
expected.dequantize(), | |
rtol=rtol, | |
atol=atol, | |
equal_nan=equal_nan, | |
identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}", | |
) | |
def _compare_sparse_coo_values( | |
self, | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
*, | |
rtol: float, | |
atol: float, | |
equal_nan: bool, | |
) -> None: | |
"""Compares sparse COO tensors by comparing | |
- the number of sparse dimensions, | |
- the number of non-zero elements (nnz) for equality, | |
- the indices for equality, and | |
- the values for closeness. | |
""" | |
if actual.sparse_dim() != expected.sparse_dim(): | |
self._fail( | |
AssertionError, | |
( | |
f"The number of sparse dimensions in sparse COO tensors does not match: " | |
f"{actual.sparse_dim()} != {expected.sparse_dim()}" | |
), | |
) | |
if actual._nnz() != expected._nnz(): | |
self._fail( | |
AssertionError, | |
( | |
f"The number of specified values in sparse COO tensors does not match: " | |
f"{actual._nnz()} != {expected._nnz()}" | |
), | |
) | |
self._compare_regular_values_equal( | |
actual._indices(), | |
expected._indices(), | |
identifier="Sparse COO indices", | |
) | |
self._compare_regular_values_close( | |
actual._values(), | |
expected._values(), | |
rtol=rtol, | |
atol=atol, | |
equal_nan=equal_nan, | |
identifier="Sparse COO values", | |
) | |
def _compare_sparse_compressed_values( | |
self, | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
*, | |
rtol: float, | |
atol: float, | |
equal_nan: bool, | |
) -> None: | |
"""Compares sparse compressed tensors by comparing | |
- the number of non-zero elements (nnz) for equality, | |
- the plain indices for equality, | |
- the compressed indices for equality, and | |
- the values for closeness. | |
""" | |
format_name, compressed_indices_method, plain_indices_method = { | |
torch.sparse_csr: ( | |
"CSR", | |
torch.Tensor.crow_indices, | |
torch.Tensor.col_indices, | |
), | |
torch.sparse_csc: ( | |
"CSC", | |
torch.Tensor.ccol_indices, | |
torch.Tensor.row_indices, | |
), | |
torch.sparse_bsr: ( | |
"BSR", | |
torch.Tensor.crow_indices, | |
torch.Tensor.col_indices, | |
), | |
torch.sparse_bsc: ( | |
"BSC", | |
torch.Tensor.ccol_indices, | |
torch.Tensor.row_indices, | |
), | |
}[actual.layout] | |
if actual._nnz() != expected._nnz(): | |
self._fail( | |
AssertionError, | |
( | |
f"The number of specified values in sparse {format_name} tensors does not match: " | |
f"{actual._nnz()} != {expected._nnz()}" | |
), | |
) | |
# Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formates can be `torch.int32` _or_ | |
# `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it | |
# can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will | |
# fail. | |
actual_compressed_indices = compressed_indices_method(actual) | |
expected_compressed_indices = compressed_indices_method(expected) | |
indices_dtype = torch.promote_types( | |
actual_compressed_indices.dtype, expected_compressed_indices.dtype | |
) | |
self._compare_regular_values_equal( | |
actual_compressed_indices.to(indices_dtype), | |
expected_compressed_indices.to(indices_dtype), | |
identifier=f"Sparse {format_name} {compressed_indices_method.__name__}", | |
) | |
self._compare_regular_values_equal( | |
plain_indices_method(actual).to(indices_dtype), | |
plain_indices_method(expected).to(indices_dtype), | |
identifier=f"Sparse {format_name} {plain_indices_method.__name__}", | |
) | |
self._compare_regular_values_close( | |
actual.values(), | |
expected.values(), | |
rtol=rtol, | |
atol=atol, | |
equal_nan=equal_nan, | |
identifier=f"Sparse {format_name} values", | |
) | |
def _compare_regular_values_equal( | |
self, | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
*, | |
equal_nan: bool = False, | |
identifier: Optional[Union[str, Callable[[str], str]]] = None, | |
) -> None: | |
"""Checks if the values of two tensors are equal.""" | |
self._compare_regular_values_close( | |
actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier | |
) | |
def _compare_regular_values_close( | |
self, | |
actual: torch.Tensor, | |
expected: torch.Tensor, | |
*, | |
rtol: float, | |
atol: float, | |
equal_nan: bool, | |
identifier: Optional[Union[str, Callable[[str], str]]] = None, | |
) -> None: | |
"""Checks if the values of two tensors are close up to a desired tolerance.""" | |
matches = torch.isclose( | |
actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan | |
) | |
if torch.all(matches): | |
return | |
if actual.shape == torch.Size([]): | |
msg = make_scalar_mismatch_msg( | |
actual.item(), | |
expected.item(), | |
rtol=rtol, | |
atol=atol, | |
identifier=identifier, | |
) | |
else: | |
msg = make_tensor_mismatch_msg( | |
actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier | |
) | |
self._fail(AssertionError, msg) | |
def extra_repr(self) -> Sequence[str]: | |
return ( | |
"rtol", | |
"atol", | |
"equal_nan", | |
"check_device", | |
"check_dtype", | |
"check_layout", | |
"check_stride", | |
) | |
def originate_pairs( | |
actual: Any, | |
expected: Any, | |
*, | |
pair_types: Sequence[Type[Pair]], | |
sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,), | |
mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,), | |
id: Tuple[Any, ...] = (), | |
**options: Any, | |
) -> List[Pair]: | |
"""Originates pairs from the individual inputs. | |
``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or | |
:class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them. | |
Args: | |
actual (Any): Actual input. | |
expected (Any): Expected input. | |
pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs. | |
First successful pair will be used. | |
sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. | |
mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. | |
id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message. | |
**options (Any): Options passed to each pair during construction. | |
Raises: | |
ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their | |
length does not match. | |
ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set of | |
keys do not match. | |
ErrorMeta: With :class`TypeError`, if no pair is able to handle the inputs. | |
ErrorMeta: With any expected exception that happens during the construction of a pair. | |
Returns: | |
(List[Pair]): Originated pairs. | |
""" | |
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: | |
# "a" == "a"[0][0]... | |
if ( | |
isinstance(actual, sequence_types) | |
and not isinstance(actual, str) | |
and isinstance(expected, sequence_types) | |
and not isinstance(expected, str) | |
): | |
actual_len = len(actual) | |
expected_len = len(expected) | |
if actual_len != expected_len: | |
raise ErrorMeta( | |
AssertionError, | |
f"The length of the sequences mismatch: {actual_len} != {expected_len}", | |
id=id, | |
) | |
pairs = [] | |
for idx in range(actual_len): | |
pairs.extend( | |
originate_pairs( | |
actual[idx], | |
expected[idx], | |
pair_types=pair_types, | |
sequence_types=sequence_types, | |
mapping_types=mapping_types, | |
id=(*id, idx), | |
**options, | |
) | |
) | |
return pairs | |
elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types): | |
actual_keys = set(actual.keys()) | |
expected_keys = set(expected.keys()) | |
if actual_keys != expected_keys: | |
missing_keys = expected_keys - actual_keys | |
additional_keys = actual_keys - expected_keys | |
raise ErrorMeta( | |
AssertionError, | |
( | |
f"The keys of the mappings do not match:\n" | |
f"Missing keys in the actual mapping: {sorted(missing_keys)}\n" | |
f"Additional keys in the actual mapping: {sorted(additional_keys)}" | |
), | |
id=id, | |
) | |
keys: Collection = actual_keys | |
# Since the origination aborts after the first failure, we try to be deterministic | |
with contextlib.suppress(Exception): | |
keys = sorted(keys) | |
pairs = [] | |
for key in keys: | |
pairs.extend( | |
originate_pairs( | |
actual[key], | |
expected[key], | |
pair_types=pair_types, | |
sequence_types=sequence_types, | |
mapping_types=mapping_types, | |
id=(*id, key), | |
**options, | |
) | |
) | |
return pairs | |
else: | |
for pair_type in pair_types: | |
try: | |
return [pair_type(actual, expected, id=id, **options)] | |
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the | |
# inputs. Thus, we try the next pair type. | |
except UnsupportedInputs: | |
continue | |
# Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This | |
# is only in a separate branch, because the one below would also except it. | |
except ErrorMeta: | |
raise | |
# Raising any other exception during origination is unexpected and will give some extra information about | |
# what happened. If applicable, the exception should be expected in the future. | |
except Exception as error: | |
raise RuntimeError( | |
f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n" | |
f"{type(actual).__name__}(): {actual}\n\n" | |
f"and\n\n" | |
f"{type(expected).__name__}(): {expected}\n\n" | |
f"resulted in the unexpected exception above. " | |
f"If you are a user and see this message during normal operation " | |
"please file an issue at https://github.com/pytorch/pytorch/issues. " | |
"If you are a developer and working on the comparison functions, " | |
"please except the previous error and raise an expressive `ErrorMeta` instead." | |
) from error | |
else: | |
raise ErrorMeta( | |
TypeError, | |
f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.", | |
id=id, | |
) | |
def not_close_error_metas( | |
actual: Any, | |
expected: Any, | |
*, | |
pair_types: Sequence[Type[Pair]] = (ObjectPair,), | |
sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,), | |
mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,), | |
**options: Any, | |
) -> List[ErrorMeta]: | |
"""Asserts that inputs are equal. | |
``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or | |
:class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them. | |
Args: | |
actual (Any): Actual input. | |
expected (Any): Expected input. | |
pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the | |
inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`. | |
sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise. | |
mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise. | |
**options (Any): Options passed to each pair during construction. | |
""" | |
# Hide this function from `pytest`'s traceback | |
__tracebackhide__ = True | |
try: | |
pairs = originate_pairs( | |
actual, | |
expected, | |
pair_types=pair_types, | |
sequence_types=sequence_types, | |
mapping_types=mapping_types, | |
**options, | |
) | |
except ErrorMeta as error_meta: | |
# Explicitly raising from None to hide the internal traceback | |
raise error_meta.to_error() from None | |
error_metas: List[ErrorMeta] = [] | |
for pair in pairs: | |
try: | |
pair.compare() | |
except ErrorMeta as error_meta: | |
error_metas.append(error_meta) | |
# Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information | |
# about what happened. If applicable, the exception should be expected in the future. | |
except Exception as error: | |
raise RuntimeError( | |
f"Comparing\n\n" | |
f"{pair}\n\n" | |
f"resulted in the unexpected exception above. " | |
f"If you are a user and see this message during normal operation " | |
"please file an issue at https://github.com/pytorch/pytorch/issues. " | |
"If you are a developer and working on the comparison functions, " | |
"please except the previous error and raise an expressive `ErrorMeta` instead." | |
) from error | |
# [ErrorMeta Cycles] | |
# ErrorMeta objects in this list capture | |
# tracebacks that refer to the frame of this function. | |
# The local variable `error_metas` refers to the error meta | |
# objects, creating a reference cycle. Frames in the traceback | |
# would not get freed until cycle collection, leaking cuda memory in tests. | |
# We break the cycle by removing the reference to the error_meta objects | |
# from this frame as it returns. | |
error_metas = [error_metas] | |
return error_metas.pop() | |
def assert_close( | |
actual: Any, | |
expected: Any, | |
*, | |
allow_subclasses: bool = True, | |
rtol: Optional[float] = None, | |
atol: Optional[float] = None, | |
equal_nan: bool = False, | |
check_device: bool = True, | |
check_dtype: bool = True, | |
check_layout: bool = True, | |
check_stride: bool = False, | |
msg: Optional[Union[str, Callable[[str], str]]] = None, | |
): | |
r"""Asserts that ``actual`` and ``expected`` are close. | |
If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if | |
.. math:: | |
\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert | |
Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are | |
only considered equal to each other if ``equal_nan`` is ``True``. | |
In addition, they are only considered close if they have the same | |
- :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``), | |
- ``dtype`` (if ``check_dtype`` is ``True``), | |
- ``layout`` (if ``check_layout`` is ``True``), and | |
- stride (if ``check_stride`` is ``True``). | |
If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed. | |
If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are | |
checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR, | |
or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively, | |
are always checked for equality whereas the values are checked for closeness according to the definition above. | |
If ``actual`` and ``expected`` are quantized, they are considered close if they have the same | |
:meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the | |
definition above. | |
``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which | |
:class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types | |
have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s | |
or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all | |
their elements are considered close according to the above definition. | |
.. note:: | |
Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e. | |
:class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus, | |
Python scalars of different types can be checked, but require ``check_dtype=False``. | |
Args: | |
actual (Any): Actual input. | |
expected (Any): Expected input. | |
allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types | |
are allowed. Otherwise type equality is required. | |
rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default | |
values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. | |
atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default | |
values based on the :attr:`~torch.Tensor.dtype` are selected with the below table. | |
equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal. | |
check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same | |
:attr:`~torch.Tensor.device`. If this check is disabled, tensors on different | |
:attr:`~torch.Tensor.device`'s are moved to the CPU before being compared. | |
check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this | |
check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to | |
:func:`torch.promote_types`) before being compared. | |
check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this | |
check is disabled, tensors with different ``layout``'s are converted to strided tensors before being | |
compared. | |
check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride. | |
msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during | |
the comparison. Can also passed as callable in which case it will be called with the generated message and | |
should return the new message. | |
Raises: | |
ValueError: If no :class:`torch.Tensor` can be constructed from an input. | |
ValueError: If only ``rtol`` or ``atol`` is specified. | |
AssertionError: If corresponding inputs are not Python scalars and are not directly related. | |
AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have | |
different types. | |
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match. | |
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match. | |
AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`. | |
AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same | |
:attr:`~torch.Tensor.layout`. | |
AssertionError: If only one of corresponding tensors is quantized. | |
AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s. | |
AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same | |
:attr:`~torch.Tensor.device`. | |
AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``. | |
AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride. | |
AssertionError: If the values of corresponding tensors are not close according to the definition above. | |
The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching | |
``dtype``'s, the maximum of both tolerances is used. | |
+---------------------------+------------+----------+ | |
| ``dtype`` | ``rtol`` | ``atol`` | | |
+===========================+============+==========+ | |
| :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` | | |
+---------------------------+------------+----------+ | |
| other | ``0.0`` | ``0.0`` | | |
+---------------------------+------------+----------+ | |
.. note:: | |
:func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged | |
to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might | |
define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default: | |
>>> import functools | |
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) | |
>>> assert_equal(1e-9, 1e-10) | |
Traceback (most recent call last): | |
... | |
AssertionError: Scalars are not equal! | |
<BLANKLINE> | |
Expected 1e-10 but got 1e-09. | |
Absolute difference: 9.000000000000001e-10 | |
Relative difference: 9.0 | |
Examples: | |
>>> # tensor to tensor comparison | |
>>> expected = torch.tensor([1e0, 1e-1, 1e-2]) | |
>>> actual = torch.acos(torch.cos(expected)) | |
>>> torch.testing.assert_close(actual, expected) | |
>>> # scalar to scalar comparison | |
>>> import math | |
>>> expected = math.sqrt(2.0) | |
>>> actual = 2.0 / math.sqrt(2.0) | |
>>> torch.testing.assert_close(actual, expected) | |
>>> # numpy array to numpy array comparison | |
>>> import numpy as np | |
>>> expected = np.array([1e0, 1e-1, 1e-2]) | |
>>> actual = np.arccos(np.cos(expected)) | |
>>> torch.testing.assert_close(actual, expected) | |
>>> # sequence to sequence comparison | |
>>> import numpy as np | |
>>> # The types of the sequences do not have to match. They only have to have the same | |
>>> # length and their elements have to match. | |
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] | |
>>> actual = tuple(expected) | |
>>> torch.testing.assert_close(actual, expected) | |
>>> # mapping to mapping comparison | |
>>> from collections import OrderedDict | |
>>> import numpy as np | |
>>> foo = torch.tensor(1.0) | |
>>> bar = 2.0 | |
>>> baz = np.array(3.0) | |
>>> # The types and a possible ordering of mappings do not have to match. They only | |
>>> # have to have the same set of keys and their elements have to match. | |
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) | |
>>> actual = {"baz": baz, "bar": bar, "foo": foo} | |
>>> torch.testing.assert_close(actual, expected) | |
>>> expected = torch.tensor([1.0, 2.0, 3.0]) | |
>>> actual = expected.clone() | |
>>> # By default, directly related instances can be compared | |
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) | |
>>> # This check can be made more strict with allow_subclasses=False | |
>>> torch.testing.assert_close( | |
... torch.nn.Parameter(actual), expected, allow_subclasses=False | |
... ) | |
Traceback (most recent call last): | |
... | |
TypeError: No comparison pair was able to handle inputs of type | |
<class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. | |
>>> # If the inputs are not directly related, they are never considered close | |
>>> torch.testing.assert_close(actual.numpy(), expected) | |
Traceback (most recent call last): | |
... | |
TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> | |
and <class 'torch.Tensor'>. | |
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of | |
>>> # their type if check_dtype=False. | |
>>> torch.testing.assert_close(1.0, 1, check_dtype=False) | |
>>> # NaN != NaN by default. | |
>>> expected = torch.tensor(float("Nan")) | |
>>> actual = expected.clone() | |
>>> torch.testing.assert_close(actual, expected) | |
Traceback (most recent call last): | |
... | |
AssertionError: Scalars are not close! | |
<BLANKLINE> | |
Expected nan but got nan. | |
Absolute difference: nan (up to 1e-05 allowed) | |
Relative difference: nan (up to 1.3e-06 allowed) | |
>>> torch.testing.assert_close(actual, expected, equal_nan=True) | |
>>> expected = torch.tensor([1.0, 2.0, 3.0]) | |
>>> actual = torch.tensor([1.0, 4.0, 5.0]) | |
>>> # The default error message can be overwritten. | |
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") | |
Traceback (most recent call last): | |
... | |
AssertionError: Argh, the tensors are not close! | |
>>> # If msg is a callable, it can be used to augment the generated message with | |
>>> # extra information | |
>>> torch.testing.assert_close( | |
... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" | |
... ) | |
Traceback (most recent call last): | |
... | |
AssertionError: Header | |
<BLANKLINE> | |
Tensor-likes are not close! | |
<BLANKLINE> | |
Mismatched elements: 2 / 3 (66.7%) | |
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) | |
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) | |
<BLANKLINE> | |
Footer | |
""" | |
# Hide this function from `pytest`'s traceback | |
__tracebackhide__ = True | |
error_metas = not_close_error_metas( | |
actual, | |
expected, | |
pair_types=( | |
NonePair, | |
BooleanPair, | |
NumberPair, | |
TensorLikePair, | |
), | |
allow_subclasses=allow_subclasses, | |
rtol=rtol, | |
atol=atol, | |
equal_nan=equal_nan, | |
check_device=check_device, | |
check_dtype=check_dtype, | |
check_layout=check_layout, | |
check_stride=check_stride, | |
msg=msg, | |
) | |
if error_metas: | |
# TODO: compose all metas into one AssertionError | |
raise error_metas[0].to_error(msg) | |
def assert_allclose( | |
actual: Any, | |
expected: Any, | |
rtol: Optional[float] = None, | |
atol: Optional[float] = None, | |
equal_nan: bool = True, | |
msg: str = "", | |
) -> None: | |
""" | |
.. warning:: | |
:func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release. | |
Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions | |
`here <https://github.com/pytorch/pytorch/issues/61844>`_. | |
""" | |
warnings.warn( | |
"`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. " | |
"Please use `torch.testing.assert_close()` instead. " | |
"You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.", | |
FutureWarning, | |
stacklevel=2, | |
) | |
if not isinstance(actual, torch.Tensor): | |
actual = torch.tensor(actual) | |
if not isinstance(expected, torch.Tensor): | |
expected = torch.tensor(expected, dtype=actual.dtype) | |
if rtol is None and atol is None: | |
rtol, atol = default_tolerances( | |
actual, | |
expected, | |
dtype_precisions={ | |
torch.float16: (1e-3, 1e-3), | |
torch.float32: (1e-4, 1e-5), | |
torch.float64: (1e-5, 1e-8), | |
}, | |
) | |
torch.testing.assert_close( | |
actual, | |
expected, | |
rtol=rtol, | |
atol=atol, | |
equal_nan=equal_nan, | |
check_device=True, | |
check_dtype=False, | |
check_stride=False, | |
msg=msg or None, | |
) | |