# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)
from typing_extensions import override

from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead

    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    Splits a saved value-head checkpoint into a decoder checkpoint and a value-head file.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.

    Args:
        model: the value-head model whose `pretrained_model` is re-saved into `output_dir`.
        output_dir: directory that already contains the combined checkpoint file.
        safe_serialization: read/write safetensors files if True, otherwise torch pickle files.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        # the file only holds tensors, so restrict unpickling to tensor data
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)

    # the combined file is replaced by the two files written below
    os.remove(path_to_checkpoint)
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            # strip the wrapper prefix (case 3) so that the keys match the bare decoder
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    # an empty decoder dict (case 2) is passed as None so `save_pretrained` uses the model's own weights
    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")


class FixValueHeadModelCallback(TrainerCallback):
    r"""
    A callback for fixing the checkpoint for valuehead models.
    """

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Rewrites the checkpoint of the current global step on the saving process.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
            )


class SaveProcessorCallback(TrainerCallback):
    r"""
    A callback for saving the processor.
    """

    def __init__(self, processor: "ProcessorMixin") -> None:
        self.processor = processor

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves the processor into the checkpoint folder of the current global step.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            self.processor.save_pretrained(output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves the processor into the top-level output folder.
        """
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)


class PissaConvertCallback(TrainerCallback):
    r"""
    A callback for converting the PiSSA adapter to a normal one.
    """

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves the initial PiSSA adapter to `<output_dir>/pissa_init` before training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                peft_config = model.peft_config["default"]
                init_lora_weights = peft_config.init_lora_weights
                # temporarily save with `init_lora_weights: True`, then restore the original setting
                peft_config.init_lora_weights = True
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                peft_config.init_lora_weights = init_lora_weights

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Saves a backup of the trained PiSSA adapter and a converted LoRA adapter after training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = model.peft_config["default"].init_lora_weights
                model.peft_config["default"].init_lora_weights = True
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                model.peft_config["default"].init_lora_weights = init_lora_weights
                model.save_pretrained(
                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
                )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                if "pissa_init" in model.peft_config:  # backward compatibility (peft<0.12.0)
                    model.delete_adapter("pissa_init")

                # the config is looked up again because the "default" adapter was just reloaded
                model.peft_config["default"].init_lora_weights = init_lora_weights


class LogCallback(TrainerCallback):
    r"""
    A callback for logging training and evaluation status.

    Progress records are appended as JSON lines to `TRAINER_LOG` inside the output folder
    by a single-worker thread pool.
    """

    def __init__(self) -> None:
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
        if self.webui_mode and not use_ray():
            # SIGABRT only raises a flag; the step/prediction hooks act on it later
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        r"""
        Signal handler that marks the run as aborted.
        """
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""
        Restarts the progress counters and the timer.
        """
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""
        Updates the current step and the elapsed / estimated remaining time strings.
        """
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""
        Appends one JSON record to the trainer log file.
        """
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""
        Ensures the output folder exists and creates the single-worker writer pool.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""
        Waits for pending writes and releases the writer pool, if any.
        """
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Removes a previous trainer log when the output folder is being overwritten.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Starts progress tracking for a training run on the saving process.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Flushes and closes the writer pool.
        """
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Requests a training stop if the run was aborted.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Requests a training stop if the run was aborted.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Finishes progress tracking of a standalone evaluation pass.
        """
        if not self.do_train:
            self._close_thread_pool()
            # clear the counters so that a later evaluate / predict pass starts from zero
            # and `on_prediction_step` creates a new writer pool
            self._reset()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Finishes progress tracking of a standalone prediction pass.
        """
        if not self.do_train:
            self._close_thread_pool()
            # clear the counters so that a later evaluate / predict pass starts from zero
            # and `on_prediction_step` creates a new writer pool
            self._reset()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Builds a progress record from the latest log entry and queues it for writing.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        last_log = state.log_history[-1]
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss"),
            eval_loss=last_log.get("eval_loss"),
            predict_loss=last_log.get("predict_loss"),
            reward=last_log.get("reward"),
            accuracy=last_log.get("rewards/accuracies"),
            lr=last_log.get("learning_rate"),
            epoch=last_log.get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            elapsed = time.time() - self.start_time
            if elapsed > 0:  # skip the rate rather than divide by zero
                logs["throughput"] = round(state.num_input_tokens_seen / elapsed, 2)

            logs["total_tokens"] = state.num_input_tokens_seen

        if is_env_enabled("RECORD_VRAM"):
            vram_allocated, vram_reserved = get_peak_memory()
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                # None values were filtered out above; a value of 0 is still worth reporting
                if extra_key in logs:
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Tracks the progress of a standalone evaluation / prediction loop.

        Does nothing during training. Exits the process if the run was aborted.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                # first step of this pass: start the timer and the writer pool
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            # only every fifth step is written to the log file
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)


class ReporterCallback(TrainerCallback):
    r"""
    A callback for reporting training status to external logger.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.model_args = model_args
        self.data_args = data_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        # keep a user-provided project name, otherwise fall back to "llamafactory"
        os.environ.setdefault("WANDB_PROJECT", "llamafactory")

    def _collect_args(self) -> Dict[str, Any]:
        r"""
        Gathers all argument groups into one dict for the external loggers.
        """
        return {
            "model_args": self.model_args.to_dict(),
            "data_args": self.data_args.to_dict(),
            "finetuning_args": self.finetuning_args.to_dict(),
            "generating_args": self.generating_args.to_dict(),
        }

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Uploads the run arguments to wandb and / or swanlab from the world process zero.
        """
        if not state.is_world_process_zero:
            return

        if "wandb" in args.report_to:
            import wandb

            wandb.config.update(self._collect_args())

        if self.finetuning_args.use_swanlab:
            import swanlab  # type: ignore

            swanlab.config.update(self._collect_args())