Spaces:
Running
on
Zero
Running
on
Zero
import ast | |
import fnmatch | |
import hashlib | |
import inspect | |
import io | |
import json | |
import logging | |
import os | |
import re | |
import shutil | |
import tarfile | |
import tempfile | |
from collections.abc import Mapping | |
from contextlib import contextmanager | |
from dataclasses import dataclass | |
from enum import Enum | |
from functools import partial | |
from hashlib import sha256 | |
from os.path import basename, isdir, isfile, join | |
from pathlib import Path | |
from typing import Callable, Dict, List, Optional, Tuple, Union | |
from zipfile import ZipFile, is_zipfile | |
import torch | |
import requests | |
from filelock import FileLock | |
from huggingface_hub import HfApi, HfFolder, snapshot_download | |
from huggingface_hub.file_download import http_get | |
from huggingface_hub.utils import ( | |
EntryNotFoundError, | |
RepositoryNotFoundError, | |
RevisionNotFoundError, | |
hf_raise_for_status, | |
) | |
from requests.exceptions import HTTPError | |
from transformers.utils import http_user_agent, is_remote_url | |
from . import __version__ | |
from .context import ForwardContext | |
logger = logging.getLogger(__name__)

# File names of the individual artifacts (configs and weights) referenced by this module.
CONFIG_NAME = "adapter_config.json"
WEIGHTS_NAME = "pytorch_adapter.bin"
SAFE_WEIGHTS_NAME = "adapter.safetensors"
HEAD_CONFIG_NAME = "head_config.json"
HEAD_WEIGHTS_NAME = "pytorch_model_head.bin"
SAFE_HEAD_WEIGHTS_NAME = "model_head.safetensors"
ADAPTERFUSION_CONFIG_NAME = "adapter_fusion_config.json"
ADAPTERFUSION_WEIGHTS_NAME = "pytorch_model_adapter_fusion.bin"
SAFE_ADAPTERFUSION_WEIGHTS_NAME = "model_adapter_fusion.safetensors"
EMBEDDING_FILE = "embedding.pt"
TOKENIZER_PATH = "tokenizer"

# URL templates of the archived Adapter-Hub index; the "{}" placeholders are filled via str.format().
ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/"
ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json"
ADAPTER_HUB_CONFIG_FILE = ADAPTER_HUB_URL + "architectures.json"
ADAPTER_HUB_ALL_FILE = ADAPTER_HUB_URL + "all.json"
ADAPTER_HUB_ADAPTER_ENTRY_JSON = ADAPTER_HUB_URL + "adapters/{}/{}.json"

# the download cache: $TORCH_HOME if set, otherwise "$XDG_CACHE_HOME/torch" (falling back to "~/.cache/torch")
torch_cache_home = os.getenv(
    "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "torch")
)
ADAPTER_CACHE = join(torch_cache_home, "adapters")

# these keys are ignored when calculating the config hash
ADAPTER_CONFIG_HASH_IGNORE = []

# old: new
ACTIVATION_RENAME = {
    "gelu": "gelu_new",
    "gelu_orig": "gelu",
}
# HACK: To keep config hashes consistent with v2, remove default values of keys introduced in v3 from hash computation
ADAPTER_CONFIG_HASH_IGNORE_DEFAULT = {
    "phm_layer": True,
    "phm_dim": 4,
    "factorized_phm_W": True,
    "shared_W_phm": False,
    "shared_phm_rule": True,
    "factorized_phm_rule": False,
    "phm_c_init": "normal",
    "phm_init_range": 0.0001,
    "learn_phm": True,
    "hypercomplex_nonlinearity": "glorot-uniform",
    "phm_rank": 1,
    "phm_bias": True,
    "init_weights": "bert",
    "scaling": 1.0,
}
# Matches "<name>" optionally followed by "[<kvs>]"; the name must not contain "[", "]", "|" or a newline.
ADAPTER_CONFIG_STRING_PATTERN = re.compile(r"^(?P<name>[^\[\]\|\n]+)(?:\[(?P<kvs>.*)\])?$")
class AdapterType(str, Enum):
    """Models all currently available model adapter types."""

    text_task = "text_task"
    text_lang = "text_lang"

    @classmethod
    def has(cls, value):
        """Return True if ``value`` equals one of the members of this enum.

        Because the enum also derives from ``str``, plain strings such as ``"text_task"`` compare equal to
        the corresponding member.
        """
        return value in cls.__members__.values()

    def __repr__(self):
        # Represent a member by its raw string value instead of the default "<AdapterType.x: 'x'>".
        return self.value
