Spaces:
Build error
Build error
File size: 6,291 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
from collections.abc import Iterable
from runpy import run_path
from shlex import split
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch
from torch.nn import GroupNorm, LayerNorm
from torch.testing import assert_allclose as _assert_allclose
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: Optional[Union[str, Callable]] = '',
) -> None:
    """Assert that ``actual`` and ``expected`` are close.

    Thin wrapper around ``torch.testing.assert_allclose`` (imported here as
    ``_assert_allclose``). ``msg`` is forwarded only when the backend is
    PyTorch (not parrots) at version 1.6 or newer.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.
        rtol (Optional[float]): Relative tolerance. If specified, ``atol``
            must also be specified. If omitted, PyTorch selects a default
            based on the ``dtype`` of the inputs.
        atol (Optional[float]): Absolute tolerance. If specified, ``rtol``
            must also be specified. If omitted, PyTorch selects a default
            based on the ``dtype`` of the inputs.
        equal_nan (bool): If ``True``, two ``NaN`` values are considered
            equal. Defaults to ``True``.
        msg (Optional[Union[str, Callable]]): Optional error message used
            when the values mismatch. Ignored on parrots and on PyTorch
            < 1.6.
    """
    options = dict(rtol=rtol, atol=atol, equal_nan=equal_nan)
    # torch.testing.assert_allclose has no ``msg`` argument when
    # PyTorch < 1.6, and parrots builds are also called without it.
    accepts_msg = ('parrots' not in TORCH_VERSION
                   and digit_version(TORCH_VERSION) >= digit_version('1.6'))
    if accepts_msg:
        options['msg'] = msg
    _assert_allclose(actual, expected, **options)
def check_python_script(cmd):
    """Run a python script command in the current process as ``__main__``.

    Unlike ``os.system``, this executes the script in the current process,
    so that it can be tracked by coverage tools. ``sys.argv`` is patched for
    the duration of the run and restored afterwards. Supported forms:

    - ./tests/data/scripts/hello.py zz
    - python tests/data/scripts/hello.py zz
    - python3 tests/data/scripts/hello.py zz

    Args:
        cmd (str): Shell-like command line; it is tokenized with
            ``shlex.split``.

    Raises:
        ValueError: If ``cmd`` does not contain a script path.
    """
    args = split(cmd)
    # Drop a leading interpreter name so that ``argv[0]`` is the script
    # path, exactly as when the script is launched directly.
    if args and args[0] in ('python', 'python3'):
        args = args[1:]
    if not args:
        raise ValueError(f'No python script found in command: {cmd!r}')
    with patch.object(sys, 'argv', args):
        run_path(args[0], run_name='__main__')
def _any(judge_result):
    """Recursively check whether any element of ``judge_result`` is truthy.

    The built-in ``any`` only looks at the top level of an iterable (e.g.
    ``any([[False]])`` is ``True`` because ``[False]`` is a non-empty list),
    so this helper descends into nested iterables such as the element-wise
    result of comparing multi-dimensional arrays or tensors.

    Args:
        judge_result (Any): Value to inspect, typically the result of an
            ``!=`` comparison.

    Returns:
        Any: ``judge_result`` itself when it is not iterable or is a ``str``
        (the caller evaluates its truthiness). Otherwise ``True`` if any
        nested element is truthy, and ``False`` if none is.
    """
    # ``str`` must be treated as atomic: each character of a string is
    # itself a one-character string, so recursing into it never terminates.
    if (isinstance(judge_result, str)
            or not isinstance(judge_result, Iterable)):
        return judge_result
    try:
        elements = iter(judge_result)
    except TypeError:
        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
        # i.e. 0-d tensors/arrays that define ``__iter__`` but refuse to be
        # iterated; fall back to their truthiness.
        return bool(judge_result)
    return any(_any(element) for element in elements)
