# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Union

import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import gather_object, is_wandb_available
from transformers import (
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import has_length
from transformers.utils import is_rich_available

from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_mergekit_available
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .utils import log_table_to_comet_experiment


if is_rich_available():
    from rich.console import Console, Group
    from rich.live import Live
    from rich.panel import Panel
    from rich.progress import Progress

if is_wandb_available():
    import wandb


def _generate_completions(
    prompts: list[str],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    accelerator: Accelerator,
    generation_config: Optional[GenerationConfig],
    batch_size: int = 1,
) -> list[str]:
    """
    Generates completions for a list of pre-formatted prompts from the given model.

    Args:
        prompts (list[str]): A list of input prompts for which completions are to be generated.
        model (PreTrainedModel): The pre-trained model to be used for generation.
        tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding.
        accelerator (Accelerator): The accelerator to be used for model execution.
        generation_config (GenerationConfig): Configuration for text generation.
        batch_size (int, optional): The number of prompts to process in each batch. Default is 1.

    Returns:
        list[str]: A list of generated text completions corresponding to the input prompts.
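
    Example (an illustrative sketch, not part of the documented API; the checkpoint name and prompt below are
    placeholders):

    ```python
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    # Placeholder checkpoint; any causal LM with a left-padding-capable tokenizer works the same way.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    tokenizer.padding_side = "left"

    completions = _generate_completions(
        ["The capital of France is"],
        model=model,
        tokenizer=tokenizer,
        accelerator=Accelerator(),
        generation_config=GenerationConfig(max_new_tokens=16),
        batch_size=1,
    )
    ```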
| """ | |
| completions = [] | |
| with unwrap_model_for_generation(model, accelerator) as unwrapped_model: | |
| for idx in range(0, len(prompts), batch_size): | |
| batch = prompts[idx : idx + batch_size] | |
| tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) | |
| generations = unwrapped_model.generate( | |
| **tokenized_batch, | |
| generation_config=generation_config, | |
| ) | |
| for prompt, generation in zip(tokenized_batch.input_ids, generations): | |
| # Remove prompt from generation | |
| generation = generation[len(prompt) :] | |
| completion = tokenizer.decode(generation, skip_special_tokens=True) | |
| completions.append(completion) | |
| return completions | |


class SyncRefModelCallback(TrainerCallback):
    """
    Callback to synchronize the model with a reference model.
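
    Usage (an illustrative sketch; assumes the attached trainer exposes a `ref_model` attribute and that its
    arguments define `ref_model_mixup_alpha` and `ref_model_sync_steps`):

    ```python
    callback = SyncRefModelCallback(ref_model=trainer.ref_model, accelerator=trainer.accelerator)
    trainer.add_callback(callback)
    ```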
| """ | |
| def __init__( | |
| self, | |
| ref_model: Union[PreTrainedModel, torch.nn.Module], | |
| accelerator: Optional[Accelerator], | |
| ): | |
| self.accelerator = accelerator | |
| self.ref_model = ref_model | |
| def _sync_target_model(model, target_model, alpha): | |
| for target_param, copy_param in zip(target_model.parameters(), model.parameters()): | |
| target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha) | |
| def sync_target_model(model, target_model, alpha): | |
| deepspeed_plugin = AcceleratorState().deepspeed_plugin | |
| if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: | |
| import deepspeed | |
| with deepspeed.zero.GatheredParameters( | |
| list(model.parameters()) + list(target_model.parameters()), modifier_rank=0 | |
| ): | |
| if deepspeed.comm.get_rank() == 0: | |
| SyncRefModelCallback._sync_target_model(model, target_model, alpha) | |
| else: | |
| SyncRefModelCallback._sync_target_model(model, target_model, alpha) | |
| def on_step_end(self, args, state, control, **kwargs): | |
| model: PreTrainedModel = kwargs["model"] | |
| if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: | |
| if self.accelerator: | |
| model = self.accelerator.unwrap_model(model) | |
| self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha) | |


class RichProgressCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that displays the progress of training or evaluation using Rich.
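
    Usage (an illustrative sketch; any `Trainer`-compatible trainer can receive the callback the same way):

    ```python
    trainer = DPOTrainer(..., callbacks=[RichProgressCallback()])
    ```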
| """ | |
| def __init__(self): | |
| if not is_rich_available(): | |
| raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.") | |
| self.training_bar = None | |
| self.prediction_bar = None | |
| self.training_task_id = None | |
| self.prediction_task_id = None | |
| self.rich_group = None | |
| self.rich_console = None | |
| self.training_status = None | |
| self.current_step = None | |
| def on_train_begin(self, args, state, control, **kwargs): | |
| if state.is_world_process_zero: | |
| self.training_bar = Progress() | |
| self.prediction_bar = Progress() | |
| self.rich_console = Console() | |
| self.training_status = self.rich_console.status("Nothing to log yet ...") | |
| self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status))) | |
| self.rich_group.start() | |
| self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps) | |
| self.current_step = 0 | |
| def on_step_end(self, args, state, control, **kwargs): | |
| if state.is_world_process_zero: | |
| self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True) | |
| self.current_step = state.global_step | |
| def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): | |
| if state.is_world_process_zero and has_length(eval_dataloader): | |
| if self.prediction_task_id is None: | |
| self.prediction_task_id = self.prediction_bar.add_task( | |
| "[blue]Predicting on the evaluation dataset", total=len(eval_dataloader) | |
| ) | |
| self.prediction_bar.update(self.prediction_task_id, advance=1, update=True) | |
| def on_evaluate(self, args, state, control, **kwargs): | |
| if state.is_world_process_zero: | |
| if self.prediction_task_id is not None: | |
| self.prediction_bar.remove_task(self.prediction_task_id) | |
| self.prediction_task_id = None | |
| def on_predict(self, args, state, control, **kwargs): | |
| if state.is_world_process_zero: | |
| if self.prediction_task_id is not None: | |
| self.prediction_bar.remove_task(self.prediction_task_id) | |
| self.prediction_task_id = None | |
| def on_log(self, args, state, control, logs=None, **kwargs): | |
| if state.is_world_process_zero and self.training_bar is not None: | |
| _ = logs.pop("total_flos", None) | |
| self.training_status.update(f"[bold green]Status = {str(logs)}") | |
| def on_train_end(self, args, state, control, **kwargs): | |
| if state.is_world_process_zero: | |
| self.rich_group.stop() | |
| self.training_bar = None | |
| self.prediction_bar = None | |
| self.training_task_id = None | |
| self.prediction_task_id = None | |
| self.rich_group = None | |
| self.rich_console = None | |
| self.training_status = None | |
| self.current_step = None | |


def _win_rate_completions_df(
    state: TrainerState, prompts: list[str], completions: list[str], winner_indices: list[str]
) -> pd.DataFrame:
    """Builds a dataframe of (step, prompt, reference completion, policy completion, winner index) rows for logging."""
    global_step = [str(state.global_step)] * len(prompts)
    data = list(zip(global_step, prompts, completions, winner_indices))
    # Split completions from reference model and policy
    split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
    return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"])


class WinRateCallback(TrainerCallback):
    """
    A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.

    It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against
    a reference. The reference is either the initial version of the model (before training) or the reference model, if
    available in the trainer. During each evaluation step, a judge determines how often the trained model's completions
    win against the reference. The win rate is then logged in the trainer's logs under the key `"eval_win_rate"`.

    Usage:

    ```python
    trainer = DPOTrainer(...)
    judge = PairRMJudge()
    win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
    trainer.add_callback(win_rate_callback)
    ```

    Args:
        judge (`BasePairwiseJudge`):
            The judge to use for comparing completions.
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
            column containing the prompts for generating completions. If the `Trainer` has a reference model (via the
            `ref_model` attribute), it will use this reference model for generating the reference completions;
            otherwise, it defaults to using the initial model.
        generation_config (`GenerationConfig`, *optional*):
            The generation config to use for generating completions.
        num_prompts (`int` or `None`, *optional*, defaults to `None`):
            The number of prompts to generate completions for. If not provided, defaults to the number of examples in
            the evaluation dataset.
        shuffle_order (`bool`, *optional*, defaults to `True`):
            Whether to shuffle the order of the completions before judging.
        use_soft_judge (`bool`, *optional*, defaults to `False`):
            Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the
            second.
    """

    def __init__(
        self,
        judge: BasePairwiseJudge,
        trainer: Trainer,
        generation_config: Optional[GenerationConfig] = None,
        num_prompts: Optional[int] = None,
        shuffle_order: bool = True,
        use_soft_judge: bool = False,
    ):
        self.judge = judge
        self.trainer = trainer
        self.shuffle_order = shuffle_order
        self.generation_config = generation_config
        self.ref_completions = []
        self.use_soft_judge = use_soft_judge

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.")
        else:
            self.eval_dataset = self.trainer.eval_dataset

        if num_prompts is not None:
            self.eval_dataset = self.eval_dataset.select(range(num_prompts))

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # When the trainer is initialized, we generate completions for the reference model.
        tokenizer = kwargs["processing_class"]
        tokenizer.padding_side = "left"
        accelerator = self.trainer.accelerator
        # Use the reference model if available, otherwise use the initial model
        model = getattr(self.trainer, "ref_model", None)
        # At this point, there are two cases where `ref_model` is None:
        # 1. The method doesn't require a reference model.
        # 2. The method uses a reference model, but `ref_model` is set to None.
        #    This occurs when using PEFT, where the reference model can be obtained by simply disabling the model's
        #    adapter. In theory, we should disable the adapter here, but since it's zero-initialized at the start of
        #    training, the model behaves identically with or without the adapter. Therefore, there's no need to
        #    explicitly disable it at this point.
        if model is None:
            model = self.trainer.model_wrapped
        with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
            self.ref_completions = _generate_completions(
                prompts,
                model=model,
                tokenizer=tokenizer,
                accelerator=accelerator,
                generation_config=self.generation_config,
                batch_size=args.per_device_eval_batch_size,
            )

            # Compute initial win rate as a reference point
            completions = list(zip(self.ref_completions, self.ref_completions))

            if self.use_soft_judge:
                ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
                winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
                ref_win_probs = gather_object(ref_win_probs)
            else:
                winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
            prompts = gather_object(prompts)
            completions = gather_object(completions)
            winner_indices = gather_object(winner_indices)

        # Logging
        if self.trainer.accelerator.is_main_process:
            win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
            if self.use_soft_judge:
                avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
                self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
            else:
                self.trainer.log({"eval_win_rate": win_rate})

            if "wandb" in args.report_to:
                import wandb

                if wandb.run is not None:
                    df = _win_rate_completions_df(
                        state=state,
                        prompts=prompts,
                        completions=completions,
                        winner_indices=winner_indices,
                    )
                    wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

            if "comet_ml" in args.report_to:
                df = _win_rate_completions_df(
                    state=state,
                    prompts=prompts,
                    completions=completions,
                    winner_indices=winner_indices,
                )
                log_table_to_comet_experiment(
                    name="win_rate_completions.csv",
                    table=df,
                )

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # At every evaluation step, we generate completions for the model and compare them with the reference
        # completions that have been generated at the beginning of training. We then compute the win rate and log it to
        # the trainer.
        tokenizer = kwargs["processing_class"]
        tokenizer.padding_side = "left"
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped
        with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
            completions = _generate_completions(
                prompts,
                model=model,
                tokenizer=tokenizer,
                accelerator=accelerator,
                generation_config=self.generation_config,
                batch_size=args.per_device_eval_batch_size,
            )

            completions = list(zip(self.ref_completions, completions))

            if self.use_soft_judge:
                ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
                winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
                ref_win_probs = gather_object(ref_win_probs)
            else:
                winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
            prompts = gather_object(prompts)
            completions = gather_object(completions)
            winner_indices = gather_object(winner_indices)

        # Logging
        if self.trainer.accelerator.is_main_process:
            win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
            if self.use_soft_judge:
                avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
                self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
            else:
                self.trainer.log({"eval_win_rate": win_rate})

            if "wandb" in args.report_to:
                import wandb

                if wandb.run is not None:
                    df = _win_rate_completions_df(
                        state=state,
                        prompts=prompts,
                        completions=completions,
                        winner_indices=winner_indices,
                    )
                    wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

            if "comet_ml" in args.report_to:
                df = _win_rate_completions_df(
                    state=state,
                    prompts=prompts,
                    completions=completions,
                    winner_indices=winner_indices,
                )
                log_table_to_comet_experiment(
                    name="win_rate_completions.csv",
                    table=df,
                )


class LogCompletionsCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases and/or Comet.

    Usage:

    ```python
    trainer = DPOTrainer(...)
    completions_callback = LogCompletionsCallback(trainer=trainer)
    trainer.add_callback(completions_callback)
    ```

    Args:
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
            column containing the prompts for generating completions.
        generation_config (`GenerationConfig`, *optional*):
            The generation config to use for generating completions.
        num_prompts (`int` or `None`, *optional*):
            The number of prompts to generate completions for. If not provided, defaults to the number of examples in
            the evaluation dataset.
        freq (`int` or `None`, *optional*):
            The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.
    """

    def __init__(
        self,
        trainer: Trainer,
        generation_config: Optional[GenerationConfig] = None,
        num_prompts: Optional[int] = None,
        freq: Optional[int] = None,
    ):
        self.trainer = trainer
        self.generation_config = generation_config
        self.freq = freq
        self.table = []
        self._last_logged_step = -1

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the LogCompletionsCallback.")
        else:
            self.eval_dataset = self.trainer.eval_dataset

        if num_prompts is not None:
            self.eval_dataset = self.eval_dataset.select(range(num_prompts))

    def on_step_end(self, args, state, control, **kwargs):
        # Only log once per step (this method may be called multiple times)
        if state.global_step == self._last_logged_step:
            return

        # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
        freq = self.freq or state.eval_steps
        if state.global_step % freq != 0:
            return

        tokenizer = kwargs["processing_class"]
        tokenizer.padding_side = "left"
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped
        with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
            prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts]
            completions = _generate_completions(
                prompts,
                model=model,
                tokenizer=tokenizer,
                accelerator=accelerator,
                generation_config=self.generation_config,
                batch_size=args.per_device_eval_batch_size,
            )
            completions = gather_object(completions)
            prompts = gather_object(prompts)

        # Build the data to log
        if self.trainer.accelerator.is_main_process:
            global_step = [str(state.global_step)] * len(prompts)
            data = list(zip(global_step, prompts, completions))
            self.table.extend(data)
            table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table)

            if "wandb" in args.report_to:
| wandb.log({"completions": table}) | |
| if "comet_ml" in args.report_to: | |
| log_table_to_comet_experiment( | |
| name="completions.csv", | |
| table=table, | |
| ) | |
| # Save the last logged step, so we don't log the same completions multiple times | |
| self._last_logged_step = state.global_step | |


class MergeModelCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
    on a merge configuration.

    Args:
        merge_config ([`MergeConfig`], *optional*, defaults to `None`):
            Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to merge the model at every checkpoint.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the merged model to the Hub after merging.

    Example:

    ```python
    # pip install mergekit

    from trl.mergekit_utils import MergeConfig
    from trl import MergeModelCallback

    config = MergeConfig()
    merge_callback = MergeModelCallback(config)
    trainer = DPOTrainer(..., callbacks=[merge_callback])
    ```
    """

    def __init__(
        self,
        merge_config: Optional["MergeConfig"] = None,
        merge_at_every_checkpoint: bool = False,
        push_to_hub: bool = False,
    ):
        if not is_mergekit_available():
            raise ImportError(
                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
            )
        self.merge_config = merge_config or MergeConfig()
        self.merge_at_every_checkpoint = merge_at_every_checkpoint
        self.push_to_hub = push_to_hub

    def _merge_and_maybe_push(self, output_dir, global_step, model):
        checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        self.merge_config.policy_model_path = checkpoint_path
        if self.merge_config.target_model_path is None:
            self.merge_config.target_model_path = model.config._name_or_path
        merge_path = os.path.join(checkpoint_path, "merged")

        merge_models(self.merge_config.create(), merge_path)

        if self.push_to_hub:
            repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
            upload_model_to_hf(merge_path, repo_name)

    def on_save(self, args, state, control, model=None, **kwargs):
        if self.merge_at_every_checkpoint:
            self._merge_and_maybe_push(args.output_dir, state.global_step, model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        if not self.merge_at_every_checkpoint:
            self._merge_and_maybe_push(args.output_dir, state.global_step, model)