@dataclass
class AdapterInfo:
    """
    Holds information about an adapter publicly available on the Hub. Returned by
    :func:`list_adapters()`.

    Args:
        source (str): The source repository of this adapter. Always 'hf' for adapters available on HF Model Hub.
        adapter_id (str): The unique identifier of this adapter.
        model_name (str, optional): The identifier of the model this adapter was trained for.
        task (str, optional): The task this adapter was trained for.
        subtask (str, optional): The subtask or dataset this adapter was trained on.
        username (str, optional): The username of author(s) of this adapter.
        adapter_config (dict, optional): The configuration dictionary of this adapter.
        sha1_checksum (str, optional): A checksum identifying the adapter; this module fills it with the
            ``sha`` value reported by the Hub for the repository.
    """

    # The @dataclass decorator generates the keyword-argument constructor used throughout this module.
    source: str
    adapter_id: str
    model_name: Optional[str] = None
    task: Optional[str] = None
    subtask: Optional[str] = None
    username: Optional[str] = None
    adapter_config: Optional[dict] = None
    sha1_checksum: Optional[str] = None
def _minimize_dict(d): | |
if isinstance(d, Mapping): | |
return {k: _minimize_dict(v) for (k, v) in d.items() if v} | |
else: | |
return d | |
def get_adapter_config_hash(config, length=16, ignore_params=None):
    """
    Calculates the hash of a given adapter configuration which is used to identify this configuration.

    Args:
        config: The adapter configuration (a mapping) to hash.
        length (int): Number of leading hex characters of the SHA-1 digest to return. Defaults to 16.
        ignore_params (iterable of str, optional): Additional top-level keys to exclude from the hash,
            on top of ``ADAPTER_CONFIG_HASH_IGNORE``.

    Returns:
        str: The resulting hash of the given config dict.
    """
    # Build the list of ignored keys once instead of once per config key; accepts any iterable or None.
    ignored_keys = [*ADAPTER_CONFIG_HASH_IGNORE, *(ignore_params or [])]
    minimized_config = _minimize_dict({k: v for (k, v) in config.items() if k not in ignored_keys})
    # ensure hash is kept consistent to previous versions
    for name, default in ADAPTER_CONFIG_HASH_IGNORE_DEFAULT.items():
        if name in minimized_config and minimized_config[name] == default:
            del minimized_config[name]
    # sort_keys makes the serialization (and therefore the hash) independent of key order
    dict_str = json.dumps(minimized_config, sort_keys=True)
    h = hashlib.sha1()
    h.update(dict_str.encode(encoding="utf-8"))
    return h.hexdigest()[:length]
def inherit_doc(cls):
    """Class decorator that copies missing docstrings of callables from the direct base classes.

    For every callable defined directly on ``cls`` that has no docstring, the first direct base class
    providing a same-named attribute with a docstring donates it. The class itself is returned.
    """
    for attr_name, attr in vars(cls).items():
        if not isinstance(attr, Callable) or attr.__doc__:
            continue
        for base in cls.__bases__:
            inherited = getattr(base, attr_name, None)
            if inherited and getattr(inherited, "__doc__", None):
                attr.__doc__ = inherited.__doc__
                break
    return cls
def urljoin(*args):
    """Join URL fragments with single slashes, stripping leading/trailing slashes from each fragment."""
    return "/".join(fragment.strip("/") for fragment in args)
