Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,191 Bytes
d1ed09d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from dataclasses import dataclass, field
from typing import Optional
from .composition import Stack
from .configuration import AdapterConfig
@dataclass
class AdapterArguments:
    """
    Command-line arguments controlling adapter training.

    Args:
        train_adapter (bool): Whether to train an adapter instead of the full model.
        load_adapter (str): Pre-trained adapter module to be loaded from Hub.
        adapter_config (str): Adapter configuration. Either a config string or a path to a file.
        load_lang_adapter (str): Pre-trained language adapter module to be loaded from Hub.
        lang_adapter_config (str): Language adapter configuration. Either an identifier or a path to a file.
    """

    # Toggle between full fine-tuning (False) and adapter training (True).
    train_adapter: bool = field(
        default=False,
        metadata={"help": "Train an adapter instead of the full model."},
    )
    # Hub identifier (or local path) of a task adapter; empty string means "none".
    load_adapter: Optional[str] = field(
        default="",
        metadata={"help": "Pre-trained adapter module to be loaded from Hub."},
    )
    # Config identifier (e.g. "seq_bn") or path to a JSON config file.
    adapter_config: Optional[str] = field(
        default="seq_bn",
        metadata={"help": "Adapter configuration. Either a config string or a path to a file."},
    )
    # Optional language adapter to stack below the task adapter.
    load_lang_adapter: Optional[str] = field(
        default=None,
        metadata={"help": "Pre-trained language adapter module to be loaded from Hub."},
    )
    lang_adapter_config: Optional[str] = field(
        default=None,
        metadata={"help": "Language adapter configuration. Either an identifier or a path to a file."},
    )
def setup_adapter_training(
    model,
    adapter_args: AdapterArguments,
    adapter_name: str,
    adapter_config_kwargs: Optional[dict] = None,
    adapter_load_kwargs: Optional[dict] = None,
):
    """Setup model for adapter training based on given adapter arguments.

    Args:
        model: The model instance to be trained. Expected to provide the adapter
            API used below (``load_adapter``, ``add_adapter``, ``train_adapter``,
            ``set_active_adapters``) and an ``adapters_config`` container.
        adapter_args (AdapterArguments): The adapter arguments used for configuration.
        adapter_name (str): The name of the adapter to be added.
        adapter_config_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``AdapterConfig.load`` when resolving adapter configurations.
        adapter_load_kwargs (dict, optional): Extra keyword arguments forwarded to
            ``model.load_adapter`` when loading pre-trained adapters.

    Returns:
        Tuple[Optional[str], Optional[str]]: The names of the (task adapter,
        language adapter) that were set up. The language adapter name is ``None``
        if no language adapter was loaded; both are ``None`` if
        ``adapter_args.train_adapter`` is False.

    Raises:
        ValueError: If ``load_adapter`` or ``load_lang_adapter`` is given while
            ``train_adapter`` is disabled.
    """
    # Normalize optional kwargs (avoids mutable default arguments).
    if adapter_config_kwargs is None:
        adapter_config_kwargs = {}
    if adapter_load_kwargs is None:
        adapter_load_kwargs = {}
    # Setup adapters
    if adapter_args.train_adapter:
        # resolve the adapter config
        adapter_config = AdapterConfig.load(adapter_args.adapter_config, **adapter_config_kwargs)
        # load a pre-trained adapter from Hub if specified
        # note: this logic has changed in versions > 3.1.0: adapter is also loaded if it already exists
        if adapter_args.load_adapter:
            model.load_adapter(
                adapter_args.load_adapter,
                config=adapter_config,
                load_as=adapter_name,
                **adapter_load_kwargs,
            )
        # otherwise, if adapter does not exist, add it
        elif adapter_name not in model.adapters_config:
            model.add_adapter(adapter_name, config=adapter_config)
        # optionally load a pre-trained language adapter
        if adapter_args.load_lang_adapter:
            # resolve the language adapter config
            lang_adapter_config = AdapterConfig.load(adapter_args.lang_adapter_config, **adapter_config_kwargs)
            # load the language adapter from Hub; load_adapter returns the name it was registered under
            lang_adapter_name = model.load_adapter(
                adapter_args.load_lang_adapter,
                config=lang_adapter_config,
                **adapter_load_kwargs,
            )
        else:
            lang_adapter_name = None
        # Freeze all model weights except of those of this adapter
        model.train_adapter(adapter_name)
        # Set the adapters to be used in every forward pass: the language adapter
        # (if any) is stacked below the task adapter
        if lang_adapter_name:
            model.set_active_adapters(Stack(lang_adapter_name, adapter_name))
        else:
            model.set_active_adapters(adapter_name)
        return adapter_name, lang_adapter_name
    else:
        # Loading adapters without training them is not supported by this helper.
        if adapter_args.load_adapter or adapter_args.load_lang_adapter:
            raise ValueError(
                "Adapters can only be loaded in adapters training mode. Use --train_adapter to enable adapter"
                " training"
            )
        return None, None
|