Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import unittest | |
| from distutils.util import strtobool | |
| import torch | |
| from packaging import version | |
# Shared RNG used by helpers in this module when no explicit RNG is supplied.
global_rng = random.Random()

# Baseline device: CUDA when available, otherwise CPU.
torch_device = "cpu"
if torch.cuda.is_available():
    torch_device = "cuda"

# Compare only the release part of the version (base_version drops
# pre/post/dev/local segments such as "+cu117").
is_torch_higher_equal_than_1_12 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("1.12")

# The MPS backend is only queried on torch >= 1.12; when it reports itself
# available it takes precedence over the device chosen above.
if is_torch_higher_equal_than_1_12 and torch.backends.mps.is_available():
    torch_device = "mps"
# Spellings accepted by parse_flag_from_env (compared case-insensitively).
# They mirror the former distutils.util.strtobool so existing settings keep working.
_ENV_TRUE_VALUES = frozenset({"y", "yes", "t", "true", "on", "1"})
_ENV_FALSE_VALUES = frozenset({"n", "no", "f", "false", "off", "0"})


def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable ``key``.

    The value is parsed locally rather than with ``distutils.util.strtobool``,
    because ``distutils`` was removed from the standard library in Python 3.12.

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when ``key`` is not set.

    Returns:
        ``default`` if ``key`` is unset. Otherwise ``True`` for
        y/yes/t/true/on/1 and ``False`` for n/no/f/false/off/0
        (case-insensitive).

    Raises:
        ValueError: If ``key`` is set to any other value.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    normalized = value.lower()
    if normalized in _ENV_TRUE_VALUES:
        return True
    if normalized in _ENV_FALSE_VALUES:
        return False
    # More values are supported, but let's keep the message simple.
    raise ValueError(f"If set, {key} must be yes or no.")
# Evaluated once at import time. Setting RUN_SLOW to a truthy value enables
# the tests marked with the `slow` decorator; they are skipped by default.
_run_slow_tests = parse_flag_from_env(key="RUN_SLOW", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Return a float32 tensor of the given ``shape`` filled with random values.

    Each element is ``rng.random() * scale``, sampled in row-major order.

    Args:
        shape: Sequence of ints giving the tensor dimensions. It is iterated
            once to count elements and then passed to ``Tensor.view``.
        scale: Multiplier applied to every sampled value.
        rng: Object with a ``random()`` method, presumably a
            ``random.Random``. The module-level ``global_rng`` is used when
            this is ``None``.
        name: Unused; accepted only for call-site compatibility.

    Returns:
        A contiguous ``torch.float`` (float32) tensor.
    """
    generator = global_rng if rng is None else rng

    num_elements = 1
    for size in shape:
        num_elements = num_elements * size

    flat = [generator.random() * scale for _ in range(num_elements)]
    return torch.tensor(data=flat, dtype=torch.float).view(shape).contiguous()
def slow(test_case):
    """Mark ``test_case`` as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable
    to a truthy value to run them. The module-level ``_run_slow_tests`` flag
    is read at decoration time, not when the test runs.
    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)