def remote_file_exists(url):
    """Return True if a HEAD request to ``url`` is answered with status code 200."""
    # NOTE(review): no timeout is passed to requests.head(), so this call may block indefinitely.
    return requests.head(url).status_code == 200
# Copied from here: https://github.com/huggingface/huggingface_hub/blob/v0.25.0/src/huggingface_hub/file_download.py#L266 | |
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """Derive a reproducible local filename from a url.

    The name is the SHA-256 hex digest of the url. If a non-empty `etag` is given, the SHA-256 hex digest
    of the etag is appended, separated by a period. Urls ending in '.h5' keep that suffix in the generated
    name.

    Args:
        url (`str`):
            The address to the file.
        etag (`str`, *optional*):
            The ETag of the file.

    Returns:
        The generated filename.
    """
    digests = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode("utf-8")).hexdigest())
    hashed_name = ".".join(digests)
    return hashed_name + ".h5" if url.endswith(".h5") else hashed_name
# Copied from last version of this method in HF codebase: | |
# https://github.com/huggingface/transformers/blob/9129fd0377e4d46cb2d0ea28dc1eb91a15f65b77/src/transformers/utils/hub.py#L460 | |
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Args:
        url: The URL of the file to fetch.
        cache_dir: Directory used as cache (created if missing). Defaults to ``ADAPTER_CACHE``.
        force_download: If True, download the file even if a cached copy for the current ETag exists.
        proxies: Proxy configuration forwarded to the HEAD request and to the download.
        etag_timeout: Timeout (in seconds) of the HEAD request used to fetch the ETag.
        resume_download: If True, download into a persistent ``<cache_path>.incomplete`` file opened in
            append mode and pass its current size to the download so it can be continued.
        user_agent: Forwarded to ``http_user_agent`` to build the ``user-agent`` header.
        use_auth_token: A token string, or True to use the token stored by ``HfFolder``.
        local_files_only: If True, no HEAD request is made and only the cache is consulted.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """
    if cache_dir is None:
        cache_dir = ADAPTER_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            hf_raise_for_status(r)
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (
            requests.exceptions.SSLError,
            requests.exceptions.ProxyError,
            RepositoryNotFoundError,
            EntryNotFoundError,
            RevisionNotFoundError,
        ):
            # Actually raise for those subclasses of ConnectionError
            # Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
            raise
        except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Otherwise, our Internet connection is down.
            # etag is None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            # Look for files whose name starts with the url hash (any etag), skipping metadata and lock files.
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    fname = url.split("/")[-1]
                    raise EntryNotFoundError(
                        f"Cannot find the requested file ({fname}) in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                else:
                    raise ValueError(
                        "Connection error, and we cannot find the requested files in the cached path."
                        " Please try again or make sure your Internet connection is on."
                    )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            # The generator must be wrapped by @contextmanager, as it is used in a `with` statement below.
            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

            http_get(
                url_to_download,
                temp_file,
                proxies=proxies,
                resume_size=resume_size,
                headers=headers,
            )

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        # os.umask() can only be read by setting it, hence it is set and immediately restored.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
def download_cached(url, checksum=None, checksum_algo="sha1", cache_dir=None, force_extract=False, **kwargs):
    """
    This method downloads a file and caches it.

    For more information on why this is needed, refer to the explanation in this Pull Request: https://github.com/adapter-hub/adapters/pull/750

    Args:
        url (str or Path): The remote URL of the file to download.
        checksum (str, optional): Expected hex digest of the downloaded file; compared case-insensitively.
        checksum_algo (str): Name of the hash algorithm passed to ``hashlib.new``. Defaults to "sha1".
        cache_dir (optional): Cache directory forwarded to ``get_from_cache``.
        force_extract (bool): If True, re-extract an archive even if a non-empty extraction folder exists.
        **kwargs: Forwarded to ``get_from_cache``.

    Returns:
        The path of the cached file, or - if the file is a zip or tar archive - the path of the folder it
        was extracted to. None if no file could be obtained.

    Raises:
        ValueError: If ``url`` is not a remote URL.
        EnvironmentError: If the checksum does not match or the archive format cannot be identified.
    """
    if isinstance(url, Path):
        url = str(url)

    if is_remote_url(url):
        output_path = get_from_cache(url, cache_dir=cache_dir, **kwargs)
    else:
        raise ValueError("Unable to parse '{}' as a URL".format(url))

    if not output_path:
        return None

    # if checksum is given, verify it
    if checksum and checksum_algo:
        h = hashlib.new(checksum_algo)
        with open(output_path, "rb") as f:
            # Hash in chunks so that large downloads are not loaded into memory at once.
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                h.update(chunk)
        calculated_checksum = h.hexdigest()
        if calculated_checksum != checksum.lower():
            raise EnvironmentError("Failed to verify checksum of '{}'".format(output_path))

    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
        return output_path

    # Path where we extract compressed archives
    # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
        return output_path_extracted

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                # we want to extract all files into a flat folder structure (i.e. no subfolders)
                for member in zip_file.namelist():
                    # check if we have a valid file (directory entries have an empty basename)
                    if basename(member):
                        file_data = zip_file.read(member)
                        with open(join(output_path_extracted, basename(member)), "wb") as f:
                            f.write(file_data)
        elif tarfile.is_tarfile(output_path):
            # The context manager closes the archive even if extraction fails.
            # NOTE(security): extractall() trusts member paths of a downloaded archive. On Python versions
            # without a safe default extraction filter, a crafted archive can write outside the target
            # directory - consider passing an extraction filter here.
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted
def parse_adapter_config_string(config_string: str) -> List[Tuple[str, dict]]:
    """
    Parses an adapter configuration string into a list of tuples. Each tuple consists of an adapter config identifier
    and dictionary.

    The string is split at "|" into chunks of the form ``name`` or ``name[key=value, ...]``; the key-value
    part is evaluated as a Python literal dictionary.

    Args:
        config_string (str): The configuration string to parse.

    Returns:
        List[Tuple[str, dict]]: One ``(name, kwargs)`` tuple per chunk, in order of appearance.

    Raises:
        ValueError: If a chunk does not match the expected format or its key-value part is not a valid
            dictionary body.
    """
    # First split by "|" into individual adapter configs
    config_string_chunks = config_string.split("|")
    # Now match each adapter config against the regex
    adapter_configs = []
    for config_string_chunk in config_string_chunks:
        match = ADAPTER_CONFIG_STRING_PATTERN.match(config_string_chunk.strip())
        if not match or not match.group("name"):
            raise ValueError(f"Invalid adapter config string format: '{config_string_chunk}'.")
        name = match.group("name")
        if match.group("kvs"):
            kvs = match.group("kvs")
            # Replace "=" with ":" in key-value pairs for valid Python dict
            kvs = re.sub(r"(\w+)=", r"'\1':", kvs)
        else:
            kvs = ""
        # Now evaluate key-value pairs as Python dict
        try:
            config_kwargs = ast.literal_eval("{" + kvs + "}")
        except Exception as ex:
            raise ValueError(f"Invalid adapter configguration '{kvs}' in '{name}'.") from ex
        # Input without key-value pairs (e.g. "name[4]") evaluates to a set literal, which is not a valid config.
        if not isinstance(config_kwargs, dict):
            raise ValueError(f"Invalid adapter configguration '{kvs}' in '{name}'.")
        adapter_configs.append((name, config_kwargs))

    return adapter_configs
def resolve_adapter_config(config: Union[dict, str], local_map=None, **kwargs) -> dict:
    """
    Turns an adapter configuration specifier into a full configuration.

    Args:
        config (Union[dict, str]): The configuration to resolve. Can be either:

            - a dictionary (any mapping): returned without further action
            - an identifier string available in local_map
            - the path to a file containing a full adapter configuration (JSON)
            - a configuration string as understood by `parse_adapter_config_string`, whose names are
              looked up in local_map
        local_map (optional): Mapping from identifier strings to configurations.

    Returns:
        dict: The resolved adapter configuration dictionary.

    Raises:
        ValueError: If the specifier cannot be resolved.
    """
    # Mappings are already full configurations.
    if isinstance(config, Mapping):
        return config

    # Exact identifier match in the local map.
    if local_map and config in local_map:
        return local_map[config]

    # A local JSON file; a config saved together with an adapter module is nested under "config".
    if isfile(config):
        with open(config, "r") as config_file:
            file_content = json.load(config_file)
        return file_content["config"] if "config" in file_content else file_content

    # Otherwise interpret the specifier as a config string.
    parsed_pairs = parse_adapter_config_string(config)
    if len(parsed_pairs) > 0:
        resolved = []
        for identifier, overrides in parsed_pairs:
            if not local_map or identifier not in local_map:
                raise ValueError("Could not identify '{}' as a valid adapter configuration.".format(identifier))
            resolved.append(local_map[identifier].replace(**overrides))
        if len(resolved) == 1:
            # a single config is returned directly
            return resolved[0]
        elif len(resolved) > 1:
            # several configs are combined into a config union
            return {"architecture": "union", "configs": resolved}

    raise ValueError("Could not identify '{}' as a valid adapter configuration.".format(config))
def _split_identifier(identifier): | |
task, subtask, org_name = None, None, None | |
identifier = identifier.split("@") | |
if len(identifier) > 1: | |
org_name = identifier[1] | |
identifier = identifier[0].split("/") | |
if len(identifier) > 1: | |
subtask = identifier[1] | |
task = identifier[0] | |
return task, subtask, org_name | |
def _dict_extract(d, primary_key, secondary_key=None): | |
for k, v in d.items(): | |
if k == primary_key: | |
if secondary_key: | |
if secondary_key in v.keys(): | |
yield v[secondary_key] | |
else: | |
for k, v in v.items(): | |
yield v | |
elif secondary_key is None: | |
for k, v in v.items(): | |
if k == primary_key: | |
yield v | |
def find_in_index(
    identifier: str,
    model_name: str,
    adapter_config: Optional[dict] = None,
    strict: bool = False,
    index_file: str = None,
) -> Optional[str]:
    """
    Looks up an adapter identifier in the Hub index of a model.

    Args:
        identifier (str): The adapter identifier, either "@<org>/<file>" or "<task>/<subtask>@<org>".
        model_name (str): Name of the model, used to build the URL of the index file.
        adapter_config (dict, optional): If given, its config hash is used to select a matching entry.
        strict (bool): If True and a config is given, no fallback to a default or the only available
            configuration takes place.
        index_file (str, optional): Path to a local index file; downloaded if not given.

    Returns:
        Optional[str]: The matching Hub entry, or None if the index has no entry for the identifier.

    Raises:
        EnvironmentError: If the index file cannot be obtained.
        ValueError: If the identifier is ambiguous or no suitable adapter is found.
    """
    identifier = identifier.strip()
    # identifiers of form "@<org>/<file>" are unique and can be retrieved directly
    direct_match = re.match(r"@(\S+)\/(\S+)", identifier)
    if direct_match:
        org_part, file_part = direct_match.group(1), direct_match.group(2)
        return ADAPTER_HUB_ADAPTER_ENTRY_JSON.format(org_part, file_part)

    if not index_file:
        index_file = download_cached(ADAPTER_HUB_INDEX_FILE.format(model_name))
    if not index_file:
        raise EnvironmentError("Unable to load adapter hub index file. The file might be temporarily unavailable.")
    with open(index_file, "r") as index_handle:
        adapter_index = json.load(index_handle)

    # split into <task>/<subtask>@<org> and collect all index entries for this task and subtask
    task, subtask, org = _split_identifier(identifier)
    candidates = list(_dict_extract(adapter_index, task, subtask))
    if not candidates:
        # nothing in the index matches
        return None
    if len(candidates) > 1:
        # the identifier is ambiguous
        raise ValueError("Found multiple possible adapters matching '{}'.".format(identifier))
    index_entry = candidates[0]

    # try to find an entry for the exact config hash first
    if adapter_config:
        config_hash = get_adapter_config_hash(adapter_config)
        if config_hash in index_entry:
            # now match the org if given
            hub_entry = _get_matching_version(index_entry[config_hash], org)
            if hub_entry:
                logger.info("Found matching adapter at: {}".format(hub_entry))
                return hub_entry

    # reaching this point means no matching config is available or no config was given
    if not (adapter_config and strict):
        if "default" in index_entry:
            logger.info("No exactly matching adapter config found for this specifier, falling back to default.")
            return index_entry["default"]
        # there's only one possible config and we allow matches with different configs
        elif len(index_entry) == 1:
            logger.info("Only one configuration available for this adapter, using default.")
            only_config_entry = list(index_entry.values())[0]
            return _get_matching_version(only_config_entry, org)

    raise ValueError("No adapter '{}' found for the current model or configuration.".format(identifier))
def _get_matching_version(config_entry, org): | |
if org: | |
return config_entry["versions"].get(org, None) | |
elif len(config_entry["versions"]) == 1: | |
return list(config_entry["versions"].values())[0] | |
elif "default" in config_entry: | |
return config_entry["default"] | |
else: | |
raise ValueError("Multiple adapters with this name are available for this config.") | |
def pull_from_hub(
    specifier: str,
    model_name: str,
    adapter_config: Optional[Union[dict, str]] = None,
    version: str = None,
    strict: bool = False,
    **kwargs,
) -> str:
    """
    Resolves an adapter of the archived Hub repository and loads it from the HuggingFace Model Hub instead.

    The specifier is looked up in the archived index; the file name of the found entry (without extension)
    is used as repository name below the "AdapterHub/" namespace on the HF Model Hub.

    Args:
        specifier (str): A string specifying the adapter to be loaded.
        model_name (str): The identifier of the pre-trained model for which to load an adapter.
        adapter_config (Union[dict, str], optional): The configuration of the adapter to be loaded.
        version (str, optional): The version of the adapter to be loaded. Defaults to None.
        strict (bool, optional):
            If set to True, only allow adapters exactly matching the given config to be loaded. Defaults to False.

    Returns:
        str: The local path to which the adapter has been downloaded.

    Raises:
        ValueError: If no model name is given.
        EnvironmentError: If the specifier is not found in the adapter index.
    """
    if not model_name:
        raise ValueError("Unable to resolve adapter without the name of a model. Please specify model_name.")

    # a config given as identifier is resolved to a full config first
    if adapter_config:
        adapter_config = resolve_adapter_config(adapter_config)

    # locate the entry in the archived index
    hub_entry_url = find_in_index(specifier, model_name, adapter_config=adapter_config, strict=strict)
    if not hub_entry_url:
        raise EnvironmentError("No adapter with name '{}' was found in the adapter index.".format(specifier))

    # "<...>/<name>.json" -> "AdapterHub/<name>"
    entry_file_name = os.path.basename(hub_entry_url)
    repo_name = entry_file_name.split(".")[0]
    hf_hub_specifier = "AdapterHub/" + repo_name
    redirect_message = (
        "Automatic redirect to HF Model Hub repo '{}'. Please switch to the new ID to remove this warning.".format(
            hf_hub_specifier
        )
    )
    logger.warning(redirect_message)
    return pull_from_hf_model_hub(hf_hub_specifier, version=version, **kwargs)
def pull_from_hf_model_hub(specifier: str, version: str = None, **kwargs) -> str:
    """
    Downloads a snapshot of a HuggingFace Model Hub repository via `snapshot_download`.

    Args:
        specifier (str): The repository identifier passed to `snapshot_download`.
        version (str, optional): The revision to download. Defaults to None.
        **kwargs: Only "cache_dir" is used; all other keyword arguments are ignored.

    Returns:
        str: The path returned by `snapshot_download`.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    return snapshot_download(
        specifier,
        revision=version,
        cache_dir=cache_dir,
        library_name="adapters",
        library_version=__version__,
    )
