# Uploaded by DrFetWartz via "Upload folder using huggingface_hub" (commit ffaa9fc).
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import sys
import warnings
from distutils.util import strtobool
from functools import lru_cache
import torch
from packaging.version import parse
from .environment import parse_flag_from_env
from .versions import compare_versions, is_torch_version
# The package importlib_metadata is in a different place, depending on the Python version.
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
# `torch_xla` is optional: `_tpu_available` records whether it could be imported, and the
# `xm` alias is only bound when the import succeeds.
_tpu_available = False
try:
    import torch_xla.core.xla_model as xm  # noqa: F401
except ImportError:
    pass
else:
    _tpu_available = True

# Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()
def is_torch_distributed_available() -> bool:
    """Report whether `torch.distributed` is available.

    The answer is read from the module-level value computed once at import time rather than
    re-querying torch on every call.
    """
    cached_answer = _torch_distributed_available
    return cached_answer
def is_ccl_available():
    """Return `True` if a oneCCL binding module for PyTorch can be located.

    Both the `torch_ccl` module name and the `oneccl_bindings_for_pytorch` module name are
    accepted; they are checked in that order and the search stops at the first hit.
    """
    candidate_modules = ("torch_ccl", "oneccl_bindings_for_pytorch")
    return any(importlib.util.find_spec(module) is not None for module in candidate_modules)
def get_ccl_version():
    """Return the installed version string of the `oneccl_bind_pt` distribution.

    Raises:
        importlib_metadata.PackageNotFoundError: if that distribution is not installed.
    """
    distribution_name = "oneccl_bind_pt"
    return importlib_metadata.version(distribution_name)
def is_apex_available():
    """Return `True` when the NVIDIA `apex` package can be located on the import path."""
    apex_spec = importlib.util.find_spec("apex")
    return apex_spec is not None
def is_fp8_available():
    """Return `True` when `transformer_engine` (used for fp8 support) can be located."""
    engine_module = "transformer_engine"
    spec = importlib.util.find_spec(engine_module)
    return spec is not None
@lru_cache()
def is_tpu_available(check_device=True):
    """Report whether `torch_xla` is installed and, optionally, whether an XLA device responds.

    Args:
        check_device: when truthy, also call `xm.xla_device()`; a `RuntimeError` from that
            call (no XLA configuration found) makes the result `False`.

    Results are memoized by `lru_cache`, so the device probe runs at most once per distinct
    argument combination.
    """
    if not _tpu_available:
        return False
    if not check_device:
        return True
    try:
        # Raises RuntimeError if no XLA configuration is found.
        xm.xla_device()
    except RuntimeError:
        return False
    return True
def is_deepspeed_available():
    """Return whether the real DeepSpeed library is installed.

    `importlib.util.find_spec` alone would also succeed for any importable `deepspeed`
    directory (for example a local folder on `sys.path`). The check therefore additionally
    requires installed distribution metadata for "deepspeed".

    Returns:
        bool: `True` if `deepspeed` is importable and its distribution metadata exists,
        otherwise `False`. The result is always a bool, never `None`.
    """
    if importlib.util.find_spec("deepspeed") is None:
        # Return an explicit bool instead of falling off the end of the function (`None`).
        return False
    # Check we're not importing a "deepspeed" directory somewhere but the actual library by
    # looking up its installed distribution metadata.
    try:
        importlib_metadata.metadata("deepspeed")
    except importlib_metadata.PackageNotFoundError:
        return False
    return True
def is_bf16_available(ignore_tpu=False):
    """Return whether bf16 can be used, optionally ignoring an available TPU.

    Args:
        ignore_tpu: when a TPU is detected, return `False` instead of `True`.

    Outside of a TPU, bf16 requires torch >= 1.10. With CUDA available, the answer defers to
    `torch.cuda.is_bf16_supported()`; without CUDA it is reported as available.
    """
    if is_tpu_available():
        return not ignore_tpu
    if not is_torch_version(">=", "1.10"):
        return False
    if not torch.cuda.is_available():
        return True
    return torch.cuda.is_bf16_supported()
