import os
import tempfile
import unittest

from distutils.util import strtobool

from transformers.file_utils import _tf_available, _torch_available


CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")

SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.

def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, so default to `default`.
        _value = default
    else:
        # KEY is set, so convert it to True or False.
        try:
            _value = strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError("If set, {} must be yes or no.".format(key))
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)

def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable
    to a truthy value to run them.
    """
    if not _run_slow_tests:
        test_case = unittest.skip("test is slow")(test_case)
    return test_case
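
# Example usage (illustrative sketch, not part of the original module): the
# class and method names below are hypothetical. A test decorated with `slow`
# is collected but skipped unless RUN_SLOW is set to a truthy value, e.g.
# `RUN_SLOW=yes python -m unittest discover -s tests`.
#
#     class ExampleSlowTest(unittest.TestCase):
#         @slow
#         def test_expensive_end_to_end(self):
#             # stands in for a long-running download/inference check
#             self.assertTrue(True)
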
def custom_tokenizers(test_case):
    """
    Decorator marking a test for a custom tokenizer.

    Custom tokenizers require additional dependencies, and are skipped
    by default. Set the RUN_CUSTOM_TOKENIZERS environment variable
    to a truthy value to run them.
    """
    if not _run_custom_tokenizers:
        test_case = unittest.skip("test of custom tokenizers")(test_case)
    return test_case

def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed.
    """
    if not _torch_available:
        test_case = unittest.skip("test requires PyTorch")(test_case)
    return test_case

def require_tf(test_case):
    """
    Decorator marking a test that requires TensorFlow.

    These tests are skipped when TensorFlow isn't installed.
    """
    if not _tf_available:
        test_case = unittest.skip("test requires TensorFlow")(test_case)
    return test_case
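
# Example usage (illustrative sketch, not part of the original module): the
# class and method names below are hypothetical. The decorators can be
# stacked, so a test may both require a backend and be marked slow; it is
# skipped unless PyTorch is installed and RUN_SLOW is truthy.
#
#     class ExampleBackendTest(unittest.TestCase):
#         @require_torch
#         @slow
#         def test_torch_forward_pass(self):
#             import torch
#             self.assertEqual(torch.zeros(2, 3).shape, (2, 3))
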
if _torch_available:
    # Run tests on GPU when the USE_CUDA environment variable is set to a
    # truthy value; otherwise fall back to CPU.
    torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
else:
    torch_device = None
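
# Example usage (illustrative sketch, not part of the original module): the
# class and method names below are hypothetical. Tests that should run on GPU
# when USE_CUDA is truthy and on CPU otherwise can place tensors or models on
# `torch_device` instead of hard-coding a device string.
#
#     class ExampleDeviceTest(unittest.TestCase):
#         @require_torch
#         def test_tensor_on_selected_device(self):
#             import torch
#             x = torch.ones(2, 2).to(torch_device)
#             self.assertEqual(str(x.device).split(":")[0], torch_device)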