def resolve_adapter_path(
    adapter_name_or_path,
    model_name: str = None,
    adapter_config: Union[dict, str] = None,
    version: str = None,
    **kwargs,
) -> str:
    """
    Resolves the path to a pre-trained adapter module. Note: If attempting to resolve an adapter from the Hub,
    adapter_config and model_name must be present.

    Args:
        adapter_name_or_path (str): Can be either:

            - the path to a folder in the file system containing the adapter configuration and weights
            - an url pointing to a zip folder containing the adapter configuration and weights
            - a specifier matching a pre-trained adapter uploaded to Adapter-Hub
        model_name (str, optional): The identifier of the pre-trained model for which to load an adapter.
        adapter_config (Union[dict, str], optional): The configuration of the adapter to be loaded.
        version (str, optional): The version of the adapter to be loaded. Defaults to None.

    Returns:
        str: The local path from where the adapter module can be loaded.

    Raises:
        EnvironmentError: If the adapter cannot be loaded from the URL, the local folder is missing the
            weights or config file, or the adapter is found on neither the HF Model Hub nor the archived Hub.
    """
    # url of a folder containing pretrained adapters -> try to load from this url
    if is_remote_url(adapter_name_or_path):
        resolved_folder = download_cached(adapter_name_or_path, **kwargs)
        if not resolved_folder:
            # Report the requested URL; `resolved_folder` is always empty at this point.
            raise EnvironmentError(
                "Unable to load file from {}. The file might be unavailable.".format(adapter_name_or_path)
            )
        return resolved_folder
    # path to a local folder saved using save()
    elif isdir(adapter_name_or_path):
        # either weights format is accepted, the config file is always required
        if (
            isfile(join(adapter_name_or_path, WEIGHTS_NAME)) or isfile(join(adapter_name_or_path, SAFE_WEIGHTS_NAME))
        ) and isfile(join(adapter_name_or_path, CONFIG_NAME)):
            return adapter_name_or_path
        else:
            raise EnvironmentError(
                "No file {} or no file {} found in directory {}".format(
                    WEIGHTS_NAME, CONFIG_NAME, adapter_name_or_path
                )
            )
    else:
        try:
            logger.info("Attempting to load adapter from HF Model Hub...")
            return pull_from_hf_model_hub(adapter_name_or_path, version=version, **kwargs)
        except (EnvironmentError, ValueError) as ex:
            logger.info(ex)
            logger.info("Attempting to redirect from archived Hub repo...")
            try:
                return pull_from_hub(
                    adapter_name_or_path,
                    model_name,
                    adapter_config=adapter_config,
                    version=version,
                    redirect_to_hf_hub=True,
                    **kwargs,
                )
            except Exception as hub_ex:
                logger.info(hub_ex)
                # Chain the cause so the underlying failure stays visible in the traceback.
                raise EnvironmentError(
                    "Unable to load adapter {} from any source. Please check the name of the adapter or the source.".format(
                        adapter_name_or_path
                    )
                ) from hub_ex
