import os
import re
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data.dataset import Dataset

from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import unwrap_model
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, is_sagemaker_mp_enabled, logging

from .composition import AdapterCompositionBlock, Fuse


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp


logger = logging.get_logger(__name__)


class AdapterTrainer(Trainer):
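    """Trainer subclass for adapter training.

    Compared to ``Trainer``, this class injects an ``AdapterTrainerCallback``,
    saves and loads adapters (and, where present, adapter fusions and prediction
    heads) instead of the full model weights, and excludes adapter fusion value
    matrices from weight decay. It expects a model whose base weights are
    frozen, e.g. via ``model.train_adapter(...)``.
    """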
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        adapter_names: Optional[List[List[str]]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        # Temporarily clear the quantization flag so that the parent `Trainer`
        # does not refuse quantized models; adapter training only updates the
        # adapter weights, not the quantized base weights.
        if model is not None:
            model_quantized = getattr(model, "is_quantized", False)
            model.is_quantized = False
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=[AdapterTrainerCallback(self)] + callbacks if callbacks else [AdapterTrainerCallback(self)],
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # Restore the original quantization flag.
        if model is not None:
            model.is_quantized = model_quantized

        if adapter_names is not None:
            self.model.set_active_adapters(adapter_names)
        # Set the defaults for loading/saving model & adapters
        if isinstance(self.model, PreTrainedModel):
            model_frozen = getattr(self.model.base_model, "model_frozen", False)
        else:
            model_frozen = False
        # Default to False so the attribute is always defined; it is checked in
        # `_save()` and `_load_best_model()`.
        self.train_adapter_fusion = False
        if model_frozen and self.model.active_adapters:
            # Check if training AdapterFusion: either the active setup is a `Fuse`
            # block itself, or it is a composition block containing a `Fuse` child.
            self.train_adapter_fusion = isinstance(self.model.active_adapters, Fuse) or (
                isinstance(self.model.active_adapters, AdapterCompositionBlock)
                and any(isinstance(child, Fuse) for child in self.model.active_adapters.children)
            )
        if self.model.active_adapters is None:
            raise ValueError(
                "Expected a model with an active adapter setup. "
                "If you want to fully finetune the model, use the Trainer class."
            )
        if (self.label_names is None or len(self.label_names) < 1) and self.model.active_head is not None:
            all_label_names = set()
            for head in self.model._active_heads:
                all_label_names |= set(self.model.heads[head].get_label_names())
            self.label_names = list(all_label_names)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through `optimizers`, or subclass and override this method.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)
            if hasattr(self.model, "config") and hasattr(self.model.config, "adapters"):
                # Exclude adapter fusion value matrices from weight decay; they are
                # regularized separately via the fusion regularization loss (see
                # `AdapterTrainerCallback.on_step_end` below).
                match_str = r"adapter_fusion_layer\..*\.value"
                decay_parameters = [name for name in decay_parameters if not re.match(match_str, name)]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
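
    # Example (illustrative; names are placeholders): to bypass this default,
    # pass your own optimizer/scheduler pair at construction time:
    #
    #   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    #   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    #   trainer = AdapterTrainer(model=model, args=args, optimizers=(optimizer, scheduler))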

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
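        """Save adapters, adapter fusions, and prediction heads instead of the full model.

        Models that are not ``PreTrainedModel`` instances fall back to saving the
        full state dict, mirroring the behavior of ``Trainer._save``.
        """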
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_all_adapters(output_dir)
            if self.train_adapter_fusion:
                self.model.save_all_adapter_fusions(output_dir)
            if hasattr(self.model, "heads"):
                self.model.save_all_heads(output_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _load_from_checkpoint(self, resume_from_checkpoint):
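        """Resume from a checkpoint directory by re-loading adapters, fusions, and heads."""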
        args = self.args
        if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
            logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and can "
                    "yield errors or unwanted behaviors."
                )

        if args.deepspeed:
            # will be resumed in deepspeed_init
            pass
        else:
            adapter_loaded = False
            if os.path.isdir(resume_from_checkpoint):
                adapter_loaded = self._load_adapters(resume_from_checkpoint)
                self._load_adapter_fusions(resume_from_checkpoint)
                # Load all heads for a model with heads
                if hasattr(self.model, "heads"):
                    self._load_heads(resume_from_checkpoint)

            if not adapter_loaded:
                raise Exception(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

    def _load_adapters(self, resume_from_checkpoint):
        adapter_loaded = False
        for file_name in os.listdir(resume_from_checkpoint):
            if os.path.isdir(os.path.join(resume_from_checkpoint, file_name)):
                # Single-adapter directories are named after the adapter; fusion
                # directories contain a comma-joined list of adapter names.
                if "," not in file_name and "adapter_config.json" in os.listdir(
                    os.path.join(resume_from_checkpoint, file_name)
                ):
                    self.model.load_adapter(os.path.join(resume_from_checkpoint, file_name))
                    adapter_loaded = True

        return adapter_loaded

    def _load_adapter_fusions(self, resume_from_checkpoint):
        for file_name in os.listdir(resume_from_checkpoint):
            if os.path.isdir(os.path.join(resume_from_checkpoint, file_name)):
                if "," in file_name:
                    self.model.load_adapter_fusion(os.path.join(resume_from_checkpoint, file_name))

    def _load_heads(self, resume_from_checkpoint):
        for file_name in os.listdir(resume_from_checkpoint):
            if os.path.isdir(os.path.join(resume_from_checkpoint, file_name)):
                if "," not in file_name and "head_config.json" in os.listdir(
                    os.path.join(resume_from_checkpoint, file_name)
                ):
                    self.model.load_head(os.path.join(resume_from_checkpoint, file_name))

    def _load_best_model(self):
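        """Re-load the best adapters (and fusions) recorded in ``self.state.best_model_checkpoint``."""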
        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        logger.info(
            f"Loading best adapter(s) from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
        )
        # attempt to re-load all adapters from checkpoint
        for adapter in model.adapters_config.adapters:
            adapter_dir = os.path.join(self.state.best_model_checkpoint, adapter)
            if os.path.exists(adapter_dir):
                model.load_adapter(adapter_dir)
                model.adapter_to(adapter, device=self.args.device)
        if self.train_adapter_fusion:
            logger.info(
                f"Loading best adapter fusion(s) from {self.state.best_model_checkpoint} (score:"
                f" {self.state.best_metric})."
            )
            # attempt to re-load all adapter fusions from checkpoint
            for fusion in model.adapters_config.fusions:
                fusion_dir = os.path.join(self.state.best_model_checkpoint, fusion)
                if os.path.exists(fusion_dir):
                    model.load_adapter_fusion(fusion_dir)
                    model.adapter_fusion_to(fusion, device=self.args.device)


class AdapterTrainerCallback(TrainerCallback):
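    """Callback that enforces and supports adapter training.

    On training start it verifies that the base model weights are frozen; after
    each step it backpropagates the adapter fusion regularization loss when
    AdapterFusion is being trained.
    """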
    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        model = kwargs.pop("model")
        model_frozen = getattr(model.base_model, "model_frozen", False)
        if not model_frozen:
            raise ValueError(
                "The pre-trained model weights are not frozen. For training adapters, please call the"
                " train_adapter() method."
            )

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Apply adapter fusion weight regularization on the value matrix.
        model = kwargs.pop("model")
        if self.trainer.train_adapter_fusion:
            fusion_reg_loss = model.base_model.get_fusion_regularization_loss()
            if fusion_reg_loss is not None:
                fusion_reg_loss.backward()


class Seq2SeqAdapterTrainer(AdapterTrainer, Seq2SeqTrainer):
    pass
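

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): typical wiring of this
# trainer. Model, dataset, and adapter names below are placeholders.
#
#   import adapters
#   from transformers import AutoModelForSequenceClassification, TrainingArguments
#
#   model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
#   adapters.init(model)                 # add adapter support to the model
#   model.add_adapter("my_task")         # new, randomly initialized adapter
#   model.train_adapter("my_task")       # freeze base weights, activate the adapter
#
#   trainer = AdapterTrainer(
#       model=model,
#       args=TrainingArguments(output_dir="checkpoints"),
#       train_dataset=train_dataset,     # user-provided, tokenized dataset
#   )
#   trainer.train()
# ---------------------------------------------------------------------------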