def is_megatron_lm_available():
    """Return whether Megatron-LM integration is both requested and usable.

    All of the following must hold:
      * the `ACCELERATE_USE_MEGATRON_LM` environment variable (default "False") parses as
        true with `strtobool`;
      * a `megatron` module can be located on the import path;
      * the installed "megatron-lm" distribution reports a version >= 2.2.0.

    A `megatron` module without "megatron-lm" distribution metadata (for example a source
    checkout placed on `sys.path`) cannot be version-checked. It is reported as unavailable
    instead of letting `PackageNotFoundError` escape, mirroring `is_deepspeed_available`.

    Raises:
        ValueError: from `strtobool` when `ACCELERATE_USE_MEGATRON_LM` holds an unrecognized
            value.
    """
    if strtobool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) != 1:
        return False
    if importlib.util.find_spec("megatron") is None:
        return False
    try:
        megatron_version = parse(importlib_metadata.version("megatron-lm"))
    except importlib_metadata.PackageNotFoundError:
        return False
    return compare_versions(megatron_version, ">=", "2.2.0")
def is_safetensors_available():
    """Return `True` if the `safetensors` package can be located on the import path."""
    safetensors_spec = importlib.util.find_spec("safetensors")
    return safetensors_spec is not None
def is_transformers_available():
    """Return `True` if the 🤗 Transformers package (`transformers`) can be located."""
    module_name = "transformers"
    return importlib.util.find_spec(module_name) is not None
def is_datasets_available():
    """Return `True` if the 🤗 Datasets package (`datasets`) can be located."""
    datasets_spec = importlib.util.find_spec("datasets")
    return datasets_spec is not None
def is_aim_available():
    """Return `True` if the `aim` experiment-tracking package can be located."""
    tracker_module = "aim"
    return importlib.util.find_spec(tracker_module) is not None
def is_tensorboard_available():
    """Return `True` if either `tensorboard` or `tensorboardX` can be located.

    The modules are checked in that order and the search stops at the first hit.
    """
    return any(
        importlib.util.find_spec(candidate) is not None
        for candidate in ("tensorboard", "tensorboardX")
    )
def is_wandb_available():
    """Return `True` if the Weights & Biases client (`wandb`) can be located."""
    wandb_spec = importlib.util.find_spec("wandb")
    return wandb_spec is not None
def is_comet_ml_available():
    """Return `True` if the Comet ML client (`comet_ml`) can be located."""
    comet_module = "comet_ml"
    return importlib.util.find_spec(comet_module) is not None
def is_boto3_available():
    """Return `True` if the AWS SDK for Python (`boto3`) can be located."""
    boto3_spec = importlib.util.find_spec("boto3")
    return boto3_spec is not None
def is_rich_available():
    """Return whether `rich` is installed and has not been disabled via environment flags.

    The deprecated `DISABLE_RICH` flag still disables `rich`, but emits a `FutureWarning`
    pointing users to `ACCELERATE_DISABLE_RICH`. Otherwise `ACCELERATE_DISABLE_RICH` decides.
    """
    if importlib.util.find_spec("rich") is None:
        return False
    if not parse_flag_from_env("DISABLE_RICH"):
        return not parse_flag_from_env("ACCELERATE_DISABLE_RICH")
    warnings.warn(
        "The `DISABLE_RICH` flag is deprecated and will be removed in version 0.17.0 of 🤗 Accelerate. Use `ACCELERATE_DISABLE_RICH` instead.",
        FutureWarning,
    )
    return not parse_flag_from_env("DISABLE_RICH")
def is_sagemaker_available():
    """Return `True` if the Amazon SageMaker SDK (`sagemaker`) can be located."""
    sagemaker_spec = importlib.util.find_spec("sagemaker")
    return sagemaker_spec is not None
def is_tqdm_available():
    """Return `True` if the `tqdm` progress-bar package can be located."""
    progress_module = "tqdm"
    return importlib.util.find_spec(progress_module) is not None
def is_mlflow_available():
    """Return `True` if the `mlflow` tracking package can be located."""
    mlflow_spec = importlib.util.find_spec("mlflow")
    return mlflow_spec is not None
def is_mps_available():
    """Return whether Apple's MPS backend is usable from this torch build.

    Requires torch >= 1.12, and `torch.backends.mps` must report the backend as both
    available and built.
    """
    # The version check runs first; presumably this guards older torch releases that lack
    # `torch.backends.mps` (TODO confirm).
    recent_enough = is_torch_version(">=", "1.12")
    if not recent_enough:
        return recent_enough
    mps_backend = torch.backends.mps
    return mps_backend.is_available() and mps_backend.is_built()