import json
import os
import time
from datetime import timedelta
from typing import TYPE_CHECKING

from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length

from .constants import LOG_FILE_NAME
from .logging import get_logger
from .misc import fix_valuehead_checkpoint


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments


logger = get_logger(__name__)


class FixValueHeadModelCallback(TrainerCallback):
    r"""
    Runs ``fix_valuehead_checkpoint`` on each checkpoint directory after the Trainer saves.
    """

    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.

        Only acts when ``args.should_save`` is true. The directory
        ``<args.output_dir>/<PREFIX_CHECKPOINT_DIR>-<state.global_step>`` (presumably the one
        the Trainer just wrote -- confirm against the transformers version in use) is passed,
        together with the ``model`` keyword argument supplied by the Trainer and
        ``args.save_safetensors``, to ``fix_valuehead_checkpoint``.
        """
        if args.should_save:
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )


class LogCallback(TrainerCallback):
    r"""
    Tracks step progress and timing and appends one JSON line per log event to
    ``LOG_FILE_NAME`` inside ``args.output_dir``.

    Args:
        runner: optional object with an ``aborted`` attribute. It is polled after every step
            and substep; when truthy, training and the current epoch are stopped. When a runner
            is given, each log event is also echoed through ``logger.info``.
    """

    def __init__(self, runner=None):
        self.runner = runner
        # True between on_train_begin and on_train_end; prediction-step progress is only
        # tracked outside of training.
        self.in_training = False
        # Reference point for elapsed/remaining time; reset when training or a prediction loop starts.
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = 0
        # Human-readable durations (str(timedelta)), refreshed by timing().
        self.elapsed_time = ""
        self.remaining_time = ""

    def timing(self):
        r"""
        Recomputes ``elapsed_time`` and ``remaining_time`` from ``start_time`` and the step counters.

        The remaining time is a linear extrapolation from the average duration of the steps
        completed so far. The number of remaining steps is clamped at zero so the estimate can
        never become a negative duration.
        """
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
        remaining_time = max(self.max_steps - self.cur_steps, 0) * avg_time_per_step
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.

        On the local main process: marks training as active, restarts the clock, records the
        total number of steps and, when ``args.overwrite_output_dir`` is set, deletes any
        existing log file so the new run starts with a fresh log.
        """
        if state.is_local_process_zero:
            self.in_training = True
            self.start_time = time.time()
            self.max_steps = state.max_steps
            log_file = os.path.join(args.output_dir, LOG_FILE_NAME)
            if os.path.exists(log_file) and args.overwrite_output_dir:
                logger.warning("Previous log file in this folder will be deleted.")
                os.remove(log_file)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.

        On the local main process: clears the training flag and the step counters.
        """
        if state.is_local_process_zero:
            self.in_training = False
            self.cur_steps = 0
            self.max_steps = 0

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of an substep during gradient accumulation.

        Requests a stop of the epoch and of training if the runner reports an abort.
        """
        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.

        Syncs ``cur_steps`` with ``state.global_step``, refreshes the timing strings and
        requests a stop if the runner reports an abort.
        """
        if state.is_local_process_zero:
            self.cur_steps = state.global_step
            self.timing()
            if self.runner is not None and self.runner.aborted:
                control.should_epoch_stop = True
                control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.

        Outside of training, resets the step counters so the next prediction loop starts
        from zero.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_predict(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
    ):
        r"""
        Event called after a successful prediction.

        Outside of training, resets the step counters so the next prediction loop starts
        from zero.
        """
        if state.is_local_process_zero and not self.in_training:
            self.cur_steps = 0
            self.max_steps = 0

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
        r"""
        Event called after logging the last logs.

        On the local main process, builds a progress record from the step counters, the timing
        strings and the newest entry of ``state.log_history``. The record is appended as one
        JSON line to ``LOG_FILE_NAME`` inside ``args.output_dir``, which is created if missing.
        Metrics absent from the newest entry are written as ``null``.
        """
        if not state.is_local_process_zero:
            return

        # Read the newest entry once; an empty history yields null metrics instead of an IndexError.
        last_log = state.log_history[-1] if state.log_history else {}
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss", None),
            eval_loss=last_log.get("eval_loss", None),
            predict_loss=last_log.get("predict_loss", None),
            reward=last_log.get("reward", None),
            learning_rate=last_log.get("learning_rate", None),
            epoch=last_log.get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if self.runner is not None:
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
                )
            )

        os.makedirs(args.output_dir, exist_ok=True)
        # Write to the same file name that on_train_begin clears, so an overwrite really starts fresh.
        with open(os.path.join(args.output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.

        Outside of training, and only when ``eval_dataloader`` has a length, counts prediction
        steps against ``len(eval_dataloader)`` and refreshes the timing strings.
        """
        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
            if self.max_steps == 0:
                # First step of a new prediction loop: restart the clock so elapsed/remaining time
                # are not measured from __init__ or from an earlier training run.
                self.start_time = time.time()
                self.max_steps = len(eval_dataloader)
            self.cur_steps += 1
            self.timing()