def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
                                expected_subset: Dict[Any, Any]) -> bool:
    """Tell whether every item of ``expected_subset`` appears in ``dict_obj``.

    An item matches when its key is present in ``dict_obj`` and ``_any``
    finds no truthy entry in ``dict_obj[key] != value``. This lets array or
    tensor values be compared element-wise.

    Args:
        dict_obj (Dict[Any, Any]): Dict object to be checked.
        expected_subset (Dict[Any, Any]): Items expected to be contained in
            ``dict_obj``.

    Returns:
        bool: ``True`` if ``dict_obj`` contains every item of
        ``expected_subset`` (vacuously so when it is empty), else ``False``.
    """
    return all(
        key in dict_obj and not _any(dict_obj[key] != value)
        for key, value in expected_subset.items())
def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
    """Tell whether ``obj`` has every expected attribute with its value.

    Values are compared with ``!=`` and the result is inspected with
    ``_any``, so array or tensor attributes are compared element-wise.

    Args:
        obj (object): Object whose attributes are checked.
        expected_attrs (Dict[str, Any]): Mapping from attribute name to the
            expected value.

    Returns:
        bool: ``True`` if every attribute exists and matches, else ``False``.
    """
    for attr_name, expected in expected_attrs.items():
        if not hasattr(obj, attr_name):
            return False
        if _any(getattr(obj, attr_name) != expected):
            return False
    return True
def assert_dict_has_keys(obj: Dict[str, Any],
                         expected_keys: List[str]) -> bool:
    """Tell whether every key in ``expected_keys`` is a key of ``obj``.

    Args:
        obj (Dict[str, Any]): Object whose ``keys()`` are checked.
        expected_keys (List[str]): Keys that must all be present. Order and
            duplicates are irrelevant.

    Returns:
        bool: ``True`` if ``obj`` has all the expected keys, else ``False``.
    """
    wanted = set(expected_keys)
    available = set(obj.keys())
    return wanted <= available
def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
    """Tell whether two key collections contain the same distinct keys.

    The comparison is set-based, so ordering and duplicates are ignored.

    Args:
        result_keys (List[str]): Keys produced by the code under test.
        target_keys (List[str]): Keys that are expected.

    Returns:
        bool: ``True`` if both contain exactly the same keys, else ``False``.
    """
    result_set = set(result_keys)
    target_set = set(target_keys)
    return result_set == target_set
def assert_is_norm_layer(module) -> bool:
    """Tell whether ``module`` is an instance of a known normalization layer.

    The recognized classes are ``_BatchNorm`` and ``_InstanceNorm`` (from
    mmengine's ``parrots_wrapper``) plus ``torch.nn.GroupNorm`` and
    ``torch.nn.LayerNorm``.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: ``True`` if ``module`` is one of those norm layers.
    """
    candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
    return any(isinstance(module, layer_cls) for layer_cls in candidates)
def assert_params_all_zeros(module) -> bool:
    """Check if the ``weight`` and ``bias`` of the module are all zeros.

    Only ``module.weight`` and, if present, ``module.bias`` are inspected;
    other parameters are ignored. Each tensor is compared against zeros
    with ``Tensor.allclose`` using its default tolerances.

    Args:
        module (nn.Module): The module to be checked. It must define a
            ``weight`` attribute. A ``weight`` or ``bias`` registered as
            ``None`` (e.g. a norm layer built with ``affine=False``) holds
            no values and counts as zero.

    Returns:
        bool: Whether the weight and bias of the module are all zeros.
    """

    def _all_zeros(param) -> bool:
        # A parameter registered as ``None`` has no values to check.
        if param is None:
            return True
        data = param.data
        return data.allclose(data.new_zeros(data.size()))

    is_weight_zero = _all_zeros(module.weight)
    # ``bias`` is optional: a missing attribute and ``None`` both count as
    # zero.
    is_bias_zero = _all_zeros(getattr(module, 'bias', None))
    return is_weight_zero and is_bias_zero
|