def list_adapters(model_name: str = None) -> List[AdapterInfo]:
    """
    Lists adapters published on the HuggingFace Model Hub (models matching the "adapters" filter).

    Args:
        model_name (str, optional): If specified, only returns adapters trained for the model with this identifier.

    Returns:
        List[AdapterInfo]: One entry per adapter found.
    """
    # `fetch_config` is only passed if the installed huggingface-hub version supports it
    supports_fetch_config = "fetch_config" in inspect.signature(HfApi.list_models).parameters
    if supports_fetch_config:
        list_kwargs = {"full": True, "fetch_config": True}
    else:
        logger.warning(
            "Using old version of huggingface-hub package for fetching. Please upgrade to latest version for"
            " accurate results."
        )
        list_kwargs = {"full": True}

    found = [
        AdapterInfo(
            source="hf",
            adapter_id=model_info.modelId,
            model_name=model_info.config.get("adapters", {}).get("model_name") if model_info.config else None,
            username=model_info.modelId.split("/")[0],
            sha1_checksum=model_info.sha,
        )
        for model_info in HfApi().list_models(filter="adapters", **list_kwargs)
    ]

    if model_name is None:
        return found
    return [info for info in found if info.model_name == model_name]
def get_adapter_info(adapter_id: str) -> Optional[AdapterInfo]:
    """
    Fetches the information of a single adapter from the HuggingFace Model Hub.

    Args:
        adapter_id (str): The identifier of the adapter to retrieve.

    Returns:
        AdapterInfo: The adapter information or None if the request failed with an HTTP error.
    """
    try:
        model_info = HfApi().model_info(adapter_id)
        # NOTE(review): list_adapters() reads the model name from the "adapters" config key, while this
        # function reads "adapter_transformers" - confirm which key the Hub metadata actually uses.
        if model_info.config:
            trained_for = model_info.config.get("adapter_transformers", {}).get("model_name")
        else:
            trained_for = None
        owner = model_info.modelId.split("/")[0]
        return AdapterInfo(
            source="hf",
            adapter_id=model_info.modelId,
            model_name=trained_for,
            username=owner,
            sha1_checksum=model_info.sha,
        )
    except requests.exceptions.HTTPError:
        return None
