""" 2025.3.15 2025.3.17 4.50.0.dev0 0.15.2 __UNSLOTH_VERSIONING__ """ from torch import Tensor import torch import torch.nn as nn from torch.nn import functional as F from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation) import os from typing import * from dataclasses import dataclass, field from packaging.version import Version import torch import numpy as np from contextlib import nullcontext from torch.nn import functional as F from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : False, "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, } @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def selective_log_softmax(logits, index): logits = logits.to(torch.float32) selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) logsumexp_values = torch.logsumexp(logits, dim = -1) per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) return per_token_logps @dataclass class UnslothGKDConfig(GKDConfig): """ Configuration class for [`GKDTrainer`]. Args: temperature (`float`, *optional*, defaults to `0.9`): Temperature for sampling. The higher the temperature, the more random the completions. lmbda (`float`, *optional*, defaults to `0.5`): Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy student-generated outputs). beta (`float`, *optional*, defaults to `0.5`): Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. max_new_tokens (`int`, *optional*, defaults to `128`): Maximum number of tokens to generate per completion. teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being trained. teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`): Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string. disable_dropout (`bool`, *optional*, defaults to `True`): Whether to disable dropout in the model. seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). """ vllm_sampling_params: Optional[Any] = field( default = None, metadata = {'help': 'vLLM SamplingParams'}, ) unsloth_num_chunks : Optional[int] = field( default = -1, metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}, ) def __init__( self, output_dir = None, overwrite_output_dir = None, do_train = False, do_eval = False, do_predict = False, eval_strategy = 'no', prediction_loss_only = False, per_device_train_batch_size = 4, per_device_eval_batch_size = 4, per_gpu_train_batch_size = None, per_gpu_eval_batch_size = None, gradient_accumulation_steps = 2, eval_accumulation_steps = 2, eval_delay = 0, torch_empty_cache_steps = 250, learning_rate = 5e-05, weight_decay = 0.01, adam_beta1 = 0.9, adam_beta2 = 0.999, adam_epsilon = 1e-08, max_grad_norm = 1.0, num_train_epochs = 3.0, max_steps = -1, lr_scheduler_type = 'linear', warmup_ratio = 0.1, warmup_steps = 0, log_level = 'passive', log_level_replica = 'warning', log_on_each_node = True, logging_dir = None, logging_strategy = 'steps', logging_first_step = False, logging_steps = 1, logging_nan_inf_filter = False, save_strategy = 'steps', save_steps = 500, save_total_limit = None, save_safetensors = True, save_on_each_node = False, save_only_model = False, restore_callback_states_from_checkpoint = False, no_cuda = False, use_cpu = False, use_mps_device = False, seed = 3407, data_seed = 3407, jit_mode_eval = False, use_ipex = False, bf16 = False, fp16 = False, fp16_opt_level = 'O1', half_precision_backend = 'auto', bf16_full_eval = False, fp16_full_eval = False, tf32 = None, local_rank = -1, ddp_backend = None, tpu_num_cores = None, tpu_metrics_debug = False, debug = '', dataloader_drop_last = False, eval_steps = None, dataloader_num_workers = 0, dataloader_prefetch_factor = None, past_index = -1, run_name = None, disable_tqdm = None, remove_unused_columns = True, label_names = None, load_best_model_at_end = False, metric_for_best_model = None, greater_is_better = None, ignore_data_skip = False, fsdp = '', fsdp_min_num_params = 0, fsdp_config = None, tp_size = 0, fsdp_transformer_layer_cls_to_wrap = None, accelerator_config = None, deepspeed = None, label_smoothing_factor = 0.0, optim = 'adamw_8bit', optim_args = None, adafactor = False, group_by_length = False, length_column_name = 'length', report_to = None, ddp_find_unused_parameters = None, ddp_bucket_cap_mb = None, ddp_broadcast_buffers = None, dataloader_pin_memory = True, dataloader_persistent_workers = False, skip_memory_metrics = True, use_legacy_prediction_loop = False, push_to_hub = False, resume_from_checkpoint = None, hub_model_id = None, hub_strategy = 'every_save', hub_token = None, hub_private_repo = None, hub_always_push = False, gradient_checkpointing = False, gradient_checkpointing_kwargs = None, include_inputs_for_metrics = False, eval_do_concat_batches = True, fp16_backend = 'auto', evaluation_strategy = None, push_to_hub_model_id = None, push_to_hub_organization = None, push_to_hub_token = None, mp_parameters = '', auto_find_batch_size = False, full_determinism = False, torchdynamo = None, ray_scope = 'last', ddp_timeout = 1800, torch_compile = False, torch_compile_backend = None, torch_compile_mode = None, dispatch_batches = None, split_batches = None, include_tokens_per_second = False, include_num_input_tokens_seen = False, neftune_noise_alpha = None, optim_target_modules = None, batch_eval_metrics = False, eval_on_start = False, use_liger_kernel = False, eval_use_gather_object = False, average_tokens_across_devices = False, model_init_kwargs = None, use_liger = False, dataset_text_field = 'text', dataset_kwargs = None, dataset_num_proc = None, max_seq_length = None, packing = False, eval_packing = None, dataset_batch_size = None, num_of_sequences = None, chars_per_token = None, temperature = 0.9, lmbda = 0.5, beta = 0.5, max_new_tokens = 128, teacher_model_name_or_path = None, teacher_model_init_kwargs = None, disable_dropout = True, seq_kd = False, vllm_sampling_params = None, unsloth_num_chunks = -1, **kwargs, ): if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!') if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!') if output_dir is None and save_strategy == 'steps' and save_steps == 500: output_dir = 'unsloth_training_checkpoints' save_strategy = 'no' if dataset_num_proc is None: from multiprocessing import cpu_count dataset_num_proc = cpu_count() super().__init__( output_dir = output_dir, overwrite_output_dir = overwrite_output_dir, do_train = do_train, do_eval = do_eval, do_predict = do_predict, eval_strategy = eval_strategy, prediction_loss_only = prediction_loss_only, per_device_train_batch_size = per_device_train_batch_size, per_device_eval_batch_size = per_device_eval_batch_size, per_gpu_train_batch_size = per_gpu_train_batch_size, per_gpu_eval_batch_size = per_gpu_eval_batch_size, gradient_accumulation_steps = gradient_accumulation_steps, eval_accumulation_steps = eval_accumulation_steps, eval_delay = eval_delay, torch_empty_cache_steps = torch_empty_cache_steps, learning_rate = learning_rate, weight_decay = weight_decay, adam_beta1 = adam_beta1, adam_beta2 = adam_beta2, adam_epsilon = adam_epsilon, max_grad_norm = max_grad_norm, num_train_epochs = num_train_epochs, max_steps = max_steps, lr_scheduler_type = lr_scheduler_type, warmup_ratio = warmup_ratio, warmup_steps = warmup_steps, log_level = log_level, log_level_replica = log_level_replica, log_on_each_node = log_on_each_node, logging_dir = logging_dir, logging_strategy = logging_strategy, logging_first_step = logging_first_step, logging_steps = logging_steps, logging_nan_inf_filter = logging_nan_inf_filter, save_strategy = save_strategy, save_steps = save_steps, save_total_limit = save_total_limit, save_safetensors = save_safetensors, save_on_each_node = save_on_each_node, save_only_model = save_only_model, restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint, no_cuda = no_cuda, use_cpu = use_cpu, use_mps_device = use_mps_device, seed = seed, data_seed = data_seed, jit_mode_eval = jit_mode_eval, use_ipex = use_ipex, bf16 = bf16, fp16 = fp16, fp16_opt_level = fp16_opt_level, half_precision_backend = half_precision_backend, bf16_full_eval = bf16_full_eval, fp16_full_eval = fp16_full_eval, tf32 = tf32, local_rank = local_rank, ddp_backend = ddp_backend, tpu_num_cores = tpu_num_cores, tpu_metrics_debug = tpu_metrics_debug, debug = debug, dataloader_drop_last = dataloader_drop_last, eval_steps = eval_steps, dataloader_num_workers = dataloader_num_workers, dataloader_prefetch_factor = dataloader_prefetch_factor, past_index = past_index, run_name = run_name, disable_tqdm = disable_tqdm, remove_unused_columns = remove_unused_columns, label_names = label_names, load_best_model_at_end = load_best_model_at_end, metric_for_best_model = metric_for_best_model, greater_is_better = greater_is_better, ignore_data_skip = ignore_data_skip, fsdp = fsdp, fsdp_min_num_params = fsdp_min_num_params, fsdp_config = fsdp_config, tp_size = tp_size, fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap, accelerator_config = accelerator_config, deepspeed = deepspeed, label_smoothing_factor = label_smoothing_factor, optim = optim, optim_args = optim_args, adafactor = adafactor, group_by_length = group_by_length, length_column_name = length_column_name, report_to = report_to, ddp_find_unused_parameters = ddp_find_unused_parameters, ddp_bucket_cap_mb = ddp_bucket_cap_mb, ddp_broadcast_buffers = ddp_broadcast_buffers, dataloader_pin_memory = dataloader_pin_memory, dataloader_persistent_workers = dataloader_persistent_workers, skip_memory_metrics = skip_memory_metrics, use_legacy_prediction_loop = use_legacy_prediction_loop, push_to_hub = push_to_hub, resume_from_checkpoint = resume_from_checkpoint, hub_model_id = hub_model_id, hub_strategy = hub_strategy, hub_token = hub_token, hub_private_repo = hub_private_repo, hub_always_push = hub_always_push, gradient_checkpointing = gradient_checkpointing, gradient_checkpointing_kwargs = gradient_checkpointing_kwargs, include_inputs_for_metrics = include_inputs_for_metrics, eval_do_concat_batches = eval_do_concat_batches, fp16_backend = fp16_backend, evaluation_strategy = evaluation_strategy, push_to_hub_model_id = push_to_hub_model_id, push_to_hub_organization = push_to_hub_organization, push_to_hub_token = push_to_hub_token, mp_parameters = mp_parameters, auto_find_batch_size = auto_find_batch_size, full_determinism = full_determinism, torchdynamo = torchdynamo, ray_scope = ray_scope, ddp_timeout = ddp_timeout, torch_compile = torch_compile, torch_compile_backend = torch_compile_backend, torch_compile_mode = torch_compile_mode, dispatch_batches = dispatch_batches, split_batches = split_batches, include_tokens_per_second = include_tokens_per_second, include_num_input_tokens_seen = include_num_input_tokens_seen, neftune_noise_alpha = neftune_noise_alpha, optim_target_modules = optim_target_modules, batch_eval_metrics = batch_eval_metrics, eval_on_start = eval_on_start, use_liger_kernel = use_liger_kernel, eval_use_gather_object = eval_use_gather_object, average_tokens_across_devices = average_tokens_across_devices, model_init_kwargs = model_init_kwargs, use_liger = use_liger, dataset_text_field = dataset_text_field, dataset_kwargs = dataset_kwargs, dataset_num_proc = dataset_num_proc, max_seq_length = max_seq_length, packing = packing, eval_packing = eval_packing, dataset_batch_size = dataset_batch_size, num_of_sequences = num_of_sequences, chars_per_token = chars_per_token, temperature = temperature, lmbda = lmbda, beta = beta, max_new_tokens = max_new_tokens, teacher_model_name_or_path = teacher_model_name_or_path, teacher_model_init_kwargs = teacher_model_init_kwargs, disable_dropout = disable_dropout, seq_kd = seq_kd,**kwargs) self.vllm_sampling_params = vllm_sampling_params self.unsloth_num_chunks = unsloth_num_chunks pass class _UnslothGKDTrainer(SFTTrainer): _tag_names = ["trl", "gkd"] def __init__( self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, teacher_model: Union[PreTrainedModel, nn.Module, str] = None, args: Optional[GKDConfig] = None, data_collator: Optional[DataCollator] = None, # type: ignore train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, processing_class: Optional[ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] ] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, peft_config: Optional["PeftConfig"] = None, formatting_func: Optional[Callable] = None, ): # add remove_unused_columns=False to the dataclass args args.remove_unused_columns = False data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length) super().__init__( model, args=args, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, peft_config=peft_config, formatting_func=formatting_func, ) if args.teacher_model_init_kwargs is None: teacher_model_init_kwargs = {} elif not isinstance(teacher_model, str): raise ValueError( "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated." ) else: teacher_model_init_kwargs = args.teacher_model_init_kwargs teacher_model_init_kwargs["torch_dtype"] = ( teacher_model_init_kwargs["torch_dtype"] if teacher_model_init_kwargs["torch_dtype"] in ["auto", None] else getattr(torch, teacher_model_init_kwargs["torch_dtype"]) ) if isinstance(teacher_model, str): if args.use_liger: teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) else: teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) # Disable dropout in the model if args.disable_dropout: disable_dropout_in_model(self.model) if self.is_deepspeed_enabled: self.teacher_model = self._prepare_deepspeed(teacher_model) else: self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) self.lmbda = args.lmbda self.beta = args.beta self.temperature = args.temperature self.seq_kd = args.seq_kd self.generation_config = GenerationConfig( max_new_tokens=args.max_new_tokens, temperature=args.temperature, do_sample=True, top_k=0, use_cache=False if args.gradient_checkpointing else True, pad_token_id=self.processing_class.pad_token_id, ) # Set custom EOS tokens if they are specified by the model's generation # config. This is important for models with the Llama 3 chat template, # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of # turns or messages. if ( hasattr(self.model.generation_config, "eos_token_id") and self.model.generation_config.eos_token_id is not None ): self.generation_config.eos_token_id = self.model.generation_config.eos_token_id def _prepare_dataset(self, dataset, *args): # SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we # need to keep the messages column as it is. We use the following workaround to keep the messages column. dataset = dataset.add_column("_messages", dataset["messages"]) dataset = super()._prepare_dataset(dataset, *args) dataset = dataset.rename_column("_messages", "messages") return dataset @staticmethod def generalized_jsd_loss( student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" ): """ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1) of https://huggingface.co/papers/2306.13649 for the definition. Args: student_logits: Tensor of shape (batch_size, sequence_length, vocab_size) teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size) labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss beta: Interpolation coefficient between 0 and 1 (default: 0.5) temperature: Softmax temperature (default: 1.0) reduction: Specifies the reduction to apply to the output (default: 'batchmean') Returns: loss: Scalar tensor with the generalized JSD loss """ # Apply temperature scaling student_logits = student_logits / temperature teacher_logits = teacher_logits / temperature # Compute log probabilities for student and probabilities for teacher student_log_probs = F.log_softmax(student_logits, dim=-1) teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) # Compute the log of the mixture distribution # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture beta = torch.tensor(beta, dtype=student_log_probs.dtype) mixture_log_probs = torch.logsumexp( torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]), dim=0, ) # Compute KL divergences using F.kl_div # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) # Compute the Generalized Jensen-Shannon Divergence jsd = beta * kl_teacher + (1 - beta) * kl_student # Masking if labels is not None: mask = labels != -100 jsd = jsd[mask] # Apply reduction if reduction == "batchmean": return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1)) elif reduction == "sum": return jsd.sum() elif reduction == "mean": return jsd.mean() else: return jsd def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # compute student output outputs_student = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) # compute teacher output in eval mode self.teacher_model.eval() with torch.no_grad(): outputs_teacher = self.teacher_model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) # slice the logits for the generated tokens using the inputs["prompts"] lengths prompt_lengths = inputs["prompts"].shape[1] shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :] shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :] shifted_labels = inputs["labels"][:, prompt_lengths:] # compute loss loss = self.generalized_jsd_loss( student_logits=shifted_student_logits, teacher_logits=shifted_teacher_logits, labels=shifted_labels, beta=self.beta, ) # empty cache empty_cache() # Return loss return (loss, outputs_student) if return_outputs else loss @staticmethod def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None): # Generate output with respect to the prompt only generated_outputs = model.generate( input_ids=inputs["prompts"], attention_mask=inputs.get("prompt_attention_mask", None), generation_config=generation_config, return_dict_in_generate=True, ) # Get the generated token IDs generated_tokens = generated_outputs.sequences # Calculate new attention mask new_attention_mask = torch.ones_like(generated_tokens) new_labels = generated_tokens.clone() # If there's pad_token_id, set attention mask to 0 for padding tokens if pad_token_id is not None: new_labels[new_labels == pad_token_id] = -100 new_attention_mask[generated_tokens == pad_token_id] = 0 return generated_tokens, new_attention_mask, new_labels def training_step( self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None ) -> torch.Tensor: """ Perform a training step for the Generalized Knowledge Distillation (GKD) model. This method implements the on-policy learning approach described in the GKD paper. With probability `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the original inputs. """ if self.seq_kd: with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model: new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id ) inputs["input_ids"] = new_input_ids inputs["attention_mask"] = new_attention_mask inputs["labels"] = new_labels if random.random() <= self.lmbda: with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id ) inputs["input_ids"] = new_input_ids inputs["attention_mask"] = new_attention_mask inputs["labels"] = new_labels loss = super().training_step(model, inputs, num_items_in_batch) return loss def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) if model is not None: if hasattr(model, "config"): hidden_size = ( max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None) ) if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 config_kwargs.update( { "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, } ) # If ZeRO-3 is used, we shard both the active and reference model. # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) if config_kwargs["zero_optimization"]["stage"] != 3: config_kwargs["zero_optimization"]["stage"] = 0 model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model def create_model_card( self, model_name: Optional[str] = None, dataset_name: Optional[str] = None, tags: Union[str, list[str], None] = None, ): """ Creates a draft of a model card using the information available to the `Trainer`. Args: model_name (`str` or `None`, *optional*, defaults to `None`): Name of the model. dataset_name (`str` or `None`, *optional*, defaults to `None`): Name of the dataset used for training. tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): Tags to be associated with the model card. """ if not self.is_world_process_zero(): return if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): base_model = self.model.config._name_or_path else: base_model = None tags = tags or [] if isinstance(tags, str): tags = [tags] if hasattr(self.model.config, "unsloth_version"): tags.append("unsloth") citation = textwrap.dedent("""\ @inproceedings{agarwal2024on-policy, title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, year = 2024, booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, publisher = {OpenReview.net}, url = {https://openreview.net/forum?id=3zKtaqxLhW}, }""") model_card = generate_model_card( base_model=base_model, model_name=model_name, hub_model_id=self.hub_model_id, dataset_name=dataset_name, tags=tags, wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, comet_url=get_comet_experiment_url(), trainer_name="GKD", trainer_citation=citation, paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", paper_id="2306.13649", ) model_card.save(os.path.join(self.args.output_dir, "README.md")) class UnslothGKDTrainer(_UnslothGKDTrainer): """ """ def __init__( self, model = None, teacher_model = None, args = None, data_collator = None, train_dataset = None, eval_dataset = None, processing_class = None, compute_metrics = None, callbacks = None, preprocess_logits_for_metrics = None, peft_config = None, formatting_func = None, **kwargs ): if args is None: args = UnslothGKDConfig() use_bf16 = getattr(args, 'bf16', False) use_fp16 = getattr(args, 'fp16', False) force_float32 = False if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': print('Unsloth: Switching to float32 training since model cannot work with float16') force_float32 = True mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') dtype = getattr(model.config, 'torch_dtype', None) if dtype is None: dtype = model.get_input_embeddings().dtype from unsloth_zoo.utils import _get_dtype dtype = _get_dtype(dtype) float16 = dtype == torch.float16 if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`') if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`') if force_float32: args.fp16 = False args.bf16 = False os.environ['ACCELERATE_MIXED_PRECISION'] = 'no' elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32': args.fp16 = float16 args.bf16 = not float16 os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16' if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no': args.eval_strategy = 'steps' if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1 ga_steps = getattr(args, 'gradient_accumulation_steps', None) if ga_steps is not None and ga_steps > 1: from transformers import __version__ as transformers_version if Version(transformers_version) <= Version('4.45.2'): print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n' '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`') if getattr(args, 'eval_strategy', 'no') != 'no': eval_bsz = getattr(args, 'per_device_eval_batch_size', 8) if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps fp16_full_eval = getattr(args, 'fp16_full_eval', False) bf16_full_eval = getattr(args, 'bf16_full_eval', False) if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False if force_float32: args.bf16_full_eval = False args.fp16_full_eval = False elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16': args.bf16_full_eval = True args.fp16_full_eval = False elif not bf16_full_eval and not fp16_full_eval: args.bf16_full_eval = args.bf16 args.fp16_full_eval = args.fp16 _output_logits = False if locals().get('compute_metrics', None) is not None: _output_logits = True if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True if _output_logits: os.environ['UNSLOTH_RETURN_LOGITS'] = '1' if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'): pass else: model_max_seq_length = getattr(model, 'max_seq_length', None) args_max_seq_length = getattr(args, 'max_seq_length', None) if args_max_seq_length is None and model_max_seq_length is not None: max_seq_length = model.max_seq_length if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length if model is not None and hasattr(model, 'for_training'): model.for_training() if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right' if 'processing_class' in locals(): if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right' if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right' __tokenizer = processing_class if 'processing_class' in locals() else tokenizer from unsloth_zoo.vision_utils import UnslothVisionDataCollator if not isinstance(data_collator, UnslothVisionDataCollator): if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names: data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False) elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names: data_collator = DataCollatorForSeq2Seq(__tokenizer) else: if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False if hasattr(args, 'dataset_text_field'): args.dataset_text_field = '' if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True} if not isinstance(data_collator, UnslothVisionDataCollator): if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'): if isinstance(data_collator, DataCollatorForSeq2Seq): data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer) else: data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False) other_metrics = [] from unsloth_zoo.logging_utils import PatchRLStatistics PatchRLStatistics('gkd_trainer', other_metrics) super().__init__( model = model, teacher_model = teacher_model, args = args, data_collator = data_collator, train_dataset = train_dataset, eval_dataset = eval_dataset, processing_class = processing_class, compute_metrics = compute_metrics, callbacks = callbacks, preprocess_logits_for_metrics = preprocess_logits_for_metrics, peft_config = peft_config, formatting_func = formatting_func,**kwargs) if hasattr(self, 'neftune_hook_handle'): self.neftune_hook_handle.remove() if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle if getattr(args, 'neftune_noise_alpha', None) is not None: model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha pass pass