# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
from collections.abc import Iterable
from runpy import run_path
from shlex import split
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch

from torch.nn import GroupNorm, LayerNorm
from torch.testing import assert_allclose as _assert_allclose

from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm


def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: Optional[Union[str, Callable]] = '',
) -> None:
    """Asserts that ``actual`` and ``expected`` are close.

    A wrapper function of ``torch.testing.assert_allclose``.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        rtol (Optional[float]): Relative tolerance. If specified ``atol`` must
            also be specified. If omitted, ``torch.testing`` selects a default
            based on the :attr:`~torch.Tensor.dtype` of the inputs.
        atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol`
            must also be specified. If omitted, ``torch.testing`` selects a
            default based on the :attr:`~torch.Tensor.dtype` of the inputs.
        equal_nan (bool): If ``True``, two ``NaN`` values will be considered
            equal.
        msg (Optional[Union[str, Callable]]): Optional error message to use if
            the values of corresponding tensors mismatch. Ignored on parrots
            and when PyTorch < 1.6.
    """
    kwargs = dict(rtol=rtol, atol=atol, equal_nan=equal_nan)
    # torch.testing.assert_allclose has no ``msg`` argument on parrots or
    # when PyTorch < 1.6, so only forward it when it is supported.
    if 'parrots' not in TORCH_VERSION and \
            digit_version(TORCH_VERSION) >= digit_version('1.6'):
        kwargs['msg'] = msg
    _assert_allclose(actual, expected, **kwargs)


def check_python_script(cmd):
    """Run the python cmd script with ``__main__``.

    Unlike ``os.system``, this function executes the script in the current
    process, so that it can be tracked by coverage tools. Currently it
    supports two forms:

    - ./tests/data/scripts/hello.py zz
    - python tests/data/scripts/hello.py zz  (``python3`` is also accepted)

    Args:
        cmd (str): The command line to run. It is split with shell-like
            syntax via :func:`shlex.split`.

    Raises:
        ValueError: If ``cmd`` does not contain a script path.
    """
    args = split(cmd)
    # Drop a leading interpreter name so that ``args[0]`` is the script.
    if args and args[0] in ('python', 'python3'):
        args = args[1:]
    if not args:
        raise ValueError(f'No python script found in command: {cmd!r}')
    # The script sees ``sys.argv`` as if it had been launched directly.
    with patch.object(sys, 'argv', args):
        run_path(args[0], run_name='__main__')


def _any(judge_result):
    """Recursively check whether any leaf of ``judge_result`` is truthy.

    The built-in ``any`` only checks the truthiness of the top-level
    elements, so a nested iterable of comparison results (e.g. lists of
    arrays) would not be evaluated element-wise. This helper recurses into
    nested iterables instead.

    Args:
        judge_result (Any): A scalar or a (possibly nested) iterable of
            comparison results.

    Returns:
        ``judge_result`` itself if it is not iterable (or is a ``str``),
        otherwise ``True`` if any leaf element is truthy and ``False``
        otherwise.
    """
    # A string is iterable and every character is itself a one-character
    # string, so recursing into it would never terminate; treat it as a
    # scalar instead.
    if isinstance(judge_result, str) or not isinstance(judge_result,
                                                       Iterable):
        return judge_result
    try:
        for element in judge_result:
            if _any(element):
                return True
    except TypeError:
        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False),
        # whose 0-d result raises TypeError when iterated; fall back to its
        # truthiness.
        if judge_result:
            return True
    return False


def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
                                expected_subset: Dict[Any, Any]) -> bool:
    """Check if the dict_obj contains the expected_subset.

    Args:
        dict_obj (Dict[Any, Any]): Dict object to be checked.
        expected_subset (Dict[Any, Any]): Subset expected to be contained in
            dict_obj.

    Returns:
        bool: Whether the dict_obj contains the expected_subset.
    """
    for key, value in expected_subset.items():
        # ``_any`` handles comparisons that yield arrays/tensors rather than
        # a single bool.
        if key not in dict_obj or _any(dict_obj[key] != value):
            return False
    return True


def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
    """Check if the attributes of an object match the expected values.

    Args:
        obj (object): Class object to be checked.
        expected_attrs (Dict[str, Any]): Dict of the expected attrs.

    Returns:
        bool: Whether every expected attribute exists on ``obj`` and equals
        the expected value.
    """
    for attr, value in expected_attrs.items():
        if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
            return False
    return True


def assert_dict_has_keys(obj: Dict[str, Any],
                         expected_keys: List[str]) -> bool:
    """Check if the obj has all the expected_keys.

    Args:
        obj (Dict[str, Any]): Object to be checked.
        expected_keys (List[str]): Keys expected to be contained in the keys
            of the obj.

    Returns:
        bool: Whether the obj has the expected keys.
    """
    return set(expected_keys).issubset(set(obj.keys()))


def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
    """Check if target_keys is equal to result_keys (order-insensitive).

    Args:
        result_keys (List[str]): Result keys to be checked.
        target_keys (List[str]): Target keys to be checked.

    Returns:
        bool: Whether target_keys is equal to result_keys.
    """
    return set(result_keys) == set(target_keys)


def assert_is_norm_layer(module) -> bool:
    """Check if the module is a norm layer.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: Whether the module is a norm layer.
    """
    norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
    return isinstance(module, norm_layer_candidates)


def assert_params_all_zeros(module) -> bool:
    """Check if the weight and bias of the module are all zeros.

    Only ``module.weight`` and, if present and not ``None``, ``module.bias``
    are checked.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: Whether the weight and bias of the module are all zeros.
    """
    weight_data = module.weight.data
    is_weight_zero = weight_data.allclose(
        weight_data.new_zeros(weight_data.size()))

    if hasattr(module, 'bias') and module.bias is not None:
        bias_data = module.bias.data
        is_bias_zero = bias_data.allclose(
            bias_data.new_zeros(bias_data.size()))
    else:
        # A module without a bias is considered to have an all-zero bias.
        is_bias_zero = True

    return is_weight_zero and is_bias_zero