| """Module containing the Trainer class and related functions""" | |
| import importlib | |
| import logging | |
| import math | |
| import os | |
| import sys | |
| from dataclasses import field | |
| from pathlib import Path | |
| from typing import Optional | |
| import bitsandbytes as bnb | |
| import torch.cuda | |
| import transformers | |
| from torch import nn | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from transformers import EarlyStoppingCallback, Trainer, TrainingArguments | |
| from transformers.trainer_pt_utils import get_parameter_names | |
| from axolotl.utils.callbacks import ( | |
| SaveBetterTransformerModelCallback, | |
| SavePeftModelCallback, | |
| ) | |
| from axolotl.utils.schedulers import ( | |
| InterpolatingLogScheduler, | |
| get_cosine_schedule_with_quadratic_warmup, | |
| ) | |
| LOG = logging.getLogger("axolotl") | |


@dataclass
class AxolotlTrainingArguments(TrainingArguments):
    """
    Extend the base TrainingArguments for axolotl helpers
    """

    lr_quadratic_warmup: bool = field(
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
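
# Example config usage (illustrative, assumed keys): an axolotl YAML config that sets
#     lr_scheduler: cosine
#     lr_quadratic_warmup: true
# is read into `cfg` and forwarded to this field by setup_trainer() below.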


class AxolotlTrainer(Trainer):
    """
    Extend the base Trainer for axolotl helpers
    """

    args = None  # type: AxolotlTrainingArguments

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer
        """
        # fmt: off
        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
            # fmt: on
            if (
                self.args.lr_scheduler_type == "cosine"
                and self.args.lr_quadratic_warmup is True
            ):
                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            else:
                return super().create_scheduler(num_training_steps, optimizer)
        return self.lr_scheduler


class OneCycleLRSchedulerTrainer(AxolotlTrainer):
    """
    Trainer subclass that uses the OneCycleLR scheduler
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        optimizer = self.optimizer if optimizer is None else optimizer
        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
        pct_start = num_warmup_steps / num_training_steps
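        # Illustrative example (assumed numbers): 100 warmup steps over 1,000
        # total steps gives pct_start = 0.1, so the learning rate ramps up for
        # the first 10% of training before the one-cycle decay phase.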
        self.lr_scheduler = OneCycleLR(
            optimizer,
            max_lr=self.args.learning_rate,
            total_steps=num_training_steps,
            pct_start=pct_start,
            div_factor=6,
        )
        return self.lr_scheduler


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
    total_num_steps = int(
        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
    )
    warmup_steps = (
        cfg.warmup_steps
        if cfg.warmup_steps is not None
        else min(int(0.03 * total_num_steps), 100)
    )
    logging_steps = (
        cfg.logging_steps
        if cfg.logging_steps is not None
        else max(min(int(0.005 * total_num_steps), 10), 1)
    )

    training_arguments_kwargs = {}
    if cfg.bf16 == "full":
        training_arguments_kwargs["bf16_full_eval"] = True
    else:
        training_arguments_kwargs["bf16"] = cfg.bf16
    training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
    training_arguments_kwargs["tf32"] = cfg.tf32
    training_arguments_kwargs["warmup_steps"] = warmup_steps
    training_arguments_kwargs["logging_steps"] = logging_steps

    if cfg.seed:
        training_arguments_kwargs["seed"] = cfg.seed

    if cfg.gradient_checkpointing:
        if cfg.gptq:
            from alpaca_lora_4bit.gradient_checkpointing import (
                apply_gradient_checkpointing,
            )

            gradient_checkpointing_ratio = (
                cfg.gradient_checkpointing_ratio
                if cfg.gradient_checkpointing_ratio
                else 1.0
            )

            apply_gradient_checkpointing(
                model, checkpoint_ratio=gradient_checkpointing_ratio
            )
        else:
            training_arguments_kwargs[
                "gradient_checkpointing"
            ] = cfg.gradient_checkpointing

    if cfg.fsdp:
        training_arguments_kwargs["fsdp"] = cfg.fsdp
        if cfg.fsdp_config:
            training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)

    if cfg.lr_quadratic_warmup is not None:
        training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup

    # deepspeed
    if (
        os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
        and torch.cuda.device_count() > 1
    ):
        if cfg.deepspeed:
            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
        else:
            # make a guess here
            # TODO search Path("./") for one
            training_arguments_kwargs["deepspeed"] = "./ds_config.json"

    if cfg.adam_beta1:
        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
    if cfg.adam_beta2:
        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
    if cfg.adam_epsilon:
        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
    if cfg.max_grad_norm:
        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm

    if cfg.hub_model_id:
        training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
        training_arguments_kwargs["push_to_hub"] = True
        training_arguments_kwargs["hub_private_repo"] = True

    if cfg.save_safetensors:
        training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors

    training_args = AxolotlTrainingArguments(
        per_device_train_batch_size=cfg.micro_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size
        if cfg.eval_batch_size is not None
        else cfg.micro_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        eval_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_epochs,
        learning_rate=cfg.learning_rate,
        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
        save_strategy="steps" if cfg.save_steps else "epoch",
        eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
        save_steps=cfg.save_steps,
        output_dir=cfg.output_dir,
        save_total_limit=3,
        load_best_model_at_end=(
            cfg.load_best_model_at_end is not False
            and cfg.val_set_size > 0
            and cfg.save_steps
            and cfg.save_steps % cfg.eval_steps == 0
            and cfg.load_in_8bit is not True
        )
        or False,
        ddp_find_unused_parameters=False if cfg.ddp else None,
        group_by_length=cfg.group_by_length,
        report_to="wandb" if cfg.use_wandb else None,
        run_name=cfg.wandb_run_id if cfg.use_wandb else None,
        optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
        lr_scheduler_type=cfg.lr_scheduler
        if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
        else "cosine",
        weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
        **training_arguments_kwargs,
    )

    trainer_kwargs = {}

    if cfg.optimizer == "adamw_anyprecision":
        if Path(cfg.torchdistx_path).exists():
            sys.path.append(cfg.torchdistx_path)
            importlib.import_module("torchdistx")
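
    # The block below hand-builds a bitsandbytes 8-bit Adam optimizer with the same
    # decay/no-decay parameter grouping the HF Trainer would use (biases and
    # LayerNorm weights get no weight decay), so that a custom LR scheduler can be
    # attached to it.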
    if (
        cfg.optimizer == "adamw_bnb_8bit"
        and not cfg.gptq
        and "deepspeed" not in training_arguments_kwargs
        and not cfg.fsdp
    ):
        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if (n in decay_parameters and p.requires_grad)
                ],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if (n not in decay_parameters and p.requires_grad)
                ],
                "weight_decay": 0.0,
            },
        ]

        optimizer = bnb.optim.Adam8bit(
            optimizer_grouped_parameters,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            lr=training_args.learning_rate,
        )

        if cfg.lr_scheduler == "one_cycle":
            lr_scheduler_kwargs = (
                cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
            )
            lr_scheduler = OneCycleLR(
                optimizer,
                cfg.learning_rate,
                total_steps=total_num_steps,
                epochs=cfg.num_epochs,
                div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
                **lr_scheduler_kwargs,
            )
        elif cfg.lr_scheduler == "log_sweep":
            lr_scheduler = InterpolatingLogScheduler(
                optimizer,
                cfg.warmup_steps,
                cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
                cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
            )
        else:
            lr_scheduler = transformers.get_cosine_schedule_with_warmup(
                optimizer,
                training_args.warmup_steps,
                total_num_steps,
            )
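
        # Passing optimizers=(optimizer, lr_scheduler) below makes the HF Trainer
        # use this pair instead of creating its own defaults.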
| trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) | |

    callbacks = []
    # TODO on_save callback to sync checkpoints to GCP/AWS in background
    if cfg.early_stopping_patience:
        early_stop_cb = EarlyStoppingCallback(
            cfg.early_stopping_patience,
        )
        callbacks.append(early_stop_cb)

    if cfg.local_rank == 0 and cfg.adapter in [
        "lora",
        "qlora",
    ]:  # only save in rank 0
        callbacks.append(SavePeftModelCallback)

    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
        callbacks.append(SaveBetterTransformerModelCallback)

    data_collator_kwargs = {
        "padding": True,
    }
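    # Both branches below pad each batch to its longest sequence; the default branch
    # also rounds lengths up to a multiple of 8, which keeps fp16/bf16 matmuls
    # tensor-core friendly.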
    if cfg.collator_pad_to_longest:
        data_collator_kwargs["padding"] = "longest"
    else:
        data_collator_kwargs["pad_to_multiple_of"] = 8

    if cfg.is_llama_derived_model and cfg.landmark_attention:
        from functools import partial

        from axolotl.monkeypatch.llama_landmark_attn import (
            add_mem_tokens,
            get_mem_id,
            set_model_mem_id,
        )

        set_model_mem_id(model, tokenizer)
        LOG.info("Adding landmark attention tokens to dataset")
        # Dataset.map returns a new dataset rather than mutating in place, so
        # rebind the mapped results instead of discarding them in a loop variable.
        train_dataset, eval_dataset = (
            dataset.map(
                partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
                batched=False,
                num_proc=32,
            )
            for dataset in (train_dataset, eval_dataset)
        )

    trainer_cls = (
        OneCycleLRSchedulerTrainer
        if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
        else AxolotlTrainer
    )
    trainer = trainer_cls(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            **data_collator_kwargs,
        ),
        callbacks=callbacks,
        **trainer_kwargs,
    )

    return trainer
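

# Illustrative usage sketch (assumes cfg, datasets, model and tokenizer have already
# been prepared elsewhere in axolotl, e.g. by its model/tokenizer loading utilities):
#
#     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
#     trainer.train()
#     trainer.save_model(cfg.output_dir)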