def prefix_attention_mask(attention_mask, dim: Union[int, List[int]] = 3, prefix_value: int = 0):
    """
    Adds a prefix to an attention mask. The length of the prefix is determined by the `prompt_tokens_length`
    attribute in the ForwardContext. If there is no attention mask, no active ForwardContext or no
    `prompt_tokens_length` set, the attention mask is returned unchanged.

    Args:
        attention_mask:
            The attention mask to add the prefix to.
        dim (int or List[int]):
            The dimension(s) along which to concatenate the prefix. Multiple dimensions are processed in the
            given order. Defaults to 3.
        prefix_value (int):
            The value to use for the prefix. Defaults to 0, however some models, e.g. DistilBert, use
            different values. BERT like models invert their extended_attention_mask, hence they use 0 as value for not
            masked tokens. This inversion is usually done in the forward method of the model in 2 different ways:
            1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. ALBERT
            does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min`

    Returns:
        The attention mask with the prefix prepended along each of the given dimensions.
    """
    forward_context = ForwardContext.get_context()
    if attention_mask is None or forward_context is None:
        return attention_mask

    prompt_length = getattr(forward_context, "prompt_tokens_length", None)
    if prompt_length is None:
        return attention_mask

    dims = [dim] if isinstance(dim, int) else dim
    for d in dims:
        # The prefix has the shape of the current mask, except for the prompt length along dimension d.
        prefix_shape = list(attention_mask.shape)
        prefix_shape[d] = prompt_length
        # Allocate directly on the mask's device instead of creating the tensor elsewhere and copying it.
        prefix_mask = torch.full(
            prefix_shape,
            prefix_value,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        # Prepend the prefix along the current dimension
        attention_mask = torch.cat((prefix_mask, attention_mask), dim=d)

    return attention_mask
def patch_forward(module: torch.nn.Module):
    """Re-point a stored `_old_forward` of the module to the forward method of its class.

    Modules without an `_old_forward` attribute are left untouched.
    """
    # HF Accelerate's `add_hook_to_module()` replaces the module forward method with a wrapper
    # and stores the original forward method in `_old_forward`. For this to work with Adapters' post-hook wrapping,
    # we need to explicitly set to potentially overridden forward methods on adapter init.
    # The `add_hook_to_module()` method is e.g. used for `device_map="auto"` in the `PreTrainedModel.from_pretrained()` method.
    if not hasattr(module, "_old_forward"):
        return
    module_cls = module.__class__
    # bind the class-level forward to this instance
    module._old_forward = module_cls.forward.__get__(module, module_cls)