Spaces:
Runtime error
Runtime error
Delete unsloth_compiled_cache
Browse files- unsloth_compiled_cache/UnslothAlignPropTrainer.py +0 -637
- unsloth_compiled_cache/UnslothBCOTrainer.py +0 -1822
- unsloth_compiled_cache/UnslothCPOTrainer.py +0 -1555
- unsloth_compiled_cache/UnslothDDPOTrainer.py +0 -872
- unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothGKDTrainer.py +0 -861
- unsloth_compiled_cache/UnslothGRPOTrainer.py +0 -1436
- unsloth_compiled_cache/UnslothKTOTrainer.py +0 -1838
- unsloth_compiled_cache/UnslothNashMDTrainer.py +0 -953
- unsloth_compiled_cache/UnslothORPOTrainer.py +0 -1541
- unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +0 -1267
- unsloth_compiled_cache/UnslothPPOTrainer.py +0 -1257
- unsloth_compiled_cache/UnslothPRMTrainer.py +0 -798
- unsloth_compiled_cache/UnslothRLOOTrainer.py +0 -1131
- unsloth_compiled_cache/UnslothRewardTrainer.py +0 -817
- unsloth_compiled_cache/UnslothSFTTrainer.py +0 -1025
- unsloth_compiled_cache/UnslothXPOTrainer.py +0 -1008
- unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc +0 -3
- unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
unsloth_compiled_cache/UnslothAlignPropTrainer.py
DELETED
@@ -1,637 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothAlignPropConfig(AlignPropConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`AlignPropTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
54 |
-
Name of this experiment (defaults to the file name without the extension).
|
55 |
-
run_name (`str`, *optional*, defaults to `""`):
|
56 |
-
Name of this run.
|
57 |
-
seed (`int`, *optional*, defaults to `0`):
|
58 |
-
Random seed for reproducibility.
|
59 |
-
log_with (`str` or `None`, *optional*, defaults to `None`):
|
60 |
-
Log with either `"wandb"` or `"tensorboard"`. Check
|
61 |
-
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
62 |
-
log_image_freq (`int`, *optional*, defaults to `1`):
|
63 |
-
Frequency for logging images.
|
64 |
-
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
65 |
-
Keyword arguments for the tracker (e.g., `wandb_project`).
|
66 |
-
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
67 |
-
Keyword arguments for the accelerator.
|
68 |
-
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
69 |
-
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
|
70 |
-
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
71 |
-
Name of project to use for tracking.
|
72 |
-
logdir (`str`, *optional*, defaults to `"logs"`):
|
73 |
-
Top-level logging directory for checkpoint saving.
|
74 |
-
num_epochs (`int`, *optional*, defaults to `100`):
|
75 |
-
Number of epochs to train.
|
76 |
-
save_freq (`int`, *optional*, defaults to `1`):
|
77 |
-
Number of epochs between saving model checkpoints.
|
78 |
-
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
79 |
-
Number of checkpoints to keep before overwriting old ones.
|
80 |
-
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
81 |
-
Mixed precision training.
|
82 |
-
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
83 |
-
Allow `tf32` on Ampere GPUs.
|
84 |
-
resume_from (`str`, *optional*, defaults to `""`):
|
85 |
-
Path to resume training from a checkpoint.
|
86 |
-
sample_num_steps (`int`, *optional*, defaults to `50`):
|
87 |
-
Number of sampler inference steps.
|
88 |
-
sample_eta (`float`, *optional*, defaults to `1.0`):
|
89 |
-
Eta parameter for the DDIM sampler.
|
90 |
-
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
91 |
-
Classifier-free guidance weight.
|
92 |
-
train_batch_size (`int`, *optional*, defaults to `1`):
|
93 |
-
Batch size for training.
|
94 |
-
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
95 |
-
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
|
96 |
-
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
|
97 |
-
Learning rate.
|
98 |
-
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
99 |
-
Beta1 for Adam optimizer.
|
100 |
-
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
101 |
-
Beta2 for Adam optimizer.
|
102 |
-
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
103 |
-
Weight decay for Adam optimizer.
|
104 |
-
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
105 |
-
Epsilon value for Adam optimizer.
|
106 |
-
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
107 |
-
Number of gradient accumulation steps.
|
108 |
-
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
109 |
-
Maximum gradient norm for gradient clipping.
|
110 |
-
negative_prompts (`str` or `None`, *optional*, defaults to `None`):
|
111 |
-
Comma-separated list of prompts to use as negative examples.
|
112 |
-
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
|
113 |
-
If `True`, randomized truncation to different diffusion timesteps is used.
|
114 |
-
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
|
115 |
-
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
|
116 |
-
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
|
117 |
-
Range of diffusion timesteps for randomized truncated backpropagation.
|
118 |
-
push_to_hub (`bool`, *optional*, defaults to `False`):
|
119 |
-
Whether to push the final model to the Hub.
|
120 |
-
|
121 |
-
"""
|
122 |
-
vllm_sampling_params: Optional[Any] = field(
|
123 |
-
default = None,
|
124 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
125 |
-
)
|
126 |
-
unsloth_num_chunks : Optional[int] = field(
|
127 |
-
default = -1,
|
128 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
129 |
-
)
|
130 |
-
def __init__(
|
131 |
-
self,
|
132 |
-
exp_name = 'main',
|
133 |
-
run_name = '',
|
134 |
-
seed = 3407,
|
135 |
-
log_with = None,
|
136 |
-
log_image_freq = 1,
|
137 |
-
tracker_project_name = 'trl',
|
138 |
-
logdir = 'logs',
|
139 |
-
num_epochs = 100,
|
140 |
-
save_freq = 1,
|
141 |
-
num_checkpoint_limit = 5,
|
142 |
-
mixed_precision = 'fp16',
|
143 |
-
allow_tf32 = True,
|
144 |
-
resume_from = '',
|
145 |
-
sample_num_steps = 50,
|
146 |
-
sample_eta = 1.0,
|
147 |
-
sample_guidance_scale = 5.0,
|
148 |
-
train_batch_size = 1,
|
149 |
-
train_use_8bit_adam = False,
|
150 |
-
train_learning_rate = 5e-05,
|
151 |
-
train_adam_beta1 = 0.9,
|
152 |
-
train_adam_beta2 = 0.999,
|
153 |
-
train_adam_weight_decay = 0.01,
|
154 |
-
train_adam_epsilon = 1e-08,
|
155 |
-
train_gradient_accumulation_steps = 2,
|
156 |
-
train_max_grad_norm = 1.0,
|
157 |
-
negative_prompts = None,
|
158 |
-
truncated_backprop_rand = True,
|
159 |
-
truncated_backprop_timestep = 49,
|
160 |
-
push_to_hub = False,
|
161 |
-
vllm_sampling_params = None,
|
162 |
-
unsloth_num_chunks = -1,
|
163 |
-
**kwargs,
|
164 |
-
):
|
165 |
-
|
166 |
-
super().__init__(
|
167 |
-
exp_name = exp_name,
|
168 |
-
run_name = run_name,
|
169 |
-
seed = seed,
|
170 |
-
log_with = log_with,
|
171 |
-
log_image_freq = log_image_freq,
|
172 |
-
tracker_project_name = tracker_project_name,
|
173 |
-
logdir = logdir,
|
174 |
-
num_epochs = num_epochs,
|
175 |
-
save_freq = save_freq,
|
176 |
-
num_checkpoint_limit = num_checkpoint_limit,
|
177 |
-
mixed_precision = mixed_precision,
|
178 |
-
allow_tf32 = allow_tf32,
|
179 |
-
resume_from = resume_from,
|
180 |
-
sample_num_steps = sample_num_steps,
|
181 |
-
sample_eta = sample_eta,
|
182 |
-
sample_guidance_scale = sample_guidance_scale,
|
183 |
-
train_batch_size = train_batch_size,
|
184 |
-
train_use_8bit_adam = train_use_8bit_adam,
|
185 |
-
train_learning_rate = train_learning_rate,
|
186 |
-
train_adam_beta1 = train_adam_beta1,
|
187 |
-
train_adam_beta2 = train_adam_beta2,
|
188 |
-
train_adam_weight_decay = train_adam_weight_decay,
|
189 |
-
train_adam_epsilon = train_adam_epsilon,
|
190 |
-
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
191 |
-
train_max_grad_norm = train_max_grad_norm,
|
192 |
-
negative_prompts = negative_prompts,
|
193 |
-
truncated_backprop_rand = truncated_backprop_rand,
|
194 |
-
truncated_backprop_timestep = truncated_backprop_timestep,
|
195 |
-
push_to_hub = push_to_hub,**kwargs)
|
196 |
-
self.vllm_sampling_params = vllm_sampling_params
|
197 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
198 |
-
pass
|
199 |
-
|
200 |
-
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
201 |
-
""""""
|
202 |
-
|
203 |
-
_tag_names = ["trl", "alignprop"]
|
204 |
-
|
205 |
-
def __init__(
|
206 |
-
self,
|
207 |
-
config: AlignPropConfig,
|
208 |
-
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
209 |
-
prompt_function: Callable[[], tuple[str, Any]],
|
210 |
-
sd_pipeline: DDPOStableDiffusionPipeline,
|
211 |
-
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
212 |
-
):
|
213 |
-
if image_samples_hook is None:
|
214 |
-
warn("No image_samples_hook provided; no images will be logged")
|
215 |
-
|
216 |
-
self.prompt_fn = prompt_function
|
217 |
-
self.reward_fn = reward_function
|
218 |
-
self.config = config
|
219 |
-
self.image_samples_callback = image_samples_hook
|
220 |
-
|
221 |
-
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
222 |
-
|
223 |
-
if self.config.resume_from:
|
224 |
-
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
225 |
-
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
226 |
-
# get the most recent checkpoint in this directory
|
227 |
-
checkpoints = list(
|
228 |
-
filter(
|
229 |
-
lambda x: "checkpoint_" in x,
|
230 |
-
os.listdir(self.config.resume_from),
|
231 |
-
)
|
232 |
-
)
|
233 |
-
if len(checkpoints) == 0:
|
234 |
-
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
235 |
-
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
236 |
-
self.config.resume_from = os.path.join(
|
237 |
-
self.config.resume_from,
|
238 |
-
f"checkpoint_{checkpoint_numbers[-1]}",
|
239 |
-
)
|
240 |
-
|
241 |
-
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
242 |
-
|
243 |
-
self.accelerator = Accelerator(
|
244 |
-
log_with=self.config.log_with,
|
245 |
-
mixed_precision=self.config.mixed_precision,
|
246 |
-
project_config=accelerator_project_config,
|
247 |
-
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
248 |
-
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
249 |
-
# the total number of optimizer steps to accumulate across.
|
250 |
-
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
|
251 |
-
**self.config.accelerator_kwargs,
|
252 |
-
)
|
253 |
-
|
254 |
-
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
255 |
-
|
256 |
-
if self.accelerator.is_main_process:
|
257 |
-
self.accelerator.init_trackers(
|
258 |
-
self.config.tracker_project_name,
|
259 |
-
config=dict(alignprop_trainer_config=config.to_dict())
|
260 |
-
if not is_using_tensorboard
|
261 |
-
else config.to_dict(),
|
262 |
-
init_kwargs=self.config.tracker_kwargs,
|
263 |
-
)
|
264 |
-
|
265 |
-
logger.info(f"\n{config}")
|
266 |
-
|
267 |
-
set_seed(self.config.seed, device_specific=True)
|
268 |
-
|
269 |
-
self.sd_pipeline = sd_pipeline
|
270 |
-
|
271 |
-
self.sd_pipeline.set_progress_bar_config(
|
272 |
-
position=1,
|
273 |
-
disable=not self.accelerator.is_local_main_process,
|
274 |
-
leave=False,
|
275 |
-
desc="Timestep",
|
276 |
-
dynamic_ncols=True,
|
277 |
-
)
|
278 |
-
|
279 |
-
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
280 |
-
# as these weights are only used for inference, keeping weights in full precision is not required.
|
281 |
-
if self.accelerator.mixed_precision == "fp16":
|
282 |
-
inference_dtype = torch.float16
|
283 |
-
elif self.accelerator.mixed_precision == "bf16":
|
284 |
-
inference_dtype = torch.bfloat16
|
285 |
-
else:
|
286 |
-
inference_dtype = torch.float32
|
287 |
-
|
288 |
-
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
289 |
-
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
290 |
-
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
291 |
-
|
292 |
-
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
293 |
-
|
294 |
-
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
295 |
-
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
296 |
-
|
297 |
-
# Enable TF32 for faster training on Ampere GPUs,
|
298 |
-
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
299 |
-
if self.config.allow_tf32:
|
300 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
301 |
-
|
302 |
-
self.optimizer = self._setup_optimizer(
|
303 |
-
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
304 |
-
)
|
305 |
-
|
306 |
-
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
307 |
-
self.sd_pipeline.tokenizer(
|
308 |
-
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
309 |
-
return_tensors="pt",
|
310 |
-
padding="max_length",
|
311 |
-
truncation=True,
|
312 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
313 |
-
).input_ids.to(self.accelerator.device)
|
314 |
-
)[0]
|
315 |
-
|
316 |
-
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
317 |
-
# more memory
|
318 |
-
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
319 |
-
|
320 |
-
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
321 |
-
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
322 |
-
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
323 |
-
else:
|
324 |
-
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
325 |
-
|
326 |
-
if config.resume_from:
|
327 |
-
logger.info(f"Resuming from {config.resume_from}")
|
328 |
-
self.accelerator.load_state(config.resume_from)
|
329 |
-
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
330 |
-
else:
|
331 |
-
self.first_epoch = 0
|
332 |
-
|
333 |
-
def compute_rewards(self, prompt_image_pairs):
|
334 |
-
reward, reward_metadata = self.reward_fn(
|
335 |
-
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
|
336 |
-
)
|
337 |
-
return reward
|
338 |
-
|
339 |
-
def step(self, epoch: int, global_step: int):
|
340 |
-
"""
|
341 |
-
Perform a single step of training.
|
342 |
-
|
343 |
-
Args:
|
344 |
-
epoch (int): The current epoch.
|
345 |
-
global_step (int): The current global step.
|
346 |
-
|
347 |
-
Side Effects:
|
348 |
-
- Model weights are updated
|
349 |
-
- Logs the statistics to the accelerator trackers.
|
350 |
-
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
351 |
-
|
352 |
-
Returns:
|
353 |
-
global_step (int): The updated global step.
|
354 |
-
"""
|
355 |
-
info = defaultdict(list)
|
356 |
-
|
357 |
-
self.sd_pipeline.unet.train()
|
358 |
-
|
359 |
-
for _ in range(self.config.train_gradient_accumulation_steps):
|
360 |
-
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
|
361 |
-
prompt_image_pairs = self._generate_samples(
|
362 |
-
batch_size=self.config.train_batch_size,
|
363 |
-
)
|
364 |
-
|
365 |
-
rewards = self.compute_rewards(prompt_image_pairs)
|
366 |
-
|
367 |
-
prompt_image_pairs["rewards"] = rewards
|
368 |
-
|
369 |
-
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
|
370 |
-
|
371 |
-
loss = self.calculate_loss(rewards)
|
372 |
-
|
373 |
-
self.accelerator.backward(loss)
|
374 |
-
|
375 |
-
if self.accelerator.sync_gradients:
|
376 |
-
self.accelerator.clip_grad_norm_(
|
377 |
-
self.trainable_layers.parameters()
|
378 |
-
if not isinstance(self.trainable_layers, list)
|
379 |
-
else self.trainable_layers,
|
380 |
-
self.config.train_max_grad_norm,
|
381 |
-
)
|
382 |
-
|
383 |
-
self.optimizer.step()
|
384 |
-
self.optimizer.zero_grad()
|
385 |
-
|
386 |
-
info["reward_mean"].append(rewards_vis.mean())
|
387 |
-
info["reward_std"].append(rewards_vis.std())
|
388 |
-
info["loss"].append(loss.item())
|
389 |
-
|
390 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
391 |
-
if self.accelerator.sync_gradients:
|
392 |
-
# log training-related stuff
|
393 |
-
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
|
394 |
-
info = self.accelerator.reduce(info, reduction="mean")
|
395 |
-
info.update({"epoch": epoch})
|
396 |
-
self.accelerator.log(info, step=global_step)
|
397 |
-
global_step += 1
|
398 |
-
info = defaultdict(list)
|
399 |
-
else:
|
400 |
-
raise ValueError(
|
401 |
-
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
402 |
-
)
|
403 |
-
# Logs generated images
|
404 |
-
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
|
405 |
-
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
|
406 |
-
|
407 |
-
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
408 |
-
self.accelerator.save_state()
|
409 |
-
|
410 |
-
return global_step
|
411 |
-
|
412 |
-
def calculate_loss(self, rewards):
|
413 |
-
"""
|
414 |
-
Calculate the loss for a batch of an unpacked sample
|
415 |
-
|
416 |
-
Args:
|
417 |
-
rewards (torch.Tensor):
|
418 |
-
Differentiable reward scalars for each generated image, shape: [batch_size]
|
419 |
-
|
420 |
-
Returns:
|
421 |
-
loss (torch.Tensor)
|
422 |
-
(all of these are of shape (1,))
|
423 |
-
"""
|
424 |
-
# Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
|
425 |
-
loss = 10.0 - (rewards).mean()
|
426 |
-
return loss
|
427 |
-
|
428 |
-
def loss(
|
429 |
-
self,
|
430 |
-
advantages: torch.Tensor,
|
431 |
-
clip_range: float,
|
432 |
-
ratio: torch.Tensor,
|
433 |
-
):
|
434 |
-
unclipped_loss = -advantages * ratio
|
435 |
-
clipped_loss = -advantages * torch.clamp(
|
436 |
-
ratio,
|
437 |
-
1.0 - clip_range,
|
438 |
-
1.0 + clip_range,
|
439 |
-
)
|
440 |
-
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
441 |
-
|
442 |
-
def _setup_optimizer(self, trainable_layers_parameters):
|
443 |
-
if self.config.train_use_8bit_adam:
|
444 |
-
import bitsandbytes
|
445 |
-
|
446 |
-
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
447 |
-
else:
|
448 |
-
optimizer_cls = torch.optim.AdamW
|
449 |
-
|
450 |
-
return optimizer_cls(
|
451 |
-
trainable_layers_parameters,
|
452 |
-
lr=self.config.train_learning_rate,
|
453 |
-
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
454 |
-
weight_decay=self.config.train_adam_weight_decay,
|
455 |
-
eps=self.config.train_adam_epsilon,
|
456 |
-
)
|
457 |
-
|
458 |
-
def _save_model_hook(self, models, weights, output_dir):
|
459 |
-
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
460 |
-
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
461 |
-
|
462 |
-
def _load_model_hook(self, models, input_dir):
|
463 |
-
self.sd_pipeline.load_checkpoint(models, input_dir)
|
464 |
-
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
465 |
-
|
466 |
-
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
467 |
-
"""
|
468 |
-
Generate samples from the model
|
469 |
-
|
470 |
-
Args:
|
471 |
-
batch_size (int): Batch size to use for sampling
|
472 |
-
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
473 |
-
|
474 |
-
Returns:
|
475 |
-
prompt_image_pairs (dict[Any])
|
476 |
-
"""
|
477 |
-
prompt_image_pairs = {}
|
478 |
-
|
479 |
-
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
480 |
-
|
481 |
-
if prompts is None:
|
482 |
-
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
483 |
-
else:
|
484 |
-
prompt_metadata = [{} for _ in range(batch_size)]
|
485 |
-
|
486 |
-
prompt_ids = self.sd_pipeline.tokenizer(
|
487 |
-
prompts,
|
488 |
-
return_tensors="pt",
|
489 |
-
padding="max_length",
|
490 |
-
truncation=True,
|
491 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
492 |
-
).input_ids.to(self.accelerator.device)
|
493 |
-
|
494 |
-
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
495 |
-
|
496 |
-
if with_grad:
|
497 |
-
sd_output = self.sd_pipeline.rgb_with_grad(
|
498 |
-
prompt_embeds=prompt_embeds,
|
499 |
-
negative_prompt_embeds=sample_neg_prompt_embeds,
|
500 |
-
num_inference_steps=self.config.sample_num_steps,
|
501 |
-
guidance_scale=self.config.sample_guidance_scale,
|
502 |
-
eta=self.config.sample_eta,
|
503 |
-
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
504 |
-
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
505 |
-
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
506 |
-
output_type="pt",
|
507 |
-
)
|
508 |
-
else:
|
509 |
-
sd_output = self.sd_pipeline(
|
510 |
-
prompt_embeds=prompt_embeds,
|
511 |
-
negative_prompt_embeds=sample_neg_prompt_embeds,
|
512 |
-
num_inference_steps=self.config.sample_num_steps,
|
513 |
-
guidance_scale=self.config.sample_guidance_scale,
|
514 |
-
eta=self.config.sample_eta,
|
515 |
-
output_type="pt",
|
516 |
-
)
|
517 |
-
|
518 |
-
images = sd_output.images
|
519 |
-
|
520 |
-
prompt_image_pairs["images"] = images
|
521 |
-
prompt_image_pairs["prompts"] = prompts
|
522 |
-
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
523 |
-
|
524 |
-
return prompt_image_pairs
|
525 |
-
|
526 |
-
def train(self, epochs: Optional[int] = None):
|
527 |
-
"""
|
528 |
-
Train the model for a given number of epochs
|
529 |
-
"""
|
530 |
-
global_step = 0
|
531 |
-
if epochs is None:
|
532 |
-
epochs = self.config.num_epochs
|
533 |
-
for epoch in range(self.first_epoch, epochs):
|
534 |
-
global_step = self.step(epoch, global_step)
|
535 |
-
|
536 |
-
def _save_pretrained(self, save_directory):
|
537 |
-
self.sd_pipeline.save_pretrained(save_directory)
|
538 |
-
self.create_model_card()
|
539 |
-
|
540 |
-
def create_model_card(
|
541 |
-
self,
|
542 |
-
model_name: Optional[str] = None,
|
543 |
-
dataset_name: Optional[str] = None,
|
544 |
-
tags: Union[str, list[str], None] = None,
|
545 |
-
):
|
546 |
-
"""
|
547 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
548 |
-
|
549 |
-
Args:
|
550 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
551 |
-
Name of the model.
|
552 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
553 |
-
Name of the dataset used for training.
|
554 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
555 |
-
Tags to be associated with the model card.
|
556 |
-
"""
|
557 |
-
if not self.is_world_process_zero():
|
558 |
-
return
|
559 |
-
|
560 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
561 |
-
base_model = self.model.config._name_or_path
|
562 |
-
else:
|
563 |
-
base_model = None
|
564 |
-
|
565 |
-
tags = tags or []
|
566 |
-
if isinstance(tags, str):
|
567 |
-
tags = [tags]
|
568 |
-
|
569 |
-
if hasattr(self.model.config, "unsloth_version"):
|
570 |
-
tags.append("unsloth")
|
571 |
-
|
572 |
-
citation = textwrap.dedent("""\
|
573 |
-
@article{prabhudesai2024aligning,
|
574 |
-
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
575 |
-
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
576 |
-
year = 2024,
|
577 |
-
eprint = {arXiv:2310.03739}
|
578 |
-
}""")
|
579 |
-
|
580 |
-
model_card = generate_model_card(
|
581 |
-
base_model=base_model,
|
582 |
-
model_name=model_name,
|
583 |
-
hub_model_id=self.hub_model_id,
|
584 |
-
dataset_name=dataset_name,
|
585 |
-
tags=tags,
|
586 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
587 |
-
comet_url=get_comet_experiment_url(),
|
588 |
-
trainer_name="AlignProp",
|
589 |
-
trainer_citation=citation,
|
590 |
-
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
591 |
-
paper_id="2310.03739",
|
592 |
-
)
|
593 |
-
|
594 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
595 |
-
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
|
596 |
-
"""
|
597 |
-
|
598 |
-
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
599 |
-
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
|
600 |
-
As of now only Stable Diffusion based pipelines are supported
|
601 |
-
|
602 |
-
Attributes:
|
603 |
-
config (`AlignPropConfig`):
|
604 |
-
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
|
605 |
-
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
|
606 |
-
Reward function to be used
|
607 |
-
prompt_function (`Callable[[], tuple[str, Any]]`):
|
608 |
-
Function to generate prompts to guide model
|
609 |
-
sd_pipeline (`DDPOStableDiffusionPipeline`):
|
610 |
-
Stable Diffusion pipeline to be used for training.
|
611 |
-
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
|
612 |
-
Hook to be called to log images
|
613 |
-
|
614 |
-
"""
|
615 |
-
def __init__(
|
616 |
-
self,
|
617 |
-
config,
|
618 |
-
reward_function,
|
619 |
-
prompt_function,
|
620 |
-
sd_pipeline,
|
621 |
-
image_samples_hook = None,
|
622 |
-
**kwargs
|
623 |
-
):
|
624 |
-
if args is None: args = UnslothAlignPropConfig()
|
625 |
-
other_metrics = []
|
626 |
-
|
627 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
628 |
-
PatchRLStatistics('alignprop_trainer', other_metrics)
|
629 |
-
|
630 |
-
super().__init__(
|
631 |
-
config = config,
|
632 |
-
reward_function = reward_function,
|
633 |
-
prompt_function = prompt_function,
|
634 |
-
sd_pipeline = sd_pipeline,
|
635 |
-
image_samples_hook = image_samples_hook,**kwargs)
|
636 |
-
|
637 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothBCOTrainer.py
DELETED
@@ -1,1822 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothBCOConfig(BCOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`BCOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
54 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
55 |
-
to use the default data collator.
|
56 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
57 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
58 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
59 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
60 |
-
and your model is an encoder-decoder.
|
61 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
62 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
63 |
-
reference model.
|
64 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
65 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
66 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
67 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
68 |
-
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
69 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
70 |
-
This argument is required if you want to use the default data collator.
|
71 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
72 |
-
Whether to disable dropout in the model and reference model.
|
73 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
74 |
-
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
75 |
-
evaluation.
|
76 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
77 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
78 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
79 |
-
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
80 |
-
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
81 |
-
useful when training without the reference model to reduce the total GPU memory needed.
|
82 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
83 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
84 |
-
string.
|
85 |
-
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
86 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
87 |
-
from a string.
|
88 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
89 |
-
Number of processes to use for processing the dataset.
|
90 |
-
prompt_sample_size (`int`, *optional*, defaults to `1024`):
|
91 |
-
Number of prompts that are fed to density ratio classifier.
|
92 |
-
min_density_ratio (`float`, *optional*, defaults to `0.5`):
|
93 |
-
Minimum value of the density ratio. The estimated density ratio is clamped to this value.
|
94 |
-
max_density_ratio (`float`, *optional*, defaults to `10.0`):
|
95 |
-
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
|
96 |
-
|
97 |
-
"""
|
98 |
-
vllm_sampling_params: Optional[Any] = field(
|
99 |
-
default = None,
|
100 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
101 |
-
)
|
102 |
-
unsloth_num_chunks : Optional[int] = field(
|
103 |
-
default = -1,
|
104 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
105 |
-
)
|
106 |
-
def __init__(
|
107 |
-
self,
|
108 |
-
output_dir = None,
|
109 |
-
overwrite_output_dir = None,
|
110 |
-
do_train = False,
|
111 |
-
do_eval = False,
|
112 |
-
do_predict = False,
|
113 |
-
eval_strategy = 'no',
|
114 |
-
prediction_loss_only = False,
|
115 |
-
per_device_train_batch_size = 4,
|
116 |
-
per_device_eval_batch_size = 4,
|
117 |
-
per_gpu_train_batch_size = None,
|
118 |
-
per_gpu_eval_batch_size = None,
|
119 |
-
gradient_accumulation_steps = 2,
|
120 |
-
eval_accumulation_steps = 2,
|
121 |
-
eval_delay = 0,
|
122 |
-
torch_empty_cache_steps = 250,
|
123 |
-
learning_rate = 5e-05,
|
124 |
-
weight_decay = 0.01,
|
125 |
-
adam_beta1 = 0.9,
|
126 |
-
adam_beta2 = 0.999,
|
127 |
-
adam_epsilon = 1e-08,
|
128 |
-
max_grad_norm = 1.0,
|
129 |
-
num_train_epochs = 3.0,
|
130 |
-
max_steps = -1,
|
131 |
-
lr_scheduler_type = 'linear',
|
132 |
-
warmup_ratio = 0.1,
|
133 |
-
warmup_steps = 0,
|
134 |
-
log_level = 'passive',
|
135 |
-
log_level_replica = 'warning',
|
136 |
-
log_on_each_node = True,
|
137 |
-
logging_dir = None,
|
138 |
-
logging_strategy = 'steps',
|
139 |
-
logging_first_step = False,
|
140 |
-
logging_steps = 1,
|
141 |
-
logging_nan_inf_filter = False,
|
142 |
-
save_strategy = 'steps',
|
143 |
-
save_steps = 500,
|
144 |
-
save_total_limit = None,
|
145 |
-
save_safetensors = True,
|
146 |
-
save_on_each_node = False,
|
147 |
-
save_only_model = False,
|
148 |
-
restore_callback_states_from_checkpoint = False,
|
149 |
-
no_cuda = False,
|
150 |
-
use_cpu = False,
|
151 |
-
use_mps_device = False,
|
152 |
-
seed = 3407,
|
153 |
-
data_seed = 3407,
|
154 |
-
jit_mode_eval = False,
|
155 |
-
use_ipex = False,
|
156 |
-
bf16 = False,
|
157 |
-
fp16 = False,
|
158 |
-
fp16_opt_level = 'O1',
|
159 |
-
half_precision_backend = 'auto',
|
160 |
-
bf16_full_eval = False,
|
161 |
-
fp16_full_eval = False,
|
162 |
-
tf32 = None,
|
163 |
-
local_rank = -1,
|
164 |
-
ddp_backend = None,
|
165 |
-
tpu_num_cores = None,
|
166 |
-
tpu_metrics_debug = False,
|
167 |
-
debug = '',
|
168 |
-
dataloader_drop_last = False,
|
169 |
-
eval_steps = None,
|
170 |
-
dataloader_num_workers = 0,
|
171 |
-
dataloader_prefetch_factor = None,
|
172 |
-
past_index = -1,
|
173 |
-
run_name = None,
|
174 |
-
disable_tqdm = None,
|
175 |
-
remove_unused_columns = True,
|
176 |
-
label_names = None,
|
177 |
-
load_best_model_at_end = False,
|
178 |
-
metric_for_best_model = None,
|
179 |
-
greater_is_better = None,
|
180 |
-
ignore_data_skip = False,
|
181 |
-
fsdp = '',
|
182 |
-
fsdp_min_num_params = 0,
|
183 |
-
fsdp_config = None,
|
184 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
185 |
-
accelerator_config = None,
|
186 |
-
deepspeed = None,
|
187 |
-
label_smoothing_factor = 0.0,
|
188 |
-
optim = 'adamw_8bit',
|
189 |
-
optim_args = None,
|
190 |
-
adafactor = False,
|
191 |
-
group_by_length = False,
|
192 |
-
length_column_name = 'length',
|
193 |
-
report_to = None,
|
194 |
-
ddp_find_unused_parameters = None,
|
195 |
-
ddp_bucket_cap_mb = None,
|
196 |
-
ddp_broadcast_buffers = None,
|
197 |
-
dataloader_pin_memory = True,
|
198 |
-
dataloader_persistent_workers = False,
|
199 |
-
skip_memory_metrics = True,
|
200 |
-
use_legacy_prediction_loop = False,
|
201 |
-
push_to_hub = False,
|
202 |
-
resume_from_checkpoint = None,
|
203 |
-
hub_model_id = None,
|
204 |
-
hub_strategy = 'every_save',
|
205 |
-
hub_token = None,
|
206 |
-
hub_private_repo = None,
|
207 |
-
hub_always_push = False,
|
208 |
-
gradient_checkpointing = False,
|
209 |
-
gradient_checkpointing_kwargs = None,
|
210 |
-
include_inputs_for_metrics = False,
|
211 |
-
eval_do_concat_batches = True,
|
212 |
-
fp16_backend = 'auto',
|
213 |
-
evaluation_strategy = None,
|
214 |
-
push_to_hub_model_id = None,
|
215 |
-
push_to_hub_organization = None,
|
216 |
-
push_to_hub_token = None,
|
217 |
-
mp_parameters = '',
|
218 |
-
auto_find_batch_size = False,
|
219 |
-
full_determinism = False,
|
220 |
-
torchdynamo = None,
|
221 |
-
ray_scope = 'last',
|
222 |
-
ddp_timeout = 1800,
|
223 |
-
torch_compile = False,
|
224 |
-
torch_compile_backend = None,
|
225 |
-
torch_compile_mode = None,
|
226 |
-
dispatch_batches = None,
|
227 |
-
split_batches = None,
|
228 |
-
include_tokens_per_second = False,
|
229 |
-
include_num_input_tokens_seen = False,
|
230 |
-
neftune_noise_alpha = None,
|
231 |
-
optim_target_modules = None,
|
232 |
-
batch_eval_metrics = False,
|
233 |
-
eval_on_start = False,
|
234 |
-
use_liger_kernel = False,
|
235 |
-
eval_use_gather_object = False,
|
236 |
-
average_tokens_across_devices = False,
|
237 |
-
max_length = 1024,
|
238 |
-
max_prompt_length = 512,
|
239 |
-
max_completion_length = None,
|
240 |
-
beta = 0.1,
|
241 |
-
label_pad_token_id = -100,
|
242 |
-
padding_value = None,
|
243 |
-
truncation_mode = 'keep_end',
|
244 |
-
disable_dropout = True,
|
245 |
-
generate_during_eval = False,
|
246 |
-
is_encoder_decoder = None,
|
247 |
-
precompute_ref_log_probs = False,
|
248 |
-
model_init_kwargs = None,
|
249 |
-
ref_model_init_kwargs = None,
|
250 |
-
dataset_num_proc = None,
|
251 |
-
prompt_sample_size = 1024,
|
252 |
-
min_density_ratio = 0.5,
|
253 |
-
max_density_ratio = 10.0,
|
254 |
-
vllm_sampling_params = None,
|
255 |
-
unsloth_num_chunks = -1,
|
256 |
-
**kwargs,
|
257 |
-
):
|
258 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
259 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
260 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
261 |
-
output_dir = 'unsloth_training_checkpoints'
|
262 |
-
save_strategy = 'no'
|
263 |
-
if dataset_num_proc is None:
|
264 |
-
from multiprocessing import cpu_count
|
265 |
-
dataset_num_proc = cpu_count()
|
266 |
-
|
267 |
-
super().__init__(
|
268 |
-
output_dir = output_dir,
|
269 |
-
overwrite_output_dir = overwrite_output_dir,
|
270 |
-
do_train = do_train,
|
271 |
-
do_eval = do_eval,
|
272 |
-
do_predict = do_predict,
|
273 |
-
eval_strategy = eval_strategy,
|
274 |
-
prediction_loss_only = prediction_loss_only,
|
275 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
276 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
277 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
278 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
279 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
280 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
281 |
-
eval_delay = eval_delay,
|
282 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
283 |
-
learning_rate = learning_rate,
|
284 |
-
weight_decay = weight_decay,
|
285 |
-
adam_beta1 = adam_beta1,
|
286 |
-
adam_beta2 = adam_beta2,
|
287 |
-
adam_epsilon = adam_epsilon,
|
288 |
-
max_grad_norm = max_grad_norm,
|
289 |
-
num_train_epochs = num_train_epochs,
|
290 |
-
max_steps = max_steps,
|
291 |
-
lr_scheduler_type = lr_scheduler_type,
|
292 |
-
warmup_ratio = warmup_ratio,
|
293 |
-
warmup_steps = warmup_steps,
|
294 |
-
log_level = log_level,
|
295 |
-
log_level_replica = log_level_replica,
|
296 |
-
log_on_each_node = log_on_each_node,
|
297 |
-
logging_dir = logging_dir,
|
298 |
-
logging_strategy = logging_strategy,
|
299 |
-
logging_first_step = logging_first_step,
|
300 |
-
logging_steps = logging_steps,
|
301 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
302 |
-
save_strategy = save_strategy,
|
303 |
-
save_steps = save_steps,
|
304 |
-
save_total_limit = save_total_limit,
|
305 |
-
save_safetensors = save_safetensors,
|
306 |
-
save_on_each_node = save_on_each_node,
|
307 |
-
save_only_model = save_only_model,
|
308 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
309 |
-
no_cuda = no_cuda,
|
310 |
-
use_cpu = use_cpu,
|
311 |
-
use_mps_device = use_mps_device,
|
312 |
-
seed = seed,
|
313 |
-
data_seed = data_seed,
|
314 |
-
jit_mode_eval = jit_mode_eval,
|
315 |
-
use_ipex = use_ipex,
|
316 |
-
bf16 = bf16,
|
317 |
-
fp16 = fp16,
|
318 |
-
fp16_opt_level = fp16_opt_level,
|
319 |
-
half_precision_backend = half_precision_backend,
|
320 |
-
bf16_full_eval = bf16_full_eval,
|
321 |
-
fp16_full_eval = fp16_full_eval,
|
322 |
-
tf32 = tf32,
|
323 |
-
local_rank = local_rank,
|
324 |
-
ddp_backend = ddp_backend,
|
325 |
-
tpu_num_cores = tpu_num_cores,
|
326 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
327 |
-
debug = debug,
|
328 |
-
dataloader_drop_last = dataloader_drop_last,
|
329 |
-
eval_steps = eval_steps,
|
330 |
-
dataloader_num_workers = dataloader_num_workers,
|
331 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
332 |
-
past_index = past_index,
|
333 |
-
run_name = run_name,
|
334 |
-
disable_tqdm = disable_tqdm,
|
335 |
-
remove_unused_columns = remove_unused_columns,
|
336 |
-
label_names = label_names,
|
337 |
-
load_best_model_at_end = load_best_model_at_end,
|
338 |
-
metric_for_best_model = metric_for_best_model,
|
339 |
-
greater_is_better = greater_is_better,
|
340 |
-
ignore_data_skip = ignore_data_skip,
|
341 |
-
fsdp = fsdp,
|
342 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
343 |
-
fsdp_config = fsdp_config,
|
344 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
345 |
-
accelerator_config = accelerator_config,
|
346 |
-
deepspeed = deepspeed,
|
347 |
-
label_smoothing_factor = label_smoothing_factor,
|
348 |
-
optim = optim,
|
349 |
-
optim_args = optim_args,
|
350 |
-
adafactor = adafactor,
|
351 |
-
group_by_length = group_by_length,
|
352 |
-
length_column_name = length_column_name,
|
353 |
-
report_to = report_to,
|
354 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
355 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
356 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
357 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
358 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
359 |
-
skip_memory_metrics = skip_memory_metrics,
|
360 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
361 |
-
push_to_hub = push_to_hub,
|
362 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
363 |
-
hub_model_id = hub_model_id,
|
364 |
-
hub_strategy = hub_strategy,
|
365 |
-
hub_token = hub_token,
|
366 |
-
hub_private_repo = hub_private_repo,
|
367 |
-
hub_always_push = hub_always_push,
|
368 |
-
gradient_checkpointing = gradient_checkpointing,
|
369 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
370 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
371 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
372 |
-
fp16_backend = fp16_backend,
|
373 |
-
evaluation_strategy = evaluation_strategy,
|
374 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
375 |
-
push_to_hub_organization = push_to_hub_organization,
|
376 |
-
push_to_hub_token = push_to_hub_token,
|
377 |
-
mp_parameters = mp_parameters,
|
378 |
-
auto_find_batch_size = auto_find_batch_size,
|
379 |
-
full_determinism = full_determinism,
|
380 |
-
torchdynamo = torchdynamo,
|
381 |
-
ray_scope = ray_scope,
|
382 |
-
ddp_timeout = ddp_timeout,
|
383 |
-
torch_compile = torch_compile,
|
384 |
-
torch_compile_backend = torch_compile_backend,
|
385 |
-
torch_compile_mode = torch_compile_mode,
|
386 |
-
dispatch_batches = dispatch_batches,
|
387 |
-
split_batches = split_batches,
|
388 |
-
include_tokens_per_second = include_tokens_per_second,
|
389 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
390 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
391 |
-
optim_target_modules = optim_target_modules,
|
392 |
-
batch_eval_metrics = batch_eval_metrics,
|
393 |
-
eval_on_start = eval_on_start,
|
394 |
-
use_liger_kernel = use_liger_kernel,
|
395 |
-
eval_use_gather_object = eval_use_gather_object,
|
396 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
397 |
-
max_length = max_length,
|
398 |
-
max_prompt_length = max_prompt_length,
|
399 |
-
max_completion_length = max_completion_length,
|
400 |
-
beta = beta,
|
401 |
-
label_pad_token_id = label_pad_token_id,
|
402 |
-
padding_value = padding_value,
|
403 |
-
truncation_mode = truncation_mode,
|
404 |
-
disable_dropout = disable_dropout,
|
405 |
-
generate_during_eval = generate_during_eval,
|
406 |
-
is_encoder_decoder = is_encoder_decoder,
|
407 |
-
precompute_ref_log_probs = precompute_ref_log_probs,
|
408 |
-
model_init_kwargs = model_init_kwargs,
|
409 |
-
ref_model_init_kwargs = ref_model_init_kwargs,
|
410 |
-
dataset_num_proc = dataset_num_proc,
|
411 |
-
prompt_sample_size = prompt_sample_size,
|
412 |
-
min_density_ratio = min_density_ratio,
|
413 |
-
max_density_ratio = max_density_ratio,**kwargs)
|
414 |
-
self.vllm_sampling_params = vllm_sampling_params
|
415 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
416 |
-
pass
|
417 |
-
|
418 |
-
class _UnslothBCOTrainer(Trainer):
|
419 |
-
r""""""
|
420 |
-
|
421 |
-
_tag_names = ["trl", "bco"]
|
422 |
-
|
423 |
-
def __init__(
|
424 |
-
self,
|
425 |
-
model: Union[PreTrainedModel, nn.Module, str] = None,
|
426 |
-
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
427 |
-
args: BCOConfig = None,
|
428 |
-
train_dataset: Optional[Dataset] = None,
|
429 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
430 |
-
processing_class: Optional[
|
431 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
432 |
-
] = None,
|
433 |
-
data_collator: Optional[DataCollator] = None,
|
434 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
435 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
436 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
437 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
438 |
-
peft_config: Optional[dict] = None,
|
439 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
440 |
-
model_adapter_name: Optional[str] = None,
|
441 |
-
ref_adapter_name: Optional[str] = None,
|
442 |
-
embedding_func: Optional[Callable] = None,
|
443 |
-
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
444 |
-
):
|
445 |
-
if not is_sklearn_available():
|
446 |
-
raise ImportError(
|
447 |
-
"BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
|
448 |
-
)
|
449 |
-
|
450 |
-
if type(args) is TrainingArguments:
|
451 |
-
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
|
452 |
-
|
453 |
-
if not isinstance(model, str) and ref_model is model:
|
454 |
-
raise ValueError(
|
455 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
456 |
-
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
457 |
-
)
|
458 |
-
|
459 |
-
if args.model_init_kwargs is None:
|
460 |
-
model_init_kwargs = {}
|
461 |
-
elif not isinstance(model, str):
|
462 |
-
raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
|
463 |
-
else:
|
464 |
-
model_init_kwargs = args.model_init_kwargs
|
465 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
466 |
-
if torch_dtype is not None:
|
467 |
-
# Convert to `torch.dtype` if an str is passed
|
468 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
469 |
-
torch_dtype = getattr(torch, torch_dtype)
|
470 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
471 |
-
raise ValueError(
|
472 |
-
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
473 |
-
)
|
474 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
475 |
-
|
476 |
-
if args.ref_model_init_kwargs is None:
|
477 |
-
ref_model_init_kwargs = {}
|
478 |
-
elif not isinstance(ref_model, str):
|
479 |
-
raise ValueError(
|
480 |
-
"You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
|
481 |
-
)
|
482 |
-
else:
|
483 |
-
ref_model_init_kwargs = args.ref_model_init_kwargs
|
484 |
-
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
485 |
-
if torch_dtype is not None:
|
486 |
-
# Convert to `torch.dtype` if an str is passed
|
487 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
488 |
-
torch_dtype = getattr(torch, torch_dtype)
|
489 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
490 |
-
raise ValueError(
|
491 |
-
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
492 |
-
)
|
493 |
-
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
494 |
-
|
495 |
-
if isinstance(model, str):
|
496 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
497 |
-
|
498 |
-
if isinstance(ref_model, str):
|
499 |
-
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
500 |
-
|
501 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
502 |
-
# has been called in order to properly call autocast if needed.
|
503 |
-
self._peft_has_been_casted_to_bf16 = False
|
504 |
-
|
505 |
-
if not is_peft_available() and peft_config is not None:
|
506 |
-
raise ValueError(
|
507 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
508 |
-
)
|
509 |
-
elif is_peft_available() and peft_config is not None:
|
510 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
511 |
-
if isinstance(model, PeftModel):
|
512 |
-
model = model.merge_and_unload()
|
513 |
-
|
514 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
515 |
-
_support_gc_kwargs = hasattr(
|
516 |
-
args, "gradient_checkpointing_kwargs"
|
517 |
-
) and "gradient_checkpointing_kwargs" in list(
|
518 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
519 |
-
)
|
520 |
-
|
521 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
522 |
-
|
523 |
-
if _support_gc_kwargs:
|
524 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
525 |
-
|
526 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
527 |
-
elif getattr(args, "gradient_checkpointing", False):
|
528 |
-
# For backward compatibility with older versions of transformers
|
529 |
-
if hasattr(model, "enable_input_require_grads"):
|
530 |
-
model.enable_input_require_grads()
|
531 |
-
else:
|
532 |
-
|
533 |
-
def make_inputs_require_grad(module, input, output):
|
534 |
-
output.requires_grad_(True)
|
535 |
-
|
536 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
537 |
-
|
538 |
-
# get peft model with the given config
|
539 |
-
model = model
|
540 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
541 |
-
peft_module_casting_to_bf16(model)
|
542 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
543 |
-
self._peft_has_been_casted_to_bf16 = True
|
544 |
-
|
545 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
546 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
547 |
-
# fail or completely fail.
|
548 |
-
elif getattr(args, "gradient_checkpointing", False):
|
549 |
-
# For backward compatibility with older versions of transformers
|
550 |
-
if hasattr(model, "enable_input_require_grads"):
|
551 |
-
model.enable_input_require_grads()
|
552 |
-
else:
|
553 |
-
|
554 |
-
def make_inputs_require_grad(module, input, output):
|
555 |
-
output.requires_grad_(True)
|
556 |
-
|
557 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
558 |
-
|
559 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
560 |
-
raise ValueError(
|
561 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
562 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
563 |
-
)
|
564 |
-
|
565 |
-
if model is not None:
|
566 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
567 |
-
elif args.is_encoder_decoder is None:
|
568 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
569 |
-
else:
|
570 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
571 |
-
|
572 |
-
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
573 |
-
self.model_adapter_name = model_adapter_name
|
574 |
-
self.ref_adapter_name = ref_adapter_name
|
575 |
-
|
576 |
-
if ref_model:
|
577 |
-
self.ref_model = ref_model
|
578 |
-
elif self.is_peft_model or args.precompute_ref_log_probs:
|
579 |
-
# The `model` with adapters turned off will be used as the reference model
|
580 |
-
self.ref_model = None
|
581 |
-
else:
|
582 |
-
self.ref_model = create_reference_model(model)
|
583 |
-
|
584 |
-
if processing_class is None:
|
585 |
-
raise ValueError(
|
586 |
-
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
587 |
-
)
|
588 |
-
if args.max_length is None:
|
589 |
-
warnings.warn(
|
590 |
-
"When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
|
591 |
-
"It will be set to `512` by default, but you should do it yourself in the future.",
|
592 |
-
UserWarning,
|
593 |
-
)
|
594 |
-
max_length = 512
|
595 |
-
if args.max_length is not None:
|
596 |
-
max_length = args.max_length
|
597 |
-
|
598 |
-
if args.max_prompt_length is None:
|
599 |
-
warnings.warn(
|
600 |
-
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
|
601 |
-
"It will be set to `128` by default, but you should do it yourself in the future.",
|
602 |
-
UserWarning,
|
603 |
-
)
|
604 |
-
max_prompt_length = 128
|
605 |
-
if args.max_prompt_length is not None:
|
606 |
-
max_prompt_length = args.max_prompt_length
|
607 |
-
|
608 |
-
max_completion_length = None
|
609 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
610 |
-
warnings.warn(
|
611 |
-
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
|
612 |
-
" it will be set to `128` by default, but you should do it yourself in the future.",
|
613 |
-
UserWarning,
|
614 |
-
)
|
615 |
-
max_completion_length = 128
|
616 |
-
if args.max_completion_length is not None and self.is_encoder_decoder:
|
617 |
-
max_completion_length = args.max_completion_length
|
618 |
-
|
619 |
-
if data_collator is None:
|
620 |
-
data_collator = DPODataCollatorWithPadding(
|
621 |
-
pad_token_id=processing_class.pad_token_id,
|
622 |
-
label_pad_token_id=args.label_pad_token_id,
|
623 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
624 |
-
)
|
625 |
-
|
626 |
-
if args.remove_unused_columns:
|
627 |
-
args.remove_unused_columns = False
|
628 |
-
# warn users
|
629 |
-
warnings.warn(
|
630 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
|
631 |
-
" we have set it for you, but you should do it yourself in the future.",
|
632 |
-
UserWarning,
|
633 |
-
)
|
634 |
-
|
635 |
-
self.use_dpo_data_collator = True
|
636 |
-
else:
|
637 |
-
self.use_dpo_data_collator = False
|
638 |
-
|
639 |
-
# Disable dropout in the model and reference model
|
640 |
-
if args.disable_dropout:
|
641 |
-
disable_dropout_in_model(model)
|
642 |
-
if self.ref_model is not None:
|
643 |
-
disable_dropout_in_model(self.ref_model)
|
644 |
-
|
645 |
-
self.max_length = max_length
|
646 |
-
self.generate_during_eval = args.generate_during_eval
|
647 |
-
self.label_pad_token_id = args.label_pad_token_id
|
648 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
649 |
-
self.max_prompt_length = max_prompt_length
|
650 |
-
self.truncation_mode = args.truncation_mode
|
651 |
-
self.max_completion_length = max_completion_length
|
652 |
-
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
653 |
-
|
654 |
-
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
655 |
-
# keep track of first called to avoid computation of future calls
|
656 |
-
self._precomputed_train_ref_log_probs = False
|
657 |
-
self._precomputed_eval_ref_log_probs = False
|
658 |
-
|
659 |
-
# metric
|
660 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
661 |
-
|
662 |
-
# BCO parameter
|
663 |
-
self.beta = args.beta
|
664 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
665 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
666 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
667 |
-
warnings.warn(
|
668 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
669 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
670 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
671 |
-
"loss.",
|
672 |
-
UserWarning,
|
673 |
-
)
|
674 |
-
|
675 |
-
# Underlying Distribution Matching argument
|
676 |
-
self.embedding_func = embedding_func
|
677 |
-
self.embedding_tokenizer = embedding_tokenizer
|
678 |
-
|
679 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
680 |
-
# input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
|
681 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
682 |
-
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
683 |
-
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
684 |
-
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
685 |
-
# issued.
|
686 |
-
model.warnings_issued["estimate_tokens"] = True
|
687 |
-
|
688 |
-
with PartialState().local_main_process_first():
|
689 |
-
# Apply the chat template if needed
|
690 |
-
train_dataset = train_dataset.map(
|
691 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
692 |
-
)
|
693 |
-
if eval_dataset is not None:
|
694 |
-
eval_dataset = eval_dataset.map(
|
695 |
-
maybe_apply_chat_template,
|
696 |
-
fn_kwargs={"tokenizer": processing_class},
|
697 |
-
num_proc=args.dataset_num_proc,
|
698 |
-
)
|
699 |
-
# Shuffle the datasets
|
700 |
-
train_dataset = train_dataset.shuffle(seed=args.data_seed)
|
701 |
-
if eval_dataset is not None:
|
702 |
-
eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
|
703 |
-
# Tokenize and prepare the training datasets
|
704 |
-
train_dataset = train_dataset.map(
|
705 |
-
_tokenize,
|
706 |
-
batched=True,
|
707 |
-
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
708 |
-
num_proc=args.dataset_num_proc,
|
709 |
-
desc="Tokenizing train dataset",
|
710 |
-
)
|
711 |
-
|
712 |
-
# Prepare the datasets
|
713 |
-
fn_kwargs = {
|
714 |
-
"prefix": "",
|
715 |
-
"is_encoder_decoder": self.is_encoder_decoder,
|
716 |
-
"tokenizer": processing_class,
|
717 |
-
"max_length": self.max_length,
|
718 |
-
"truncation_mode": self.truncation_mode,
|
719 |
-
"label_pad_token_id": self.label_pad_token_id,
|
720 |
-
"max_prompt_length": self.max_prompt_length,
|
721 |
-
"max_completion_length": self.max_completion_length,
|
722 |
-
}
|
723 |
-
train_dataset = train_dataset.map(
|
724 |
-
_process_tokens,
|
725 |
-
fn_kwargs=fn_kwargs,
|
726 |
-
num_proc=args.dataset_num_proc,
|
727 |
-
desc="Processing tokenized train dataset",
|
728 |
-
)
|
729 |
-
|
730 |
-
if eval_dataset is not None:
|
731 |
-
# Tokenize
|
732 |
-
eval_dataset = eval_dataset.map(
|
733 |
-
_tokenize,
|
734 |
-
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
735 |
-
batched=True,
|
736 |
-
num_proc=args.dataset_num_proc,
|
737 |
-
desc="Tokenizing eval dataset",
|
738 |
-
)
|
739 |
-
|
740 |
-
# Process
|
741 |
-
fn_kwargs = {
|
742 |
-
"prefix": "",
|
743 |
-
"is_encoder_decoder": self.is_encoder_decoder,
|
744 |
-
"tokenizer": processing_class,
|
745 |
-
"max_length": self.max_length,
|
746 |
-
"truncation_mode": self.truncation_mode,
|
747 |
-
"label_pad_token_id": self.label_pad_token_id,
|
748 |
-
"max_prompt_length": self.max_prompt_length,
|
749 |
-
"max_completion_length": self.max_completion_length,
|
750 |
-
}
|
751 |
-
eval_dataset = eval_dataset.map(
|
752 |
-
_process_tokens,
|
753 |
-
fn_kwargs=fn_kwargs,
|
754 |
-
num_proc=args.dataset_num_proc,
|
755 |
-
desc="Processing tokenized eval dataset",
|
756 |
-
)
|
757 |
-
|
758 |
-
desirable = train_dataset.filter(
|
759 |
-
lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
|
760 |
-
)
|
761 |
-
undesirable = train_dataset.filter(
|
762 |
-
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
|
763 |
-
)
|
764 |
-
|
765 |
-
desirable = desirable.shuffle(seed=args.data_seed)
|
766 |
-
undesirable = undesirable.shuffle(seed=args.data_seed)
|
767 |
-
|
768 |
-
super().__init__(
|
769 |
-
model=model,
|
770 |
-
args=args,
|
771 |
-
data_collator=data_collator,
|
772 |
-
train_dataset=train_dataset,
|
773 |
-
eval_dataset=eval_dataset,
|
774 |
-
processing_class=processing_class,
|
775 |
-
model_init=model_init,
|
776 |
-
compute_metrics=compute_metrics,
|
777 |
-
callbacks=callbacks,
|
778 |
-
optimizers=optimizers,
|
779 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
780 |
-
)
|
781 |
-
|
782 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
783 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
784 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
785 |
-
self.model_accepts_loss_kwargs = False
|
786 |
-
|
787 |
-
# Add tags for models that have been loaded with the correct transformers version
|
788 |
-
if hasattr(self.model, "add_model_tags"):
|
789 |
-
self.model.add_model_tags(self._tag_names)
|
790 |
-
|
791 |
-
if not hasattr(self, "accelerator"):
|
792 |
-
raise AttributeError(
|
793 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
794 |
-
)
|
795 |
-
|
796 |
-
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
797 |
-
if self.is_deepspeed_enabled:
|
798 |
-
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
799 |
-
raise ValueError(
|
800 |
-
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
801 |
-
)
|
802 |
-
|
803 |
-
if self.ref_model is None:
|
804 |
-
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
805 |
-
raise ValueError(
|
806 |
-
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
807 |
-
)
|
808 |
-
else:
|
809 |
-
if self.is_deepspeed_enabled:
|
810 |
-
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
811 |
-
else:
|
812 |
-
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
813 |
-
|
814 |
-
self.running = RunningMoments(accelerator=self.accelerator)
|
815 |
-
|
816 |
-
if self.embedding_func is None:
|
817 |
-
return
|
818 |
-
|
819 |
-
chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
|
820 |
-
rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
|
821 |
-
|
822 |
-
embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
|
823 |
-
labels = torch.cat(
|
824 |
-
(torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
|
825 |
-
)
|
826 |
-
|
827 |
-
self.clf = LogisticRegression(class_weight="balanced").fit(
|
828 |
-
embeddings.cpu().float().numpy(), labels.cpu().numpy()
|
829 |
-
)
|
830 |
-
|
831 |
-
@property
|
832 |
-
def match_underlying_distribution(self):
|
833 |
-
return self.embedding_func is not None and self.embedding_tokenizer is not None
|
834 |
-
|
835 |
-
def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
836 |
-
"""
|
837 |
-
Calculates the probability if the given prompt embedding is from desirable dataset.
|
838 |
-
This function calculates the probability in the process and ensemble across processes.
|
839 |
-
"""
|
840 |
-
dtype = prompt_embeddings.dtype
|
841 |
-
device = prompt_embeddings.device
|
842 |
-
rank = self.accelerator.process_index
|
843 |
-
|
844 |
-
padded_prompt_embeddings = self.accelerator.pad_across_processes(
|
845 |
-
prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
|
846 |
-
)
|
847 |
-
sample_size = padded_prompt_embeddings.shape[0]
|
848 |
-
nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
|
849 |
-
prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
|
850 |
-
|
851 |
-
# cannot predict for all empty values
|
852 |
-
if prompt_embeddings.shape[0] == 0:
|
853 |
-
return torch.tensor([], device=device, dtype=dtype)
|
854 |
-
|
855 |
-
prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
|
856 |
-
prob = torch.as_tensor(prob, dtype=dtype, device=device)
|
857 |
-
prob = self.accelerator.reduce(prob, reduction="mean")
|
858 |
-
|
859 |
-
prob = prob[sample_size * rank : sample_size * (rank + 1)]
|
860 |
-
prob = prob[nonzero]
|
861 |
-
|
862 |
-
return prob
|
863 |
-
|
864 |
-
def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
|
865 |
-
"""
|
866 |
-
Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
|
867 |
-
and applies self.embedding_func
|
868 |
-
"""
|
869 |
-
input_ids = torch.where(
|
870 |
-
input_ids == self.processing_class.pad_token_id,
|
871 |
-
self.embedding_tokenizer.pad_token_id,
|
872 |
-
input_ids,
|
873 |
-
)
|
874 |
-
|
875 |
-
with torch.no_grad():
|
876 |
-
embeddings = self.embedding_func(
|
877 |
-
input_ids=input_ids,
|
878 |
-
attention_mask=attention_mask,
|
879 |
-
)
|
880 |
-
|
881 |
-
return embeddings
|
882 |
-
|
883 |
-
def _get_prompt_embeddings(
|
884 |
-
self, batch: dict[str, Union[list, torch.LongTensor]]
|
885 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
886 |
-
"""Extract embeddings from frozen embedding model"""
|
887 |
-
|
888 |
-
if not self.match_underlying_distribution:
|
889 |
-
return None, None
|
890 |
-
|
891 |
-
embeddings = self._vectorize_prompt(
|
892 |
-
input_ids=batch["embedding_input_ids"],
|
893 |
-
attention_mask=batch["embedding_attention_mask"],
|
894 |
-
)
|
895 |
-
|
896 |
-
chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
|
897 |
-
rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
|
898 |
-
|
899 |
-
chosen_embeddings = embeddings[chosen_idx, ...]
|
900 |
-
rejected_embeddings = embeddings[rejected_idx, ...]
|
901 |
-
|
902 |
-
return (chosen_embeddings, rejected_embeddings)
|
903 |
-
|
904 |
-
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
|
905 |
-
"""
|
906 |
-
Sample instances from dataset and get prompt embeddings.
|
907 |
-
Used for density ratio classifier training.
|
908 |
-
"""
|
909 |
-
n_samples = min(len(dataset), sample_size)
|
910 |
-
rand_indices = np.random.choice(len(dataset), size=(n_samples,))
|
911 |
-
|
912 |
-
embedding_dataset = dataset.select(rand_indices)
|
913 |
-
|
914 |
-
dataloader_params = {
|
915 |
-
"batch_size": self.args.per_device_train_batch_size,
|
916 |
-
"collate_fn": self.data_collator,
|
917 |
-
"num_workers": self.args.dataloader_num_workers,
|
918 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
919 |
-
"shuffle": False,
|
920 |
-
}
|
921 |
-
|
922 |
-
# prepare dataloader
|
923 |
-
data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
|
924 |
-
|
925 |
-
with torch.no_grad():
|
926 |
-
all_embeddings = torch.empty(0)
|
927 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
|
928 |
-
embeddings = self._vectorize_prompt(
|
929 |
-
input_ids=padded_batch["embedding_input_ids"],
|
930 |
-
attention_mask=padded_batch["embedding_attention_mask"],
|
931 |
-
)
|
932 |
-
embeddings = self.accelerator.gather_for_metrics(embeddings)
|
933 |
-
all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
|
934 |
-
|
935 |
-
return all_embeddings
|
936 |
-
|
937 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
938 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
939 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
940 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
941 |
-
|
942 |
-
if model is not None:
|
943 |
-
if hasattr(model, "config"):
|
944 |
-
hidden_size = (
|
945 |
-
max(model.config.hidden_sizes)
|
946 |
-
if getattr(model.config, "hidden_sizes", None)
|
947 |
-
else getattr(model.config, "hidden_size", None)
|
948 |
-
)
|
949 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
950 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
951 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
952 |
-
config_kwargs.update(
|
953 |
-
{
|
954 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
955 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
956 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
957 |
-
}
|
958 |
-
)
|
959 |
-
|
960 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
961 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
962 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
963 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
964 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
965 |
-
model.eval()
|
966 |
-
return model
|
967 |
-
|
968 |
-
def _save_optimizer_and_scheduler(self, output_dir):
|
969 |
-
super()._save_optimizer_and_scheduler(output_dir)
|
970 |
-
|
971 |
-
# When saving optimizer and scheduler to checkpoint, save also the running delta object.
|
972 |
-
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
973 |
-
|
974 |
-
self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
|
975 |
-
|
976 |
-
if self.match_underlying_distribution:
|
977 |
-
torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
|
978 |
-
|
979 |
-
def _load_optimizer_and_scheduler(self, checkpoint):
|
980 |
-
super()._load_optimizer_and_scheduler(checkpoint)
|
981 |
-
|
982 |
-
if checkpoint is None:
|
983 |
-
return
|
984 |
-
# when loading optimizer and scheduler from checkpoint, also load the running delta object.
|
985 |
-
running_file = os.path.join(checkpoint, RUNNING_NAME)
|
986 |
-
if os.path.isfile(running_file):
|
987 |
-
self.running = RunningMoments.load_from_json(self.accelerator, running_file)
|
988 |
-
|
989 |
-
if self.match_underlying_distribution:
|
990 |
-
clf_file = os.path.join(checkpoint, CLF_NAME)
|
991 |
-
if os.path.isfile(running_file):
|
992 |
-
self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
|
993 |
-
|
994 |
-
@contextmanager
|
995 |
-
def null_ref_context(self):
|
996 |
-
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
997 |
-
with (
|
998 |
-
self.accelerator.unwrap_model(self.model).disable_adapter()
|
999 |
-
if self.is_peft_model and not self.ref_adapter_name
|
1000 |
-
else nullcontext()
|
1001 |
-
):
|
1002 |
-
if self.ref_adapter_name:
|
1003 |
-
self.model.set_adapter(self.ref_adapter_name)
|
1004 |
-
yield
|
1005 |
-
if self.ref_adapter_name:
|
1006 |
-
self.model.set_adapter(self.model_adapter_name or "default")
|
1007 |
-
|
1008 |
-
def get_train_dataloader(self) -> DataLoader:
|
1009 |
-
"""
|
1010 |
-
Returns the training [`~torch.utils.data.DataLoader`].
|
1011 |
-
|
1012 |
-
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
1013 |
-
"""
|
1014 |
-
|
1015 |
-
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
1016 |
-
dataloader_params = {
|
1017 |
-
"batch_size": self.args.per_device_train_batch_size,
|
1018 |
-
"collate_fn": self.data_collator,
|
1019 |
-
"num_workers": self.args.dataloader_num_workers,
|
1020 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
1021 |
-
"shuffle": False,
|
1022 |
-
}
|
1023 |
-
|
1024 |
-
# prepare dataloader
|
1025 |
-
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
1026 |
-
reference_completion_logps = []
|
1027 |
-
|
1028 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
1029 |
-
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
1030 |
-
|
1031 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1032 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
1033 |
-
|
1034 |
-
self.train_dataset = self.train_dataset.add_column(
|
1035 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1036 |
-
)
|
1037 |
-
|
1038 |
-
self._precomputed_train_ref_log_probs = True
|
1039 |
-
|
1040 |
-
return super().get_train_dataloader()
|
1041 |
-
|
1042 |
-
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
1043 |
-
"""
|
1044 |
-
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
1045 |
-
|
1046 |
-
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
1047 |
-
|
1048 |
-
Args:
|
1049 |
-
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
1050 |
-
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
1051 |
-
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
1052 |
-
"""
|
1053 |
-
if eval_dataset is None and self.eval_dataset is None:
|
1054 |
-
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
1055 |
-
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
1056 |
-
|
1057 |
-
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
1058 |
-
dataloader_params = {
|
1059 |
-
"batch_size": self.args.per_device_eval_batch_size,
|
1060 |
-
"collate_fn": self.data_collator,
|
1061 |
-
"num_workers": self.args.dataloader_num_workers,
|
1062 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
1063 |
-
"shuffle": False,
|
1064 |
-
}
|
1065 |
-
|
1066 |
-
# prepare dataloader
|
1067 |
-
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
1068 |
-
|
1069 |
-
reference_completion_logps = []
|
1070 |
-
|
1071 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
1072 |
-
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
1073 |
-
|
1074 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1075 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
1076 |
-
|
1077 |
-
eval_dataset = eval_dataset.add_column(
|
1078 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1079 |
-
)
|
1080 |
-
|
1081 |
-
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
1082 |
-
if self.eval_dataset is not None:
|
1083 |
-
self.eval_dataset = eval_dataset
|
1084 |
-
self._precomputed_eval_ref_log_probs = True
|
1085 |
-
|
1086 |
-
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
1087 |
-
|
1088 |
-
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
1089 |
-
"""Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
|
1090 |
-
with torch.no_grad():
|
1091 |
-
if self.ref_model is None:
|
1092 |
-
with self.null_ref_context():
|
1093 |
-
if self.is_encoder_decoder:
|
1094 |
-
completion_logits = self.model(
|
1095 |
-
padded_batch["prompt_input_ids"],
|
1096 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
1097 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1098 |
-
labels=padded_batch["completion_labels"],
|
1099 |
-
).logits
|
1100 |
-
|
1101 |
-
else:
|
1102 |
-
completion_logits = self.model(
|
1103 |
-
padded_batch["completion_input_ids"],
|
1104 |
-
attention_mask=padded_batch["completion_attention_mask"],
|
1105 |
-
).logits
|
1106 |
-
|
1107 |
-
else:
|
1108 |
-
if self.is_encoder_decoder:
|
1109 |
-
completion_logits = self.ref_model(
|
1110 |
-
padded_batch["prompt_input_ids"],
|
1111 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
1112 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1113 |
-
labels=padded_batch["completion_labels"],
|
1114 |
-
).logits
|
1115 |
-
|
1116 |
-
else:
|
1117 |
-
completion_logits = self.ref_model(
|
1118 |
-
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
1119 |
-
).logits
|
1120 |
-
|
1121 |
-
completion_logps = self.get_batch_logps(
|
1122 |
-
completion_logits,
|
1123 |
-
padded_batch["completion_labels"],
|
1124 |
-
average_log_prob=False,
|
1125 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1126 |
-
label_pad_token_id=self.label_pad_token_id,
|
1127 |
-
)
|
1128 |
-
|
1129 |
-
return completion_logps
|
1130 |
-
|
1131 |
-
@staticmethod
|
1132 |
-
def get_batch_logps(
|
1133 |
-
logits: torch.FloatTensor,
|
1134 |
-
labels: torch.LongTensor,
|
1135 |
-
average_log_prob: bool = False,
|
1136 |
-
label_pad_token_id: int = -100,
|
1137 |
-
is_encoder_decoder: bool = False,
|
1138 |
-
) -> torch.FloatTensor:
|
1139 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
1140 |
-
|
1141 |
-
Args:
|
1142 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1143 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1144 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1145 |
-
|
1146 |
-
Returns:
|
1147 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1148 |
-
"""
|
1149 |
-
if logits.shape[:-1] != labels.shape:
|
1150 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1151 |
-
|
1152 |
-
if not is_encoder_decoder:
|
1153 |
-
labels = labels[:, 1:].clone()
|
1154 |
-
logits = logits[:, :-1, :]
|
1155 |
-
else:
|
1156 |
-
# Fixes end-dec RuntimeError
|
1157 |
-
labels = labels.clone()
|
1158 |
-
|
1159 |
-
loss_mask = labels != label_pad_token_id
|
1160 |
-
|
1161 |
-
# dummy token; we'll ignore the losses on these tokens later
|
1162 |
-
labels[labels == label_pad_token_id] = 0
|
1163 |
-
|
1164 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
1165 |
-
|
1166 |
-
if average_log_prob:
|
1167 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1168 |
-
else:
|
1169 |
-
return (per_token_logps * loss_mask).sum(-1)
|
1170 |
-
|
1171 |
-
def forward(
|
1172 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1173 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1174 |
-
model_kwargs = (
|
1175 |
-
{
|
1176 |
-
"labels": batch["completion_labels"],
|
1177 |
-
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
1178 |
-
}
|
1179 |
-
if self.is_encoder_decoder
|
1180 |
-
else {}
|
1181 |
-
)
|
1182 |
-
if self.aux_loss_enabled:
|
1183 |
-
model_kwargs["output_router_logits"] = True
|
1184 |
-
|
1185 |
-
outputs = model(
|
1186 |
-
batch["completion_input_ids"],
|
1187 |
-
attention_mask=batch["completion_attention_mask"],
|
1188 |
-
**model_kwargs,
|
1189 |
-
)
|
1190 |
-
completion_logits = outputs.logits
|
1191 |
-
|
1192 |
-
completion_logps = self.get_batch_logps(
|
1193 |
-
completion_logits,
|
1194 |
-
batch["completion_labels"],
|
1195 |
-
average_log_prob=False,
|
1196 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1197 |
-
label_pad_token_id=self.label_pad_token_id,
|
1198 |
-
)
|
1199 |
-
|
1200 |
-
if completion_logps.shape[0] != len(batch["label"]):
|
1201 |
-
raise ValueError(
|
1202 |
-
"There is a mismatch between the number of examples in this batch and the number of "
|
1203 |
-
"examples for which an output sequence was predicted."
|
1204 |
-
)
|
1205 |
-
|
1206 |
-
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
1207 |
-
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
1208 |
-
|
1209 |
-
chosen_logps = completion_logps[chosen_idx, ...]
|
1210 |
-
rejected_logps = completion_logps[rejected_idx, ...]
|
1211 |
-
|
1212 |
-
chosen_logits = completion_logits[chosen_idx, ...]
|
1213 |
-
rejected_logits = completion_logits[rejected_idx, ...]
|
1214 |
-
|
1215 |
-
if self.aux_loss_enabled:
|
1216 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
|
1217 |
-
else:
|
1218 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
1219 |
-
|
1220 |
-
def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
1221 |
-
prob_desirable = self._get_chosen_prob(rejected_embeddings)
|
1222 |
-
min_ratio = self.args.min_density_ratio
|
1223 |
-
max_ratio = self.args.max_density_ratio
|
1224 |
-
|
1225 |
-
weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
|
1226 |
-
|
1227 |
-
return weight
|
1228 |
-
|
1229 |
-
def bco_loss(
|
1230 |
-
self,
|
1231 |
-
policy_chosen_logps: torch.FloatTensor,
|
1232 |
-
policy_rejected_logps: torch.FloatTensor,
|
1233 |
-
reference_chosen_logps: torch.FloatTensor,
|
1234 |
-
reference_rejected_logps: torch.FloatTensor,
|
1235 |
-
chosen_embeddings: Optional[torch.FloatTensor],
|
1236 |
-
rejected_embeddings: Optional[torch.FloatTensor],
|
1237 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1238 |
-
"""Compute the BCO loss for a batch of policy and reference model log probabilities.
|
1239 |
-
|
1240 |
-
Args:
|
1241 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1242 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1243 |
-
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1244 |
-
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1245 |
-
chosen_embeddings: embeddings of desirable prompts
|
1246 |
-
rejected_embeddings: embeddings of undesirable prompts
|
1247 |
-
|
1248 |
-
Returns:
|
1249 |
-
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
|
1250 |
-
The losses tensor contains the BCO loss for each example in the batch.
|
1251 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
1252 |
-
The delta value contains the moving average of all implicit rewards.
|
1253 |
-
"""
|
1254 |
-
|
1255 |
-
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1256 |
-
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
1257 |
-
chosen_rewards = self.beta * chosen_logratios
|
1258 |
-
else:
|
1259 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1260 |
-
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
1261 |
-
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1262 |
-
|
1263 |
-
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1264 |
-
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
1265 |
-
rejected_rewards = self.beta * rejected_logratios
|
1266 |
-
else:
|
1267 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1268 |
-
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
1269 |
-
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1270 |
-
|
1271 |
-
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
|
1272 |
-
self.running.update(rewards)
|
1273 |
-
delta = self.running.mean
|
1274 |
-
|
1275 |
-
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1276 |
-
chosen_losses = -F.logsigmoid(chosen_rewards - delta)
|
1277 |
-
|
1278 |
-
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1279 |
-
rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
|
1280 |
-
|
1281 |
-
if self.match_underlying_distribution:
|
1282 |
-
chosen_weight = torch.ones_like(chosen_losses)
|
1283 |
-
rejected_weight = self._get_udm_weight(rejected_embeddings)
|
1284 |
-
|
1285 |
-
losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
|
1286 |
-
else:
|
1287 |
-
losses = torch.cat((chosen_losses, rejected_losses), dim=0)
|
1288 |
-
|
1289 |
-
return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
|
1290 |
-
|
1291 |
-
def get_batch_loss_metrics(
|
1292 |
-
self,
|
1293 |
-
model,
|
1294 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
1295 |
-
):
|
1296 |
-
"""Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
|
1297 |
-
metrics = {}
|
1298 |
-
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
1299 |
-
|
1300 |
-
forward_output = self.forward(model, batch)
|
1301 |
-
(
|
1302 |
-
policy_chosen_logps,
|
1303 |
-
policy_rejected_logps,
|
1304 |
-
policy_chosen_logits,
|
1305 |
-
policy_rejected_logits,
|
1306 |
-
) = forward_output[:4]
|
1307 |
-
if self.aux_loss_enabled:
|
1308 |
-
aux_loss = forward_output[4]
|
1309 |
-
|
1310 |
-
# if reference_logps in batch use them, otherwise use the reference model
|
1311 |
-
if "reference_logps" in batch:
|
1312 |
-
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
1313 |
-
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
1314 |
-
|
1315 |
-
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
1316 |
-
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
1317 |
-
else:
|
1318 |
-
with torch.no_grad():
|
1319 |
-
if self.ref_model is None:
|
1320 |
-
with self.null_ref_context():
|
1321 |
-
(
|
1322 |
-
reference_chosen_logps,
|
1323 |
-
reference_rejected_logps,
|
1324 |
-
_,
|
1325 |
-
_,
|
1326 |
-
) = self.forward(self.model, batch)[:4]
|
1327 |
-
else:
|
1328 |
-
(
|
1329 |
-
reference_chosen_logps,
|
1330 |
-
reference_rejected_logps,
|
1331 |
-
_,
|
1332 |
-
_,
|
1333 |
-
) = self.forward(self.ref_model, batch)[:4]
|
1334 |
-
|
1335 |
-
chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
|
1336 |
-
|
1337 |
-
losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
|
1338 |
-
policy_chosen_logps,
|
1339 |
-
policy_rejected_logps,
|
1340 |
-
reference_chosen_logps,
|
1341 |
-
reference_rejected_logps,
|
1342 |
-
chosen_embeddings,
|
1343 |
-
rejected_embeddings,
|
1344 |
-
)
|
1345 |
-
metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
|
1346 |
-
|
1347 |
-
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
1348 |
-
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
1349 |
-
|
1350 |
-
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
1351 |
-
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
1352 |
-
|
1353 |
-
if all_num_chosen > 0:
|
1354 |
-
metrics["rewards/chosen_sum"] = (
|
1355 |
-
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
1356 |
-
)
|
1357 |
-
metrics["logps/chosen_sum"] = (
|
1358 |
-
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
1359 |
-
)
|
1360 |
-
metrics["logits/chosen_sum"] = (
|
1361 |
-
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
1362 |
-
)
|
1363 |
-
metrics["count/chosen"] = all_num_chosen
|
1364 |
-
|
1365 |
-
if all_num_rejected > 0:
|
1366 |
-
metrics["rewards/rejected_sum"] = (
|
1367 |
-
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
1368 |
-
)
|
1369 |
-
metrics["logps/rejected_sum"] = (
|
1370 |
-
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
1371 |
-
)
|
1372 |
-
metrics["logits/rejected_sum"] = (
|
1373 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
1374 |
-
)
|
1375 |
-
metrics["count/rejected"] = all_num_rejected
|
1376 |
-
|
1377 |
-
loss = losses.nanmean()
|
1378 |
-
if self.aux_loss_enabled:
|
1379 |
-
loss += self.aux_loss_coef * aux_loss
|
1380 |
-
|
1381 |
-
return loss, metrics
|
1382 |
-
|
1383 |
-
def compute_loss(
|
1384 |
-
self,
|
1385 |
-
model: Union[PreTrainedModel, nn.Module],
|
1386 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1387 |
-
return_outputs=False,
|
1388 |
-
num_items_in_batch=None,
|
1389 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1390 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1391 |
-
|
1392 |
-
with compute_loss_context_manager:
|
1393 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1394 |
-
|
1395 |
-
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1396 |
-
loss = loss.to(self.args.device)
|
1397 |
-
# force log the metrics
|
1398 |
-
if self.accelerator.is_main_process:
|
1399 |
-
self.store_metrics(metrics, train_eval="train")
|
1400 |
-
|
1401 |
-
if return_outputs:
|
1402 |
-
return (loss, metrics)
|
1403 |
-
return loss
|
1404 |
-
|
1405 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1406 |
-
for key, value in metrics.items():
|
1407 |
-
self._stored_metrics[train_eval][key].append(value)
|
1408 |
-
|
1409 |
-
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
1410 |
-
if self.train_dataset is None or not has_length(self.train_dataset):
|
1411 |
-
return None
|
1412 |
-
return SequentialSampler(self.train_dataset)
|
1413 |
-
|
1414 |
-
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
1415 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1416 |
-
|
1417 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1418 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1419 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1420 |
-
with generate_context_manager:
|
1421 |
-
policy_output = model.generate(
|
1422 |
-
input_ids=batch["prompt_input_ids"],
|
1423 |
-
attention_mask=batch["prompt_attention_mask"],
|
1424 |
-
max_length=self.max_length,
|
1425 |
-
do_sample=True,
|
1426 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1427 |
-
)
|
1428 |
-
|
1429 |
-
# if reference_output in batch use that otherwise use the reference model
|
1430 |
-
if "reference_output" in batch:
|
1431 |
-
reference_output = batch["reference_output"]
|
1432 |
-
else:
|
1433 |
-
if self.ref_model is None:
|
1434 |
-
with self.null_ref_context():
|
1435 |
-
reference_output = self.model.generate(
|
1436 |
-
input_ids=batch["prompt_input_ids"],
|
1437 |
-
attention_mask=batch["prompt_attention_mask"],
|
1438 |
-
max_length=self.max_length,
|
1439 |
-
do_sample=True,
|
1440 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1441 |
-
)
|
1442 |
-
else:
|
1443 |
-
reference_output = self.ref_model.generate(
|
1444 |
-
input_ids=batch["prompt_input_ids"],
|
1445 |
-
attention_mask=batch["prompt_attention_mask"],
|
1446 |
-
max_length=self.max_length,
|
1447 |
-
do_sample=True,
|
1448 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1449 |
-
)
|
1450 |
-
|
1451 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1452 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1453 |
-
|
1454 |
-
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
1455 |
-
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
1456 |
-
|
1457 |
-
return policy_output_decoded, reference_output_decoded
|
1458 |
-
|
1459 |
-
def prediction_step(
|
1460 |
-
self,
|
1461 |
-
model: Union[PreTrainedModel, nn.Module],
|
1462 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1463 |
-
prediction_loss_only: bool,
|
1464 |
-
ignore_keys: Optional[list[str]] = None,
|
1465 |
-
):
|
1466 |
-
if ignore_keys is None:
|
1467 |
-
if hasattr(model, "config"):
|
1468 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1469 |
-
else:
|
1470 |
-
ignore_keys = []
|
1471 |
-
|
1472 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1473 |
-
with torch.no_grad(), prediction_context_manager:
|
1474 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1475 |
-
|
1476 |
-
# force log the metrics
|
1477 |
-
if self.accelerator.is_main_process:
|
1478 |
-
self.store_metrics(metrics, train_eval="eval")
|
1479 |
-
|
1480 |
-
if prediction_loss_only:
|
1481 |
-
return (loss.detach(), None, None)
|
1482 |
-
|
1483 |
-
# logits for the chosen and rejected samples from model
|
1484 |
-
logits_dict = {
|
1485 |
-
"eval_logits/chosen": metrics["logits/chosen"],
|
1486 |
-
"eval_logits/rejected": metrics["logits/rejected"],
|
1487 |
-
}
|
1488 |
-
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
1489 |
-
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
1490 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1491 |
-
|
1492 |
-
return (loss.detach(), logits, labels)
|
1493 |
-
|
1494 |
-
def evaluation_loop(
|
1495 |
-
self,
|
1496 |
-
dataloader: DataLoader,
|
1497 |
-
description: str,
|
1498 |
-
prediction_loss_only: Optional[bool] = None,
|
1499 |
-
ignore_keys: Optional[list[str]] = None,
|
1500 |
-
metric_key_prefix: str = "eval",
|
1501 |
-
) -> EvalLoopOutput:
|
1502 |
-
"""
|
1503 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
1504 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1505 |
-
|
1506 |
-
Works both with or without labels.
|
1507 |
-
"""
|
1508 |
-
|
1509 |
-
# Sample and save to game log if requested (for one batch to save time)
|
1510 |
-
if self.generate_during_eval:
|
1511 |
-
# Generate random indices within the range of the total number of samples
|
1512 |
-
num_samples = len(dataloader.dataset)
|
1513 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1514 |
-
|
1515 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1516 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1517 |
-
random_batch = self.data_collator(random_batch_dataset)
|
1518 |
-
random_batch = self._prepare_inputs(random_batch)
|
1519 |
-
|
1520 |
-
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
1521 |
-
target_batch = {
|
1522 |
-
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
1523 |
-
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
1524 |
-
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
1525 |
-
}
|
1526 |
-
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
1527 |
-
|
1528 |
-
table = pd.DataFrame(
|
1529 |
-
columns=["Prompt", "Policy", "Ref Model"],
|
1530 |
-
data=[
|
1531 |
-
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
1532 |
-
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
1533 |
-
],
|
1534 |
-
)
|
1535 |
-
if "wandb" in self.args.report_to:
|
1536 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
1537 |
-
|
1538 |
-
if "comet_ml" in self.args.report_to:
|
1539 |
-
log_table_to_comet_experiment(
|
1540 |
-
name="game_log.csv",
|
1541 |
-
table=table,
|
1542 |
-
)
|
1543 |
-
|
1544 |
-
# Base evaluation
|
1545 |
-
initial_output = super().evaluation_loop(
|
1546 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1547 |
-
)
|
1548 |
-
|
1549 |
-
return initial_output
|
1550 |
-
|
1551 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1552 |
-
"""
|
1553 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
1554 |
-
|
1555 |
-
Args:
|
1556 |
-
logs (`dict[str, float]`):
|
1557 |
-
The values to log.
|
1558 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1559 |
-
Start time of the training.
|
1560 |
-
"""
|
1561 |
-
# logs either has 'loss' or 'eval_loss'
|
1562 |
-
train_eval = "train" if "loss" in logs else "eval"
|
1563 |
-
# train metrics should have no prefix, eval should have 'eval_'
|
1564 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
1565 |
-
# accumulate average metrics from sums and lengths
|
1566 |
-
for split in ["chosen", "rejected"]:
|
1567 |
-
if f"count/{split}" in self._stored_metrics[train_eval]:
|
1568 |
-
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
1569 |
-
for metric in ["rewards", "logps", "logits"]:
|
1570 |
-
logs[f"{prefix}{metric}/{split}"] = (
|
1571 |
-
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
1572 |
-
/ count_sum
|
1573 |
-
)
|
1574 |
-
# delete obsolete metric
|
1575 |
-
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
1576 |
-
del self._stored_metrics[train_eval][f"count/{split}"]
|
1577 |
-
# calculate reward margin
|
1578 |
-
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
1579 |
-
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
1580 |
-
# Add averaged stored metrics to logs
|
1581 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
1582 |
-
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
1583 |
-
del self._stored_metrics[train_eval]
|
1584 |
-
|
1585 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1586 |
-
return super().log(logs, start_time)
|
1587 |
-
else: # transformers<=4.46
|
1588 |
-
return super().log(logs)
|
1589 |
-
|
1590 |
-
def create_model_card(
|
1591 |
-
self,
|
1592 |
-
model_name: Optional[str] = None,
|
1593 |
-
dataset_name: Optional[str] = None,
|
1594 |
-
tags: Union[str, list[str], None] = None,
|
1595 |
-
):
|
1596 |
-
"""
|
1597 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1598 |
-
|
1599 |
-
Args:
|
1600 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1601 |
-
Name of the model.
|
1602 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1603 |
-
Name of the dataset used for training.
|
1604 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1605 |
-
Tags to be associated with the model card.
|
1606 |
-
"""
|
1607 |
-
if not self.is_world_process_zero():
|
1608 |
-
return
|
1609 |
-
|
1610 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1611 |
-
base_model = self.model.config._name_or_path
|
1612 |
-
else:
|
1613 |
-
base_model = None
|
1614 |
-
|
1615 |
-
tags = tags or []
|
1616 |
-
if isinstance(tags, str):
|
1617 |
-
tags = [tags]
|
1618 |
-
|
1619 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1620 |
-
tags.append("unsloth")
|
1621 |
-
|
1622 |
-
citation = textwrap.dedent("""\
|
1623 |
-
@article{jung2024binary,
|
1624 |
-
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
|
1625 |
-
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
|
1626 |
-
year = 2024,
|
1627 |
-
eprint = {arXiv:2404.04656}
|
1628 |
-
}""")
|
1629 |
-
|
1630 |
-
model_card = generate_model_card(
|
1631 |
-
base_model=base_model,
|
1632 |
-
model_name=model_name,
|
1633 |
-
hub_model_id=self.hub_model_id,
|
1634 |
-
dataset_name=dataset_name,
|
1635 |
-
tags=tags,
|
1636 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1637 |
-
comet_url=get_comet_experiment_url(),
|
1638 |
-
trainer_name="BCO",
|
1639 |
-
trainer_citation=citation,
|
1640 |
-
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
|
1641 |
-
paper_id="2404.04656",
|
1642 |
-
)
|
1643 |
-
|
1644 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1645 |
-
class UnslothBCOTrainer(_UnslothBCOTrainer):
|
1646 |
-
"""
|
1647 |
-
|
1648 |
-
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
1649 |
-
|
1650 |
-
Args:
|
1651 |
-
model (`transformers.PreTrainedModel`):
|
1652 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1653 |
-
ref_model (`PreTrainedModelWrapper`):
|
1654 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
1655 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
1656 |
-
args (`BCOConfig`):
|
1657 |
-
The arguments to use for training.
|
1658 |
-
train_dataset (`datasets.Dataset`):
|
1659 |
-
The dataset to use for training.
|
1660 |
-
eval_dataset (`datasets.Dataset`):
|
1661 |
-
The dataset to use for evaluation.
|
1662 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1663 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1664 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1665 |
-
reuse the fine-tuned model.
|
1666 |
-
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
1667 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1668 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1669 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1670 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1671 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
1672 |
-
The callbacks to use for training.
|
1673 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1674 |
-
The optimizer and scheduler to use for training.
|
1675 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1676 |
-
The function to use to preprocess the logits before computing the metrics.
|
1677 |
-
peft_config (`dict`, defaults to `None`):
|
1678 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1679 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1680 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1681 |
-
a dictionary string to metric values.
|
1682 |
-
model_adapter_name (`str`, defaults to `None`):
|
1683 |
-
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
1684 |
-
ref_adapter_name (`str`, defaults to `None`):
|
1685 |
-
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
1686 |
-
|
1687 |
-
"""
|
1688 |
-
def __init__(
|
1689 |
-
self,
|
1690 |
-
model = None,
|
1691 |
-
ref_model = None,
|
1692 |
-
args = None,
|
1693 |
-
train_dataset = None,
|
1694 |
-
eval_dataset = None,
|
1695 |
-
processing_class = None,
|
1696 |
-
data_collator = None,
|
1697 |
-
model_init = None,
|
1698 |
-
callbacks = None,
|
1699 |
-
preprocess_logits_for_metrics = None,
|
1700 |
-
peft_config = None,
|
1701 |
-
compute_metrics = None,
|
1702 |
-
model_adapter_name = None,
|
1703 |
-
ref_adapter_name = None,
|
1704 |
-
embedding_func = None,
|
1705 |
-
embedding_tokenizer = None,
|
1706 |
-
**kwargs
|
1707 |
-
):
|
1708 |
-
if args is None: args = UnslothBCOConfig()
|
1709 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1710 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1711 |
-
force_float32 = False
|
1712 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1713 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1714 |
-
force_float32 = True
|
1715 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1716 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1717 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1718 |
-
from unsloth_zoo.utils import _get_dtype
|
1719 |
-
dtype = _get_dtype(dtype)
|
1720 |
-
float16 = dtype == torch.float16
|
1721 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1722 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1723 |
-
if force_float32:
|
1724 |
-
args.fp16 = False
|
1725 |
-
args.bf16 = False
|
1726 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1727 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1728 |
-
args.fp16 = float16
|
1729 |
-
args.bf16 = not float16
|
1730 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1731 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1732 |
-
args.eval_strategy = 'steps'
|
1733 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1734 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1735 |
-
if ga_steps is not None and ga_steps > 1:
|
1736 |
-
from transformers import __version__ as transformers_version
|
1737 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1738 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1739 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1740 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1741 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1742 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1743 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1744 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1745 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1746 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1747 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1748 |
-
if force_float32:
|
1749 |
-
args.bf16_full_eval = False
|
1750 |
-
args.fp16_full_eval = False
|
1751 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1752 |
-
args.bf16_full_eval = True
|
1753 |
-
args.fp16_full_eval = False
|
1754 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1755 |
-
args.bf16_full_eval = args.bf16
|
1756 |
-
args.fp16_full_eval = args.fp16
|
1757 |
-
_output_logits = False
|
1758 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1759 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1760 |
-
if _output_logits:
|
1761 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1762 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1763 |
-
pass
|
1764 |
-
else:
|
1765 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1766 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1767 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1768 |
-
max_seq_length = model.max_seq_length
|
1769 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1770 |
-
if model is not None and hasattr(model, 'for_training'):
|
1771 |
-
model.for_training()
|
1772 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1773 |
-
if 'processing_class' in locals():
|
1774 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1775 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1776 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1777 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1778 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1779 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1780 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1781 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1782 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1783 |
-
else:
|
1784 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1785 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1786 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1787 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1788 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1789 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1790 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1791 |
-
else:
|
1792 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1793 |
-
other_metrics = []
|
1794 |
-
|
1795 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1796 |
-
PatchRLStatistics('bco_trainer', other_metrics)
|
1797 |
-
|
1798 |
-
super().__init__(
|
1799 |
-
model = model,
|
1800 |
-
ref_model = ref_model,
|
1801 |
-
args = args,
|
1802 |
-
train_dataset = train_dataset,
|
1803 |
-
eval_dataset = eval_dataset,
|
1804 |
-
processing_class = processing_class,
|
1805 |
-
data_collator = data_collator,
|
1806 |
-
model_init = model_init,
|
1807 |
-
callbacks = callbacks,
|
1808 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1809 |
-
peft_config = peft_config,
|
1810 |
-
compute_metrics = compute_metrics,
|
1811 |
-
model_adapter_name = model_adapter_name,
|
1812 |
-
ref_adapter_name = ref_adapter_name,
|
1813 |
-
embedding_func = embedding_func,
|
1814 |
-
embedding_tokenizer = embedding_tokenizer,**kwargs)
|
1815 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1816 |
-
self.neftune_hook_handle.remove()
|
1817 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1818 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1819 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1820 |
-
pass
|
1821 |
-
|
1822 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothCPOTrainer.py
DELETED
@@ -1,1555 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothCPOConfig(CPOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`CPOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
54 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
-
[`~transformers.TrainingArguments`].
|
56 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
-
to use the default data collator.
|
59 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
-
and your model is an encoder-decoder.
|
64 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
66 |
-
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
67 |
-
the [paper](https://huggingface.co/papers/2310.12036).
|
68 |
-
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
69 |
-
Label smoothing factor. This argument is required if you want to use the default data collator.
|
70 |
-
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
71 |
-
Type of loss to use. Possible values are:
|
72 |
-
|
73 |
-
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
74 |
-
- `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
|
75 |
-
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
76 |
-
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
|
77 |
-
|
78 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
79 |
-
Whether to disable dropout in the model.
|
80 |
-
cpo_alpha (`float`, *optional*, defaults to `1.0`):
|
81 |
-
Weight of the BC regularizer in CPO training.
|
82 |
-
simpo_gamma (`float`, *optional*, defaults to `0.5`):
|
83 |
-
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
|
84 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
85 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
86 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
87 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
88 |
-
truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
|
89 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
90 |
-
This argument is required if you want to use the default data collator.
|
91 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
92 |
-
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
93 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
94 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
95 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
96 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
97 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
98 |
-
string.
|
99 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
100 |
-
Number of processes to use for processing the dataset.
|
101 |
-
|
102 |
-
"""
|
103 |
-
vllm_sampling_params: Optional[Any] = field(
|
104 |
-
default = None,
|
105 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
106 |
-
)
|
107 |
-
unsloth_num_chunks : Optional[int] = field(
|
108 |
-
default = -1,
|
109 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
110 |
-
)
|
111 |
-
def __init__(
|
112 |
-
self,
|
113 |
-
output_dir = None,
|
114 |
-
overwrite_output_dir = None,
|
115 |
-
do_train = False,
|
116 |
-
do_eval = False,
|
117 |
-
do_predict = False,
|
118 |
-
eval_strategy = 'no',
|
119 |
-
prediction_loss_only = False,
|
120 |
-
per_device_train_batch_size = 4,
|
121 |
-
per_device_eval_batch_size = 4,
|
122 |
-
per_gpu_train_batch_size = None,
|
123 |
-
per_gpu_eval_batch_size = None,
|
124 |
-
gradient_accumulation_steps = 2,
|
125 |
-
eval_accumulation_steps = 2,
|
126 |
-
eval_delay = 0,
|
127 |
-
torch_empty_cache_steps = 250,
|
128 |
-
learning_rate = 5e-05,
|
129 |
-
weight_decay = 0.01,
|
130 |
-
adam_beta1 = 0.9,
|
131 |
-
adam_beta2 = 0.999,
|
132 |
-
adam_epsilon = 1e-08,
|
133 |
-
max_grad_norm = 1.0,
|
134 |
-
num_train_epochs = 3.0,
|
135 |
-
max_steps = -1,
|
136 |
-
lr_scheduler_type = 'linear',
|
137 |
-
warmup_ratio = 0.1,
|
138 |
-
warmup_steps = 0,
|
139 |
-
log_level = 'passive',
|
140 |
-
log_level_replica = 'warning',
|
141 |
-
log_on_each_node = True,
|
142 |
-
logging_dir = None,
|
143 |
-
logging_strategy = 'steps',
|
144 |
-
logging_first_step = False,
|
145 |
-
logging_steps = 1,
|
146 |
-
logging_nan_inf_filter = False,
|
147 |
-
save_strategy = 'steps',
|
148 |
-
save_steps = 500,
|
149 |
-
save_total_limit = None,
|
150 |
-
save_safetensors = True,
|
151 |
-
save_on_each_node = False,
|
152 |
-
save_only_model = False,
|
153 |
-
restore_callback_states_from_checkpoint = False,
|
154 |
-
no_cuda = False,
|
155 |
-
use_cpu = False,
|
156 |
-
use_mps_device = False,
|
157 |
-
seed = 3407,
|
158 |
-
data_seed = 3407,
|
159 |
-
jit_mode_eval = False,
|
160 |
-
use_ipex = False,
|
161 |
-
bf16 = False,
|
162 |
-
fp16 = False,
|
163 |
-
fp16_opt_level = 'O1',
|
164 |
-
half_precision_backend = 'auto',
|
165 |
-
bf16_full_eval = False,
|
166 |
-
fp16_full_eval = False,
|
167 |
-
tf32 = None,
|
168 |
-
local_rank = -1,
|
169 |
-
ddp_backend = None,
|
170 |
-
tpu_num_cores = None,
|
171 |
-
tpu_metrics_debug = False,
|
172 |
-
debug = '',
|
173 |
-
dataloader_drop_last = False,
|
174 |
-
eval_steps = None,
|
175 |
-
dataloader_num_workers = 0,
|
176 |
-
dataloader_prefetch_factor = None,
|
177 |
-
past_index = -1,
|
178 |
-
run_name = None,
|
179 |
-
disable_tqdm = None,
|
180 |
-
remove_unused_columns = True,
|
181 |
-
label_names = None,
|
182 |
-
load_best_model_at_end = False,
|
183 |
-
metric_for_best_model = None,
|
184 |
-
greater_is_better = None,
|
185 |
-
ignore_data_skip = False,
|
186 |
-
fsdp = '',
|
187 |
-
fsdp_min_num_params = 0,
|
188 |
-
fsdp_config = None,
|
189 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
190 |
-
accelerator_config = None,
|
191 |
-
deepspeed = None,
|
192 |
-
label_smoothing_factor = 0.0,
|
193 |
-
optim = 'adamw_8bit',
|
194 |
-
optim_args = None,
|
195 |
-
adafactor = False,
|
196 |
-
group_by_length = False,
|
197 |
-
length_column_name = 'length',
|
198 |
-
report_to = None,
|
199 |
-
ddp_find_unused_parameters = None,
|
200 |
-
ddp_bucket_cap_mb = None,
|
201 |
-
ddp_broadcast_buffers = None,
|
202 |
-
dataloader_pin_memory = True,
|
203 |
-
dataloader_persistent_workers = False,
|
204 |
-
skip_memory_metrics = True,
|
205 |
-
use_legacy_prediction_loop = False,
|
206 |
-
push_to_hub = False,
|
207 |
-
resume_from_checkpoint = None,
|
208 |
-
hub_model_id = None,
|
209 |
-
hub_strategy = 'every_save',
|
210 |
-
hub_token = None,
|
211 |
-
hub_private_repo = None,
|
212 |
-
hub_always_push = False,
|
213 |
-
gradient_checkpointing = False,
|
214 |
-
gradient_checkpointing_kwargs = None,
|
215 |
-
include_inputs_for_metrics = False,
|
216 |
-
eval_do_concat_batches = True,
|
217 |
-
fp16_backend = 'auto',
|
218 |
-
evaluation_strategy = None,
|
219 |
-
push_to_hub_model_id = None,
|
220 |
-
push_to_hub_organization = None,
|
221 |
-
push_to_hub_token = None,
|
222 |
-
mp_parameters = '',
|
223 |
-
auto_find_batch_size = False,
|
224 |
-
full_determinism = False,
|
225 |
-
torchdynamo = None,
|
226 |
-
ray_scope = 'last',
|
227 |
-
ddp_timeout = 1800,
|
228 |
-
torch_compile = False,
|
229 |
-
torch_compile_backend = None,
|
230 |
-
torch_compile_mode = None,
|
231 |
-
dispatch_batches = None,
|
232 |
-
split_batches = None,
|
233 |
-
include_tokens_per_second = False,
|
234 |
-
include_num_input_tokens_seen = False,
|
235 |
-
neftune_noise_alpha = None,
|
236 |
-
optim_target_modules = None,
|
237 |
-
batch_eval_metrics = False,
|
238 |
-
eval_on_start = False,
|
239 |
-
use_liger_kernel = False,
|
240 |
-
eval_use_gather_object = False,
|
241 |
-
average_tokens_across_devices = False,
|
242 |
-
max_length = 1024,
|
243 |
-
max_prompt_length = 512,
|
244 |
-
max_completion_length = None,
|
245 |
-
beta = 0.1,
|
246 |
-
label_smoothing = 0.0,
|
247 |
-
loss_type = 'sigmoid',
|
248 |
-
disable_dropout = True,
|
249 |
-
cpo_alpha = 1.0,
|
250 |
-
simpo_gamma = 0.5,
|
251 |
-
label_pad_token_id = -100,
|
252 |
-
padding_value = None,
|
253 |
-
truncation_mode = 'keep_end',
|
254 |
-
generate_during_eval = False,
|
255 |
-
is_encoder_decoder = None,
|
256 |
-
model_init_kwargs = None,
|
257 |
-
dataset_num_proc = None,
|
258 |
-
vllm_sampling_params = None,
|
259 |
-
unsloth_num_chunks = -1,
|
260 |
-
**kwargs,
|
261 |
-
):
|
262 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
263 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
264 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
265 |
-
output_dir = 'unsloth_training_checkpoints'
|
266 |
-
save_strategy = 'no'
|
267 |
-
if dataset_num_proc is None:
|
268 |
-
from multiprocessing import cpu_count
|
269 |
-
dataset_num_proc = cpu_count()
|
270 |
-
|
271 |
-
super().__init__(
|
272 |
-
output_dir = output_dir,
|
273 |
-
overwrite_output_dir = overwrite_output_dir,
|
274 |
-
do_train = do_train,
|
275 |
-
do_eval = do_eval,
|
276 |
-
do_predict = do_predict,
|
277 |
-
eval_strategy = eval_strategy,
|
278 |
-
prediction_loss_only = prediction_loss_only,
|
279 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
280 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
281 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
282 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
283 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
284 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
285 |
-
eval_delay = eval_delay,
|
286 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
287 |
-
learning_rate = learning_rate,
|
288 |
-
weight_decay = weight_decay,
|
289 |
-
adam_beta1 = adam_beta1,
|
290 |
-
adam_beta2 = adam_beta2,
|
291 |
-
adam_epsilon = adam_epsilon,
|
292 |
-
max_grad_norm = max_grad_norm,
|
293 |
-
num_train_epochs = num_train_epochs,
|
294 |
-
max_steps = max_steps,
|
295 |
-
lr_scheduler_type = lr_scheduler_type,
|
296 |
-
warmup_ratio = warmup_ratio,
|
297 |
-
warmup_steps = warmup_steps,
|
298 |
-
log_level = log_level,
|
299 |
-
log_level_replica = log_level_replica,
|
300 |
-
log_on_each_node = log_on_each_node,
|
301 |
-
logging_dir = logging_dir,
|
302 |
-
logging_strategy = logging_strategy,
|
303 |
-
logging_first_step = logging_first_step,
|
304 |
-
logging_steps = logging_steps,
|
305 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
306 |
-
save_strategy = save_strategy,
|
307 |
-
save_steps = save_steps,
|
308 |
-
save_total_limit = save_total_limit,
|
309 |
-
save_safetensors = save_safetensors,
|
310 |
-
save_on_each_node = save_on_each_node,
|
311 |
-
save_only_model = save_only_model,
|
312 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
313 |
-
no_cuda = no_cuda,
|
314 |
-
use_cpu = use_cpu,
|
315 |
-
use_mps_device = use_mps_device,
|
316 |
-
seed = seed,
|
317 |
-
data_seed = data_seed,
|
318 |
-
jit_mode_eval = jit_mode_eval,
|
319 |
-
use_ipex = use_ipex,
|
320 |
-
bf16 = bf16,
|
321 |
-
fp16 = fp16,
|
322 |
-
fp16_opt_level = fp16_opt_level,
|
323 |
-
half_precision_backend = half_precision_backend,
|
324 |
-
bf16_full_eval = bf16_full_eval,
|
325 |
-
fp16_full_eval = fp16_full_eval,
|
326 |
-
tf32 = tf32,
|
327 |
-
local_rank = local_rank,
|
328 |
-
ddp_backend = ddp_backend,
|
329 |
-
tpu_num_cores = tpu_num_cores,
|
330 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
331 |
-
debug = debug,
|
332 |
-
dataloader_drop_last = dataloader_drop_last,
|
333 |
-
eval_steps = eval_steps,
|
334 |
-
dataloader_num_workers = dataloader_num_workers,
|
335 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
336 |
-
past_index = past_index,
|
337 |
-
run_name = run_name,
|
338 |
-
disable_tqdm = disable_tqdm,
|
339 |
-
remove_unused_columns = remove_unused_columns,
|
340 |
-
label_names = label_names,
|
341 |
-
load_best_model_at_end = load_best_model_at_end,
|
342 |
-
metric_for_best_model = metric_for_best_model,
|
343 |
-
greater_is_better = greater_is_better,
|
344 |
-
ignore_data_skip = ignore_data_skip,
|
345 |
-
fsdp = fsdp,
|
346 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
347 |
-
fsdp_config = fsdp_config,
|
348 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
349 |
-
accelerator_config = accelerator_config,
|
350 |
-
deepspeed = deepspeed,
|
351 |
-
label_smoothing_factor = label_smoothing_factor,
|
352 |
-
optim = optim,
|
353 |
-
optim_args = optim_args,
|
354 |
-
adafactor = adafactor,
|
355 |
-
group_by_length = group_by_length,
|
356 |
-
length_column_name = length_column_name,
|
357 |
-
report_to = report_to,
|
358 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
359 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
360 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
361 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
362 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
363 |
-
skip_memory_metrics = skip_memory_metrics,
|
364 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
365 |
-
push_to_hub = push_to_hub,
|
366 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
367 |
-
hub_model_id = hub_model_id,
|
368 |
-
hub_strategy = hub_strategy,
|
369 |
-
hub_token = hub_token,
|
370 |
-
hub_private_repo = hub_private_repo,
|
371 |
-
hub_always_push = hub_always_push,
|
372 |
-
gradient_checkpointing = gradient_checkpointing,
|
373 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
374 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
375 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
376 |
-
fp16_backend = fp16_backend,
|
377 |
-
evaluation_strategy = evaluation_strategy,
|
378 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
379 |
-
push_to_hub_organization = push_to_hub_organization,
|
380 |
-
push_to_hub_token = push_to_hub_token,
|
381 |
-
mp_parameters = mp_parameters,
|
382 |
-
auto_find_batch_size = auto_find_batch_size,
|
383 |
-
full_determinism = full_determinism,
|
384 |
-
torchdynamo = torchdynamo,
|
385 |
-
ray_scope = ray_scope,
|
386 |
-
ddp_timeout = ddp_timeout,
|
387 |
-
torch_compile = torch_compile,
|
388 |
-
torch_compile_backend = torch_compile_backend,
|
389 |
-
torch_compile_mode = torch_compile_mode,
|
390 |
-
dispatch_batches = dispatch_batches,
|
391 |
-
split_batches = split_batches,
|
392 |
-
include_tokens_per_second = include_tokens_per_second,
|
393 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
394 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
395 |
-
optim_target_modules = optim_target_modules,
|
396 |
-
batch_eval_metrics = batch_eval_metrics,
|
397 |
-
eval_on_start = eval_on_start,
|
398 |
-
use_liger_kernel = use_liger_kernel,
|
399 |
-
eval_use_gather_object = eval_use_gather_object,
|
400 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
401 |
-
max_length = max_length,
|
402 |
-
max_prompt_length = max_prompt_length,
|
403 |
-
max_completion_length = max_completion_length,
|
404 |
-
beta = beta,
|
405 |
-
label_smoothing = label_smoothing,
|
406 |
-
loss_type = loss_type,
|
407 |
-
disable_dropout = disable_dropout,
|
408 |
-
cpo_alpha = cpo_alpha,
|
409 |
-
simpo_gamma = simpo_gamma,
|
410 |
-
label_pad_token_id = label_pad_token_id,
|
411 |
-
padding_value = padding_value,
|
412 |
-
truncation_mode = truncation_mode,
|
413 |
-
generate_during_eval = generate_during_eval,
|
414 |
-
is_encoder_decoder = is_encoder_decoder,
|
415 |
-
model_init_kwargs = model_init_kwargs,
|
416 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
417 |
-
self.vllm_sampling_params = vllm_sampling_params
|
418 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
419 |
-
pass
|
420 |
-
|
421 |
-
class _UnslothCPOTrainer(Trainer):
|
422 |
-
r""""""
|
423 |
-
|
424 |
-
_tag_names = ["trl", "cpo"]
|
425 |
-
|
426 |
-
def __init__(
|
427 |
-
self,
|
428 |
-
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
429 |
-
args: Optional[CPOConfig] = None,
|
430 |
-
data_collator: Optional[DataCollator] = None,
|
431 |
-
train_dataset: Optional[Dataset] = None,
|
432 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
433 |
-
processing_class: Optional[
|
434 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
435 |
-
] = None,
|
436 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
437 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
438 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
439 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
440 |
-
peft_config: Optional[dict] = None,
|
441 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
442 |
-
):
|
443 |
-
if args.model_init_kwargs is None:
|
444 |
-
model_init_kwargs = {}
|
445 |
-
elif not isinstance(model, str):
|
446 |
-
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
|
447 |
-
else:
|
448 |
-
model_init_kwargs = args.model_init_kwargs
|
449 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
450 |
-
if torch_dtype is not None:
|
451 |
-
# Convert to `torch.dtype` if an str is passed
|
452 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
453 |
-
torch_dtype = getattr(torch, torch_dtype)
|
454 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
455 |
-
raise ValueError(
|
456 |
-
f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
457 |
-
)
|
458 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
459 |
-
|
460 |
-
if isinstance(model, str):
|
461 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
462 |
-
|
463 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
464 |
-
# has been called in order to properly call autocast if needed.
|
465 |
-
self._peft_has_been_casted_to_bf16 = False
|
466 |
-
|
467 |
-
if not is_peft_available() and peft_config is not None:
|
468 |
-
raise ValueError(
|
469 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
470 |
-
)
|
471 |
-
elif is_peft_available() and peft_config is not None:
|
472 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
473 |
-
if isinstance(model, PeftModel):
|
474 |
-
model = model.merge_and_unload()
|
475 |
-
|
476 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
477 |
-
_support_gc_kwargs = hasattr(
|
478 |
-
args, "gradient_checkpointing_kwargs"
|
479 |
-
) and "gradient_checkpointing_kwargs" in list(
|
480 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
481 |
-
)
|
482 |
-
|
483 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
484 |
-
|
485 |
-
if _support_gc_kwargs:
|
486 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
487 |
-
|
488 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
489 |
-
elif getattr(args, "gradient_checkpointing", False):
|
490 |
-
# For backward compatibility with older versions of transformers
|
491 |
-
if hasattr(model, "enable_input_require_grads"):
|
492 |
-
model.enable_input_require_grads()
|
493 |
-
else:
|
494 |
-
|
495 |
-
def make_inputs_require_grad(module, input, output):
|
496 |
-
output.requires_grad_(True)
|
497 |
-
|
498 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
499 |
-
|
500 |
-
# get peft model with the given config
|
501 |
-
model = model
|
502 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
503 |
-
peft_module_casting_to_bf16(model)
|
504 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
505 |
-
self._peft_has_been_casted_to_bf16 = True
|
506 |
-
|
507 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
508 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
509 |
-
# fail or completely fail.
|
510 |
-
elif getattr(args, "gradient_checkpointing", False):
|
511 |
-
# For backward compatibility with older versions of transformers
|
512 |
-
if hasattr(model, "enable_input_require_grads"):
|
513 |
-
model.enable_input_require_grads()
|
514 |
-
else:
|
515 |
-
|
516 |
-
def make_inputs_require_grad(module, input, output):
|
517 |
-
output.requires_grad_(True)
|
518 |
-
|
519 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
520 |
-
|
521 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
522 |
-
raise ValueError(
|
523 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
524 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
525 |
-
)
|
526 |
-
|
527 |
-
if model is not None:
|
528 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
529 |
-
elif args.is_encoder_decoder is None:
|
530 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
531 |
-
else:
|
532 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
533 |
-
|
534 |
-
if self.is_encoder_decoder:
|
535 |
-
self.decoder_start_token_id = model.config.decoder_start_token_id
|
536 |
-
self.pad_token_id = model.config.pad_token_id
|
537 |
-
|
538 |
-
if processing_class is None:
|
539 |
-
raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
|
540 |
-
if args.max_length is None:
|
541 |
-
warnings.warn(
|
542 |
-
"`max_length` is not set in the CPOConfig's init"
|
543 |
-
" it will default to `512` by default, but you should do it yourself in the future.",
|
544 |
-
UserWarning,
|
545 |
-
)
|
546 |
-
max_length = 512
|
547 |
-
else:
|
548 |
-
max_length = args.max_length
|
549 |
-
if args.max_prompt_length is None:
|
550 |
-
warnings.warn(
|
551 |
-
"`max_prompt_length` is not set in the CPOConfig's init"
|
552 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
553 |
-
UserWarning,
|
554 |
-
)
|
555 |
-
max_prompt_length = 128
|
556 |
-
else:
|
557 |
-
max_prompt_length = args.max_prompt_length
|
558 |
-
|
559 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
560 |
-
warnings.warn(
|
561 |
-
"When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
|
562 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
563 |
-
UserWarning,
|
564 |
-
)
|
565 |
-
max_completion_length = 128
|
566 |
-
else:
|
567 |
-
max_completion_length = args.max_completion_length
|
568 |
-
|
569 |
-
if data_collator is None:
|
570 |
-
data_collator = DPODataCollatorWithPadding(
|
571 |
-
pad_token_id=processing_class.pad_token_id,
|
572 |
-
label_pad_token_id=args.label_pad_token_id,
|
573 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
574 |
-
)
|
575 |
-
|
576 |
-
if args.remove_unused_columns:
|
577 |
-
args.remove_unused_columns = False
|
578 |
-
# warn users
|
579 |
-
warnings.warn(
|
580 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
581 |
-
" we have set it for you, but you should do it yourself in the future.",
|
582 |
-
UserWarning,
|
583 |
-
)
|
584 |
-
|
585 |
-
self.use_dpo_data_collator = True
|
586 |
-
else:
|
587 |
-
self.use_dpo_data_collator = False
|
588 |
-
|
589 |
-
# Disable dropout in the model
|
590 |
-
if args.disable_dropout:
|
591 |
-
disable_dropout_in_model(model)
|
592 |
-
|
593 |
-
self.max_length = max_length
|
594 |
-
self.generate_during_eval = args.generate_during_eval
|
595 |
-
self.label_pad_token_id = args.label_pad_token_id
|
596 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
597 |
-
self.max_prompt_length = max_prompt_length
|
598 |
-
self.truncation_mode = args.truncation_mode
|
599 |
-
self.max_completion_length = max_completion_length
|
600 |
-
self.processing_class = processing_class
|
601 |
-
|
602 |
-
if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
|
603 |
-
warnings.warn(
|
604 |
-
f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
|
605 |
-
"`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
|
606 |
-
UserWarning,
|
607 |
-
)
|
608 |
-
if args.loss_type == "kto_pair":
|
609 |
-
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
|
610 |
-
|
611 |
-
self.beta = args.beta
|
612 |
-
self.label_smoothing = args.label_smoothing
|
613 |
-
self.loss_type = args.loss_type
|
614 |
-
self.cpo_alpha = args.cpo_alpha
|
615 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
616 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
617 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
618 |
-
warnings.warn(
|
619 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
620 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
621 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
622 |
-
"loss.",
|
623 |
-
UserWarning,
|
624 |
-
)
|
625 |
-
|
626 |
-
if args.loss_type == "simpo":
|
627 |
-
self.simpo_gamma = args.simpo_gamma
|
628 |
-
|
629 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
630 |
-
|
631 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
632 |
-
# input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
|
633 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
634 |
-
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
635 |
-
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
636 |
-
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
637 |
-
# that the warning has already been issued.
|
638 |
-
model.warnings_issued["estimate_tokens"] = True
|
639 |
-
|
640 |
-
# Compute that only on the main process for faster data processing.
|
641 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
642 |
-
with PartialState().local_main_process_first():
|
643 |
-
# Extract the prompt if needed, and apply the chat template if needed
|
644 |
-
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
645 |
-
train_dataset = train_dataset.map(
|
646 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
647 |
-
)
|
648 |
-
if eval_dataset is not None:
|
649 |
-
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
650 |
-
eval_dataset = eval_dataset.map(
|
651 |
-
maybe_apply_chat_template,
|
652 |
-
fn_kwargs={"tokenizer": processing_class},
|
653 |
-
num_proc=args.dataset_num_proc,
|
654 |
-
)
|
655 |
-
|
656 |
-
# tokenize the dataset
|
657 |
-
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
658 |
-
if eval_dataset is not None:
|
659 |
-
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
660 |
-
|
661 |
-
super().__init__(
|
662 |
-
model=model,
|
663 |
-
args=args,
|
664 |
-
data_collator=data_collator,
|
665 |
-
train_dataset=train_dataset,
|
666 |
-
eval_dataset=eval_dataset,
|
667 |
-
processing_class=processing_class,
|
668 |
-
model_init=model_init,
|
669 |
-
compute_metrics=compute_metrics,
|
670 |
-
callbacks=callbacks,
|
671 |
-
optimizers=optimizers,
|
672 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
673 |
-
)
|
674 |
-
|
675 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
676 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
677 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
678 |
-
self.model_accepts_loss_kwargs = False
|
679 |
-
|
680 |
-
# Add tags for models that have been loaded with the correct transformers version
|
681 |
-
if hasattr(self.model, "add_model_tags"):
|
682 |
-
self.model.add_model_tags(self._tag_names)
|
683 |
-
|
684 |
-
if not hasattr(self, "accelerator"):
|
685 |
-
raise AttributeError(
|
686 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
687 |
-
)
|
688 |
-
|
689 |
-
def build_tokenized_answer(self, prompt, answer):
|
690 |
-
"""
|
691 |
-
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
692 |
-
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
693 |
-
Reference:
|
694 |
-
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
695 |
-
"""
|
696 |
-
|
697 |
-
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
698 |
-
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
699 |
-
|
700 |
-
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
701 |
-
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
702 |
-
|
703 |
-
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
704 |
-
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
705 |
-
|
706 |
-
# Prepare input tokens for token by token comparison
|
707 |
-
full_input_ids = np.array(full_tokenized["input_ids"])
|
708 |
-
|
709 |
-
if len(full_input_ids) != len(full_concat_input_ids):
|
710 |
-
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
711 |
-
|
712 |
-
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
713 |
-
# can be merged together when tokenizing prompt+answer. This could result
|
714 |
-
# on the last token from the prompt being different when tokenized on its own
|
715 |
-
# vs when done as prompt+answer.
|
716 |
-
response_token_ids_start_idx = len(prompt_input_ids)
|
717 |
-
|
718 |
-
# If tokenized prompt is different than both prompt+answer, then it means the
|
719 |
-
# last token has changed due to merging.
|
720 |
-
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
721 |
-
response_token_ids_start_idx -= 1
|
722 |
-
|
723 |
-
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
724 |
-
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
725 |
-
|
726 |
-
if len(prompt_input_ids) != len(prompt_attention_mask):
|
727 |
-
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
728 |
-
|
729 |
-
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
730 |
-
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
731 |
-
|
732 |
-
return dict(
|
733 |
-
prompt_input_ids=prompt_input_ids,
|
734 |
-
prompt_attention_mask=prompt_attention_mask,
|
735 |
-
input_ids=answer_input_ids,
|
736 |
-
attention_mask=answer_attention_mask,
|
737 |
-
)
|
738 |
-
|
739 |
-
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
740 |
-
"""Tokenize a single row from a CPO specific dataset.
|
741 |
-
|
742 |
-
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
743 |
-
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
744 |
-
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
745 |
-
|
746 |
-
We also create the labels for the chosen/rejected responses, which are of length equal to
|
747 |
-
the sum of the length of the prompt and the chosen/rejected response, with
|
748 |
-
label_pad_token_id for the prompt tokens.
|
749 |
-
"""
|
750 |
-
batch = {}
|
751 |
-
prompt = feature["prompt"]
|
752 |
-
chosen = feature["chosen"]
|
753 |
-
rejected = feature["rejected"]
|
754 |
-
|
755 |
-
if not self.is_encoder_decoder:
|
756 |
-
# Check issues below for more details
|
757 |
-
# 1. https://github.com/huggingface/trl/issues/907
|
758 |
-
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
759 |
-
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
760 |
-
|
761 |
-
if not isinstance(prompt, str):
|
762 |
-
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
763 |
-
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
764 |
-
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
765 |
-
|
766 |
-
if not isinstance(chosen, str):
|
767 |
-
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
768 |
-
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
769 |
-
|
770 |
-
if not isinstance(rejected, str):
|
771 |
-
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
772 |
-
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
773 |
-
|
774 |
-
# Last prompt token might get merged by tokenizer and
|
775 |
-
# it should not be included for generation if that happens
|
776 |
-
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
777 |
-
|
778 |
-
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
779 |
-
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
780 |
-
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
781 |
-
|
782 |
-
for k, v in prompt_tokens.items():
|
783 |
-
prompt_tokens[k] = v[:prompt_len_input_ids]
|
784 |
-
|
785 |
-
# Make sure prompts only have one different token at most an
|
786 |
-
# and length only differs by 1 at most
|
787 |
-
num_diff_tokens = sum(
|
788 |
-
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
789 |
-
)
|
790 |
-
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
791 |
-
if num_diff_tokens > 1 or num_diff_len > 1:
|
792 |
-
raise ValueError(
|
793 |
-
"Chosen and rejected prompt_input_ids might only differ on the "
|
794 |
-
"last token due to tokenizer merge ops."
|
795 |
-
)
|
796 |
-
|
797 |
-
# add BOS token to head of prompt. Avoid adding if it's already there
|
798 |
-
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
799 |
-
self.processing_class.bos_token_id,
|
800 |
-
prompt_len_input_ids,
|
801 |
-
prompt_tokens,
|
802 |
-
chosen_prompt_len_input_ids,
|
803 |
-
chosen_tokens,
|
804 |
-
rejected_prompt_len_input_ids,
|
805 |
-
rejected_tokens,
|
806 |
-
)
|
807 |
-
|
808 |
-
# add EOS token to end of answer. Avoid adding if it's already there
|
809 |
-
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
810 |
-
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
811 |
-
)
|
812 |
-
|
813 |
-
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
814 |
-
|
815 |
-
# if combined sequence is too long, truncate the prompt
|
816 |
-
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
817 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
818 |
-
if self.truncation_mode == "keep_start":
|
819 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
820 |
-
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
821 |
-
elif self.truncation_mode == "keep_end":
|
822 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
823 |
-
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
824 |
-
else:
|
825 |
-
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
826 |
-
|
827 |
-
# if that's still too long, truncate the response
|
828 |
-
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
829 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
830 |
-
for k in ["input_ids", "attention_mask"]:
|
831 |
-
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
832 |
-
|
833 |
-
# Create labels
|
834 |
-
chosen_sequence_tokens = {
|
835 |
-
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
836 |
-
}
|
837 |
-
rejected_sequence_tokens = {
|
838 |
-
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
839 |
-
}
|
840 |
-
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
841 |
-
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
842 |
-
self.label_pad_token_id
|
843 |
-
] * len(chosen_tokens["prompt_input_ids"])
|
844 |
-
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
845 |
-
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
846 |
-
self.label_pad_token_id
|
847 |
-
] * len(rejected_tokens["prompt_input_ids"])
|
848 |
-
|
849 |
-
for k, toks in {
|
850 |
-
"chosen_": chosen_sequence_tokens,
|
851 |
-
"rejected_": rejected_sequence_tokens,
|
852 |
-
"": prompt_tokens,
|
853 |
-
}.items():
|
854 |
-
for type_key, tokens in toks.items():
|
855 |
-
if type_key == "token_type_ids":
|
856 |
-
continue
|
857 |
-
batch[f"{k}{type_key}"] = tokens
|
858 |
-
|
859 |
-
else:
|
860 |
-
chosen_tokens = self.processing_class(
|
861 |
-
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
862 |
-
)
|
863 |
-
rejected_tokens = self.processing_class(
|
864 |
-
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
865 |
-
)
|
866 |
-
prompt_tokens = self.processing_class(
|
867 |
-
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
868 |
-
)
|
869 |
-
|
870 |
-
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
871 |
-
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
872 |
-
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
873 |
-
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
874 |
-
|
875 |
-
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
876 |
-
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
877 |
-
labels=torch.tensor(batch["rejected_labels"])
|
878 |
-
)
|
879 |
-
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
880 |
-
labels=torch.tensor(batch["chosen_labels"])
|
881 |
-
)
|
882 |
-
|
883 |
-
return batch
|
884 |
-
|
885 |
-
@staticmethod
|
886 |
-
def concatenated_inputs(
|
887 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
888 |
-
is_encoder_decoder: bool = False,
|
889 |
-
label_pad_token_id: int = -100,
|
890 |
-
padding_value: int = 0,
|
891 |
-
device: Optional[torch.device] = None,
|
892 |
-
) -> dict[str, torch.LongTensor]:
|
893 |
-
"""Concatenate the chosen and rejected inputs into a single tensor.
|
894 |
-
|
895 |
-
Args:
|
896 |
-
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
897 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
898 |
-
label_pad_token_id: The label pad token id.
|
899 |
-
padding_value: The padding value to use for the concatenated inputs_ids.
|
900 |
-
device: The device for the concatenated inputs.
|
901 |
-
|
902 |
-
Returns:
|
903 |
-
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
904 |
-
"""
|
905 |
-
concatenated_batch = {}
|
906 |
-
|
907 |
-
if is_encoder_decoder:
|
908 |
-
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
909 |
-
else:
|
910 |
-
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
911 |
-
|
912 |
-
for k in batch:
|
913 |
-
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
914 |
-
if "labels" in k or is_encoder_decoder:
|
915 |
-
pad_value = label_pad_token_id
|
916 |
-
elif k.endswith("_input_ids"):
|
917 |
-
pad_value = padding_value
|
918 |
-
elif k.endswith("_attention_mask"):
|
919 |
-
pad_value = 0
|
920 |
-
concatenated_key = k.replace("chosen", "concatenated")
|
921 |
-
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
922 |
-
for k in batch:
|
923 |
-
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
924 |
-
if "labels" in k or is_encoder_decoder:
|
925 |
-
pad_value = label_pad_token_id
|
926 |
-
elif k.endswith("_input_ids"):
|
927 |
-
pad_value = padding_value
|
928 |
-
elif k.endswith("_attention_mask"):
|
929 |
-
pad_value = 0
|
930 |
-
concatenated_key = k.replace("rejected", "concatenated")
|
931 |
-
concatenated_batch[concatenated_key] = torch.cat(
|
932 |
-
(
|
933 |
-
concatenated_batch[concatenated_key],
|
934 |
-
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
935 |
-
),
|
936 |
-
dim=0,
|
937 |
-
).to(device=device)
|
938 |
-
|
939 |
-
if is_encoder_decoder:
|
940 |
-
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
941 |
-
concatenated_batch["concatenated_attention_mask"] = (
|
942 |
-
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
943 |
-
)
|
944 |
-
|
945 |
-
return concatenated_batch
|
946 |
-
|
947 |
-
def cpo_loss(
|
948 |
-
self,
|
949 |
-
policy_chosen_logps: torch.FloatTensor,
|
950 |
-
policy_rejected_logps: torch.FloatTensor,
|
951 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
952 |
-
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
|
953 |
-
|
954 |
-
Args:
|
955 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
956 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
957 |
-
|
958 |
-
Returns:
|
959 |
-
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
960 |
-
The losses tensor contains the CPO loss for each example in the batch.
|
961 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
962 |
-
"""
|
963 |
-
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
|
964 |
-
|
965 |
-
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
|
966 |
-
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
|
967 |
-
# calculates a conservative CPO loss.
|
968 |
-
|
969 |
-
if self.loss_type == "simpo":
|
970 |
-
gamma_logratios = self.simpo_gamma / self.beta
|
971 |
-
logits = logits - gamma_logratios
|
972 |
-
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
973 |
-
losses = (
|
974 |
-
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
975 |
-
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
976 |
-
)
|
977 |
-
elif self.loss_type == "sigmoid":
|
978 |
-
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
979 |
-
losses = (
|
980 |
-
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
981 |
-
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
982 |
-
)
|
983 |
-
elif self.loss_type == "hinge":
|
984 |
-
losses = torch.relu(1 - self.beta * logits)
|
985 |
-
elif self.loss_type == "ipo":
|
986 |
-
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
|
987 |
-
losses = (logits - 1 / (2 * self.beta)) ** 2
|
988 |
-
else:
|
989 |
-
raise ValueError(
|
990 |
-
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
|
991 |
-
)
|
992 |
-
|
993 |
-
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
994 |
-
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
995 |
-
|
996 |
-
return losses, chosen_rewards, rejected_rewards
|
997 |
-
|
998 |
-
@staticmethod
|
999 |
-
def get_batch_logps(
|
1000 |
-
logits: torch.FloatTensor,
|
1001 |
-
labels: torch.LongTensor,
|
1002 |
-
average_log_prob: bool = False,
|
1003 |
-
label_pad_token_id: int = -100,
|
1004 |
-
is_encoder_decoder: bool = False,
|
1005 |
-
) -> torch.FloatTensor:
|
1006 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
1007 |
-
|
1008 |
-
Args:
|
1009 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1010 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1011 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1012 |
-
label_pad_token_id: The label pad token id.
|
1013 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
1014 |
-
|
1015 |
-
Returns:
|
1016 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1017 |
-
"""
|
1018 |
-
if logits.shape[:-1] != labels.shape:
|
1019 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1020 |
-
|
1021 |
-
if not is_encoder_decoder:
|
1022 |
-
labels = labels[:, 1:].clone()
|
1023 |
-
logits = logits[:, :-1, :]
|
1024 |
-
loss_mask = labels != label_pad_token_id
|
1025 |
-
|
1026 |
-
# dummy token; we'll ignore the losses on these tokens later
|
1027 |
-
labels[labels == label_pad_token_id] = 0
|
1028 |
-
|
1029 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
1030 |
-
|
1031 |
-
if average_log_prob:
|
1032 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1033 |
-
else:
|
1034 |
-
return (per_token_logps * loss_mask).sum(-1)
|
1035 |
-
|
1036 |
-
def concatenated_forward(
|
1037 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1038 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1039 |
-
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
1040 |
-
|
1041 |
-
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
1042 |
-
"""
|
1043 |
-
concatenated_batch = self.concatenated_inputs(
|
1044 |
-
batch,
|
1045 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1046 |
-
label_pad_token_id=self.label_pad_token_id,
|
1047 |
-
padding_value=self.padding_value,
|
1048 |
-
device=self.accelerator.device,
|
1049 |
-
)
|
1050 |
-
len_chosen = batch["chosen_labels"].shape[0]
|
1051 |
-
|
1052 |
-
model_kwargs = (
|
1053 |
-
{
|
1054 |
-
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
1055 |
-
}
|
1056 |
-
if self.is_encoder_decoder
|
1057 |
-
else {}
|
1058 |
-
)
|
1059 |
-
|
1060 |
-
if self.aux_loss_enabled:
|
1061 |
-
model_kwargs["output_router_logits"] = True
|
1062 |
-
|
1063 |
-
outputs = model(
|
1064 |
-
concatenated_batch["concatenated_input_ids"],
|
1065 |
-
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
1066 |
-
use_cache=False,
|
1067 |
-
**model_kwargs,
|
1068 |
-
)
|
1069 |
-
all_logits = outputs.logits
|
1070 |
-
|
1071 |
-
def cross_entropy_loss(logits, labels):
|
1072 |
-
if not self.is_encoder_decoder:
|
1073 |
-
# Shift so that tokens < n predict n
|
1074 |
-
logits = logits[..., :-1, :].contiguous()
|
1075 |
-
labels = labels[..., 1:].contiguous()
|
1076 |
-
# Flatten the tokens
|
1077 |
-
loss_fct = nn.CrossEntropyLoss()
|
1078 |
-
logits = logits.view(-1, logits.shape[-1])
|
1079 |
-
labels = labels.view(-1)
|
1080 |
-
# Enable model parallelism
|
1081 |
-
labels = labels.to(logits.device)
|
1082 |
-
loss = loss_fct(logits, labels)
|
1083 |
-
return loss
|
1084 |
-
|
1085 |
-
labels = concatenated_batch["concatenated_labels"].clone()
|
1086 |
-
|
1087 |
-
if self.cpo_alpha == 0:
|
1088 |
-
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
|
1089 |
-
else:
|
1090 |
-
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
1091 |
-
|
1092 |
-
all_logps = self.get_batch_logps(
|
1093 |
-
all_logits,
|
1094 |
-
concatenated_batch["concatenated_labels"],
|
1095 |
-
average_log_prob=self.loss_type in ["ipo", "simpo"],
|
1096 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1097 |
-
label_pad_token_id=self.label_pad_token_id,
|
1098 |
-
)
|
1099 |
-
|
1100 |
-
chosen_logps = all_logps[:len_chosen]
|
1101 |
-
rejected_logps = all_logps[len_chosen:]
|
1102 |
-
|
1103 |
-
chosen_logits = all_logits[:len_chosen]
|
1104 |
-
rejected_logits = all_logits[len_chosen:]
|
1105 |
-
|
1106 |
-
if self.aux_loss_enabled:
|
1107 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
|
1108 |
-
|
1109 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
1110 |
-
|
1111 |
-
def get_batch_loss_metrics(
|
1112 |
-
self,
|
1113 |
-
model,
|
1114 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
1115 |
-
train_eval: Literal["train", "eval"] = "train",
|
1116 |
-
):
|
1117 |
-
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
1118 |
-
metrics = {}
|
1119 |
-
|
1120 |
-
forward_output = self.concatenated_forward(model, batch)
|
1121 |
-
(
|
1122 |
-
policy_chosen_logps,
|
1123 |
-
policy_rejected_logps,
|
1124 |
-
policy_chosen_logits,
|
1125 |
-
policy_rejected_logits,
|
1126 |
-
policy_nll_loss,
|
1127 |
-
) = forward_output[:5]
|
1128 |
-
if self.aux_loss_enabled:
|
1129 |
-
aux_loss = forward_output[5]
|
1130 |
-
|
1131 |
-
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
1132 |
-
policy_chosen_logps,
|
1133 |
-
policy_rejected_logps,
|
1134 |
-
)
|
1135 |
-
|
1136 |
-
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
1137 |
-
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
1138 |
-
|
1139 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
1140 |
-
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
1141 |
-
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
1142 |
-
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
1143 |
-
metrics[f"{prefix}rewards/margins"] = (
|
1144 |
-
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
|
1145 |
-
)
|
1146 |
-
metrics[f"{prefix}logps/rejected"] = (
|
1147 |
-
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
|
1148 |
-
)
|
1149 |
-
metrics[f"{prefix}logps/chosen"] = (
|
1150 |
-
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
|
1151 |
-
)
|
1152 |
-
metrics[f"{prefix}logits/rejected"] = (
|
1153 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
|
1154 |
-
)
|
1155 |
-
metrics[f"{prefix}logits/chosen"] = (
|
1156 |
-
self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
|
1157 |
-
)
|
1158 |
-
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
1159 |
-
|
1160 |
-
if self.aux_loss_enabled:
|
1161 |
-
loss += self.aux_loss_coef * aux_loss
|
1162 |
-
|
1163 |
-
return loss, metrics
|
1164 |
-
|
1165 |
-
def compute_loss(
|
1166 |
-
self,
|
1167 |
-
model: Union[PreTrainedModel, nn.Module],
|
1168 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1169 |
-
return_outputs=False,
|
1170 |
-
num_items_in_batch=None,
|
1171 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1172 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1173 |
-
|
1174 |
-
with compute_loss_context_manager:
|
1175 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
1176 |
-
|
1177 |
-
# force log the metrics
|
1178 |
-
self.store_metrics(metrics, train_eval="train")
|
1179 |
-
|
1180 |
-
if return_outputs:
|
1181 |
-
return (loss, metrics)
|
1182 |
-
return loss
|
1183 |
-
|
1184 |
-
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
1185 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1186 |
-
|
1187 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1188 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1189 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1190 |
-
|
1191 |
-
with generate_context_manager:
|
1192 |
-
policy_output = model.generate(
|
1193 |
-
input_ids=batch["prompt_input_ids"],
|
1194 |
-
attention_mask=batch["prompt_attention_mask"],
|
1195 |
-
max_length=self.max_length,
|
1196 |
-
do_sample=True,
|
1197 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1198 |
-
)
|
1199 |
-
|
1200 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1201 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1202 |
-
|
1203 |
-
return policy_output_decoded
|
1204 |
-
|
1205 |
-
def prediction_step(
|
1206 |
-
self,
|
1207 |
-
model: Union[PreTrainedModel, nn.Module],
|
1208 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1209 |
-
prediction_loss_only: bool,
|
1210 |
-
ignore_keys: Optional[list[str]] = None,
|
1211 |
-
):
|
1212 |
-
if ignore_keys is None:
|
1213 |
-
if hasattr(model, "config"):
|
1214 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1215 |
-
else:
|
1216 |
-
ignore_keys = []
|
1217 |
-
|
1218 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1219 |
-
|
1220 |
-
with torch.no_grad(), prediction_context_manager:
|
1221 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
1222 |
-
|
1223 |
-
# force log the metrics
|
1224 |
-
self.store_metrics(metrics, train_eval="eval")
|
1225 |
-
|
1226 |
-
if prediction_loss_only:
|
1227 |
-
return (loss.detach(), None, None)
|
1228 |
-
|
1229 |
-
# logits for the chosen and rejected samples from model
|
1230 |
-
logits_dict = {
|
1231 |
-
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
1232 |
-
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
1233 |
-
}
|
1234 |
-
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
1235 |
-
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
1236 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1237 |
-
|
1238 |
-
return (loss.detach(), logits, labels)
|
1239 |
-
|
1240 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1241 |
-
for key, value in metrics.items():
|
1242 |
-
self._stored_metrics[train_eval][key].append(value)
|
1243 |
-
|
1244 |
-
def evaluation_loop(
|
1245 |
-
self,
|
1246 |
-
dataloader: DataLoader,
|
1247 |
-
description: str,
|
1248 |
-
prediction_loss_only: Optional[bool] = None,
|
1249 |
-
ignore_keys: Optional[list[str]] = None,
|
1250 |
-
metric_key_prefix: str = "eval",
|
1251 |
-
) -> EvalLoopOutput:
|
1252 |
-
"""
|
1253 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
1254 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1255 |
-
|
1256 |
-
Works both with or without labels.
|
1257 |
-
"""
|
1258 |
-
|
1259 |
-
# Sample and save to game log if requested (for one batch to save time)
|
1260 |
-
if self.generate_during_eval:
|
1261 |
-
# Generate random indices within the range of the total number of samples
|
1262 |
-
num_samples = len(dataloader.dataset)
|
1263 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1264 |
-
|
1265 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1266 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1267 |
-
random_batch = self.data_collator(random_batch_dataset)
|
1268 |
-
random_batch = self._prepare_inputs(random_batch)
|
1269 |
-
|
1270 |
-
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
1271 |
-
|
1272 |
-
table = pd.DataFrame(
|
1273 |
-
columns=["Prompt", "Policy"],
|
1274 |
-
data=[
|
1275 |
-
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
1276 |
-
],
|
1277 |
-
)
|
1278 |
-
if "wandb" in self.args.report_to:
|
1279 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
1280 |
-
|
1281 |
-
if "comet_ml" in self.args.report_to:
|
1282 |
-
log_table_to_comet_experiment(
|
1283 |
-
name="game_log.csv",
|
1284 |
-
table=table,
|
1285 |
-
)
|
1286 |
-
|
1287 |
-
# Base evaluation
|
1288 |
-
initial_output = super().evaluation_loop(
|
1289 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1290 |
-
)
|
1291 |
-
|
1292 |
-
return initial_output
|
1293 |
-
|
1294 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1295 |
-
"""
|
1296 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
1297 |
-
|
1298 |
-
Args:
|
1299 |
-
logs (`dict[str, float]`):
|
1300 |
-
The values to log.
|
1301 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1302 |
-
Start time of the training.
|
1303 |
-
"""
|
1304 |
-
# logs either has 'loss' or 'eval_loss'
|
1305 |
-
train_eval = "train" if "loss" in logs else "eval"
|
1306 |
-
# Add averaged stored metrics to logs
|
1307 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
1308 |
-
logs[key] = torch.tensor(metrics).mean().item()
|
1309 |
-
del self._stored_metrics[train_eval]
|
1310 |
-
|
1311 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1312 |
-
return super().log(logs, start_time)
|
1313 |
-
else: # transformers<=4.46
|
1314 |
-
return super().log(logs)
|
1315 |
-
|
1316 |
-
def _shift_right(self, input_ids):
|
1317 |
-
if self.decoder_start_token_id is None:
|
1318 |
-
raise ValueError(
|
1319 |
-
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
1320 |
-
)
|
1321 |
-
|
1322 |
-
# shift inputs to the right
|
1323 |
-
if is_torch_fx_proxy(input_ids):
|
1324 |
-
# Item assignment is not supported natively for proxies.
|
1325 |
-
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
1326 |
-
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
1327 |
-
else:
|
1328 |
-
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1329 |
-
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1330 |
-
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1331 |
-
|
1332 |
-
if self.pad_token_id is None:
|
1333 |
-
raise ValueError("model.config.pad_token_id has to be defined.")
|
1334 |
-
# replace possible -100 values in labels by `pad_token_id`
|
1335 |
-
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
1336 |
-
|
1337 |
-
return shifted_input_ids
|
1338 |
-
|
1339 |
-
def create_model_card(
|
1340 |
-
self,
|
1341 |
-
model_name: Optional[str] = None,
|
1342 |
-
dataset_name: Optional[str] = None,
|
1343 |
-
tags: Union[str, list[str], None] = None,
|
1344 |
-
):
|
1345 |
-
"""
|
1346 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1347 |
-
|
1348 |
-
Args:
|
1349 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1350 |
-
Name of the model.
|
1351 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1352 |
-
Name of the dataset used for training.
|
1353 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1354 |
-
Tags to be associated with the model card.
|
1355 |
-
"""
|
1356 |
-
if not self.is_world_process_zero():
|
1357 |
-
return
|
1358 |
-
|
1359 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1360 |
-
base_model = self.model.config._name_or_path
|
1361 |
-
else:
|
1362 |
-
base_model = None
|
1363 |
-
|
1364 |
-
tags = tags or []
|
1365 |
-
if isinstance(tags, str):
|
1366 |
-
tags = [tags]
|
1367 |
-
|
1368 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1369 |
-
tags.append("unsloth")
|
1370 |
-
|
1371 |
-
citation = textwrap.dedent("""\
|
1372 |
-
@inproceedings{xu2024contrastive,
|
1373 |
-
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
|
1374 |
-
author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
|
1375 |
-
year = 2024,
|
1376 |
-
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
1377 |
-
publisher = {OpenReview.net},
|
1378 |
-
url = {https://openreview.net/forum?id=51iwkioZpn}
|
1379 |
-
}""")
|
1380 |
-
|
1381 |
-
model_card = generate_model_card(
|
1382 |
-
base_model=base_model,
|
1383 |
-
model_name=model_name,
|
1384 |
-
hub_model_id=self.hub_model_id,
|
1385 |
-
dataset_name=dataset_name,
|
1386 |
-
tags=tags,
|
1387 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1388 |
-
comet_url=get_comet_experiment_url(),
|
1389 |
-
trainer_name="CPO",
|
1390 |
-
trainer_citation=citation,
|
1391 |
-
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
|
1392 |
-
paper_id="2401.08417",
|
1393 |
-
)
|
1394 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1395 |
-
class UnslothCPOTrainer(_UnslothCPOTrainer):
|
1396 |
-
"""
|
1397 |
-
|
1398 |
-
Initialize CPOTrainer.
|
1399 |
-
|
1400 |
-
Args:
|
1401 |
-
model (`transformers.PreTrainedModel`):
|
1402 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1403 |
-
args (`CPOConfig`):
|
1404 |
-
The CPO config arguments to use for training.
|
1405 |
-
data_collator (`transformers.DataCollator`):
|
1406 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1407 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1408 |
-
train_dataset (`datasets.Dataset`):
|
1409 |
-
The dataset to use for training.
|
1410 |
-
eval_dataset (`datasets.Dataset`):
|
1411 |
-
The dataset to use for evaluation.
|
1412 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1413 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1414 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1415 |
-
reuse the fine-tuned model.
|
1416 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1417 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1418 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
1419 |
-
The callbacks to use for training.
|
1420 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1421 |
-
The optimizer and scheduler to use for training.
|
1422 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1423 |
-
The function to use to preprocess the logits before computing the metrics.
|
1424 |
-
peft_config (`dict`, defaults to `None`):
|
1425 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1426 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1427 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1428 |
-
a dictionary string to metric values.
|
1429 |
-
|
1430 |
-
"""
|
1431 |
-
def __init__(
|
1432 |
-
self,
|
1433 |
-
model = None,
|
1434 |
-
args = None,
|
1435 |
-
data_collator = None,
|
1436 |
-
train_dataset = None,
|
1437 |
-
eval_dataset = None,
|
1438 |
-
processing_class = None,
|
1439 |
-
model_init = None,
|
1440 |
-
callbacks = None,
|
1441 |
-
preprocess_logits_for_metrics = None,
|
1442 |
-
peft_config = None,
|
1443 |
-
compute_metrics = None,
|
1444 |
-
**kwargs
|
1445 |
-
):
|
1446 |
-
if args is None: args = UnslothCPOConfig()
|
1447 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1448 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1449 |
-
force_float32 = False
|
1450 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1451 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1452 |
-
force_float32 = True
|
1453 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1454 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1455 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1456 |
-
from unsloth_zoo.utils import _get_dtype
|
1457 |
-
dtype = _get_dtype(dtype)
|
1458 |
-
float16 = dtype == torch.float16
|
1459 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1460 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1461 |
-
if force_float32:
|
1462 |
-
args.fp16 = False
|
1463 |
-
args.bf16 = False
|
1464 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1465 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1466 |
-
args.fp16 = float16
|
1467 |
-
args.bf16 = not float16
|
1468 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1469 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1470 |
-
args.eval_strategy = 'steps'
|
1471 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1472 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1473 |
-
if ga_steps is not None and ga_steps > 1:
|
1474 |
-
from transformers import __version__ as transformers_version
|
1475 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1476 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1477 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1478 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1479 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1480 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1481 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1482 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1483 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1484 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1485 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1486 |
-
if force_float32:
|
1487 |
-
args.bf16_full_eval = False
|
1488 |
-
args.fp16_full_eval = False
|
1489 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1490 |
-
args.bf16_full_eval = True
|
1491 |
-
args.fp16_full_eval = False
|
1492 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1493 |
-
args.bf16_full_eval = args.bf16
|
1494 |
-
args.fp16_full_eval = args.fp16
|
1495 |
-
_output_logits = False
|
1496 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1497 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1498 |
-
if _output_logits:
|
1499 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1500 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1501 |
-
pass
|
1502 |
-
else:
|
1503 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1504 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1505 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1506 |
-
max_seq_length = model.max_seq_length
|
1507 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1508 |
-
if model is not None and hasattr(model, 'for_training'):
|
1509 |
-
model.for_training()
|
1510 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1511 |
-
if 'processing_class' in locals():
|
1512 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1513 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1514 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1515 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1516 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1517 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1518 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1519 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1520 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1521 |
-
else:
|
1522 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1523 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1524 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1525 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1526 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1527 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1528 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1529 |
-
else:
|
1530 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1531 |
-
other_metrics = []
|
1532 |
-
|
1533 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1534 |
-
PatchRLStatistics('cpo_trainer', other_metrics)
|
1535 |
-
|
1536 |
-
super().__init__(
|
1537 |
-
model = model,
|
1538 |
-
args = args,
|
1539 |
-
data_collator = data_collator,
|
1540 |
-
train_dataset = train_dataset,
|
1541 |
-
eval_dataset = eval_dataset,
|
1542 |
-
processing_class = processing_class,
|
1543 |
-
model_init = model_init,
|
1544 |
-
callbacks = callbacks,
|
1545 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1546 |
-
peft_config = peft_config,
|
1547 |
-
compute_metrics = compute_metrics,**kwargs)
|
1548 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1549 |
-
self.neftune_hook_handle.remove()
|
1550 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1551 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1552 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1553 |
-
pass
|
1554 |
-
|
1555 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothDDPOTrainer.py
DELETED
@@ -1,872 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothDDPOConfig(DDPOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`DDPOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
54 |
-
Name of this experiment (by default is the file name without the extension name).
|
55 |
-
run_name (`str`, *optional*, defaults to `""`):
|
56 |
-
Name of this run.
|
57 |
-
seed (`int`, *optional*, defaults to `0`):
|
58 |
-
Random seed.
|
59 |
-
log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
|
60 |
-
Log with either 'wandb' or 'tensorboard', check
|
61 |
-
https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
|
62 |
-
tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
|
63 |
-
Keyword arguments for the tracker (e.g. wandb_project).
|
64 |
-
accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
|
65 |
-
Keyword arguments for the accelerator.
|
66 |
-
project_kwargs (`Dict`, *optional*, defaults to `{}`):
|
67 |
-
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
|
68 |
-
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
69 |
-
Name of project to use for tracking.
|
70 |
-
logdir (`str`, *optional*, defaults to `"logs"`):
|
71 |
-
Top-level logging directory for checkpoint saving.
|
72 |
-
num_epochs (`int`, *optional*, defaults to `100`):
|
73 |
-
Number of epochs to train.
|
74 |
-
save_freq (`int`, *optional*, defaults to `1`):
|
75 |
-
Number of epochs between saving model checkpoints.
|
76 |
-
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
77 |
-
Number of checkpoints to keep before overwriting old ones.
|
78 |
-
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
79 |
-
Mixed precision training.
|
80 |
-
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
81 |
-
Allow `tf32` on Ampere GPUs.
|
82 |
-
resume_from (`str`, *optional*, defaults to `""`):
|
83 |
-
Resume training from a checkpoint.
|
84 |
-
sample_num_steps (`int`, *optional*, defaults to `50`):
|
85 |
-
Number of sampler inference steps.
|
86 |
-
sample_eta (`float`, *optional*, defaults to `1.0`):
|
87 |
-
Eta parameter for the DDIM sampler.
|
88 |
-
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
89 |
-
Classifier-free guidance weight.
|
90 |
-
sample_batch_size (`int`, *optional*, defaults to `1`):
|
91 |
-
Batch size (per GPU) to use for sampling.
|
92 |
-
sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
|
93 |
-
Number of batches to sample per epoch.
|
94 |
-
train_batch_size (`int`, *optional*, defaults to `1`):
|
95 |
-
Batch size (per GPU) to use for training.
|
96 |
-
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
97 |
-
Use 8bit Adam optimizer from bitsandbytes.
|
98 |
-
train_learning_rate (`float`, *optional*, defaults to `3e-4`):
|
99 |
-
Learning rate.
|
100 |
-
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
101 |
-
Adam beta1.
|
102 |
-
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
103 |
-
Adam beta2.
|
104 |
-
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
105 |
-
Adam weight decay.
|
106 |
-
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
107 |
-
Adam epsilon.
|
108 |
-
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
109 |
-
Number of gradient accumulation steps.
|
110 |
-
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
111 |
-
Maximum gradient norm for gradient clipping.
|
112 |
-
train_num_inner_epochs (`int`, *optional*, defaults to `1`):
|
113 |
-
Number of inner epochs per outer epoch.
|
114 |
-
train_cfg (`bool`, *optional*, defaults to `True`):
|
115 |
-
Whether to use classifier-free guidance during training.
|
116 |
-
train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
|
117 |
-
Clip advantages to the range.
|
118 |
-
train_clip_range (`float`, *optional*, defaults to `1e-4`):
|
119 |
-
PPO clip range.
|
120 |
-
train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
|
121 |
-
Fraction of timesteps to train on.
|
122 |
-
per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
|
123 |
-
Whether to track statistics for each prompt separately.
|
124 |
-
per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
|
125 |
-
Number of reward values to store in the buffer for each prompt.
|
126 |
-
per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
|
127 |
-
Minimum number of reward values to store in the buffer.
|
128 |
-
async_reward_computation (`bool`, *optional*, defaults to `False`):
|
129 |
-
Whether to compute rewards asynchronously.
|
130 |
-
max_workers (`int`, *optional*, defaults to `2`):
|
131 |
-
Maximum number of workers to use for async reward computation.
|
132 |
-
negative_prompts (`str`, *optional*, defaults to `""`):
|
133 |
-
Comma-separated list of prompts to use as negative examples.
|
134 |
-
push_to_hub (`bool`, *optional*, defaults to `False`):
|
135 |
-
Whether to push the final model checkpoint to the Hub.
|
136 |
-
|
137 |
-
"""
|
138 |
-
vllm_sampling_params: Optional[Any] = field(
|
139 |
-
default = None,
|
140 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
141 |
-
)
|
142 |
-
unsloth_num_chunks : Optional[int] = field(
|
143 |
-
default = -1,
|
144 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
145 |
-
)
|
146 |
-
def __init__(
|
147 |
-
self,
|
148 |
-
exp_name = 'main',
|
149 |
-
run_name = '',
|
150 |
-
seed = 3407,
|
151 |
-
log_with = None,
|
152 |
-
tracker_project_name = 'trl',
|
153 |
-
logdir = 'logs',
|
154 |
-
num_epochs = 100,
|
155 |
-
save_freq = 1,
|
156 |
-
num_checkpoint_limit = 5,
|
157 |
-
mixed_precision = 'fp16',
|
158 |
-
allow_tf32 = True,
|
159 |
-
resume_from = '',
|
160 |
-
sample_num_steps = 50,
|
161 |
-
sample_eta = 1.0,
|
162 |
-
sample_guidance_scale = 5.0,
|
163 |
-
sample_batch_size = 1,
|
164 |
-
sample_num_batches_per_epoch = 2,
|
165 |
-
train_batch_size = 1,
|
166 |
-
train_use_8bit_adam = False,
|
167 |
-
train_learning_rate = 5e-05,
|
168 |
-
train_adam_beta1 = 0.9,
|
169 |
-
train_adam_beta2 = 0.999,
|
170 |
-
train_adam_weight_decay = 0.01,
|
171 |
-
train_adam_epsilon = 1e-08,
|
172 |
-
train_gradient_accumulation_steps = 2,
|
173 |
-
train_max_grad_norm = 1.0,
|
174 |
-
train_num_inner_epochs = 1,
|
175 |
-
train_cfg = True,
|
176 |
-
train_adv_clip_max = 5.0,
|
177 |
-
train_clip_range = 0.0001,
|
178 |
-
train_timestep_fraction = 1.0,
|
179 |
-
per_prompt_stat_tracking = False,
|
180 |
-
per_prompt_stat_tracking_buffer_size = 16,
|
181 |
-
per_prompt_stat_tracking_min_count = 16,
|
182 |
-
async_reward_computation = False,
|
183 |
-
max_workers = 2,
|
184 |
-
negative_prompts = '',
|
185 |
-
push_to_hub = False,
|
186 |
-
vllm_sampling_params = None,
|
187 |
-
unsloth_num_chunks = -1,
|
188 |
-
**kwargs,
|
189 |
-
):
|
190 |
-
|
191 |
-
super().__init__(
|
192 |
-
exp_name = exp_name,
|
193 |
-
run_name = run_name,
|
194 |
-
seed = seed,
|
195 |
-
log_with = log_with,
|
196 |
-
tracker_project_name = tracker_project_name,
|
197 |
-
logdir = logdir,
|
198 |
-
num_epochs = num_epochs,
|
199 |
-
save_freq = save_freq,
|
200 |
-
num_checkpoint_limit = num_checkpoint_limit,
|
201 |
-
mixed_precision = mixed_precision,
|
202 |
-
allow_tf32 = allow_tf32,
|
203 |
-
resume_from = resume_from,
|
204 |
-
sample_num_steps = sample_num_steps,
|
205 |
-
sample_eta = sample_eta,
|
206 |
-
sample_guidance_scale = sample_guidance_scale,
|
207 |
-
sample_batch_size = sample_batch_size,
|
208 |
-
sample_num_batches_per_epoch = sample_num_batches_per_epoch,
|
209 |
-
train_batch_size = train_batch_size,
|
210 |
-
train_use_8bit_adam = train_use_8bit_adam,
|
211 |
-
train_learning_rate = train_learning_rate,
|
212 |
-
train_adam_beta1 = train_adam_beta1,
|
213 |
-
train_adam_beta2 = train_adam_beta2,
|
214 |
-
train_adam_weight_decay = train_adam_weight_decay,
|
215 |
-
train_adam_epsilon = train_adam_epsilon,
|
216 |
-
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
217 |
-
train_max_grad_norm = train_max_grad_norm,
|
218 |
-
train_num_inner_epochs = train_num_inner_epochs,
|
219 |
-
train_cfg = train_cfg,
|
220 |
-
train_adv_clip_max = train_adv_clip_max,
|
221 |
-
train_clip_range = train_clip_range,
|
222 |
-
train_timestep_fraction = train_timestep_fraction,
|
223 |
-
per_prompt_stat_tracking = per_prompt_stat_tracking,
|
224 |
-
per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
|
225 |
-
per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
|
226 |
-
async_reward_computation = async_reward_computation,
|
227 |
-
max_workers = max_workers,
|
228 |
-
negative_prompts = negative_prompts,
|
229 |
-
push_to_hub = push_to_hub,**kwargs)
|
230 |
-
self.vllm_sampling_params = vllm_sampling_params
|
231 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
232 |
-
pass
|
233 |
-
|
234 |
-
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
235 |
-
""""""
|
236 |
-
|
237 |
-
_tag_names = ["trl", "ddpo"]
|
238 |
-
|
239 |
-
def __init__(
|
240 |
-
self,
|
241 |
-
config: DDPOConfig,
|
242 |
-
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
243 |
-
prompt_function: Callable[[], tuple[str, Any]],
|
244 |
-
sd_pipeline: DDPOStableDiffusionPipeline,
|
245 |
-
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
246 |
-
):
|
247 |
-
if image_samples_hook is None:
|
248 |
-
warn("No image_samples_hook provided; no images will be logged")
|
249 |
-
|
250 |
-
self.prompt_fn = prompt_function
|
251 |
-
self.reward_fn = reward_function
|
252 |
-
self.config = config
|
253 |
-
self.image_samples_callback = image_samples_hook
|
254 |
-
|
255 |
-
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
256 |
-
|
257 |
-
if self.config.resume_from:
|
258 |
-
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
259 |
-
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
260 |
-
# get the most recent checkpoint in this directory
|
261 |
-
checkpoints = list(
|
262 |
-
filter(
|
263 |
-
lambda x: "checkpoint_" in x,
|
264 |
-
os.listdir(self.config.resume_from),
|
265 |
-
)
|
266 |
-
)
|
267 |
-
if len(checkpoints) == 0:
|
268 |
-
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
269 |
-
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
270 |
-
self.config.resume_from = os.path.join(
|
271 |
-
self.config.resume_from,
|
272 |
-
f"checkpoint_{checkpoint_numbers[-1]}",
|
273 |
-
)
|
274 |
-
|
275 |
-
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
276 |
-
|
277 |
-
# number of timesteps within each trajectory to train on
|
278 |
-
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
279 |
-
|
280 |
-
self.accelerator = Accelerator(
|
281 |
-
log_with=self.config.log_with,
|
282 |
-
mixed_precision=self.config.mixed_precision,
|
283 |
-
project_config=accelerator_project_config,
|
284 |
-
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
285 |
-
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
286 |
-
# the total number of optimizer steps to accumulate across.
|
287 |
-
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
288 |
-
**self.config.accelerator_kwargs,
|
289 |
-
)
|
290 |
-
|
291 |
-
is_okay, message = self._config_check()
|
292 |
-
if not is_okay:
|
293 |
-
raise ValueError(message)
|
294 |
-
|
295 |
-
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
296 |
-
|
297 |
-
if self.accelerator.is_main_process:
|
298 |
-
self.accelerator.init_trackers(
|
299 |
-
self.config.tracker_project_name,
|
300 |
-
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
|
301 |
-
init_kwargs=self.config.tracker_kwargs,
|
302 |
-
)
|
303 |
-
|
304 |
-
logger.info(f"\n{config}")
|
305 |
-
|
306 |
-
set_seed(self.config.seed, device_specific=True)
|
307 |
-
|
308 |
-
self.sd_pipeline = sd_pipeline
|
309 |
-
|
310 |
-
self.sd_pipeline.set_progress_bar_config(
|
311 |
-
position=1,
|
312 |
-
disable=not self.accelerator.is_local_main_process,
|
313 |
-
leave=False,
|
314 |
-
desc="Timestep",
|
315 |
-
dynamic_ncols=True,
|
316 |
-
)
|
317 |
-
|
318 |
-
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
319 |
-
# as these weights are only used for inference, keeping weights in full precision is not required.
|
320 |
-
if self.accelerator.mixed_precision == "fp16":
|
321 |
-
inference_dtype = torch.float16
|
322 |
-
elif self.accelerator.mixed_precision == "bf16":
|
323 |
-
inference_dtype = torch.bfloat16
|
324 |
-
else:
|
325 |
-
inference_dtype = torch.float32
|
326 |
-
|
327 |
-
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
328 |
-
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
329 |
-
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
330 |
-
|
331 |
-
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
332 |
-
|
333 |
-
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
334 |
-
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
335 |
-
|
336 |
-
# Enable TF32 for faster training on Ampere GPUs,
|
337 |
-
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
338 |
-
if self.config.allow_tf32:
|
339 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
340 |
-
|
341 |
-
self.optimizer = self._setup_optimizer(
|
342 |
-
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
343 |
-
)
|
344 |
-
|
345 |
-
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
346 |
-
self.sd_pipeline.tokenizer(
|
347 |
-
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
348 |
-
return_tensors="pt",
|
349 |
-
padding="max_length",
|
350 |
-
truncation=True,
|
351 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
352 |
-
).input_ids.to(self.accelerator.device)
|
353 |
-
)[0]
|
354 |
-
|
355 |
-
if config.per_prompt_stat_tracking:
|
356 |
-
self.stat_tracker = PerPromptStatTracker(
|
357 |
-
config.per_prompt_stat_tracking_buffer_size,
|
358 |
-
config.per_prompt_stat_tracking_min_count,
|
359 |
-
)
|
360 |
-
|
361 |
-
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
362 |
-
# more memory
|
363 |
-
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
364 |
-
|
365 |
-
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
366 |
-
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
367 |
-
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
368 |
-
else:
|
369 |
-
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
370 |
-
|
371 |
-
if self.config.async_reward_computation:
|
372 |
-
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
|
373 |
-
|
374 |
-
if config.resume_from:
|
375 |
-
logger.info(f"Resuming from {config.resume_from}")
|
376 |
-
self.accelerator.load_state(config.resume_from)
|
377 |
-
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
378 |
-
else:
|
379 |
-
self.first_epoch = 0
|
380 |
-
|
381 |
-
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
382 |
-
if not is_async:
|
383 |
-
rewards = []
|
384 |
-
for images, prompts, prompt_metadata in prompt_image_pairs:
|
385 |
-
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
386 |
-
rewards.append(
|
387 |
-
(
|
388 |
-
torch.as_tensor(reward, device=self.accelerator.device),
|
389 |
-
reward_metadata,
|
390 |
-
)
|
391 |
-
)
|
392 |
-
else:
|
393 |
-
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
394 |
-
rewards = [
|
395 |
-
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
396 |
-
for reward, reward_metadata in rewards
|
397 |
-
]
|
398 |
-
|
399 |
-
return zip(*rewards)
|
400 |
-
|
401 |
-
def step(self, epoch: int, global_step: int):
|
402 |
-
"""
|
403 |
-
Perform a single step of training.
|
404 |
-
|
405 |
-
Args:
|
406 |
-
epoch (int): The current epoch.
|
407 |
-
global_step (int): The current global step.
|
408 |
-
|
409 |
-
Side Effects:
|
410 |
-
- Model weights are updated
|
411 |
-
- Logs the statistics to the accelerator trackers.
|
412 |
-
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
413 |
-
|
414 |
-
Returns:
|
415 |
-
global_step (int): The updated global step.
|
416 |
-
|
417 |
-
"""
|
418 |
-
samples, prompt_image_data = self._generate_samples(
|
419 |
-
iterations=self.config.sample_num_batches_per_epoch,
|
420 |
-
batch_size=self.config.sample_batch_size,
|
421 |
-
)
|
422 |
-
|
423 |
-
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
424 |
-
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
425 |
-
rewards, rewards_metadata = self.compute_rewards(
|
426 |
-
prompt_image_data, is_async=self.config.async_reward_computation
|
427 |
-
)
|
428 |
-
|
429 |
-
for i, image_data in enumerate(prompt_image_data):
|
430 |
-
image_data.extend([rewards[i], rewards_metadata[i]])
|
431 |
-
|
432 |
-
if self.image_samples_callback is not None:
|
433 |
-
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
|
434 |
-
|
435 |
-
rewards = torch.cat(rewards)
|
436 |
-
rewards = self.accelerator.gather(rewards).cpu().numpy()
|
437 |
-
|
438 |
-
self.accelerator.log(
|
439 |
-
{
|
440 |
-
"reward": rewards,
|
441 |
-
"epoch": epoch,
|
442 |
-
"reward_mean": rewards.mean(),
|
443 |
-
"reward_std": rewards.std(),
|
444 |
-
},
|
445 |
-
step=global_step,
|
446 |
-
)
|
447 |
-
|
448 |
-
if self.config.per_prompt_stat_tracking:
|
449 |
-
# gather the prompts across processes
|
450 |
-
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
|
451 |
-
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
|
452 |
-
advantages = self.stat_tracker.update(prompts, rewards)
|
453 |
-
else:
|
454 |
-
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
455 |
-
|
456 |
-
# ungather advantages; keep the entries corresponding to the samples on this process
|
457 |
-
samples["advantages"] = (
|
458 |
-
torch.as_tensor(advantages)
|
459 |
-
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
|
460 |
-
.to(self.accelerator.device)
|
461 |
-
)
|
462 |
-
|
463 |
-
del samples["prompt_ids"]
|
464 |
-
|
465 |
-
total_batch_size, num_timesteps = samples["timesteps"].shape
|
466 |
-
|
467 |
-
for inner_epoch in range(self.config.train_num_inner_epochs):
|
468 |
-
# shuffle samples along batch dimension
|
469 |
-
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
|
470 |
-
samples = {k: v[perm] for k, v in samples.items()}
|
471 |
-
|
472 |
-
# shuffle along time dimension independently for each sample
|
473 |
-
# still trying to understand the code below
|
474 |
-
perms = torch.stack(
|
475 |
-
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
|
476 |
-
)
|
477 |
-
|
478 |
-
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
479 |
-
samples[key] = samples[key][
|
480 |
-
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
|
481 |
-
perms,
|
482 |
-
]
|
483 |
-
|
484 |
-
original_keys = samples.keys()
|
485 |
-
original_values = samples.values()
|
486 |
-
# rebatch them as user defined train_batch_size is different from sample_batch_size
|
487 |
-
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
|
488 |
-
|
489 |
-
# Transpose the list of original values
|
490 |
-
transposed_values = zip(*reshaped_values)
|
491 |
-
# Create new dictionaries for each row of transposed values
|
492 |
-
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
|
493 |
-
|
494 |
-
self.sd_pipeline.unet.train()
|
495 |
-
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
|
496 |
-
# ensure optimization step at the end of the inner epoch
|
497 |
-
if not self.accelerator.sync_gradients:
|
498 |
-
raise ValueError(
|
499 |
-
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
500 |
-
)
|
501 |
-
|
502 |
-
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
503 |
-
self.accelerator.save_state()
|
504 |
-
|
505 |
-
return global_step
|
506 |
-
|
507 |
-
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
|
508 |
-
"""
|
509 |
-
Calculate the loss for a batch of an unpacked sample
|
510 |
-
|
511 |
-
Args:
|
512 |
-
latents (torch.Tensor):
|
513 |
-
The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
514 |
-
timesteps (torch.Tensor):
|
515 |
-
The timesteps sampled from the diffusion model, shape: [batch_size]
|
516 |
-
next_latents (torch.Tensor):
|
517 |
-
The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
518 |
-
log_probs (torch.Tensor):
|
519 |
-
The log probabilities of the latents, shape: [batch_size]
|
520 |
-
advantages (torch.Tensor):
|
521 |
-
The advantages of the latents, shape: [batch_size]
|
522 |
-
embeds (torch.Tensor):
|
523 |
-
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
|
524 |
-
Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
|
525 |
-
|
526 |
-
Returns:
|
527 |
-
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
|
528 |
-
(all of these are of shape (1,))
|
529 |
-
"""
|
530 |
-
with self.autocast():
|
531 |
-
if self.config.train_cfg:
|
532 |
-
noise_pred = self.sd_pipeline.unet(
|
533 |
-
torch.cat([latents] * 2),
|
534 |
-
torch.cat([timesteps] * 2),
|
535 |
-
embeds,
|
536 |
-
).sample
|
537 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
538 |
-
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
|
539 |
-
noise_pred_text - noise_pred_uncond
|
540 |
-
)
|
541 |
-
else:
|
542 |
-
noise_pred = self.sd_pipeline.unet(
|
543 |
-
latents,
|
544 |
-
timesteps,
|
545 |
-
embeds,
|
546 |
-
).sample
|
547 |
-
# compute the log prob of next_latents given latents under the current model
|
548 |
-
|
549 |
-
scheduler_step_output = self.sd_pipeline.scheduler_step(
|
550 |
-
noise_pred,
|
551 |
-
timesteps,
|
552 |
-
latents,
|
553 |
-
eta=self.config.sample_eta,
|
554 |
-
prev_sample=next_latents,
|
555 |
-
)
|
556 |
-
|
557 |
-
log_prob = scheduler_step_output.log_probs
|
558 |
-
|
559 |
-
advantages = torch.clamp(
|
560 |
-
advantages,
|
561 |
-
-self.config.train_adv_clip_max,
|
562 |
-
self.config.train_adv_clip_max,
|
563 |
-
)
|
564 |
-
|
565 |
-
ratio = torch.exp(log_prob - log_probs)
|
566 |
-
|
567 |
-
loss = self.loss(advantages, self.config.train_clip_range, ratio)
|
568 |
-
|
569 |
-
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
|
570 |
-
|
571 |
-
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
|
572 |
-
|
573 |
-
return loss, approx_kl, clipfrac
|
574 |
-
|
575 |
-
def loss(
|
576 |
-
self,
|
577 |
-
advantages: torch.Tensor,
|
578 |
-
clip_range: float,
|
579 |
-
ratio: torch.Tensor,
|
580 |
-
):
|
581 |
-
unclipped_loss = -advantages * ratio
|
582 |
-
clipped_loss = -advantages * torch.clamp(
|
583 |
-
ratio,
|
584 |
-
1.0 - clip_range,
|
585 |
-
1.0 + clip_range,
|
586 |
-
)
|
587 |
-
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
588 |
-
|
589 |
-
def _setup_optimizer(self, trainable_layers_parameters):
|
590 |
-
if self.config.train_use_8bit_adam:
|
591 |
-
import bitsandbytes
|
592 |
-
|
593 |
-
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
594 |
-
else:
|
595 |
-
optimizer_cls = torch.optim.AdamW
|
596 |
-
|
597 |
-
return optimizer_cls(
|
598 |
-
trainable_layers_parameters,
|
599 |
-
lr=self.config.train_learning_rate,
|
600 |
-
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
601 |
-
weight_decay=self.config.train_adam_weight_decay,
|
602 |
-
eps=self.config.train_adam_epsilon,
|
603 |
-
)
|
604 |
-
|
605 |
-
def _save_model_hook(self, models, weights, output_dir):
|
606 |
-
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
607 |
-
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
608 |
-
|
609 |
-
def _load_model_hook(self, models, input_dir):
|
610 |
-
self.sd_pipeline.load_checkpoint(models, input_dir)
|
611 |
-
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
612 |
-
|
613 |
-
def _generate_samples(self, iterations, batch_size):
|
614 |
-
"""
|
615 |
-
Generate samples from the model
|
616 |
-
|
617 |
-
Args:
|
618 |
-
iterations (int): Number of iterations to generate samples for
|
619 |
-
batch_size (int): Batch size to use for sampling
|
620 |
-
|
621 |
-
Returns:
|
622 |
-
samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
|
623 |
-
"""
|
624 |
-
samples = []
|
625 |
-
prompt_image_pairs = []
|
626 |
-
self.sd_pipeline.unet.eval()
|
627 |
-
|
628 |
-
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
629 |
-
|
630 |
-
for _ in range(iterations):
|
631 |
-
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
632 |
-
|
633 |
-
prompt_ids = self.sd_pipeline.tokenizer(
|
634 |
-
prompts,
|
635 |
-
return_tensors="pt",
|
636 |
-
padding="max_length",
|
637 |
-
truncation=True,
|
638 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
639 |
-
).input_ids.to(self.accelerator.device)
|
640 |
-
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
641 |
-
|
642 |
-
with self.autocast():
|
643 |
-
sd_output = self.sd_pipeline(
|
644 |
-
prompt_embeds=prompt_embeds,
|
645 |
-
negative_prompt_embeds=sample_neg_prompt_embeds,
|
646 |
-
num_inference_steps=self.config.sample_num_steps,
|
647 |
-
guidance_scale=self.config.sample_guidance_scale,
|
648 |
-
eta=self.config.sample_eta,
|
649 |
-
output_type="pt",
|
650 |
-
)
|
651 |
-
|
652 |
-
images = sd_output.images
|
653 |
-
latents = sd_output.latents
|
654 |
-
log_probs = sd_output.log_probs
|
655 |
-
|
656 |
-
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
|
657 |
-
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
|
658 |
-
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
|
659 |
-
|
660 |
-
samples.append(
|
661 |
-
{
|
662 |
-
"prompt_ids": prompt_ids,
|
663 |
-
"prompt_embeds": prompt_embeds,
|
664 |
-
"timesteps": timesteps,
|
665 |
-
"latents": latents[:, :-1], # each entry is the latent before timestep t
|
666 |
-
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
|
667 |
-
"log_probs": log_probs,
|
668 |
-
"negative_prompt_embeds": sample_neg_prompt_embeds,
|
669 |
-
}
|
670 |
-
)
|
671 |
-
prompt_image_pairs.append([images, prompts, prompt_metadata])
|
672 |
-
|
673 |
-
return samples, prompt_image_pairs
|
674 |
-
|
675 |
-
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
|
676 |
-
"""
|
677 |
-
Train on a batch of samples. Main training segment
|
678 |
-
|
679 |
-
Args:
|
680 |
-
inner_epoch (int): The current inner epoch
|
681 |
-
epoch (int): The current epoch
|
682 |
-
global_step (int): The current global step
|
683 |
-
batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
|
684 |
-
|
685 |
-
Side Effects:
|
686 |
-
- Model weights are updated
|
687 |
-
- Logs the statistics to the accelerator trackers.
|
688 |
-
|
689 |
-
Returns:
|
690 |
-
global_step (int): The updated global step
|
691 |
-
"""
|
692 |
-
info = defaultdict(list)
|
693 |
-
for _i, sample in enumerate(batched_samples):
|
694 |
-
if self.config.train_cfg:
|
695 |
-
# concat negative prompts to sample prompts to avoid two forward passes
|
696 |
-
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
|
697 |
-
else:
|
698 |
-
embeds = sample["prompt_embeds"]
|
699 |
-
|
700 |
-
for j in range(self.num_train_timesteps):
|
701 |
-
with self.accelerator.accumulate(self.sd_pipeline.unet):
|
702 |
-
loss, approx_kl, clipfrac = self.calculate_loss(
|
703 |
-
sample["latents"][:, j],
|
704 |
-
sample["timesteps"][:, j],
|
705 |
-
sample["next_latents"][:, j],
|
706 |
-
sample["log_probs"][:, j],
|
707 |
-
sample["advantages"],
|
708 |
-
embeds,
|
709 |
-
)
|
710 |
-
info["approx_kl"].append(approx_kl)
|
711 |
-
info["clipfrac"].append(clipfrac)
|
712 |
-
info["loss"].append(loss)
|
713 |
-
|
714 |
-
self.accelerator.backward(loss)
|
715 |
-
if self.accelerator.sync_gradients:
|
716 |
-
self.accelerator.clip_grad_norm_(
|
717 |
-
self.trainable_layers.parameters()
|
718 |
-
if not isinstance(self.trainable_layers, list)
|
719 |
-
else self.trainable_layers,
|
720 |
-
self.config.train_max_grad_norm,
|
721 |
-
)
|
722 |
-
self.optimizer.step()
|
723 |
-
self.optimizer.zero_grad()
|
724 |
-
|
725 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
726 |
-
if self.accelerator.sync_gradients:
|
727 |
-
# log training-related stuff
|
728 |
-
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
729 |
-
info = self.accelerator.reduce(info, reduction="mean")
|
730 |
-
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
731 |
-
self.accelerator.log(info, step=global_step)
|
732 |
-
global_step += 1
|
733 |
-
info = defaultdict(list)
|
734 |
-
return global_step
|
735 |
-
|
736 |
-
def _config_check(self) -> tuple[bool, str]:
|
737 |
-
samples_per_epoch = (
|
738 |
-
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
|
739 |
-
)
|
740 |
-
total_train_batch_size = (
|
741 |
-
self.config.train_batch_size
|
742 |
-
* self.accelerator.num_processes
|
743 |
-
* self.config.train_gradient_accumulation_steps
|
744 |
-
)
|
745 |
-
|
746 |
-
if not self.config.sample_batch_size >= self.config.train_batch_size:
|
747 |
-
return (
|
748 |
-
False,
|
749 |
-
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
|
750 |
-
)
|
751 |
-
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
|
752 |
-
return (
|
753 |
-
False,
|
754 |
-
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
|
755 |
-
)
|
756 |
-
if not samples_per_epoch % total_train_batch_size == 0:
|
757 |
-
return (
|
758 |
-
False,
|
759 |
-
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
|
760 |
-
)
|
761 |
-
return True, ""
|
762 |
-
|
763 |
-
def train(self, epochs: Optional[int] = None):
|
764 |
-
"""
|
765 |
-
Train the model for a given number of epochs
|
766 |
-
"""
|
767 |
-
global_step = 0
|
768 |
-
if epochs is None:
|
769 |
-
epochs = self.config.num_epochs
|
770 |
-
for epoch in range(self.first_epoch, epochs):
|
771 |
-
global_step = self.step(epoch, global_step)
|
772 |
-
|
773 |
-
def _save_pretrained(self, save_directory):
|
774 |
-
self.sd_pipeline.save_pretrained(save_directory)
|
775 |
-
self.create_model_card()
|
776 |
-
|
777 |
-
def create_model_card(
|
778 |
-
self,
|
779 |
-
model_name: Optional[str] = None,
|
780 |
-
dataset_name: Optional[str] = None,
|
781 |
-
tags: Union[str, list[str], None] = None,
|
782 |
-
):
|
783 |
-
"""
|
784 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
785 |
-
|
786 |
-
Args:
|
787 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
788 |
-
Name of the model.
|
789 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
790 |
-
Name of the dataset used for training.
|
791 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
792 |
-
Tags to be associated with the model card.
|
793 |
-
"""
|
794 |
-
if not self.is_world_process_zero():
|
795 |
-
return
|
796 |
-
|
797 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
798 |
-
base_model = self.model.config._name_or_path
|
799 |
-
else:
|
800 |
-
base_model = None
|
801 |
-
|
802 |
-
tags = tags or []
|
803 |
-
if isinstance(tags, str):
|
804 |
-
tags = [tags]
|
805 |
-
|
806 |
-
if hasattr(self.model.config, "unsloth_version"):
|
807 |
-
tags.append("unsloth")
|
808 |
-
|
809 |
-
citation = textwrap.dedent("""\
|
810 |
-
@inproceedings{black2024training,
|
811 |
-
title = {{Training Diffusion Models with Reinforcement Learning}},
|
812 |
-
author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
|
813 |
-
year = 2024,
|
814 |
-
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
815 |
-
publisher = {OpenReview.net},
|
816 |
-
url = {https://openreview.net/forum?id=YCWjhGrJFD},
|
817 |
-
}""")
|
818 |
-
|
819 |
-
model_card = generate_model_card(
|
820 |
-
base_model=base_model,
|
821 |
-
model_name=model_name,
|
822 |
-
hub_model_id=self.hub_model_id,
|
823 |
-
dataset_name=dataset_name,
|
824 |
-
tags=tags,
|
825 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
826 |
-
comet_url=get_comet_experiment_url(),
|
827 |
-
trainer_name="DDPO",
|
828 |
-
trainer_citation=citation,
|
829 |
-
paper_title="Training Diffusion Models with Reinforcement Learning",
|
830 |
-
paper_id="2305.13301",
|
831 |
-
)
|
832 |
-
|
833 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
834 |
-
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
|
835 |
-
"""
|
836 |
-
|
837 |
-
The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
838 |
-
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
|
839 |
-
As of now only Stable Diffusion based pipelines are supported
|
840 |
-
|
841 |
-
Attributes:
|
842 |
-
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
|
843 |
-
details.
|
844 |
-
**reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
|
845 |
-
**prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
|
846 |
-
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
|
847 |
-
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
|
848 |
-
|
849 |
-
"""
|
850 |
-
def __init__(
|
851 |
-
self,
|
852 |
-
config,
|
853 |
-
reward_function,
|
854 |
-
prompt_function,
|
855 |
-
sd_pipeline,
|
856 |
-
image_samples_hook = None,
|
857 |
-
**kwargs
|
858 |
-
):
|
859 |
-
if args is None: args = UnslothDDPOConfig()
|
860 |
-
other_metrics = []
|
861 |
-
|
862 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
863 |
-
PatchRLStatistics('ddpo_trainer', other_metrics)
|
864 |
-
|
865 |
-
super().__init__(
|
866 |
-
config = config,
|
867 |
-
reward_function = reward_function,
|
868 |
-
prompt_function = prompt_function,
|
869 |
-
sd_pipeline = sd_pipeline,
|
870 |
-
image_samples_hook = image_samples_hook,**kwargs)
|
871 |
-
|
872 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothDPOTrainer.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
unsloth_compiled_cache/UnslothGKDTrainer.py
DELETED
@@ -1,861 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothGKDConfig(GKDConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for [`GKDTrainer`].
|
47 |
-
|
48 |
-
Args:
|
49 |
-
temperature (`float`, *optional*, defaults to `0.9`):
|
50 |
-
Temperature for sampling. The higher the temperature, the more random the completions.
|
51 |
-
lmbda (`float`, *optional*, defaults to `0.5`):
|
52 |
-
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
53 |
-
student-generated outputs).
|
54 |
-
beta (`float`, *optional*, defaults to `0.5`):
|
55 |
-
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
56 |
-
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
57 |
-
max_new_tokens (`int`, *optional*, defaults to `128`):
|
58 |
-
Maximum number of tokens to generate per completion.
|
59 |
-
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
60 |
-
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
|
61 |
-
being trained.
|
62 |
-
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
63 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
64 |
-
from a string.
|
65 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
66 |
-
Whether to disable dropout in the model.
|
67 |
-
seq_kd (`bool`, *optional*, defaults to `False`):
|
68 |
-
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
|
69 |
-
on teacher-generated output).
|
70 |
-
|
71 |
-
"""
|
72 |
-
vllm_sampling_params: Optional[Any] = field(
|
73 |
-
default = None,
|
74 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
75 |
-
)
|
76 |
-
unsloth_num_chunks : Optional[int] = field(
|
77 |
-
default = -1,
|
78 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
79 |
-
)
|
80 |
-
def __init__(
|
81 |
-
self,
|
82 |
-
output_dir = None,
|
83 |
-
overwrite_output_dir = None,
|
84 |
-
do_train = False,
|
85 |
-
do_eval = False,
|
86 |
-
do_predict = False,
|
87 |
-
eval_strategy = 'no',
|
88 |
-
prediction_loss_only = False,
|
89 |
-
per_device_train_batch_size = 4,
|
90 |
-
per_device_eval_batch_size = 4,
|
91 |
-
per_gpu_train_batch_size = None,
|
92 |
-
per_gpu_eval_batch_size = None,
|
93 |
-
gradient_accumulation_steps = 2,
|
94 |
-
eval_accumulation_steps = 2,
|
95 |
-
eval_delay = 0,
|
96 |
-
torch_empty_cache_steps = 250,
|
97 |
-
learning_rate = 5e-05,
|
98 |
-
weight_decay = 0.01,
|
99 |
-
adam_beta1 = 0.9,
|
100 |
-
adam_beta2 = 0.999,
|
101 |
-
adam_epsilon = 1e-08,
|
102 |
-
max_grad_norm = 1.0,
|
103 |
-
num_train_epochs = 3.0,
|
104 |
-
max_steps = -1,
|
105 |
-
lr_scheduler_type = 'linear',
|
106 |
-
warmup_ratio = 0.1,
|
107 |
-
warmup_steps = 0,
|
108 |
-
log_level = 'passive',
|
109 |
-
log_level_replica = 'warning',
|
110 |
-
log_on_each_node = True,
|
111 |
-
logging_dir = None,
|
112 |
-
logging_strategy = 'steps',
|
113 |
-
logging_first_step = False,
|
114 |
-
logging_steps = 1,
|
115 |
-
logging_nan_inf_filter = False,
|
116 |
-
save_strategy = 'steps',
|
117 |
-
save_steps = 500,
|
118 |
-
save_total_limit = None,
|
119 |
-
save_safetensors = True,
|
120 |
-
save_on_each_node = False,
|
121 |
-
save_only_model = False,
|
122 |
-
restore_callback_states_from_checkpoint = False,
|
123 |
-
no_cuda = False,
|
124 |
-
use_cpu = False,
|
125 |
-
use_mps_device = False,
|
126 |
-
seed = 3407,
|
127 |
-
data_seed = 3407,
|
128 |
-
jit_mode_eval = False,
|
129 |
-
use_ipex = False,
|
130 |
-
bf16 = False,
|
131 |
-
fp16 = False,
|
132 |
-
fp16_opt_level = 'O1',
|
133 |
-
half_precision_backend = 'auto',
|
134 |
-
bf16_full_eval = False,
|
135 |
-
fp16_full_eval = False,
|
136 |
-
tf32 = None,
|
137 |
-
local_rank = -1,
|
138 |
-
ddp_backend = None,
|
139 |
-
tpu_num_cores = None,
|
140 |
-
tpu_metrics_debug = False,
|
141 |
-
debug = '',
|
142 |
-
dataloader_drop_last = False,
|
143 |
-
eval_steps = None,
|
144 |
-
dataloader_num_workers = 0,
|
145 |
-
dataloader_prefetch_factor = None,
|
146 |
-
past_index = -1,
|
147 |
-
run_name = None,
|
148 |
-
disable_tqdm = None,
|
149 |
-
remove_unused_columns = True,
|
150 |
-
label_names = None,
|
151 |
-
load_best_model_at_end = False,
|
152 |
-
metric_for_best_model = None,
|
153 |
-
greater_is_better = None,
|
154 |
-
ignore_data_skip = False,
|
155 |
-
fsdp = '',
|
156 |
-
fsdp_min_num_params = 0,
|
157 |
-
fsdp_config = None,
|
158 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
159 |
-
accelerator_config = None,
|
160 |
-
deepspeed = None,
|
161 |
-
label_smoothing_factor = 0.0,
|
162 |
-
optim = 'adamw_8bit',
|
163 |
-
optim_args = None,
|
164 |
-
adafactor = False,
|
165 |
-
group_by_length = False,
|
166 |
-
length_column_name = 'length',
|
167 |
-
report_to = None,
|
168 |
-
ddp_find_unused_parameters = None,
|
169 |
-
ddp_bucket_cap_mb = None,
|
170 |
-
ddp_broadcast_buffers = None,
|
171 |
-
dataloader_pin_memory = True,
|
172 |
-
dataloader_persistent_workers = False,
|
173 |
-
skip_memory_metrics = True,
|
174 |
-
use_legacy_prediction_loop = False,
|
175 |
-
push_to_hub = False,
|
176 |
-
resume_from_checkpoint = None,
|
177 |
-
hub_model_id = None,
|
178 |
-
hub_strategy = 'every_save',
|
179 |
-
hub_token = None,
|
180 |
-
hub_private_repo = None,
|
181 |
-
hub_always_push = False,
|
182 |
-
gradient_checkpointing = False,
|
183 |
-
gradient_checkpointing_kwargs = None,
|
184 |
-
include_inputs_for_metrics = False,
|
185 |
-
eval_do_concat_batches = True,
|
186 |
-
fp16_backend = 'auto',
|
187 |
-
evaluation_strategy = None,
|
188 |
-
push_to_hub_model_id = None,
|
189 |
-
push_to_hub_organization = None,
|
190 |
-
push_to_hub_token = None,
|
191 |
-
mp_parameters = '',
|
192 |
-
auto_find_batch_size = False,
|
193 |
-
full_determinism = False,
|
194 |
-
torchdynamo = None,
|
195 |
-
ray_scope = 'last',
|
196 |
-
ddp_timeout = 1800,
|
197 |
-
torch_compile = False,
|
198 |
-
torch_compile_backend = None,
|
199 |
-
torch_compile_mode = None,
|
200 |
-
dispatch_batches = None,
|
201 |
-
split_batches = None,
|
202 |
-
include_tokens_per_second = False,
|
203 |
-
include_num_input_tokens_seen = False,
|
204 |
-
neftune_noise_alpha = None,
|
205 |
-
optim_target_modules = None,
|
206 |
-
batch_eval_metrics = False,
|
207 |
-
eval_on_start = False,
|
208 |
-
use_liger_kernel = False,
|
209 |
-
eval_use_gather_object = False,
|
210 |
-
average_tokens_across_devices = False,
|
211 |
-
model_init_kwargs = None,
|
212 |
-
use_liger = False,
|
213 |
-
dataset_text_field = 'text',
|
214 |
-
dataset_kwargs = None,
|
215 |
-
dataset_num_proc = None,
|
216 |
-
max_seq_length = None,
|
217 |
-
packing = False,
|
218 |
-
eval_packing = None,
|
219 |
-
dataset_batch_size = None,
|
220 |
-
num_of_sequences = None,
|
221 |
-
chars_per_token = None,
|
222 |
-
temperature = 0.9,
|
223 |
-
lmbda = 0.5,
|
224 |
-
beta = 0.5,
|
225 |
-
max_new_tokens = 128,
|
226 |
-
teacher_model_name_or_path = None,
|
227 |
-
teacher_model_init_kwargs = None,
|
228 |
-
disable_dropout = True,
|
229 |
-
seq_kd = False,
|
230 |
-
vllm_sampling_params = None,
|
231 |
-
unsloth_num_chunks = -1,
|
232 |
-
**kwargs,
|
233 |
-
):
|
234 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
235 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
236 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
237 |
-
output_dir = 'unsloth_training_checkpoints'
|
238 |
-
save_strategy = 'no'
|
239 |
-
if dataset_num_proc is None:
|
240 |
-
from multiprocessing import cpu_count
|
241 |
-
dataset_num_proc = cpu_count()
|
242 |
-
|
243 |
-
super().__init__(
|
244 |
-
output_dir = output_dir,
|
245 |
-
overwrite_output_dir = overwrite_output_dir,
|
246 |
-
do_train = do_train,
|
247 |
-
do_eval = do_eval,
|
248 |
-
do_predict = do_predict,
|
249 |
-
eval_strategy = eval_strategy,
|
250 |
-
prediction_loss_only = prediction_loss_only,
|
251 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
252 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
253 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
254 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
255 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
256 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
257 |
-
eval_delay = eval_delay,
|
258 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
259 |
-
learning_rate = learning_rate,
|
260 |
-
weight_decay = weight_decay,
|
261 |
-
adam_beta1 = adam_beta1,
|
262 |
-
adam_beta2 = adam_beta2,
|
263 |
-
adam_epsilon = adam_epsilon,
|
264 |
-
max_grad_norm = max_grad_norm,
|
265 |
-
num_train_epochs = num_train_epochs,
|
266 |
-
max_steps = max_steps,
|
267 |
-
lr_scheduler_type = lr_scheduler_type,
|
268 |
-
warmup_ratio = warmup_ratio,
|
269 |
-
warmup_steps = warmup_steps,
|
270 |
-
log_level = log_level,
|
271 |
-
log_level_replica = log_level_replica,
|
272 |
-
log_on_each_node = log_on_each_node,
|
273 |
-
logging_dir = logging_dir,
|
274 |
-
logging_strategy = logging_strategy,
|
275 |
-
logging_first_step = logging_first_step,
|
276 |
-
logging_steps = logging_steps,
|
277 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
278 |
-
save_strategy = save_strategy,
|
279 |
-
save_steps = save_steps,
|
280 |
-
save_total_limit = save_total_limit,
|
281 |
-
save_safetensors = save_safetensors,
|
282 |
-
save_on_each_node = save_on_each_node,
|
283 |
-
save_only_model = save_only_model,
|
284 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
285 |
-
no_cuda = no_cuda,
|
286 |
-
use_cpu = use_cpu,
|
287 |
-
use_mps_device = use_mps_device,
|
288 |
-
seed = seed,
|
289 |
-
data_seed = data_seed,
|
290 |
-
jit_mode_eval = jit_mode_eval,
|
291 |
-
use_ipex = use_ipex,
|
292 |
-
bf16 = bf16,
|
293 |
-
fp16 = fp16,
|
294 |
-
fp16_opt_level = fp16_opt_level,
|
295 |
-
half_precision_backend = half_precision_backend,
|
296 |
-
bf16_full_eval = bf16_full_eval,
|
297 |
-
fp16_full_eval = fp16_full_eval,
|
298 |
-
tf32 = tf32,
|
299 |
-
local_rank = local_rank,
|
300 |
-
ddp_backend = ddp_backend,
|
301 |
-
tpu_num_cores = tpu_num_cores,
|
302 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
303 |
-
debug = debug,
|
304 |
-
dataloader_drop_last = dataloader_drop_last,
|
305 |
-
eval_steps = eval_steps,
|
306 |
-
dataloader_num_workers = dataloader_num_workers,
|
307 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
308 |
-
past_index = past_index,
|
309 |
-
run_name = run_name,
|
310 |
-
disable_tqdm = disable_tqdm,
|
311 |
-
remove_unused_columns = remove_unused_columns,
|
312 |
-
label_names = label_names,
|
313 |
-
load_best_model_at_end = load_best_model_at_end,
|
314 |
-
metric_for_best_model = metric_for_best_model,
|
315 |
-
greater_is_better = greater_is_better,
|
316 |
-
ignore_data_skip = ignore_data_skip,
|
317 |
-
fsdp = fsdp,
|
318 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
319 |
-
fsdp_config = fsdp_config,
|
320 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
321 |
-
accelerator_config = accelerator_config,
|
322 |
-
deepspeed = deepspeed,
|
323 |
-
label_smoothing_factor = label_smoothing_factor,
|
324 |
-
optim = optim,
|
325 |
-
optim_args = optim_args,
|
326 |
-
adafactor = adafactor,
|
327 |
-
group_by_length = group_by_length,
|
328 |
-
length_column_name = length_column_name,
|
329 |
-
report_to = report_to,
|
330 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
331 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
332 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
333 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
334 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
335 |
-
skip_memory_metrics = skip_memory_metrics,
|
336 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
337 |
-
push_to_hub = push_to_hub,
|
338 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
339 |
-
hub_model_id = hub_model_id,
|
340 |
-
hub_strategy = hub_strategy,
|
341 |
-
hub_token = hub_token,
|
342 |
-
hub_private_repo = hub_private_repo,
|
343 |
-
hub_always_push = hub_always_push,
|
344 |
-
gradient_checkpointing = gradient_checkpointing,
|
345 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
346 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
347 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
348 |
-
fp16_backend = fp16_backend,
|
349 |
-
evaluation_strategy = evaluation_strategy,
|
350 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
351 |
-
push_to_hub_organization = push_to_hub_organization,
|
352 |
-
push_to_hub_token = push_to_hub_token,
|
353 |
-
mp_parameters = mp_parameters,
|
354 |
-
auto_find_batch_size = auto_find_batch_size,
|
355 |
-
full_determinism = full_determinism,
|
356 |
-
torchdynamo = torchdynamo,
|
357 |
-
ray_scope = ray_scope,
|
358 |
-
ddp_timeout = ddp_timeout,
|
359 |
-
torch_compile = torch_compile,
|
360 |
-
torch_compile_backend = torch_compile_backend,
|
361 |
-
torch_compile_mode = torch_compile_mode,
|
362 |
-
dispatch_batches = dispatch_batches,
|
363 |
-
split_batches = split_batches,
|
364 |
-
include_tokens_per_second = include_tokens_per_second,
|
365 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
366 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
367 |
-
optim_target_modules = optim_target_modules,
|
368 |
-
batch_eval_metrics = batch_eval_metrics,
|
369 |
-
eval_on_start = eval_on_start,
|
370 |
-
use_liger_kernel = use_liger_kernel,
|
371 |
-
eval_use_gather_object = eval_use_gather_object,
|
372 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
373 |
-
model_init_kwargs = model_init_kwargs,
|
374 |
-
use_liger = use_liger,
|
375 |
-
dataset_text_field = dataset_text_field,
|
376 |
-
dataset_kwargs = dataset_kwargs,
|
377 |
-
dataset_num_proc = dataset_num_proc,
|
378 |
-
max_seq_length = max_seq_length,
|
379 |
-
packing = packing,
|
380 |
-
eval_packing = eval_packing,
|
381 |
-
dataset_batch_size = dataset_batch_size,
|
382 |
-
num_of_sequences = num_of_sequences,
|
383 |
-
chars_per_token = chars_per_token,
|
384 |
-
temperature = temperature,
|
385 |
-
lmbda = lmbda,
|
386 |
-
beta = beta,
|
387 |
-
max_new_tokens = max_new_tokens,
|
388 |
-
teacher_model_name_or_path = teacher_model_name_or_path,
|
389 |
-
teacher_model_init_kwargs = teacher_model_init_kwargs,
|
390 |
-
disable_dropout = disable_dropout,
|
391 |
-
seq_kd = seq_kd,**kwargs)
|
392 |
-
self.vllm_sampling_params = vllm_sampling_params
|
393 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
394 |
-
pass
|
395 |
-
|
396 |
-
class _UnslothGKDTrainer(SFTTrainer):
|
397 |
-
_tag_names = ["trl", "gkd"]
|
398 |
-
|
399 |
-
def __init__(
|
400 |
-
self,
|
401 |
-
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
402 |
-
teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
|
403 |
-
args: Optional[GKDConfig] = None,
|
404 |
-
data_collator: Optional[DataCollator] = None, # type: ignore
|
405 |
-
train_dataset: Optional[Dataset] = None,
|
406 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
407 |
-
processing_class: Optional[
|
408 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
409 |
-
] = None,
|
410 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
411 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
412 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
413 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
414 |
-
peft_config: Optional["PeftConfig"] = None,
|
415 |
-
formatting_func: Optional[Callable] = None,
|
416 |
-
):
|
417 |
-
# add remove_unused_columns=False to the dataclass args
|
418 |
-
args.remove_unused_columns = False
|
419 |
-
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
|
420 |
-
|
421 |
-
super().__init__(
|
422 |
-
model,
|
423 |
-
args=args,
|
424 |
-
data_collator=data_collator,
|
425 |
-
train_dataset=train_dataset,
|
426 |
-
eval_dataset=eval_dataset,
|
427 |
-
processing_class=processing_class,
|
428 |
-
compute_metrics=compute_metrics,
|
429 |
-
callbacks=callbacks,
|
430 |
-
optimizers=optimizers,
|
431 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
432 |
-
peft_config=peft_config,
|
433 |
-
formatting_func=formatting_func,
|
434 |
-
)
|
435 |
-
|
436 |
-
if args.teacher_model_init_kwargs is None:
|
437 |
-
teacher_model_init_kwargs = {}
|
438 |
-
elif not isinstance(teacher_model, str):
|
439 |
-
raise ValueError(
|
440 |
-
"You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
|
441 |
-
)
|
442 |
-
else:
|
443 |
-
teacher_model_init_kwargs = args.teacher_model_init_kwargs
|
444 |
-
teacher_model_init_kwargs["torch_dtype"] = (
|
445 |
-
teacher_model_init_kwargs["torch_dtype"]
|
446 |
-
if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
|
447 |
-
else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
|
448 |
-
)
|
449 |
-
|
450 |
-
if isinstance(teacher_model, str):
|
451 |
-
if args.use_liger:
|
452 |
-
teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
453 |
-
else:
|
454 |
-
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
455 |
-
|
456 |
-
# Disable dropout in the model
|
457 |
-
if args.disable_dropout:
|
458 |
-
disable_dropout_in_model(self.model)
|
459 |
-
|
460 |
-
if self.is_deepspeed_enabled:
|
461 |
-
self.teacher_model = self._prepare_deepspeed(teacher_model)
|
462 |
-
else:
|
463 |
-
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
|
464 |
-
|
465 |
-
self.lmbda = args.lmbda
|
466 |
-
self.beta = args.beta
|
467 |
-
self.temperature = args.temperature
|
468 |
-
self.seq_kd = args.seq_kd
|
469 |
-
|
470 |
-
self.generation_config = GenerationConfig(
|
471 |
-
max_new_tokens=args.max_new_tokens,
|
472 |
-
temperature=args.temperature,
|
473 |
-
do_sample=True,
|
474 |
-
top_k=0,
|
475 |
-
use_cache=False if args.gradient_checkpointing else True,
|
476 |
-
pad_token_id=self.processing_class.pad_token_id,
|
477 |
-
)
|
478 |
-
# Set custom EOS tokens if they are specified by the model's generation
|
479 |
-
# config. This is important for models with the Llama 3 chat template,
|
480 |
-
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
|
481 |
-
# turns or messages.
|
482 |
-
if (
|
483 |
-
hasattr(self.model.generation_config, "eos_token_id")
|
484 |
-
and self.model.generation_config.eos_token_id is not None
|
485 |
-
):
|
486 |
-
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
487 |
-
|
488 |
-
def _prepare_dataset(self, dataset, *args):
|
489 |
-
# SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
|
490 |
-
# need to keep the messages column as it is. We use the following workaround to keep the messages column.
|
491 |
-
dataset = dataset.add_column("_messages", dataset["messages"])
|
492 |
-
dataset = super()._prepare_dataset(dataset, *args)
|
493 |
-
dataset = dataset.rename_column("_messages", "messages")
|
494 |
-
return dataset
|
495 |
-
|
496 |
-
@staticmethod
|
497 |
-
def generalized_jsd_loss(
|
498 |
-
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
|
499 |
-
):
|
500 |
-
"""
|
501 |
-
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
|
502 |
-
of https://huggingface.co/papers/2306.13649 for the definition.
|
503 |
-
|
504 |
-
Args:
|
505 |
-
student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
506 |
-
teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
507 |
-
labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
|
508 |
-
beta: Interpolation coefficient between 0 and 1 (default: 0.5)
|
509 |
-
temperature: Softmax temperature (default: 1.0)
|
510 |
-
reduction: Specifies the reduction to apply to the output (default: 'batchmean')
|
511 |
-
|
512 |
-
Returns:
|
513 |
-
loss: Scalar tensor with the generalized JSD loss
|
514 |
-
"""
|
515 |
-
|
516 |
-
# Apply temperature scaling
|
517 |
-
student_logits = student_logits / temperature
|
518 |
-
teacher_logits = teacher_logits / temperature
|
519 |
-
|
520 |
-
# Compute log probabilities for student and probabilities for teacher
|
521 |
-
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
522 |
-
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
523 |
-
|
524 |
-
# Compute the log of the mixture distribution
|
525 |
-
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
|
526 |
-
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
|
527 |
-
mixture_log_probs = torch.logsumexp(
|
528 |
-
torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
|
529 |
-
dim=0,
|
530 |
-
)
|
531 |
-
|
532 |
-
# Compute KL divergences using F.kl_div
|
533 |
-
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
|
534 |
-
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
535 |
-
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
|
536 |
-
|
537 |
-
# Compute the Generalized Jensen-Shannon Divergence
|
538 |
-
jsd = beta * kl_teacher + (1 - beta) * kl_student
|
539 |
-
|
540 |
-
# Masking
|
541 |
-
if labels is not None:
|
542 |
-
mask = labels != -100
|
543 |
-
jsd = jsd[mask]
|
544 |
-
|
545 |
-
# Apply reduction
|
546 |
-
if reduction == "batchmean":
|
547 |
-
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
|
548 |
-
elif reduction == "sum":
|
549 |
-
return jsd.sum()
|
550 |
-
elif reduction == "mean":
|
551 |
-
return jsd.mean()
|
552 |
-
else:
|
553 |
-
return jsd
|
554 |
-
|
555 |
-
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
556 |
-
# compute student output
|
557 |
-
outputs_student = model(
|
558 |
-
input_ids=inputs["input_ids"],
|
559 |
-
attention_mask=inputs["attention_mask"],
|
560 |
-
)
|
561 |
-
|
562 |
-
# compute teacher output in eval mode
|
563 |
-
self.teacher_model.eval()
|
564 |
-
with torch.no_grad():
|
565 |
-
outputs_teacher = self.teacher_model(
|
566 |
-
input_ids=inputs["input_ids"],
|
567 |
-
attention_mask=inputs["attention_mask"],
|
568 |
-
)
|
569 |
-
|
570 |
-
# slice the logits for the generated tokens using the inputs["prompts"] lengths
|
571 |
-
prompt_lengths = inputs["prompts"].shape[1]
|
572 |
-
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
|
573 |
-
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
|
574 |
-
shifted_labels = inputs["labels"][:, prompt_lengths:]
|
575 |
-
|
576 |
-
# compute loss
|
577 |
-
loss = self.generalized_jsd_loss(
|
578 |
-
student_logits=shifted_student_logits,
|
579 |
-
teacher_logits=shifted_teacher_logits,
|
580 |
-
labels=shifted_labels,
|
581 |
-
beta=self.beta,
|
582 |
-
)
|
583 |
-
|
584 |
-
# empty cache
|
585 |
-
empty_cache()
|
586 |
-
|
587 |
-
# Return loss
|
588 |
-
return (loss, outputs_student) if return_outputs else loss
|
589 |
-
|
590 |
-
@staticmethod
|
591 |
-
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
|
592 |
-
# Generate output with respect to the prompt only
|
593 |
-
generated_outputs = model.generate(
|
594 |
-
input_ids=inputs["prompts"],
|
595 |
-
attention_mask=inputs.get("prompt_attention_mask", None),
|
596 |
-
generation_config=generation_config,
|
597 |
-
return_dict_in_generate=True,
|
598 |
-
)
|
599 |
-
|
600 |
-
# Get the generated token IDs
|
601 |
-
generated_tokens = generated_outputs.sequences
|
602 |
-
# Calculate new attention mask
|
603 |
-
new_attention_mask = torch.ones_like(generated_tokens)
|
604 |
-
new_labels = generated_tokens.clone()
|
605 |
-
|
606 |
-
# If there's pad_token_id, set attention mask to 0 for padding tokens
|
607 |
-
if pad_token_id is not None:
|
608 |
-
new_labels[new_labels == pad_token_id] = -100
|
609 |
-
new_attention_mask[generated_tokens == pad_token_id] = 0
|
610 |
-
|
611 |
-
return generated_tokens, new_attention_mask, new_labels
|
612 |
-
|
613 |
-
def training_step(
|
614 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
615 |
-
) -> torch.Tensor:
|
616 |
-
"""
|
617 |
-
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
|
618 |
-
|
619 |
-
This method implements the on-policy learning approach described in the GKD paper.
|
620 |
-
With probability `self.lmbda`, it generates new responses using the student model,
|
621 |
-
which are then used for training instead of the original inputs.
|
622 |
-
"""
|
623 |
-
if self.seq_kd:
|
624 |
-
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
|
625 |
-
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
626 |
-
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
627 |
-
)
|
628 |
-
inputs["input_ids"] = new_input_ids
|
629 |
-
inputs["attention_mask"] = new_attention_mask
|
630 |
-
inputs["labels"] = new_labels
|
631 |
-
if random.random() <= self.lmbda:
|
632 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
633 |
-
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
634 |
-
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
635 |
-
)
|
636 |
-
inputs["input_ids"] = new_input_ids
|
637 |
-
inputs["attention_mask"] = new_attention_mask
|
638 |
-
inputs["labels"] = new_labels
|
639 |
-
|
640 |
-
loss = super().training_step(model, inputs, num_items_in_batch)
|
641 |
-
return loss
|
642 |
-
|
643 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
644 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
645 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
646 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
647 |
-
|
648 |
-
if model is not None:
|
649 |
-
if hasattr(model, "config"):
|
650 |
-
hidden_size = (
|
651 |
-
max(model.config.hidden_sizes)
|
652 |
-
if getattr(model.config, "hidden_sizes", None)
|
653 |
-
else getattr(model.config, "hidden_size", None)
|
654 |
-
)
|
655 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
656 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
657 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
658 |
-
config_kwargs.update(
|
659 |
-
{
|
660 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
661 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
662 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
663 |
-
}
|
664 |
-
)
|
665 |
-
|
666 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
667 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
668 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
669 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
670 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
671 |
-
model.eval()
|
672 |
-
return model
|
673 |
-
|
674 |
-
def create_model_card(
|
675 |
-
self,
|
676 |
-
model_name: Optional[str] = None,
|
677 |
-
dataset_name: Optional[str] = None,
|
678 |
-
tags: Union[str, list[str], None] = None,
|
679 |
-
):
|
680 |
-
"""
|
681 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
682 |
-
|
683 |
-
Args:
|
684 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
685 |
-
Name of the model.
|
686 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
687 |
-
Name of the dataset used for training.
|
688 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
689 |
-
Tags to be associated with the model card.
|
690 |
-
"""
|
691 |
-
if not self.is_world_process_zero():
|
692 |
-
return
|
693 |
-
|
694 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
695 |
-
base_model = self.model.config._name_or_path
|
696 |
-
else:
|
697 |
-
base_model = None
|
698 |
-
|
699 |
-
tags = tags or []
|
700 |
-
if isinstance(tags, str):
|
701 |
-
tags = [tags]
|
702 |
-
|
703 |
-
if hasattr(self.model.config, "unsloth_version"):
|
704 |
-
tags.append("unsloth")
|
705 |
-
|
706 |
-
citation = textwrap.dedent("""\
|
707 |
-
@inproceedings{agarwal2024on-policy,
|
708 |
-
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
|
709 |
-
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
|
710 |
-
year = 2024,
|
711 |
-
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
712 |
-
publisher = {OpenReview.net},
|
713 |
-
url = {https://openreview.net/forum?id=3zKtaqxLhW},
|
714 |
-
}""")
|
715 |
-
|
716 |
-
model_card = generate_model_card(
|
717 |
-
base_model=base_model,
|
718 |
-
model_name=model_name,
|
719 |
-
hub_model_id=self.hub_model_id,
|
720 |
-
dataset_name=dataset_name,
|
721 |
-
tags=tags,
|
722 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
723 |
-
comet_url=get_comet_experiment_url(),
|
724 |
-
trainer_name="GKD",
|
725 |
-
trainer_citation=citation,
|
726 |
-
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
|
727 |
-
paper_id="2306.13649",
|
728 |
-
)
|
729 |
-
|
730 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
731 |
-
class UnslothGKDTrainer(_UnslothGKDTrainer):
|
732 |
-
"""
|
733 |
-
|
734 |
-
"""
|
735 |
-
def __init__(
|
736 |
-
self,
|
737 |
-
model = None,
|
738 |
-
teacher_model = None,
|
739 |
-
args = None,
|
740 |
-
data_collator = None,
|
741 |
-
train_dataset = None,
|
742 |
-
eval_dataset = None,
|
743 |
-
processing_class = None,
|
744 |
-
compute_metrics = None,
|
745 |
-
callbacks = None,
|
746 |
-
preprocess_logits_for_metrics = None,
|
747 |
-
peft_config = None,
|
748 |
-
formatting_func = None,
|
749 |
-
**kwargs
|
750 |
-
):
|
751 |
-
if args is None: args = UnslothGKDConfig()
|
752 |
-
use_bf16 = getattr(args, 'bf16', False)
|
753 |
-
use_fp16 = getattr(args, 'fp16', False)
|
754 |
-
force_float32 = False
|
755 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
756 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
757 |
-
force_float32 = True
|
758 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
759 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
760 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
761 |
-
from unsloth_zoo.utils import _get_dtype
|
762 |
-
dtype = _get_dtype(dtype)
|
763 |
-
float16 = dtype == torch.float16
|
764 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
765 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
766 |
-
if force_float32:
|
767 |
-
args.fp16 = False
|
768 |
-
args.bf16 = False
|
769 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
770 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
771 |
-
args.fp16 = float16
|
772 |
-
args.bf16 = not float16
|
773 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
774 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
775 |
-
args.eval_strategy = 'steps'
|
776 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
777 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
778 |
-
if ga_steps is not None and ga_steps > 1:
|
779 |
-
from transformers import __version__ as transformers_version
|
780 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
781 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
782 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
783 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
784 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
785 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
786 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
787 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
788 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
789 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
790 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
791 |
-
if force_float32:
|
792 |
-
args.bf16_full_eval = False
|
793 |
-
args.fp16_full_eval = False
|
794 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
795 |
-
args.bf16_full_eval = True
|
796 |
-
args.fp16_full_eval = False
|
797 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
798 |
-
args.bf16_full_eval = args.bf16
|
799 |
-
args.fp16_full_eval = args.fp16
|
800 |
-
_output_logits = False
|
801 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
802 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
803 |
-
if _output_logits:
|
804 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
805 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
806 |
-
pass
|
807 |
-
else:
|
808 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
809 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
810 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
811 |
-
max_seq_length = model.max_seq_length
|
812 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
813 |
-
if model is not None and hasattr(model, 'for_training'):
|
814 |
-
model.for_training()
|
815 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
816 |
-
if 'processing_class' in locals():
|
817 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
818 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
819 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
820 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
821 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
822 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
823 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
824 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
825 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
826 |
-
else:
|
827 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
828 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
829 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
830 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
831 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
832 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
833 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
834 |
-
else:
|
835 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
836 |
-
other_metrics = []
|
837 |
-
|
838 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
839 |
-
PatchRLStatistics('gkd_trainer', other_metrics)
|
840 |
-
|
841 |
-
super().__init__(
|
842 |
-
model = model,
|
843 |
-
teacher_model = teacher_model,
|
844 |
-
args = args,
|
845 |
-
data_collator = data_collator,
|
846 |
-
train_dataset = train_dataset,
|
847 |
-
eval_dataset = eval_dataset,
|
848 |
-
processing_class = processing_class,
|
849 |
-
compute_metrics = compute_metrics,
|
850 |
-
callbacks = callbacks,
|
851 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
852 |
-
peft_config = peft_config,
|
853 |
-
formatting_func = formatting_func,**kwargs)
|
854 |
-
if hasattr(self, 'neftune_hook_handle'):
|
855 |
-
self.neftune_hook_handle.remove()
|
856 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
857 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
858 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
859 |
-
pass
|
860 |
-
|
861 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothGRPOTrainer.py
DELETED
@@ -1,1436 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
|
43 |
-
def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
|
44 |
-
# All Unsloth Zoo code licensed under LGPLv3
|
45 |
-
old_logits = old_logits.to(torch.float32)
|
46 |
-
new_logits = new_logits.to(torch.float32)
|
47 |
-
input_ids = input_ids.unsqueeze(-1)
|
48 |
-
|
49 |
-
# x_i - logsumexp(x_i)
|
50 |
-
old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
|
51 |
-
new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
|
52 |
-
old = old_x - torch.logsumexp(old_logits, dim = -1)
|
53 |
-
new = new_x - torch.logsumexp(new_logits, dim = -1)
|
54 |
-
|
55 |
-
# Reverse KL
|
56 |
-
kl_i = torch.exp(old - new) - (old - new) - 1.0
|
57 |
-
# Full correct reverse KL divergence?? Missing term maybe?
|
58 |
-
# kl_i = torch.exp(new) * kl_i
|
59 |
-
|
60 |
-
# Below is forward KL (normal KL)
|
61 |
-
# kl_i = torch.exp(old) * (old - new)
|
62 |
-
|
63 |
-
# Must detach - otherwise gradients are not propagated correctly!
|
64 |
-
# exp(x - x) == 1
|
65 |
-
loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
|
66 |
-
loss_i = -(loss_i - beta * kl_i)
|
67 |
-
|
68 |
-
mask = mask.to(torch.float32)
|
69 |
-
n_mask_per_reward = mask.sum(1)
|
70 |
-
|
71 |
-
# See https://github.com/huggingface/trl/pull/2881
|
72 |
-
loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
|
73 |
-
loss = loss_per_reward.mean()
|
74 |
-
# loss = (loss_i * mask).sum() / mask.sum()
|
75 |
-
|
76 |
-
# Get metrics as well which are folded
|
77 |
-
with torch.inference_mode():
|
78 |
-
completion_length = n_mask_per_reward.mean()
|
79 |
-
mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
|
80 |
-
mean_kl = mean_kl_per_reward.mean()
|
81 |
-
pass
|
82 |
-
return loss, completion_length, mean_kl
|
83 |
-
|
84 |
-
class UnslothEfficientGRPO(torch.autograd.Function):
|
85 |
-
# All Unsloth Zoo code licensed under LGPLv3
|
86 |
-
@staticmethod
|
87 |
-
def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
|
88 |
-
def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
|
89 |
-
new_logits = torch.matmul(new_hidden_states, lm_head.t())
|
90 |
-
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
91 |
-
old_logits = torch.matmul(old_hidden_states, lm_head.t())
|
92 |
-
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
93 |
-
loss, completion_length, mean_kl = grpo_compute_loss(
|
94 |
-
old_logits, new_logits, input_ids, mask, beta, advantages,
|
95 |
-
)
|
96 |
-
# Scale loss if needed for mixed precision training
|
97 |
-
scaled_loss = loss * scaling
|
98 |
-
# Must add .loss.detach otherwise autograd uses 2x VRAM
|
99 |
-
return scaled_loss, (loss.detach(), completion_length, mean_kl,)
|
100 |
-
pass
|
101 |
-
|
102 |
-
device =_new_hidden_states.device
|
103 |
-
grad_inputs = torch.empty_like(_new_hidden_states)
|
104 |
-
accumulated_loss = torch.zeros(1, device = device)
|
105 |
-
accumulated_completion_length = torch.zeros(1, device = device)
|
106 |
-
accumulated_mean_kl = torch.zeros(1, device = device)
|
107 |
-
|
108 |
-
def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
|
109 |
-
(chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
|
110 |
-
compute_loss,
|
111 |
-
argnums = (0,),
|
112 |
-
has_aux = True,
|
113 |
-
)(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
|
114 |
-
accumulated_loss .add_(unscaled_loss)
|
115 |
-
accumulated_completion_length.add_(chunk_completion_length)
|
116 |
-
accumulated_mean_kl .add_(chunk_mean_kl)
|
117 |
-
return chunk_grad_input
|
118 |
-
pass
|
119 |
-
|
120 |
-
accumulate_chunk = torch.compile(
|
121 |
-
accumulate_chunk,
|
122 |
-
fullgraph = True,
|
123 |
-
options = torch_compile_options,
|
124 |
-
)
|
125 |
-
|
126 |
-
grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
|
127 |
-
new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
|
128 |
-
old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
|
129 |
-
input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
|
130 |
-
mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
|
131 |
-
advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
|
132 |
-
|
133 |
-
# Get mixed precision scaling if seen
|
134 |
-
scaling = scaler.get_scale() if scaler is not None else 1.0
|
135 |
-
|
136 |
-
# Force torch.compile to use dynamic shapes for seqlen dim
|
137 |
-
mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
|
138 |
-
|
139 |
-
for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
|
140 |
-
zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
|
141 |
-
|
142 |
-
mark_dynamic(new_hidden_states_j)
|
143 |
-
mark_dynamic(old_hidden_states_j)
|
144 |
-
mark_dynamic(input_ids_j)
|
145 |
-
mark_dynamic(mask_j)
|
146 |
-
|
147 |
-
grad_inputs_j.copy_(
|
148 |
-
accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
|
149 |
-
)
|
150 |
-
pass
|
151 |
-
|
152 |
-
grad_inputs .div_(n_chunks)
|
153 |
-
accumulated_loss .div_(n_chunks)
|
154 |
-
accumulated_completion_length.div_(n_chunks)
|
155 |
-
accumulated_mean_kl .div_(n_chunks)
|
156 |
-
ctx.save_for_backward(grad_inputs)
|
157 |
-
|
158 |
-
return (
|
159 |
-
accumulated_loss,
|
160 |
-
accumulated_completion_length,
|
161 |
-
accumulated_mean_kl,
|
162 |
-
)
|
163 |
-
pass
|
164 |
-
|
165 |
-
@staticmethod
|
166 |
-
def backward(ctx, grad_output, dcompletion_length, dmean_kl):
|
167 |
-
(grad_input,) = ctx.saved_tensors
|
168 |
-
return (grad_input, None, None, None, None, None, None, None, None,)
|
169 |
-
pass
|
170 |
-
|
171 |
-
def grpo_accumulated_loss(
|
172 |
-
trainer,
|
173 |
-
input_ids,
|
174 |
-
logits_to_keep,
|
175 |
-
completion_mask,
|
176 |
-
advantages,
|
177 |
-
n_chunks = -1,
|
178 |
-
):
|
179 |
-
# All Unsloth Zoo code licensed under LGPLv3
|
180 |
-
bsz, qlen = input_ids.shape
|
181 |
-
# Find closest multiple
|
182 |
-
factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
|
183 |
-
if n_chunks == -1: n_chunks = bsz
|
184 |
-
n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
|
185 |
-
|
186 |
-
mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
|
187 |
-
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
|
188 |
-
|
189 |
-
completion_input_ids = input_ids[:, -logits_to_keep:]
|
190 |
-
lm_head = trainer.model.get_output_embeddings().weight
|
191 |
-
|
192 |
-
with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
|
193 |
-
with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
|
194 |
-
old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
|
195 |
-
pass
|
196 |
-
|
197 |
-
new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
|
198 |
-
|
199 |
-
loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
|
200 |
-
new_hidden_states, old_hidden_states, lm_head,
|
201 |
-
completion_input_ids, completion_mask, advantages, trainer.beta,
|
202 |
-
trainer.accelerator.scaler,
|
203 |
-
n_chunks,
|
204 |
-
)
|
205 |
-
return loss, completion_length, mean_kl
|
206 |
-
|
207 |
-
# Old non efficient code path
|
208 |
-
new_logits = torch.matmul(new_hidden_states, lm_head.t())
|
209 |
-
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
210 |
-
old_logits = torch.matmul(old_hidden_states, lm_head.t())
|
211 |
-
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
212 |
-
loss, completion_length, mean_kl = grpo_compute_loss(
|
213 |
-
old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
|
214 |
-
)
|
215 |
-
return loss, completion_length, mean_kl
|
216 |
-
pass
|
217 |
-
|
218 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
|
219 |
-
def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
|
220 |
-
# All Unsloth Zoo code licensed under LGPLv3
|
221 |
-
old_logits = old_logits.to(torch.float32)
|
222 |
-
new_logits = new_logits.to(torch.float32)
|
223 |
-
input_ids = input_ids.unsqueeze(-1)
|
224 |
-
|
225 |
-
# x_i - logsumexp(x_i)
|
226 |
-
old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
|
227 |
-
new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
|
228 |
-
old = old_x - torch.logsumexp(old_logits, dim = -1)
|
229 |
-
new = new_x - torch.logsumexp(new_logits, dim = -1)
|
230 |
-
|
231 |
-
# Reverse KL
|
232 |
-
kl_i = torch.exp(old - new) - (old - new) - 1.0
|
233 |
-
# Full correct reverse KL divergence?? Missing term maybe?
|
234 |
-
# kl_i = torch.exp(new) * kl_i
|
235 |
-
|
236 |
-
# Below is forward KL (normal KL)
|
237 |
-
# kl_i = torch.exp(old) * (old - new)
|
238 |
-
|
239 |
-
# Must detach - otherwise gradients are not propagated correctly!
|
240 |
-
# exp(x - x) == 1
|
241 |
-
loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
|
242 |
-
loss_i = -(loss_i - beta * kl_i)
|
243 |
-
|
244 |
-
mask = mask.to(torch.float32)
|
245 |
-
n_mask_per_reward = mask.sum(1)
|
246 |
-
|
247 |
-
# See https://github.com/huggingface/trl/pull/2881
|
248 |
-
loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
|
249 |
-
loss = loss_per_reward.mean()
|
250 |
-
# loss = (loss_i * mask).sum() / mask.sum()
|
251 |
-
|
252 |
-
# Get metrics as well which are folded
|
253 |
-
with torch.inference_mode():
|
254 |
-
completion_length = n_mask_per_reward.mean()
|
255 |
-
mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
|
256 |
-
mean_kl = mean_kl_per_reward.mean()
|
257 |
-
pass
|
258 |
-
return loss, completion_length, mean_kl
|
259 |
-
|
260 |
-
def vLLMSamplingParams(**kwargs):
|
261 |
-
from vllm import SamplingParams
|
262 |
-
sampling_params = SamplingParams(**kwargs)
|
263 |
-
sampling_params._set_kwargs = kwargs
|
264 |
-
return sampling_params
|
265 |
-
@dataclass
|
266 |
-
class UnslothGRPOConfig(GRPOConfig):
|
267 |
-
"""
|
268 |
-
|
269 |
-
Configuration class for the [`GRPOTrainer`].
|
270 |
-
|
271 |
-
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
272 |
-
[`~transformers.TrainingArguments`] documentation.
|
273 |
-
|
274 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
275 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
276 |
-
command line.
|
277 |
-
|
278 |
-
Parameters:
|
279 |
-
> Parameters that control the model and reference model
|
280 |
-
|
281 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
282 |
-
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
283 |
-
argument of the [`GRPOTrainer`] is provided as a string.
|
284 |
-
|
285 |
-
> Parameters that control the data preprocessing
|
286 |
-
|
287 |
-
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
288 |
-
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
289 |
-
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
290 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
291 |
-
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
292 |
-
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
293 |
-
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
294 |
-
must be divisible by this value.
|
295 |
-
temperature (`float`, *optional*, defaults to `0.9`):
|
296 |
-
Temperature for sampling. The higher the temperature, the more random the completions.
|
297 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
298 |
-
Maximum length of the generated completion.
|
299 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
300 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
301 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
302 |
-
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
303 |
-
with vLLM generation.
|
304 |
-
|
305 |
-
> Parameters that control generation acceleration powered by vLLM
|
306 |
-
|
307 |
-
use_vllm (`bool`, *optional*, defaults to `False`):
|
308 |
-
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
309 |
-
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
310 |
-
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
311 |
-
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
312 |
-
automatically select the next available GPU after the last one used for training. This assumes that
|
313 |
-
training has not already occupied all available GPUs. If only one device is available, the device will be
|
314 |
-
shared between both training and vLLM.
|
315 |
-
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
316 |
-
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
317 |
-
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
318 |
-
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
319 |
-
during initialization.
|
320 |
-
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
321 |
-
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
322 |
-
based on the model configuration. Find the supported values in the vLLM documentation.
|
323 |
-
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
324 |
-
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
325 |
-
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
326 |
-
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
327 |
-
|
328 |
-
> Parameters that control the training
|
329 |
-
|
330 |
-
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
331 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
332 |
-
[`~transformers.TrainingArguments`].
|
333 |
-
beta (`float`, *optional*, defaults to `0.04`):
|
334 |
-
KL coefficient.
|
335 |
-
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
336 |
-
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
337 |
-
weighted equally with weight `1.0`.
|
338 |
-
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
339 |
-
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
340 |
-
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
|
341 |
-
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
342 |
-
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
|
343 |
-
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
344 |
-
between the current policy and the previous reference policy during updates. The reference policy is
|
345 |
-
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
346 |
-
must set `sync_ref_model=True`.
|
347 |
-
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
|
348 |
-
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
349 |
-
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
350 |
-
set `sync_ref_model=True`.
|
351 |
-
|
352 |
-
> Parameters that control the logging
|
353 |
-
|
354 |
-
log_completions (`bool`, *optional*, defaults to `False`):
|
355 |
-
Whether to log the completions during training.
|
356 |
-
|
357 |
-
"""
|
358 |
-
vllm_sampling_params: Optional[Any] = field(
|
359 |
-
default = None,
|
360 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
361 |
-
)
|
362 |
-
unsloth_num_chunks : Optional[int] = field(
|
363 |
-
default = -1,
|
364 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
365 |
-
)
|
366 |
-
def __init__(
|
367 |
-
self,
|
368 |
-
output_dir = None,
|
369 |
-
overwrite_output_dir = None,
|
370 |
-
do_train = False,
|
371 |
-
do_eval = False,
|
372 |
-
do_predict = False,
|
373 |
-
eval_strategy = 'no',
|
374 |
-
prediction_loss_only = False,
|
375 |
-
per_device_train_batch_size = 4,
|
376 |
-
per_device_eval_batch_size = 4,
|
377 |
-
per_gpu_train_batch_size = None,
|
378 |
-
per_gpu_eval_batch_size = None,
|
379 |
-
gradient_accumulation_steps = 2,
|
380 |
-
eval_accumulation_steps = 2,
|
381 |
-
eval_delay = 0,
|
382 |
-
torch_empty_cache_steps = 250,
|
383 |
-
learning_rate = 5e-05,
|
384 |
-
weight_decay = 0.01,
|
385 |
-
adam_beta1 = 0.9,
|
386 |
-
adam_beta2 = 0.999,
|
387 |
-
adam_epsilon = 1e-08,
|
388 |
-
max_grad_norm = 1.0,
|
389 |
-
num_train_epochs = 3.0,
|
390 |
-
max_steps = -1,
|
391 |
-
lr_scheduler_type = 'linear',
|
392 |
-
warmup_ratio = 0.1,
|
393 |
-
warmup_steps = 0,
|
394 |
-
log_level = 'passive',
|
395 |
-
log_level_replica = 'warning',
|
396 |
-
log_on_each_node = True,
|
397 |
-
logging_dir = None,
|
398 |
-
logging_strategy = 'steps',
|
399 |
-
logging_first_step = False,
|
400 |
-
logging_steps = 1,
|
401 |
-
logging_nan_inf_filter = False,
|
402 |
-
save_strategy = 'steps',
|
403 |
-
save_steps = 500,
|
404 |
-
save_total_limit = None,
|
405 |
-
save_safetensors = True,
|
406 |
-
save_on_each_node = False,
|
407 |
-
save_only_model = False,
|
408 |
-
restore_callback_states_from_checkpoint = False,
|
409 |
-
no_cuda = False,
|
410 |
-
use_cpu = False,
|
411 |
-
use_mps_device = False,
|
412 |
-
seed = 3407,
|
413 |
-
data_seed = 3407,
|
414 |
-
jit_mode_eval = False,
|
415 |
-
use_ipex = False,
|
416 |
-
bf16 = False,
|
417 |
-
fp16 = False,
|
418 |
-
fp16_opt_level = 'O1',
|
419 |
-
half_precision_backend = 'auto',
|
420 |
-
bf16_full_eval = False,
|
421 |
-
fp16_full_eval = False,
|
422 |
-
tf32 = None,
|
423 |
-
local_rank = -1,
|
424 |
-
ddp_backend = None,
|
425 |
-
tpu_num_cores = None,
|
426 |
-
tpu_metrics_debug = False,
|
427 |
-
debug = '',
|
428 |
-
dataloader_drop_last = False,
|
429 |
-
eval_steps = None,
|
430 |
-
dataloader_num_workers = 0,
|
431 |
-
dataloader_prefetch_factor = None,
|
432 |
-
past_index = -1,
|
433 |
-
run_name = None,
|
434 |
-
disable_tqdm = None,
|
435 |
-
remove_unused_columns = False,
|
436 |
-
label_names = None,
|
437 |
-
load_best_model_at_end = False,
|
438 |
-
metric_for_best_model = None,
|
439 |
-
greater_is_better = None,
|
440 |
-
ignore_data_skip = False,
|
441 |
-
fsdp = '',
|
442 |
-
fsdp_min_num_params = 0,
|
443 |
-
fsdp_config = None,
|
444 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
445 |
-
accelerator_config = None,
|
446 |
-
deepspeed = None,
|
447 |
-
label_smoothing_factor = 0.0,
|
448 |
-
optim = 'adamw_8bit',
|
449 |
-
optim_args = None,
|
450 |
-
adafactor = False,
|
451 |
-
group_by_length = False,
|
452 |
-
length_column_name = 'length',
|
453 |
-
report_to = None,
|
454 |
-
ddp_find_unused_parameters = None,
|
455 |
-
ddp_bucket_cap_mb = None,
|
456 |
-
ddp_broadcast_buffers = None,
|
457 |
-
dataloader_pin_memory = True,
|
458 |
-
dataloader_persistent_workers = False,
|
459 |
-
skip_memory_metrics = True,
|
460 |
-
use_legacy_prediction_loop = False,
|
461 |
-
push_to_hub = False,
|
462 |
-
resume_from_checkpoint = None,
|
463 |
-
hub_model_id = None,
|
464 |
-
hub_strategy = 'every_save',
|
465 |
-
hub_token = None,
|
466 |
-
hub_private_repo = None,
|
467 |
-
hub_always_push = False,
|
468 |
-
gradient_checkpointing = False,
|
469 |
-
gradient_checkpointing_kwargs = None,
|
470 |
-
include_inputs_for_metrics = False,
|
471 |
-
eval_do_concat_batches = True,
|
472 |
-
fp16_backend = 'auto',
|
473 |
-
evaluation_strategy = None,
|
474 |
-
push_to_hub_model_id = None,
|
475 |
-
push_to_hub_organization = None,
|
476 |
-
push_to_hub_token = None,
|
477 |
-
mp_parameters = '',
|
478 |
-
auto_find_batch_size = False,
|
479 |
-
full_determinism = False,
|
480 |
-
torchdynamo = None,
|
481 |
-
ray_scope = 'last',
|
482 |
-
ddp_timeout = 1800,
|
483 |
-
torch_compile = False,
|
484 |
-
torch_compile_backend = None,
|
485 |
-
torch_compile_mode = None,
|
486 |
-
dispatch_batches = None,
|
487 |
-
split_batches = None,
|
488 |
-
include_tokens_per_second = False,
|
489 |
-
include_num_input_tokens_seen = False,
|
490 |
-
neftune_noise_alpha = None,
|
491 |
-
optim_target_modules = None,
|
492 |
-
batch_eval_metrics = False,
|
493 |
-
eval_on_start = False,
|
494 |
-
use_liger_kernel = False,
|
495 |
-
eval_use_gather_object = False,
|
496 |
-
average_tokens_across_devices = False,
|
497 |
-
model_init_kwargs = None,
|
498 |
-
max_prompt_length = 512,
|
499 |
-
num_generations = 8,
|
500 |
-
temperature = 0.9,
|
501 |
-
max_completion_length = 256,
|
502 |
-
ds3_gather_for_generation = True,
|
503 |
-
use_vllm = False,
|
504 |
-
vllm_device = 'auto',
|
505 |
-
vllm_gpu_memory_utilization = 0.9,
|
506 |
-
vllm_dtype = 'auto',
|
507 |
-
vllm_max_model_len = None,
|
508 |
-
beta = 0.04,
|
509 |
-
reward_weights = None,
|
510 |
-
sync_ref_model = False,
|
511 |
-
ref_model_mixup_alpha = 0.9,
|
512 |
-
ref_model_sync_steps = 64,
|
513 |
-
log_completions = False,
|
514 |
-
vllm_sampling_params = None,
|
515 |
-
unsloth_num_chunks = -1,
|
516 |
-
**kwargs,
|
517 |
-
):
|
518 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
519 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
520 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
521 |
-
output_dir = 'unsloth_training_checkpoints'
|
522 |
-
save_strategy = 'no'
|
523 |
-
div = per_device_train_batch_size // num_generations
|
524 |
-
if div * num_generations != per_device_train_batch_size:
|
525 |
-
print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
|
526 |
-
per_device_train_batch_size = num_generations
|
527 |
-
|
528 |
-
super().__init__(
|
529 |
-
output_dir = output_dir,
|
530 |
-
overwrite_output_dir = overwrite_output_dir,
|
531 |
-
do_train = do_train,
|
532 |
-
do_eval = do_eval,
|
533 |
-
do_predict = do_predict,
|
534 |
-
eval_strategy = eval_strategy,
|
535 |
-
prediction_loss_only = prediction_loss_only,
|
536 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
537 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
538 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
539 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
540 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
541 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
542 |
-
eval_delay = eval_delay,
|
543 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
544 |
-
learning_rate = learning_rate,
|
545 |
-
weight_decay = weight_decay,
|
546 |
-
adam_beta1 = adam_beta1,
|
547 |
-
adam_beta2 = adam_beta2,
|
548 |
-
adam_epsilon = adam_epsilon,
|
549 |
-
max_grad_norm = max_grad_norm,
|
550 |
-
num_train_epochs = num_train_epochs,
|
551 |
-
max_steps = max_steps,
|
552 |
-
lr_scheduler_type = lr_scheduler_type,
|
553 |
-
warmup_ratio = warmup_ratio,
|
554 |
-
warmup_steps = warmup_steps,
|
555 |
-
log_level = log_level,
|
556 |
-
log_level_replica = log_level_replica,
|
557 |
-
log_on_each_node = log_on_each_node,
|
558 |
-
logging_dir = logging_dir,
|
559 |
-
logging_strategy = logging_strategy,
|
560 |
-
logging_first_step = logging_first_step,
|
561 |
-
logging_steps = logging_steps,
|
562 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
563 |
-
save_strategy = save_strategy,
|
564 |
-
save_steps = save_steps,
|
565 |
-
save_total_limit = save_total_limit,
|
566 |
-
save_safetensors = save_safetensors,
|
567 |
-
save_on_each_node = save_on_each_node,
|
568 |
-
save_only_model = save_only_model,
|
569 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
570 |
-
no_cuda = no_cuda,
|
571 |
-
use_cpu = use_cpu,
|
572 |
-
use_mps_device = use_mps_device,
|
573 |
-
seed = seed,
|
574 |
-
data_seed = data_seed,
|
575 |
-
jit_mode_eval = jit_mode_eval,
|
576 |
-
use_ipex = use_ipex,
|
577 |
-
bf16 = bf16,
|
578 |
-
fp16 = fp16,
|
579 |
-
fp16_opt_level = fp16_opt_level,
|
580 |
-
half_precision_backend = half_precision_backend,
|
581 |
-
bf16_full_eval = bf16_full_eval,
|
582 |
-
fp16_full_eval = fp16_full_eval,
|
583 |
-
tf32 = tf32,
|
584 |
-
local_rank = local_rank,
|
585 |
-
ddp_backend = ddp_backend,
|
586 |
-
tpu_num_cores = tpu_num_cores,
|
587 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
588 |
-
debug = debug,
|
589 |
-
dataloader_drop_last = dataloader_drop_last,
|
590 |
-
eval_steps = eval_steps,
|
591 |
-
dataloader_num_workers = dataloader_num_workers,
|
592 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
593 |
-
past_index = past_index,
|
594 |
-
run_name = run_name,
|
595 |
-
disable_tqdm = disable_tqdm,
|
596 |
-
remove_unused_columns = remove_unused_columns,
|
597 |
-
label_names = label_names,
|
598 |
-
load_best_model_at_end = load_best_model_at_end,
|
599 |
-
metric_for_best_model = metric_for_best_model,
|
600 |
-
greater_is_better = greater_is_better,
|
601 |
-
ignore_data_skip = ignore_data_skip,
|
602 |
-
fsdp = fsdp,
|
603 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
604 |
-
fsdp_config = fsdp_config,
|
605 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
606 |
-
accelerator_config = accelerator_config,
|
607 |
-
deepspeed = deepspeed,
|
608 |
-
label_smoothing_factor = label_smoothing_factor,
|
609 |
-
optim = optim,
|
610 |
-
optim_args = optim_args,
|
611 |
-
adafactor = adafactor,
|
612 |
-
group_by_length = group_by_length,
|
613 |
-
length_column_name = length_column_name,
|
614 |
-
report_to = report_to,
|
615 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
616 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
617 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
618 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
619 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
620 |
-
skip_memory_metrics = skip_memory_metrics,
|
621 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
622 |
-
push_to_hub = push_to_hub,
|
623 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
624 |
-
hub_model_id = hub_model_id,
|
625 |
-
hub_strategy = hub_strategy,
|
626 |
-
hub_token = hub_token,
|
627 |
-
hub_private_repo = hub_private_repo,
|
628 |
-
hub_always_push = hub_always_push,
|
629 |
-
gradient_checkpointing = gradient_checkpointing,
|
630 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
631 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
632 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
633 |
-
fp16_backend = fp16_backend,
|
634 |
-
evaluation_strategy = evaluation_strategy,
|
635 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
636 |
-
push_to_hub_organization = push_to_hub_organization,
|
637 |
-
push_to_hub_token = push_to_hub_token,
|
638 |
-
mp_parameters = mp_parameters,
|
639 |
-
auto_find_batch_size = auto_find_batch_size,
|
640 |
-
full_determinism = full_determinism,
|
641 |
-
torchdynamo = torchdynamo,
|
642 |
-
ray_scope = ray_scope,
|
643 |
-
ddp_timeout = ddp_timeout,
|
644 |
-
torch_compile = torch_compile,
|
645 |
-
torch_compile_backend = torch_compile_backend,
|
646 |
-
torch_compile_mode = torch_compile_mode,
|
647 |
-
dispatch_batches = dispatch_batches,
|
648 |
-
split_batches = split_batches,
|
649 |
-
include_tokens_per_second = include_tokens_per_second,
|
650 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
651 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
652 |
-
optim_target_modules = optim_target_modules,
|
653 |
-
batch_eval_metrics = batch_eval_metrics,
|
654 |
-
eval_on_start = eval_on_start,
|
655 |
-
use_liger_kernel = use_liger_kernel,
|
656 |
-
eval_use_gather_object = eval_use_gather_object,
|
657 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
658 |
-
model_init_kwargs = model_init_kwargs,
|
659 |
-
max_prompt_length = max_prompt_length,
|
660 |
-
num_generations = num_generations,
|
661 |
-
temperature = temperature,
|
662 |
-
max_completion_length = max_completion_length,
|
663 |
-
ds3_gather_for_generation = ds3_gather_for_generation,
|
664 |
-
use_vllm = use_vllm,
|
665 |
-
vllm_device = vllm_device,
|
666 |
-
vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
|
667 |
-
vllm_dtype = vllm_dtype,
|
668 |
-
vllm_max_model_len = vllm_max_model_len,
|
669 |
-
beta = beta,
|
670 |
-
reward_weights = reward_weights,
|
671 |
-
sync_ref_model = sync_ref_model,
|
672 |
-
ref_model_mixup_alpha = ref_model_mixup_alpha,
|
673 |
-
ref_model_sync_steps = ref_model_sync_steps,
|
674 |
-
log_completions = log_completions,**kwargs)
|
675 |
-
self.vllm_sampling_params = vllm_sampling_params
|
676 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
677 |
-
pass
|
678 |
-
|
679 |
-
class _UnslothGRPOTrainer(Trainer):
|
680 |
-
""""""
|
681 |
-
|
682 |
-
_tag_names = ["trl", "grpo"]
|
683 |
-
|
684 |
-
def __init__(
|
685 |
-
self,
|
686 |
-
model: Union[str, PreTrainedModel],
|
687 |
-
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
688 |
-
args: GRPOConfig = None,
|
689 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
690 |
-
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
691 |
-
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
692 |
-
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
693 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
694 |
-
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
695 |
-
peft_config: Optional["PeftConfig"] = None,
|
696 |
-
):
|
697 |
-
|
698 |
-
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
|
699 |
-
# Args
|
700 |
-
if args is None:
|
701 |
-
model_name = model if isinstance(model, str) else model.config._name_or_path
|
702 |
-
model_name = model_name.split("/")[-1]
|
703 |
-
args = GRPOConfig(f"{model_name}-GRPO")
|
704 |
-
|
705 |
-
# Models
|
706 |
-
# Trained model
|
707 |
-
model_init_kwargs = args.model_init_kwargs or {}
|
708 |
-
if isinstance(model, str):
|
709 |
-
model_id = model
|
710 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
711 |
-
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
712 |
-
pass # torch_dtype is already a torch.dtype or "auto" or None
|
713 |
-
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
714 |
-
torch_dtype = getattr(torch, torch_dtype)
|
715 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
716 |
-
else:
|
717 |
-
raise ValueError(
|
718 |
-
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
719 |
-
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
720 |
-
)
|
721 |
-
# Disable caching if gradient checkpointing is enabled (not supported)
|
722 |
-
model_init_kwargs["use_cache"] = (
|
723 |
-
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
724 |
-
)
|
725 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
726 |
-
else:
|
727 |
-
model_id = model.config._name_or_path
|
728 |
-
if args.model_init_kwargs is not None:
|
729 |
-
raise ValueError(
|
730 |
-
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
731 |
-
"This argument can only be used when the `model` argument is a string."
|
732 |
-
)
|
733 |
-
|
734 |
-
if False:
|
735 |
-
model = model
|
736 |
-
|
737 |
-
# Reference model
|
738 |
-
if is_deepspeed_zero3_enabled():
|
739 |
-
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
|
740 |
-
elif not is_peft_model(model):
|
741 |
-
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
742 |
-
self.ref_model = create_reference_model(model)
|
743 |
-
else:
|
744 |
-
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
745 |
-
# to revert to the initial model.
|
746 |
-
self.ref_model = None
|
747 |
-
|
748 |
-
# Processing class
|
749 |
-
if processing_class is None:
|
750 |
-
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
751 |
-
|
752 |
-
# Reward functions
|
753 |
-
if not isinstance(reward_funcs, list):
|
754 |
-
reward_funcs = [reward_funcs]
|
755 |
-
for i, reward_func in enumerate(reward_funcs):
|
756 |
-
if isinstance(reward_func, str):
|
757 |
-
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
758 |
-
reward_func, num_labels=1, **model_init_kwargs
|
759 |
-
)
|
760 |
-
self.reward_funcs = reward_funcs
|
761 |
-
|
762 |
-
# Reward weights
|
763 |
-
if args.reward_weights is not None:
|
764 |
-
if len(args.reward_weights) != len(reward_funcs):
|
765 |
-
raise ValueError(
|
766 |
-
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
767 |
-
f"functions ({len(reward_funcs)})"
|
768 |
-
)
|
769 |
-
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
770 |
-
else:
|
771 |
-
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
772 |
-
|
773 |
-
# Reward processing class
|
774 |
-
if reward_processing_classes is None:
|
775 |
-
reward_processing_classes = [None] * len(reward_funcs)
|
776 |
-
elif not isinstance(reward_processing_classes, list):
|
777 |
-
reward_processing_classes = [reward_processing_classes]
|
778 |
-
else:
|
779 |
-
if len(reward_processing_classes) != len(reward_funcs):
|
780 |
-
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
781 |
-
|
782 |
-
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
783 |
-
if isinstance(reward_func, PreTrainedModel):
|
784 |
-
if reward_processing_class is None:
|
785 |
-
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
786 |
-
if reward_processing_class.pad_token_id is None:
|
787 |
-
reward_processing_class.pad_token = reward_processing_class.eos_token
|
788 |
-
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
789 |
-
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
790 |
-
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
791 |
-
reward_processing_classes[i] = reward_processing_class
|
792 |
-
self.reward_processing_classes = reward_processing_classes
|
793 |
-
|
794 |
-
# Data collator
|
795 |
-
def data_collator(features): # No data collation is needed in GRPO
|
796 |
-
return features
|
797 |
-
|
798 |
-
# Training arguments
|
799 |
-
self.max_prompt_length = args.max_prompt_length
|
800 |
-
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
801 |
-
self.num_generations = args.num_generations # = G in the GRPO paper
|
802 |
-
self.use_vllm = args.use_vllm
|
803 |
-
|
804 |
-
self.beta = args.beta
|
805 |
-
|
806 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
807 |
-
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
808 |
-
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
809 |
-
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
810 |
-
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
811 |
-
# This acts as a flag to indicate that the warning has already been issued.
|
812 |
-
model.warnings_issued["estimate_tokens"] = True
|
813 |
-
|
814 |
-
# Initialize the metrics
|
815 |
-
self._metrics = defaultdict(list)
|
816 |
-
self.log_completions = args.log_completions
|
817 |
-
|
818 |
-
super().__init__(
|
819 |
-
model=model,
|
820 |
-
args=args,
|
821 |
-
data_collator=data_collator,
|
822 |
-
train_dataset=train_dataset,
|
823 |
-
eval_dataset=eval_dataset,
|
824 |
-
processing_class=processing_class,
|
825 |
-
callbacks=callbacks,
|
826 |
-
optimizers=optimizers,
|
827 |
-
)
|
828 |
-
|
829 |
-
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
830 |
-
num_processes = self.accelerator.num_processes
|
831 |
-
global_batch_size = args.per_device_train_batch_size * num_processes
|
832 |
-
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
833 |
-
if self.num_generations not in possible_values:
|
834 |
-
raise ValueError(
|
835 |
-
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
836 |
-
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
837 |
-
f"batch size, the valid values for the number of generations are: {possible_values}."
|
838 |
-
)
|
839 |
-
if self.args.eval_strategy != "no":
|
840 |
-
global_batch_size = args.per_device_eval_batch_size * num_processes
|
841 |
-
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
842 |
-
if self.num_generations not in possible_values:
|
843 |
-
raise ValueError(
|
844 |
-
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
845 |
-
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
846 |
-
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
847 |
-
)
|
848 |
-
|
849 |
-
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
850 |
-
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
851 |
-
# it's safer to set it in all cases.
|
852 |
-
set_seed(args.seed, device_specific=True)
|
853 |
-
|
854 |
-
if self.use_vllm:
|
855 |
-
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
|
856 |
-
temperature=args.temperature,
|
857 |
-
max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
|
858 |
-
else:
|
859 |
-
self.generation_config = GenerationConfig(
|
860 |
-
max_new_tokens=self.max_completion_length,
|
861 |
-
do_sample=True,
|
862 |
-
temperature=args.temperature,
|
863 |
-
pad_token_id=processing_class.pad_token_id,
|
864 |
-
)
|
865 |
-
|
866 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
867 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
868 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
869 |
-
self.model_accepts_loss_kwargs = False
|
870 |
-
|
871 |
-
# Add tags to the model
|
872 |
-
self.model.add_model_tags(self._tag_names)
|
873 |
-
|
874 |
-
if self.ref_model is not None:
|
875 |
-
if self.is_deepspeed_enabled:
|
876 |
-
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
877 |
-
else:
|
878 |
-
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
879 |
-
|
880 |
-
if args.sync_ref_model:
|
881 |
-
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
|
882 |
-
|
883 |
-
for i, reward_func in enumerate(self.reward_funcs):
|
884 |
-
if isinstance(reward_func, PreTrainedModel):
|
885 |
-
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
886 |
-
|
887 |
-
def _set_signature_columns_if_needed(self):
|
888 |
-
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
889 |
-
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
890 |
-
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
891 |
-
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
892 |
-
if self._signature_columns is None:
|
893 |
-
self._signature_columns = ["prompt"]
|
894 |
-
|
895 |
-
def _get_train_sampler(self) -> Sampler:
|
896 |
-
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
897 |
-
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
898 |
-
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
899 |
-
# preventing discrepancies in group formation.
|
900 |
-
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
|
901 |
-
|
902 |
-
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
903 |
-
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
904 |
-
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
905 |
-
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
906 |
-
# preventing discrepancies in group formation.
|
907 |
-
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
|
908 |
-
|
909 |
-
# Get the per-token log probabilities for the completions for the model and the reference model
|
910 |
-
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
911 |
-
if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
|
912 |
-
return None # Unsloth efficient GRPO
|
913 |
-
# Otherwise, calculate normally:
|
914 |
-
if not hasattr(self, '_autocast_dtype'):
|
915 |
-
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
|
916 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
|
917 |
-
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
|
918 |
-
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
919 |
-
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
920 |
-
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
921 |
-
|
922 |
-
input_ids = input_ids[:, -logits_to_keep:]
|
923 |
-
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
924 |
-
# See https://github.com/huggingface/trl/issues/2770
|
925 |
-
logits = logits[:, -logits_to_keep:]
|
926 |
-
return logits
|
927 |
-
# return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
928 |
-
pass
|
929 |
-
|
930 |
-
def _move_model_to_vllm(self, *args, **kwargs): return None
|
931 |
-
|
932 |
-
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
933 |
-
device = self.accelerator.device
|
934 |
-
prompts = [x["prompt"] for x in inputs]
|
935 |
-
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
936 |
-
prompt_inputs = self.processing_class(
|
937 |
-
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
938 |
-
)
|
939 |
-
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
940 |
-
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
941 |
-
|
942 |
-
if self.max_prompt_length is not None:
|
943 |
-
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
944 |
-
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
945 |
-
|
946 |
-
# Generate completions using either vLLM or regular generation
|
947 |
-
if self.args.use_vllm:
|
948 |
-
# First, have main process load weights if needed
|
949 |
-
if self.state.global_step != self._last_loaded_step:
|
950 |
-
self._move_model_to_vllm()
|
951 |
-
self._last_loaded_step = self.state.global_step
|
952 |
-
|
953 |
-
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
954 |
-
all_prompts_text = gather_object(prompts_text)
|
955 |
-
if self.accelerator.is_main_process:
|
956 |
-
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
|
957 |
-
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
|
958 |
-
else:
|
959 |
-
completion_ids = [None] * len(all_prompts_text)
|
960 |
-
# Broadcast the completions from the main process to all processes, ensuring each process receives its
|
961 |
-
# corresponding slice.
|
962 |
-
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
963 |
-
process_slice = slice(
|
964 |
-
self.accelerator.process_index * len(prompts),
|
965 |
-
(self.accelerator.process_index + 1) * len(prompts),
|
966 |
-
)
|
967 |
-
completion_ids = completion_ids[process_slice]
|
968 |
-
|
969 |
-
# Pad the completions, and concatenate them with the prompts
|
970 |
-
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
971 |
-
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
972 |
-
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
973 |
-
else:
|
974 |
-
# Regular generation path
|
975 |
-
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
976 |
-
prompt_completion_ids = unwrapped_model.generate(
|
977 |
-
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
|
978 |
-
)
|
979 |
-
|
980 |
-
# Compute prompt length and extract completion ids
|
981 |
-
prompt_length = prompt_ids.size(1)
|
982 |
-
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
983 |
-
completion_ids = prompt_completion_ids[:, prompt_length:]
|
984 |
-
|
985 |
-
# Mask everything after the first EOS token
|
986 |
-
is_eos = completion_ids == self.processing_class.eos_token_id
|
987 |
-
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
988 |
-
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
989 |
-
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
990 |
-
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
991 |
-
|
992 |
-
# Concatenate prompt_mask with completion_mask for logit computation
|
993 |
-
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
994 |
-
|
995 |
-
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
996 |
-
|
997 |
-
with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
|
998 |
-
if self.ref_model is not None:
|
999 |
-
ref_per_token_logps = self._get_per_token_logps(
|
1000 |
-
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
|
1001 |
-
)
|
1002 |
-
else:
|
1003 |
-
with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
|
1004 |
-
ref_per_token_logps = self._get_per_token_logps(
|
1005 |
-
self.model, prompt_completion_ids, attention_mask, logits_to_keep
|
1006 |
-
)
|
1007 |
-
|
1008 |
-
# Decode the generated completions
|
1009 |
-
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
1010 |
-
if is_conversational(inputs[0]):
|
1011 |
-
completions = []
|
1012 |
-
for prompt, completion in zip(prompts, completions_text):
|
1013 |
-
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
1014 |
-
completions.append([{"role": "assistant", "content": bootstrap + completion}])
|
1015 |
-
else:
|
1016 |
-
completions = completions_text
|
1017 |
-
|
1018 |
-
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
1019 |
-
for i, (reward_func, reward_processing_class) in enumerate(
|
1020 |
-
zip(self.reward_funcs, self.reward_processing_classes)
|
1021 |
-
):
|
1022 |
-
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
1023 |
-
if is_conversational(inputs[0]):
|
1024 |
-
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
1025 |
-
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
1026 |
-
else:
|
1027 |
-
texts = [p + c for p, c in zip(prompts, completions)]
|
1028 |
-
reward_inputs = reward_processing_class(
|
1029 |
-
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
1030 |
-
)
|
1031 |
-
reward_inputs = super()._prepare_inputs(reward_inputs)
|
1032 |
-
with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
|
1033 |
-
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
1034 |
-
else:
|
1035 |
-
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
1036 |
-
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
|
1037 |
-
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
|
1038 |
-
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
1039 |
-
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
1040 |
-
|
1041 |
-
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
|
1042 |
-
# completions may be distributed across processes
|
1043 |
-
rewards_per_func = gather(rewards_per_func)
|
1044 |
-
|
1045 |
-
# Apply weights to each reward function's output and sum
|
1046 |
-
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
|
1047 |
-
|
1048 |
-
# Compute grouped-wise rewards
|
1049 |
-
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
1050 |
-
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
1051 |
-
|
1052 |
-
# Normalize the rewards to compute the advantages
|
1053 |
-
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
1054 |
-
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
1055 |
-
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
1056 |
-
|
1057 |
-
# Slice to keep only the local part of the data
|
1058 |
-
process_slice = slice(
|
1059 |
-
self.accelerator.process_index * len(prompts),
|
1060 |
-
(self.accelerator.process_index + 1) * len(prompts),
|
1061 |
-
)
|
1062 |
-
advantages = advantages[process_slice]
|
1063 |
-
|
1064 |
-
# Log the metrics
|
1065 |
-
reward_per_func = rewards_per_func.mean(0)
|
1066 |
-
for i, reward_func in enumerate(self.reward_funcs):
|
1067 |
-
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
1068 |
-
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
1069 |
-
else:
|
1070 |
-
reward_func_name = reward_func.__name__
|
1071 |
-
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
1072 |
-
|
1073 |
-
self._metrics["reward"].append(rewards.mean().item())
|
1074 |
-
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
|
1075 |
-
|
1076 |
-
if (
|
1077 |
-
self.log_completions
|
1078 |
-
and self.state.global_step % self.args.logging_steps == 0
|
1079 |
-
and "wandb" in self.args.report_to
|
1080 |
-
):
|
1081 |
-
import pandas as pd
|
1082 |
-
|
1083 |
-
# For logging
|
1084 |
-
table = {
|
1085 |
-
"step": [str(self.state.global_step)] * len(rewards),
|
1086 |
-
"prompt": gather_object(prompts_text),
|
1087 |
-
"completion": gather_object(completions_text),
|
1088 |
-
"reward": rewards.tolist(),
|
1089 |
-
}
|
1090 |
-
df = pd.DataFrame(table)
|
1091 |
-
|
1092 |
-
if wandb.run is not None and self.accelerator.is_main_process:
|
1093 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
1094 |
-
|
1095 |
-
return {
|
1096 |
-
"prompt_ids": prompt_ids,
|
1097 |
-
"prompt_mask": prompt_mask,
|
1098 |
-
"completion_ids": completion_ids,
|
1099 |
-
"completion_mask": completion_mask,
|
1100 |
-
"ref_per_token_logps": ref_per_token_logps,
|
1101 |
-
"advantages": advantages,
|
1102 |
-
}
|
1103 |
-
|
1104 |
-
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
|
1105 |
-
if return_outputs:
|
1106 |
-
raise ValueError("The GRPOTrainer does not support returning outputs")
|
1107 |
-
# Compute the per-token log probabilities for the model
|
1108 |
-
|
1109 |
-
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
1110 |
-
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
1111 |
-
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
1112 |
-
bsz, qlen = input_ids.shape
|
1113 |
-
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
1114 |
-
# attention_mask = None
|
1115 |
-
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
1116 |
-
_input_ids = input_ids
|
1117 |
-
_logits_to_keep = logits_to_keep
|
1118 |
-
|
1119 |
-
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
|
1120 |
-
|
1121 |
-
# Compute the KL divergence between the model and the reference model
|
1122 |
-
ref_per_token_logps = inputs["ref_per_token_logps"]
|
1123 |
-
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
1124 |
-
|
1125 |
-
# x - x.detach() allows for preserving gradients from x
|
1126 |
-
advantages = inputs["advantages"]
|
1127 |
-
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
1128 |
-
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
1129 |
-
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
1130 |
-
input_ids = input_ids[:, -logits_to_keep:]
|
1131 |
-
if per_token_logps is not None:
|
1132 |
-
loss, completion_length, mean_kl = grpo_compute_loss_slow(
|
1133 |
-
ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
|
1134 |
-
)
|
1135 |
-
else:
|
1136 |
-
loss, completion_length, mean_kl = grpo_accumulated_loss(
|
1137 |
-
self, _input_ids, logits_to_keep, completion_mask, advantages,
|
1138 |
-
n_chunks = self.args.unsloth_num_chunks,
|
1139 |
-
)
|
1140 |
-
|
1141 |
-
# Log the metrics
|
1142 |
-
# completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
1143 |
-
|
1144 |
-
# mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
1145 |
-
# self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
1146 |
-
|
1147 |
-
if "train" in self._metrics:
|
1148 |
-
mode = "eval" if self.control.should_evaluate else "train"
|
1149 |
-
self._metrics[mode]["completion_length"].append(completion_length.item())
|
1150 |
-
self._metrics[mode]["kl"].append(mean_kl.item())
|
1151 |
-
else:
|
1152 |
-
self._metrics["completion_length"].append(completion_length.item())
|
1153 |
-
self._metrics["kl"].append(mean_kl.item())
|
1154 |
-
return loss
|
1155 |
-
|
1156 |
-
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
|
1157 |
-
inputs = self._prepare_inputs(inputs)
|
1158 |
-
with torch.no_grad():
|
1159 |
-
with self.compute_loss_context_manager():
|
1160 |
-
loss = self.compute_loss(model, inputs)
|
1161 |
-
loss = loss.mean().detach()
|
1162 |
-
return loss, None, None
|
1163 |
-
|
1164 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1165 |
-
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
1166 |
-
|
1167 |
-
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
1168 |
-
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
1169 |
-
if next(iter(logs.keys())).startswith("eval_"):
|
1170 |
-
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
1171 |
-
|
1172 |
-
logs = {**logs, **metrics}
|
1173 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1174 |
-
super().log(logs, start_time)
|
1175 |
-
else: # transformers<=4.46
|
1176 |
-
super().log(logs)
|
1177 |
-
self._metrics.clear()
|
1178 |
-
|
1179 |
-
def create_model_card(
|
1180 |
-
self,
|
1181 |
-
model_name: Optional[str] = None,
|
1182 |
-
dataset_name: Optional[str] = None,
|
1183 |
-
tags: Union[str, list[str], None] = None,
|
1184 |
-
):
|
1185 |
-
"""
|
1186 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1187 |
-
|
1188 |
-
Args:
|
1189 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1190 |
-
Name of the model.
|
1191 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1192 |
-
Name of the dataset used for training.
|
1193 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1194 |
-
Tags to be associated with the model card.
|
1195 |
-
"""
|
1196 |
-
if not self.is_world_process_zero():
|
1197 |
-
return
|
1198 |
-
|
1199 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1200 |
-
base_model = self.model.config._name_or_path
|
1201 |
-
else:
|
1202 |
-
base_model = None
|
1203 |
-
|
1204 |
-
tags = tags or []
|
1205 |
-
if isinstance(tags, str):
|
1206 |
-
tags = [tags]
|
1207 |
-
|
1208 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1209 |
-
tags.append("unsloth")
|
1210 |
-
|
1211 |
-
citation = textwrap.dedent(
|
1212 |
-
"""\
|
1213 |
-
@article{zhihong2024deepseekmath,
|
1214 |
-
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
1215 |
-
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
1216 |
-
year = 2024,
|
1217 |
-
eprint = {arXiv:2402.03300},
|
1218 |
-
}
|
1219 |
-
"""
|
1220 |
-
)
|
1221 |
-
|
1222 |
-
model_card = generate_model_card(
|
1223 |
-
base_model=base_model,
|
1224 |
-
model_name=model_name,
|
1225 |
-
hub_model_id=self.hub_model_id,
|
1226 |
-
dataset_name=dataset_name,
|
1227 |
-
tags=tags,
|
1228 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1229 |
-
comet_url=get_comet_experiment_url(),
|
1230 |
-
trainer_name="GRPO",
|
1231 |
-
trainer_citation=citation,
|
1232 |
-
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
1233 |
-
paper_id="2402.03300",
|
1234 |
-
)
|
1235 |
-
|
1236 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1237 |
-
class UnslothGRPOTrainer(_UnslothGRPOTrainer):
|
1238 |
-
"""
|
1239 |
-
|
1240 |
-
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
1241 |
-
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
1242 |
-
|
1243 |
-
Example:
|
1244 |
-
|
1245 |
-
```python
|
1246 |
-
from datasets import load_dataset
|
1247 |
-
from trl import GRPOTrainer
|
1248 |
-
|
1249 |
-
dataset = load_dataset("trl-lib/tldr", split="train")
|
1250 |
-
|
1251 |
-
def reward_func(completions, **kwargs):
|
1252 |
-
# Dummy reward function that rewards completions with more unique letters.
|
1253 |
-
return [float(len(set(completion))) for completion in completions]
|
1254 |
-
|
1255 |
-
trainer = GRPOTrainer(
|
1256 |
-
model="Qwen/Qwen2-0.5B-Instruct",
|
1257 |
-
reward_funcs=reward_func,
|
1258 |
-
train_dataset=dataset,
|
1259 |
-
)
|
1260 |
-
|
1261 |
-
trainer.train()
|
1262 |
-
```
|
1263 |
-
|
1264 |
-
Args:
|
1265 |
-
model (`Union[str, PreTrainedModel]`):
|
1266 |
-
Model to be trained. Can be either:
|
1267 |
-
|
1268 |
-
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
1269 |
-
a path to a *directory* containing model weights saved using
|
1270 |
-
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
1271 |
-
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
1272 |
-
in `args.model_init_kwargs`.
|
1273 |
-
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
1274 |
-
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
1275 |
-
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
1276 |
-
functions with the prompts and completions and sum the rewards. Can be either:
|
1277 |
-
|
1278 |
-
- A single reward function, such as:
|
1279 |
-
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
1280 |
-
path to a *directory* containing model weights saved using
|
1281 |
-
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
1282 |
-
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
1283 |
-
keyword arguments in `args.model_init_kwargs`.
|
1284 |
-
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
1285 |
-
- A custom reward function: The function is provided with the prompts and the generated completions,
|
1286 |
-
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
1287 |
-
[Using a custom reward function](#using-a-custom-reward-function).
|
1288 |
-
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
1289 |
-
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
1290 |
-
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
1291 |
-
Configuration for this trainer. If `None`, a default configuration is used.
|
1292 |
-
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
1293 |
-
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
1294 |
-
ignored. The format of the samples can be either:
|
1295 |
-
|
1296 |
-
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
1297 |
-
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
1298 |
-
and content).
|
1299 |
-
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
1300 |
-
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
1301 |
-
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
1302 |
-
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
1303 |
-
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
1304 |
-
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
1305 |
-
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
1306 |
-
|
1307 |
-
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
1308 |
-
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
1309 |
-
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
1310 |
-
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
1311 |
-
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
1312 |
-
the corresponding entries in `reward_processing_classes` are ignored.
|
1313 |
-
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
1314 |
-
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
1315 |
-
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
1316 |
-
|
1317 |
-
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
1318 |
-
method.
|
1319 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
1320 |
-
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
1321 |
-
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
1322 |
-
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
1323 |
-
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
1324 |
-
|
1325 |
-
"""
|
1326 |
-
def __init__(
|
1327 |
-
self,
|
1328 |
-
model,
|
1329 |
-
reward_funcs,
|
1330 |
-
args = None,
|
1331 |
-
train_dataset = None,
|
1332 |
-
eval_dataset = None,
|
1333 |
-
processing_class = None,
|
1334 |
-
reward_processing_classes = None,
|
1335 |
-
callbacks = None,
|
1336 |
-
peft_config = None,
|
1337 |
-
**kwargs
|
1338 |
-
):
|
1339 |
-
if args is None: args = UnslothGRPOConfig()
|
1340 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1341 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1342 |
-
force_float32 = False
|
1343 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1344 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1345 |
-
force_float32 = True
|
1346 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1347 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1348 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1349 |
-
from unsloth_zoo.utils import _get_dtype
|
1350 |
-
dtype = _get_dtype(dtype)
|
1351 |
-
float16 = dtype == torch.float16
|
1352 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1353 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1354 |
-
if force_float32:
|
1355 |
-
args.fp16 = False
|
1356 |
-
args.bf16 = False
|
1357 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1358 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1359 |
-
args.fp16 = float16
|
1360 |
-
args.bf16 = not float16
|
1361 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1362 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1363 |
-
args.eval_strategy = 'steps'
|
1364 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1365 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1366 |
-
if ga_steps is not None and ga_steps > 1:
|
1367 |
-
from transformers import __version__ as transformers_version
|
1368 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1369 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1370 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1371 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1372 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1373 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1374 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1375 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1376 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1377 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1378 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1379 |
-
if force_float32:
|
1380 |
-
args.bf16_full_eval = False
|
1381 |
-
args.fp16_full_eval = False
|
1382 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1383 |
-
args.bf16_full_eval = True
|
1384 |
-
args.fp16_full_eval = False
|
1385 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1386 |
-
args.bf16_full_eval = args.bf16
|
1387 |
-
args.fp16_full_eval = args.fp16
|
1388 |
-
_output_logits = False
|
1389 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1390 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1391 |
-
if _output_logits:
|
1392 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1393 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1394 |
-
pass
|
1395 |
-
else:
|
1396 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1397 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1398 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1399 |
-
max_seq_length = model.max_seq_length
|
1400 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1401 |
-
if model is not None and hasattr(model, 'for_training'):
|
1402 |
-
model.for_training()
|
1403 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1404 |
-
if 'processing_class' in locals():
|
1405 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1406 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1407 |
-
other_metrics = []
|
1408 |
-
if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
|
1409 |
-
else: _reward_funcs = reward_funcs
|
1410 |
-
for reward_func in _reward_funcs:
|
1411 |
-
try:
|
1412 |
-
reward_func_name = reward_func.__name__
|
1413 |
-
other_metrics.append(f'rewards/{reward_func_name}')
|
1414 |
-
except: pass
|
1415 |
-
|
1416 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1417 |
-
PatchRLStatistics('grpo_trainer', other_metrics)
|
1418 |
-
|
1419 |
-
super().__init__(
|
1420 |
-
model = model,
|
1421 |
-
reward_funcs = reward_funcs,
|
1422 |
-
args = args,
|
1423 |
-
train_dataset = train_dataset,
|
1424 |
-
eval_dataset = eval_dataset,
|
1425 |
-
processing_class = processing_class,
|
1426 |
-
reward_processing_classes = reward_processing_classes,
|
1427 |
-
callbacks = callbacks,
|
1428 |
-
peft_config = peft_config,**kwargs)
|
1429 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1430 |
-
self.neftune_hook_handle.remove()
|
1431 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1432 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1433 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1434 |
-
pass
|
1435 |
-
|
1436 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothKTOTrainer.py
DELETED
@@ -1,1838 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothKTOConfig(KTOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`KTOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
54 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
-
[`~transformers.TrainingArguments`].
|
56 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
-
to use the default data collator.
|
59 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
-
and your model is an encoder-decoder.
|
64 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
66 |
-
reference model.
|
67 |
-
loss_type (`str`, *optional*, defaults to `"kto"`):
|
68 |
-
Type of loss to use. Possible values are:
|
69 |
-
|
70 |
-
- `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
|
71 |
-
- `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
|
72 |
-
|
73 |
-
desirable_weight (`float`, *optional*, defaults to `1.0`):
|
74 |
-
Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
|
75 |
-
undesirable_weight (`float`, *optional*, defaults to `1.0`):
|
76 |
-
Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
|
77 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
78 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
79 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
80 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
81 |
-
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
82 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
83 |
-
This argument is required if you want to use the default data collator.
|
84 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
85 |
-
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
86 |
-
evaluation.
|
87 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
88 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
89 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
90 |
-
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
91 |
-
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
92 |
-
useful when training without the reference model to reduce the total GPU memory needed.
|
93 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
94 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
95 |
-
string.
|
96 |
-
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
97 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
98 |
-
from a string.
|
99 |
-
dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`):
|
100 |
-
Number of processes to use for processing the dataset.
|
101 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
102 |
-
Whether to disable dropout in the model and reference model.
|
103 |
-
|
104 |
-
"""
|
105 |
-
vllm_sampling_params: Optional[Any] = field(
|
106 |
-
default = None,
|
107 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
108 |
-
)
|
109 |
-
unsloth_num_chunks : Optional[int] = field(
|
110 |
-
default = -1,
|
111 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
112 |
-
)
|
113 |
-
def __init__(
|
114 |
-
self,
|
115 |
-
output_dir = None,
|
116 |
-
overwrite_output_dir = None,
|
117 |
-
do_train = False,
|
118 |
-
do_eval = False,
|
119 |
-
do_predict = False,
|
120 |
-
eval_strategy = 'no',
|
121 |
-
prediction_loss_only = False,
|
122 |
-
per_device_train_batch_size = 4,
|
123 |
-
per_device_eval_batch_size = 4,
|
124 |
-
per_gpu_train_batch_size = None,
|
125 |
-
per_gpu_eval_batch_size = None,
|
126 |
-
gradient_accumulation_steps = 2,
|
127 |
-
eval_accumulation_steps = 2,
|
128 |
-
eval_delay = 0,
|
129 |
-
torch_empty_cache_steps = 250,
|
130 |
-
learning_rate = 5e-05,
|
131 |
-
weight_decay = 0.01,
|
132 |
-
adam_beta1 = 0.9,
|
133 |
-
adam_beta2 = 0.999,
|
134 |
-
adam_epsilon = 1e-08,
|
135 |
-
max_grad_norm = 1.0,
|
136 |
-
num_train_epochs = 3.0,
|
137 |
-
max_steps = -1,
|
138 |
-
lr_scheduler_type = 'linear',
|
139 |
-
warmup_ratio = 0.1,
|
140 |
-
warmup_steps = 0,
|
141 |
-
log_level = 'passive',
|
142 |
-
log_level_replica = 'warning',
|
143 |
-
log_on_each_node = True,
|
144 |
-
logging_dir = None,
|
145 |
-
logging_strategy = 'steps',
|
146 |
-
logging_first_step = False,
|
147 |
-
logging_steps = 1,
|
148 |
-
logging_nan_inf_filter = False,
|
149 |
-
save_strategy = 'steps',
|
150 |
-
save_steps = 500,
|
151 |
-
save_total_limit = None,
|
152 |
-
save_safetensors = True,
|
153 |
-
save_on_each_node = False,
|
154 |
-
save_only_model = False,
|
155 |
-
restore_callback_states_from_checkpoint = False,
|
156 |
-
no_cuda = False,
|
157 |
-
use_cpu = False,
|
158 |
-
use_mps_device = False,
|
159 |
-
seed = 3407,
|
160 |
-
data_seed = 3407,
|
161 |
-
jit_mode_eval = False,
|
162 |
-
use_ipex = False,
|
163 |
-
bf16 = False,
|
164 |
-
fp16 = False,
|
165 |
-
fp16_opt_level = 'O1',
|
166 |
-
half_precision_backend = 'auto',
|
167 |
-
bf16_full_eval = False,
|
168 |
-
fp16_full_eval = False,
|
169 |
-
tf32 = None,
|
170 |
-
local_rank = -1,
|
171 |
-
ddp_backend = None,
|
172 |
-
tpu_num_cores = None,
|
173 |
-
tpu_metrics_debug = False,
|
174 |
-
debug = '',
|
175 |
-
dataloader_drop_last = False,
|
176 |
-
eval_steps = None,
|
177 |
-
dataloader_num_workers = 0,
|
178 |
-
dataloader_prefetch_factor = None,
|
179 |
-
past_index = -1,
|
180 |
-
run_name = None,
|
181 |
-
disable_tqdm = None,
|
182 |
-
remove_unused_columns = True,
|
183 |
-
label_names = None,
|
184 |
-
load_best_model_at_end = False,
|
185 |
-
metric_for_best_model = None,
|
186 |
-
greater_is_better = None,
|
187 |
-
ignore_data_skip = False,
|
188 |
-
fsdp = '',
|
189 |
-
fsdp_min_num_params = 0,
|
190 |
-
fsdp_config = None,
|
191 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
192 |
-
accelerator_config = None,
|
193 |
-
deepspeed = None,
|
194 |
-
label_smoothing_factor = 0.0,
|
195 |
-
optim = 'adamw_8bit',
|
196 |
-
optim_args = None,
|
197 |
-
adafactor = False,
|
198 |
-
group_by_length = False,
|
199 |
-
length_column_name = 'length',
|
200 |
-
report_to = None,
|
201 |
-
ddp_find_unused_parameters = None,
|
202 |
-
ddp_bucket_cap_mb = None,
|
203 |
-
ddp_broadcast_buffers = None,
|
204 |
-
dataloader_pin_memory = True,
|
205 |
-
dataloader_persistent_workers = False,
|
206 |
-
skip_memory_metrics = True,
|
207 |
-
use_legacy_prediction_loop = False,
|
208 |
-
push_to_hub = False,
|
209 |
-
resume_from_checkpoint = None,
|
210 |
-
hub_model_id = None,
|
211 |
-
hub_strategy = 'every_save',
|
212 |
-
hub_token = None,
|
213 |
-
hub_private_repo = None,
|
214 |
-
hub_always_push = False,
|
215 |
-
gradient_checkpointing = False,
|
216 |
-
gradient_checkpointing_kwargs = None,
|
217 |
-
include_inputs_for_metrics = False,
|
218 |
-
eval_do_concat_batches = True,
|
219 |
-
fp16_backend = 'auto',
|
220 |
-
evaluation_strategy = None,
|
221 |
-
push_to_hub_model_id = None,
|
222 |
-
push_to_hub_organization = None,
|
223 |
-
push_to_hub_token = None,
|
224 |
-
mp_parameters = '',
|
225 |
-
auto_find_batch_size = False,
|
226 |
-
full_determinism = False,
|
227 |
-
torchdynamo = None,
|
228 |
-
ray_scope = 'last',
|
229 |
-
ddp_timeout = 1800,
|
230 |
-
torch_compile = False,
|
231 |
-
torch_compile_backend = None,
|
232 |
-
torch_compile_mode = None,
|
233 |
-
dispatch_batches = None,
|
234 |
-
split_batches = None,
|
235 |
-
include_tokens_per_second = False,
|
236 |
-
include_num_input_tokens_seen = False,
|
237 |
-
neftune_noise_alpha = None,
|
238 |
-
optim_target_modules = None,
|
239 |
-
batch_eval_metrics = False,
|
240 |
-
eval_on_start = False,
|
241 |
-
use_liger_kernel = False,
|
242 |
-
eval_use_gather_object = False,
|
243 |
-
average_tokens_across_devices = False,
|
244 |
-
max_length = 1024,
|
245 |
-
max_prompt_length = 512,
|
246 |
-
max_completion_length = None,
|
247 |
-
beta = 0.1,
|
248 |
-
loss_type = 'kto',
|
249 |
-
desirable_weight = 1.0,
|
250 |
-
undesirable_weight = 1.0,
|
251 |
-
label_pad_token_id = -100,
|
252 |
-
padding_value = None,
|
253 |
-
truncation_mode = 'keep_end',
|
254 |
-
generate_during_eval = False,
|
255 |
-
is_encoder_decoder = None,
|
256 |
-
disable_dropout = True,
|
257 |
-
precompute_ref_log_probs = False,
|
258 |
-
model_init_kwargs = None,
|
259 |
-
ref_model_init_kwargs = None,
|
260 |
-
dataset_num_proc = None,
|
261 |
-
vllm_sampling_params = None,
|
262 |
-
unsloth_num_chunks = -1,
|
263 |
-
**kwargs,
|
264 |
-
):
|
265 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
266 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
267 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
268 |
-
output_dir = 'unsloth_training_checkpoints'
|
269 |
-
save_strategy = 'no'
|
270 |
-
if dataset_num_proc is None:
|
271 |
-
from multiprocessing import cpu_count
|
272 |
-
dataset_num_proc = cpu_count()
|
273 |
-
|
274 |
-
super().__init__(
|
275 |
-
output_dir = output_dir,
|
276 |
-
overwrite_output_dir = overwrite_output_dir,
|
277 |
-
do_train = do_train,
|
278 |
-
do_eval = do_eval,
|
279 |
-
do_predict = do_predict,
|
280 |
-
eval_strategy = eval_strategy,
|
281 |
-
prediction_loss_only = prediction_loss_only,
|
282 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
283 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
284 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
285 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
286 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
287 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
288 |
-
eval_delay = eval_delay,
|
289 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
290 |
-
learning_rate = learning_rate,
|
291 |
-
weight_decay = weight_decay,
|
292 |
-
adam_beta1 = adam_beta1,
|
293 |
-
adam_beta2 = adam_beta2,
|
294 |
-
adam_epsilon = adam_epsilon,
|
295 |
-
max_grad_norm = max_grad_norm,
|
296 |
-
num_train_epochs = num_train_epochs,
|
297 |
-
max_steps = max_steps,
|
298 |
-
lr_scheduler_type = lr_scheduler_type,
|
299 |
-
warmup_ratio = warmup_ratio,
|
300 |
-
warmup_steps = warmup_steps,
|
301 |
-
log_level = log_level,
|
302 |
-
log_level_replica = log_level_replica,
|
303 |
-
log_on_each_node = log_on_each_node,
|
304 |
-
logging_dir = logging_dir,
|
305 |
-
logging_strategy = logging_strategy,
|
306 |
-
logging_first_step = logging_first_step,
|
307 |
-
logging_steps = logging_steps,
|
308 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
309 |
-
save_strategy = save_strategy,
|
310 |
-
save_steps = save_steps,
|
311 |
-
save_total_limit = save_total_limit,
|
312 |
-
save_safetensors = save_safetensors,
|
313 |
-
save_on_each_node = save_on_each_node,
|
314 |
-
save_only_model = save_only_model,
|
315 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
316 |
-
no_cuda = no_cuda,
|
317 |
-
use_cpu = use_cpu,
|
318 |
-
use_mps_device = use_mps_device,
|
319 |
-
seed = seed,
|
320 |
-
data_seed = data_seed,
|
321 |
-
jit_mode_eval = jit_mode_eval,
|
322 |
-
use_ipex = use_ipex,
|
323 |
-
bf16 = bf16,
|
324 |
-
fp16 = fp16,
|
325 |
-
fp16_opt_level = fp16_opt_level,
|
326 |
-
half_precision_backend = half_precision_backend,
|
327 |
-
bf16_full_eval = bf16_full_eval,
|
328 |
-
fp16_full_eval = fp16_full_eval,
|
329 |
-
tf32 = tf32,
|
330 |
-
local_rank = local_rank,
|
331 |
-
ddp_backend = ddp_backend,
|
332 |
-
tpu_num_cores = tpu_num_cores,
|
333 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
334 |
-
debug = debug,
|
335 |
-
dataloader_drop_last = dataloader_drop_last,
|
336 |
-
eval_steps = eval_steps,
|
337 |
-
dataloader_num_workers = dataloader_num_workers,
|
338 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
339 |
-
past_index = past_index,
|
340 |
-
run_name = run_name,
|
341 |
-
disable_tqdm = disable_tqdm,
|
342 |
-
remove_unused_columns = remove_unused_columns,
|
343 |
-
label_names = label_names,
|
344 |
-
load_best_model_at_end = load_best_model_at_end,
|
345 |
-
metric_for_best_model = metric_for_best_model,
|
346 |
-
greater_is_better = greater_is_better,
|
347 |
-
ignore_data_skip = ignore_data_skip,
|
348 |
-
fsdp = fsdp,
|
349 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
350 |
-
fsdp_config = fsdp_config,
|
351 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
352 |
-
accelerator_config = accelerator_config,
|
353 |
-
deepspeed = deepspeed,
|
354 |
-
label_smoothing_factor = label_smoothing_factor,
|
355 |
-
optim = optim,
|
356 |
-
optim_args = optim_args,
|
357 |
-
adafactor = adafactor,
|
358 |
-
group_by_length = group_by_length,
|
359 |
-
length_column_name = length_column_name,
|
360 |
-
report_to = report_to,
|
361 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
362 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
363 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
364 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
365 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
366 |
-
skip_memory_metrics = skip_memory_metrics,
|
367 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
368 |
-
push_to_hub = push_to_hub,
|
369 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
370 |
-
hub_model_id = hub_model_id,
|
371 |
-
hub_strategy = hub_strategy,
|
372 |
-
hub_token = hub_token,
|
373 |
-
hub_private_repo = hub_private_repo,
|
374 |
-
hub_always_push = hub_always_push,
|
375 |
-
gradient_checkpointing = gradient_checkpointing,
|
376 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
377 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
378 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
379 |
-
fp16_backend = fp16_backend,
|
380 |
-
evaluation_strategy = evaluation_strategy,
|
381 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
382 |
-
push_to_hub_organization = push_to_hub_organization,
|
383 |
-
push_to_hub_token = push_to_hub_token,
|
384 |
-
mp_parameters = mp_parameters,
|
385 |
-
auto_find_batch_size = auto_find_batch_size,
|
386 |
-
full_determinism = full_determinism,
|
387 |
-
torchdynamo = torchdynamo,
|
388 |
-
ray_scope = ray_scope,
|
389 |
-
ddp_timeout = ddp_timeout,
|
390 |
-
torch_compile = torch_compile,
|
391 |
-
torch_compile_backend = torch_compile_backend,
|
392 |
-
torch_compile_mode = torch_compile_mode,
|
393 |
-
dispatch_batches = dispatch_batches,
|
394 |
-
split_batches = split_batches,
|
395 |
-
include_tokens_per_second = include_tokens_per_second,
|
396 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
397 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
398 |
-
optim_target_modules = optim_target_modules,
|
399 |
-
batch_eval_metrics = batch_eval_metrics,
|
400 |
-
eval_on_start = eval_on_start,
|
401 |
-
use_liger_kernel = use_liger_kernel,
|
402 |
-
eval_use_gather_object = eval_use_gather_object,
|
403 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
404 |
-
max_length = max_length,
|
405 |
-
max_prompt_length = max_prompt_length,
|
406 |
-
max_completion_length = max_completion_length,
|
407 |
-
beta = beta,
|
408 |
-
loss_type = loss_type,
|
409 |
-
desirable_weight = desirable_weight,
|
410 |
-
undesirable_weight = undesirable_weight,
|
411 |
-
label_pad_token_id = label_pad_token_id,
|
412 |
-
padding_value = padding_value,
|
413 |
-
truncation_mode = truncation_mode,
|
414 |
-
generate_during_eval = generate_during_eval,
|
415 |
-
is_encoder_decoder = is_encoder_decoder,
|
416 |
-
disable_dropout = disable_dropout,
|
417 |
-
precompute_ref_log_probs = precompute_ref_log_probs,
|
418 |
-
model_init_kwargs = model_init_kwargs,
|
419 |
-
ref_model_init_kwargs = ref_model_init_kwargs,
|
420 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
421 |
-
self.vllm_sampling_params = vllm_sampling_params
|
422 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
423 |
-
pass
|
424 |
-
|
425 |
-
class _UnslothKTOTrainer(Trainer):
|
426 |
-
r""""""
|
427 |
-
|
428 |
-
_tag_names = ["trl", "kto"]
|
429 |
-
|
430 |
-
def __init__(
|
431 |
-
self,
|
432 |
-
model: Union[PreTrainedModel, nn.Module, str] = None,
|
433 |
-
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
434 |
-
args: KTOConfig = None,
|
435 |
-
train_dataset: Optional[Dataset] = None,
|
436 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
437 |
-
processing_class: Optional[
|
438 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
439 |
-
] = None,
|
440 |
-
data_collator: Optional[DataCollator] = None,
|
441 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
442 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
443 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
444 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
445 |
-
peft_config: Optional[dict] = None,
|
446 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
447 |
-
model_adapter_name: Optional[str] = None,
|
448 |
-
ref_adapter_name: Optional[str] = None,
|
449 |
-
):
|
450 |
-
if type(args) is TrainingArguments:
|
451 |
-
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
452 |
-
|
453 |
-
if not isinstance(model, str) and ref_model is model:
|
454 |
-
raise ValueError(
|
455 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
456 |
-
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
457 |
-
)
|
458 |
-
|
459 |
-
if args.model_init_kwargs is None:
|
460 |
-
model_init_kwargs = {}
|
461 |
-
elif not isinstance(model, str):
|
462 |
-
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
|
463 |
-
else:
|
464 |
-
model_init_kwargs = args.model_init_kwargs
|
465 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
466 |
-
if torch_dtype is not None:
|
467 |
-
# Convert to `torch.dtype` if an str is passed
|
468 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
469 |
-
torch_dtype = getattr(torch, torch_dtype)
|
470 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
471 |
-
raise ValueError(
|
472 |
-
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
473 |
-
)
|
474 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
475 |
-
|
476 |
-
if args.ref_model_init_kwargs is None:
|
477 |
-
ref_model_init_kwargs = {}
|
478 |
-
elif not isinstance(ref_model, str):
|
479 |
-
raise ValueError(
|
480 |
-
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
|
481 |
-
)
|
482 |
-
else:
|
483 |
-
ref_model_init_kwargs = args.ref_model_init_kwargs
|
484 |
-
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
485 |
-
if torch_dtype is not None:
|
486 |
-
# Convert to `torch.dtype` if an str is passed
|
487 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
488 |
-
torch_dtype = getattr(torch, torch_dtype)
|
489 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
490 |
-
raise ValueError(
|
491 |
-
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
492 |
-
)
|
493 |
-
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
494 |
-
|
495 |
-
if isinstance(model, str):
|
496 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
497 |
-
|
498 |
-
if isinstance(ref_model, str):
|
499 |
-
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
500 |
-
|
501 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
502 |
-
# has been called in order to properly call autocast if needed.
|
503 |
-
self._peft_has_been_casted_to_bf16 = False
|
504 |
-
|
505 |
-
if not is_peft_available() and peft_config is not None:
|
506 |
-
raise ValueError(
|
507 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
508 |
-
)
|
509 |
-
elif is_peft_available() and peft_config is not None:
|
510 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
511 |
-
if isinstance(model, PeftModel):
|
512 |
-
model = model.merge_and_unload()
|
513 |
-
|
514 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
515 |
-
_support_gc_kwargs = hasattr(
|
516 |
-
args, "gradient_checkpointing_kwargs"
|
517 |
-
) and "gradient_checkpointing_kwargs" in list(
|
518 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
519 |
-
)
|
520 |
-
|
521 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
522 |
-
|
523 |
-
if _support_gc_kwargs:
|
524 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
525 |
-
|
526 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
527 |
-
elif getattr(args, "gradient_checkpointing", False):
|
528 |
-
# For backward compatibility with older versions of transformers
|
529 |
-
if hasattr(model, "enable_input_require_grads"):
|
530 |
-
model.enable_input_require_grads()
|
531 |
-
else:
|
532 |
-
|
533 |
-
def make_inputs_require_grad(module, input, output):
|
534 |
-
output.requires_grad_(True)
|
535 |
-
|
536 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
537 |
-
|
538 |
-
# get peft model with the given config
|
539 |
-
model = model
|
540 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
541 |
-
peft_module_casting_to_bf16(model)
|
542 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
543 |
-
self._peft_has_been_casted_to_bf16 = True
|
544 |
-
|
545 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
546 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
547 |
-
# fail or completely fail.
|
548 |
-
elif getattr(args, "gradient_checkpointing", False):
|
549 |
-
# For backward compatibility with older versions of transformers
|
550 |
-
if hasattr(model, "enable_input_require_grads"):
|
551 |
-
model.enable_input_require_grads()
|
552 |
-
else:
|
553 |
-
|
554 |
-
def make_inputs_require_grad(module, input, output):
|
555 |
-
output.requires_grad_(True)
|
556 |
-
|
557 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
558 |
-
|
559 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
560 |
-
raise ValueError(
|
561 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
562 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
563 |
-
)
|
564 |
-
|
565 |
-
if model is not None:
|
566 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
567 |
-
elif args.is_encoder_decoder is None:
|
568 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
569 |
-
else:
|
570 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
571 |
-
|
572 |
-
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
573 |
-
self.model_adapter_name = model_adapter_name
|
574 |
-
self.ref_adapter_name = ref_adapter_name
|
575 |
-
|
576 |
-
if ref_model:
|
577 |
-
self.ref_model = ref_model
|
578 |
-
elif self.is_peft_model or args.precompute_ref_log_probs:
|
579 |
-
# The `model` with adapters turned off will be used as the reference model
|
580 |
-
self.ref_model = None
|
581 |
-
else:
|
582 |
-
self.ref_model = create_reference_model(model)
|
583 |
-
|
584 |
-
if processing_class is None:
|
585 |
-
raise ValueError(
|
586 |
-
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
587 |
-
)
|
588 |
-
if args.max_length is None:
|
589 |
-
warnings.warn(
|
590 |
-
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
591 |
-
" it will be set to `512` by default, but you should do it yourself in the future.",
|
592 |
-
UserWarning,
|
593 |
-
)
|
594 |
-
max_length = 512
|
595 |
-
if args.max_length is not None:
|
596 |
-
max_length = args.max_length
|
597 |
-
|
598 |
-
if args.max_prompt_length is None:
|
599 |
-
warnings.warn(
|
600 |
-
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
|
601 |
-
" it will be set to `128` by default, but you should do it yourself in the future.",
|
602 |
-
UserWarning,
|
603 |
-
)
|
604 |
-
max_prompt_length = 128
|
605 |
-
if args.max_prompt_length is not None:
|
606 |
-
max_prompt_length = args.max_prompt_length
|
607 |
-
|
608 |
-
max_completion_length = None
|
609 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
610 |
-
warnings.warn(
|
611 |
-
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
|
612 |
-
" it will be set to `128` by default, but you should do it yourself in the future.",
|
613 |
-
UserWarning,
|
614 |
-
)
|
615 |
-
max_completion_length = 128
|
616 |
-
if args.max_completion_length is not None and self.is_encoder_decoder:
|
617 |
-
max_completion_length = args.max_completion_length
|
618 |
-
|
619 |
-
if data_collator is None:
|
620 |
-
data_collator = DPODataCollatorWithPadding(
|
621 |
-
pad_token_id=processing_class.pad_token_id,
|
622 |
-
label_pad_token_id=args.label_pad_token_id,
|
623 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
624 |
-
)
|
625 |
-
|
626 |
-
if args.remove_unused_columns:
|
627 |
-
args.remove_unused_columns = False
|
628 |
-
# warn users
|
629 |
-
warnings.warn(
|
630 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
631 |
-
" we have set it for you, but you should do it yourself in the future.",
|
632 |
-
UserWarning,
|
633 |
-
)
|
634 |
-
|
635 |
-
self.use_dpo_data_collator = True
|
636 |
-
else:
|
637 |
-
self.use_dpo_data_collator = False
|
638 |
-
|
639 |
-
# Disable dropout in the model and reference model
|
640 |
-
if args.disable_dropout:
|
641 |
-
disable_dropout_in_model(model)
|
642 |
-
if self.ref_model is not None:
|
643 |
-
disable_dropout_in_model(self.ref_model)
|
644 |
-
|
645 |
-
self.loss_type = args.loss_type
|
646 |
-
self.max_length = max_length
|
647 |
-
self.generate_during_eval = args.generate_during_eval
|
648 |
-
self.label_pad_token_id = args.label_pad_token_id
|
649 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
650 |
-
self.max_prompt_length = max_prompt_length
|
651 |
-
self.truncation_mode = args.truncation_mode
|
652 |
-
self.max_completion_length = max_completion_length
|
653 |
-
self.processing_class = processing_class
|
654 |
-
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
655 |
-
|
656 |
-
# Not all losses require a KL calculation
|
657 |
-
self.calculate_KL = True
|
658 |
-
if self.loss_type in ["apo_zero_unpaired"]:
|
659 |
-
self.calculate_KL = False
|
660 |
-
|
661 |
-
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
662 |
-
# keep track of first called to avoid computation of future calls
|
663 |
-
self._precomputed_train_ref_log_probs = False
|
664 |
-
self._precomputed_eval_ref_log_probs = False
|
665 |
-
|
666 |
-
# metric
|
667 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
668 |
-
|
669 |
-
# KTO parameter
|
670 |
-
self.beta = args.beta
|
671 |
-
self.desirable_weight = args.desirable_weight
|
672 |
-
self.undesirable_weight = args.undesirable_weight
|
673 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
674 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
675 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
676 |
-
warnings.warn(
|
677 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
678 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
679 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
680 |
-
"loss.",
|
681 |
-
UserWarning,
|
682 |
-
)
|
683 |
-
|
684 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
685 |
-
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
|
686 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
687 |
-
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
688 |
-
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
689 |
-
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
690 |
-
# issued.
|
691 |
-
model.warnings_issued["estimate_tokens"] = True
|
692 |
-
|
693 |
-
# Compute that only on the main process for faster data processing.
|
694 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
695 |
-
with PartialState().local_main_process_first():
|
696 |
-
# Extract the prompt if needed
|
697 |
-
train_dataset = train_dataset.map(
|
698 |
-
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
699 |
-
)
|
700 |
-
# Unpair the dataset if needed
|
701 |
-
train_dataset = maybe_unpair_preference_dataset(
|
702 |
-
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
703 |
-
)
|
704 |
-
# Apply the chat template if needed
|
705 |
-
train_dataset = train_dataset.map(
|
706 |
-
maybe_apply_chat_template,
|
707 |
-
fn_kwargs={"tokenizer": processing_class},
|
708 |
-
num_proc=args.dataset_num_proc,
|
709 |
-
desc="Applying chat template to train dataset",
|
710 |
-
)
|
711 |
-
if eval_dataset is not None:
|
712 |
-
eval_dataset = eval_dataset.map(
|
713 |
-
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
714 |
-
)
|
715 |
-
eval_dataset = maybe_unpair_preference_dataset(
|
716 |
-
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
717 |
-
)
|
718 |
-
eval_dataset = eval_dataset.map(
|
719 |
-
maybe_apply_chat_template,
|
720 |
-
fn_kwargs={"tokenizer": processing_class},
|
721 |
-
num_proc=args.dataset_num_proc,
|
722 |
-
desc="Applying chat template to eval dataset",
|
723 |
-
)
|
724 |
-
|
725 |
-
# Tokenize and prepare the training datasets
|
726 |
-
train_dataset = train_dataset.map(
|
727 |
-
_tokenize,
|
728 |
-
batched=True,
|
729 |
-
fn_kwargs={"tokenizer": self.processing_class},
|
730 |
-
num_proc=args.dataset_num_proc,
|
731 |
-
desc="Tokenizing train dataset",
|
732 |
-
)
|
733 |
-
|
734 |
-
fn_kwargs = {
|
735 |
-
"prefix": "",
|
736 |
-
"is_encoder_decoder": self.is_encoder_decoder,
|
737 |
-
"tokenizer": self.processing_class,
|
738 |
-
"max_length": self.max_length,
|
739 |
-
"truncation_mode": self.truncation_mode,
|
740 |
-
"label_pad_token_id": self.label_pad_token_id,
|
741 |
-
"max_prompt_length": self.max_prompt_length,
|
742 |
-
"max_completion_length": self.max_completion_length,
|
743 |
-
}
|
744 |
-
|
745 |
-
train_dataset = train_dataset.map(
|
746 |
-
_process_tokens,
|
747 |
-
fn_kwargs=fn_kwargs,
|
748 |
-
num_proc=args.dataset_num_proc,
|
749 |
-
desc="Processing tokenized train dataset",
|
750 |
-
)
|
751 |
-
|
752 |
-
# Tokenize and prepare the eval datasets
|
753 |
-
if eval_dataset is not None:
|
754 |
-
eval_dataset = eval_dataset.map(
|
755 |
-
_tokenize,
|
756 |
-
fn_kwargs={"tokenizer": self.processing_class},
|
757 |
-
batched=True,
|
758 |
-
num_proc=args.dataset_num_proc,
|
759 |
-
desc="Tokenizing eval dataset",
|
760 |
-
)
|
761 |
-
|
762 |
-
eval_dataset = eval_dataset.map(
|
763 |
-
_process_tokens,
|
764 |
-
fn_kwargs=fn_kwargs,
|
765 |
-
num_proc=args.dataset_num_proc,
|
766 |
-
desc="Processing tokenized eval dataset",
|
767 |
-
)
|
768 |
-
|
769 |
-
# Get KL datasets if needed
|
770 |
-
if self.calculate_KL:
|
771 |
-
if args.per_device_train_batch_size <= 1:
|
772 |
-
raise ValueError(
|
773 |
-
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
774 |
-
)
|
775 |
-
|
776 |
-
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
777 |
-
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
|
778 |
-
train_kl_dataset = train_dataset.map(
|
779 |
-
_get_kl_dataset,
|
780 |
-
batched=True,
|
781 |
-
batch_size=args.per_device_train_batch_size,
|
782 |
-
num_proc=args.dataset_num_proc,
|
783 |
-
desc="Extracting KL train dataset",
|
784 |
-
)
|
785 |
-
|
786 |
-
fn_kwargs["prefix"] = "KL_"
|
787 |
-
train_kl_dataset = train_kl_dataset.map(
|
788 |
-
_process_tokens,
|
789 |
-
fn_kwargs=fn_kwargs,
|
790 |
-
num_proc=args.dataset_num_proc,
|
791 |
-
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
792 |
-
desc="Processing tokenized train KL dataset",
|
793 |
-
)
|
794 |
-
|
795 |
-
# merge the datasets
|
796 |
-
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
797 |
-
|
798 |
-
if eval_dataset is not None:
|
799 |
-
# Get KL dataset
|
800 |
-
eval_kl_dataset = eval_dataset.map(
|
801 |
-
_get_kl_dataset,
|
802 |
-
batched=True,
|
803 |
-
batch_size=args.per_device_train_batch_size,
|
804 |
-
num_proc=args.dataset_num_proc,
|
805 |
-
desc="Extracting eval KL dataset",
|
806 |
-
)
|
807 |
-
|
808 |
-
eval_kl_dataset = eval_kl_dataset.map(
|
809 |
-
_process_tokens,
|
810 |
-
fn_kwargs=fn_kwargs,
|
811 |
-
num_proc=args.dataset_num_proc,
|
812 |
-
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
813 |
-
desc="Processing tokenized eval KL dataset",
|
814 |
-
)
|
815 |
-
|
816 |
-
# merge the datasets
|
817 |
-
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
818 |
-
|
819 |
-
# calculate dataset desirability balance
|
820 |
-
num_desirable = max(sum(train_dataset["label"]), 1)
|
821 |
-
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
822 |
-
|
823 |
-
if num_desirable != num_undesirable:
|
824 |
-
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
|
825 |
-
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
826 |
-
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
827 |
-
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
828 |
-
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
829 |
-
|
830 |
-
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
831 |
-
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
832 |
-
|
833 |
-
if not (des_weight_in_range or und_weight_in_range):
|
834 |
-
warnings.warn(
|
835 |
-
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
836 |
-
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
837 |
-
f"on your data, we recommend EITHER "
|
838 |
-
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
839 |
-
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
840 |
-
"See the documentation on how to optimally set these weights.",
|
841 |
-
UserWarning,
|
842 |
-
)
|
843 |
-
|
844 |
-
super().__init__(
|
845 |
-
model=model,
|
846 |
-
args=args,
|
847 |
-
data_collator=data_collator,
|
848 |
-
train_dataset=train_dataset,
|
849 |
-
eval_dataset=eval_dataset,
|
850 |
-
processing_class=processing_class,
|
851 |
-
model_init=model_init,
|
852 |
-
compute_metrics=compute_metrics,
|
853 |
-
callbacks=callbacks,
|
854 |
-
optimizers=optimizers,
|
855 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
856 |
-
)
|
857 |
-
|
858 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
859 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
860 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
861 |
-
self.model_accepts_loss_kwargs = False
|
862 |
-
|
863 |
-
# Add tags for models that have been loaded with the correct transformers version
|
864 |
-
if hasattr(self.model, "add_model_tags"):
|
865 |
-
self.model.add_model_tags(self._tag_names)
|
866 |
-
|
867 |
-
if not hasattr(self, "accelerator"):
|
868 |
-
raise AttributeError(
|
869 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
870 |
-
)
|
871 |
-
|
872 |
-
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
873 |
-
if self.is_deepspeed_enabled:
|
874 |
-
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
875 |
-
raise ValueError(
|
876 |
-
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
877 |
-
)
|
878 |
-
|
879 |
-
if self.ref_model is None:
|
880 |
-
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
881 |
-
raise ValueError(
|
882 |
-
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
883 |
-
)
|
884 |
-
else:
|
885 |
-
if self.is_deepspeed_enabled:
|
886 |
-
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
887 |
-
else:
|
888 |
-
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
889 |
-
|
890 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
891 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
892 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
893 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
894 |
-
|
895 |
-
if model is not None:
|
896 |
-
if hasattr(model, "config"):
|
897 |
-
hidden_size = (
|
898 |
-
max(model.config.hidden_sizes)
|
899 |
-
if getattr(model.config, "hidden_sizes", None)
|
900 |
-
else getattr(model.config, "hidden_size", None)
|
901 |
-
)
|
902 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
903 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
904 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
905 |
-
config_kwargs.update(
|
906 |
-
{
|
907 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
908 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
909 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
910 |
-
}
|
911 |
-
)
|
912 |
-
|
913 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
914 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
915 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
916 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
917 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
918 |
-
model.eval()
|
919 |
-
return model
|
920 |
-
|
921 |
-
@contextmanager
|
922 |
-
def null_ref_context(self):
|
923 |
-
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
924 |
-
with (
|
925 |
-
self.accelerator.unwrap_model(self.model).disable_adapter()
|
926 |
-
if self.is_peft_model and not self.ref_adapter_name
|
927 |
-
else nullcontext()
|
928 |
-
):
|
929 |
-
if self.ref_adapter_name:
|
930 |
-
self.model.set_adapter(self.ref_adapter_name)
|
931 |
-
yield
|
932 |
-
if self.ref_adapter_name:
|
933 |
-
self.model.set_adapter(self.model_adapter_name or "default")
|
934 |
-
|
935 |
-
def get_train_dataloader(self) -> DataLoader:
|
936 |
-
"""
|
937 |
-
Returns the training [`~torch.utils.data.DataLoader`].
|
938 |
-
|
939 |
-
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
940 |
-
"""
|
941 |
-
|
942 |
-
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
943 |
-
dataloader_params = {
|
944 |
-
"batch_size": self.args.per_device_train_batch_size,
|
945 |
-
"collate_fn": self.data_collator,
|
946 |
-
"num_workers": self.args.dataloader_num_workers,
|
947 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
948 |
-
"shuffle": False,
|
949 |
-
}
|
950 |
-
|
951 |
-
# prepare dataloader
|
952 |
-
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
953 |
-
reference_completion_logps = []
|
954 |
-
reference_KL_logps = []
|
955 |
-
|
956 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
957 |
-
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
958 |
-
|
959 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
960 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
961 |
-
|
962 |
-
if self.calculate_KL:
|
963 |
-
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
964 |
-
reference_KL_logps.append(reference_KL_logp.cpu())
|
965 |
-
|
966 |
-
self.train_dataset = self.train_dataset.add_column(
|
967 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
968 |
-
)
|
969 |
-
|
970 |
-
if self.calculate_KL:
|
971 |
-
self.train_dataset = self.train_dataset.add_column(
|
972 |
-
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
973 |
-
)
|
974 |
-
|
975 |
-
self._precomputed_train_ref_log_probs = True
|
976 |
-
|
977 |
-
return super().get_train_dataloader()
|
978 |
-
|
979 |
-
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
980 |
-
"""
|
981 |
-
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
982 |
-
|
983 |
-
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
984 |
-
|
985 |
-
Args:
|
986 |
-
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
987 |
-
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
988 |
-
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
989 |
-
"""
|
990 |
-
if eval_dataset is None and self.eval_dataset is None:
|
991 |
-
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
992 |
-
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
993 |
-
|
994 |
-
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
995 |
-
dataloader_params = {
|
996 |
-
"batch_size": self.args.per_device_eval_batch_size,
|
997 |
-
"collate_fn": self.data_collator,
|
998 |
-
"num_workers": self.args.dataloader_num_workers,
|
999 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
1000 |
-
"shuffle": False,
|
1001 |
-
}
|
1002 |
-
|
1003 |
-
# prepare dataloader
|
1004 |
-
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
1005 |
-
|
1006 |
-
reference_completion_logps = []
|
1007 |
-
reference_KL_logps = []
|
1008 |
-
|
1009 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
1010 |
-
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
1011 |
-
|
1012 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1013 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
1014 |
-
|
1015 |
-
if self.calculate_KL:
|
1016 |
-
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
1017 |
-
reference_KL_logps.append(reference_KL_logp.cpu())
|
1018 |
-
|
1019 |
-
eval_dataset = eval_dataset.add_column(
|
1020 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1021 |
-
)
|
1022 |
-
if self.calculate_KL:
|
1023 |
-
eval_dataset = eval_dataset.add_column(
|
1024 |
-
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
1025 |
-
)
|
1026 |
-
|
1027 |
-
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
1028 |
-
if self.eval_dataset is not None:
|
1029 |
-
self.eval_dataset = eval_dataset
|
1030 |
-
self._precomputed_eval_ref_log_probs = True
|
1031 |
-
|
1032 |
-
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
1033 |
-
|
1034 |
-
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
1035 |
-
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
|
1036 |
-
with torch.no_grad():
|
1037 |
-
if self.ref_model is None:
|
1038 |
-
with self.null_ref_context():
|
1039 |
-
if self.is_encoder_decoder:
|
1040 |
-
completion_logits = self.model(
|
1041 |
-
padded_batch["prompt_input_ids"],
|
1042 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
1043 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1044 |
-
labels=padded_batch["completion_labels"],
|
1045 |
-
).logits
|
1046 |
-
|
1047 |
-
if self.calculate_KL:
|
1048 |
-
KL_logits = self.model(
|
1049 |
-
padded_batch["KL_prompt_input_ids"],
|
1050 |
-
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
1051 |
-
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
1052 |
-
labels=padded_batch["KL_completion_labels"],
|
1053 |
-
).logits
|
1054 |
-
else:
|
1055 |
-
completion_logits = self.model(
|
1056 |
-
padded_batch["completion_input_ids"],
|
1057 |
-
attention_mask=padded_batch["completion_attention_mask"],
|
1058 |
-
).logits
|
1059 |
-
|
1060 |
-
if self.calculate_KL:
|
1061 |
-
KL_logits = self.model(
|
1062 |
-
padded_batch["KL_completion_input_ids"],
|
1063 |
-
attention_mask=padded_batch["KL_completion_attention_mask"],
|
1064 |
-
).logits
|
1065 |
-
else:
|
1066 |
-
if self.is_encoder_decoder:
|
1067 |
-
completion_logits = self.ref_model(
|
1068 |
-
padded_batch["prompt_input_ids"],
|
1069 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
1070 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1071 |
-
labels=padded_batch["completion_labels"],
|
1072 |
-
).logits
|
1073 |
-
|
1074 |
-
if self.calculate_KL:
|
1075 |
-
KL_logits = self.ref_model(
|
1076 |
-
padded_batch["KL_prompt_input_ids"],
|
1077 |
-
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
1078 |
-
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
1079 |
-
labels=padded_batch["KL_completion_labels"],
|
1080 |
-
).logits
|
1081 |
-
else:
|
1082 |
-
completion_logits = self.ref_model(
|
1083 |
-
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
1084 |
-
).logits
|
1085 |
-
|
1086 |
-
if self.calculate_KL:
|
1087 |
-
KL_logits = self.ref_model(
|
1088 |
-
padded_batch["KL_completion_input_ids"],
|
1089 |
-
attention_mask=padded_batch["KL_completion_attention_mask"],
|
1090 |
-
).logits
|
1091 |
-
|
1092 |
-
completion_logps = self.get_batch_logps(
|
1093 |
-
completion_logits,
|
1094 |
-
padded_batch["completion_labels"],
|
1095 |
-
average_log_prob=False,
|
1096 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1097 |
-
label_pad_token_id=self.label_pad_token_id,
|
1098 |
-
)
|
1099 |
-
|
1100 |
-
if self.calculate_KL:
|
1101 |
-
KL_logps = self.get_batch_logps(
|
1102 |
-
KL_logits,
|
1103 |
-
padded_batch["KL_completion_labels"],
|
1104 |
-
average_log_prob=False,
|
1105 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1106 |
-
label_pad_token_id=self.label_pad_token_id,
|
1107 |
-
)
|
1108 |
-
else:
|
1109 |
-
KL_logps = None
|
1110 |
-
|
1111 |
-
return completion_logps, KL_logps
|
1112 |
-
|
1113 |
-
@staticmethod
|
1114 |
-
def get_batch_logps(
|
1115 |
-
logits: torch.FloatTensor,
|
1116 |
-
labels: torch.LongTensor,
|
1117 |
-
average_log_prob: bool = False,
|
1118 |
-
label_pad_token_id: int = -100,
|
1119 |
-
is_encoder_decoder: bool = False,
|
1120 |
-
) -> torch.FloatTensor:
|
1121 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
1122 |
-
|
1123 |
-
Args:
|
1124 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1125 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1126 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1127 |
-
|
1128 |
-
Returns:
|
1129 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1130 |
-
"""
|
1131 |
-
if logits.shape[:-1] != labels.shape:
|
1132 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1133 |
-
|
1134 |
-
if not is_encoder_decoder:
|
1135 |
-
labels = labels[:, 1:].clone()
|
1136 |
-
logits = logits[:, :-1, :]
|
1137 |
-
else:
|
1138 |
-
# Fixes end-dec RuntimeError
|
1139 |
-
labels = labels.clone()
|
1140 |
-
|
1141 |
-
loss_mask = labels != label_pad_token_id
|
1142 |
-
|
1143 |
-
# dummy token; we'll ignore the losses on these tokens later
|
1144 |
-
labels[labels == label_pad_token_id] = 0
|
1145 |
-
|
1146 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
1147 |
-
|
1148 |
-
if average_log_prob:
|
1149 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1150 |
-
else:
|
1151 |
-
return (per_token_logps * loss_mask).sum(-1)
|
1152 |
-
|
1153 |
-
def forward(
|
1154 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1155 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1156 |
-
if self.calculate_KL:
|
1157 |
-
KL_logps = None
|
1158 |
-
KL_model_kwargs = (
|
1159 |
-
{
|
1160 |
-
"input_ids": batch["KL_prompt_input_ids"],
|
1161 |
-
"attention_mask": batch["KL_prompt_attention_mask"],
|
1162 |
-
"labels": batch["KL_completion_labels"],
|
1163 |
-
"decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
|
1164 |
-
}
|
1165 |
-
if self.is_encoder_decoder
|
1166 |
-
else {
|
1167 |
-
"input_ids": batch["KL_completion_input_ids"],
|
1168 |
-
"attention_mask": batch["KL_completion_attention_mask"],
|
1169 |
-
}
|
1170 |
-
)
|
1171 |
-
with torch.no_grad():
|
1172 |
-
KL_logits = model(
|
1173 |
-
**KL_model_kwargs,
|
1174 |
-
).logits
|
1175 |
-
|
1176 |
-
KL_logps = self.get_batch_logps(
|
1177 |
-
KL_logits,
|
1178 |
-
batch["KL_completion_labels"],
|
1179 |
-
average_log_prob=False,
|
1180 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1181 |
-
label_pad_token_id=self.label_pad_token_id,
|
1182 |
-
)
|
1183 |
-
else:
|
1184 |
-
KL_logps = None
|
1185 |
-
|
1186 |
-
model_kwargs = (
|
1187 |
-
{
|
1188 |
-
"labels": batch["completion_labels"],
|
1189 |
-
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
1190 |
-
}
|
1191 |
-
if self.is_encoder_decoder
|
1192 |
-
else {}
|
1193 |
-
)
|
1194 |
-
if self.aux_loss_enabled:
|
1195 |
-
model_kwargs["output_router_logits"] = True
|
1196 |
-
|
1197 |
-
outputs = model(
|
1198 |
-
batch["completion_input_ids"],
|
1199 |
-
attention_mask=batch["completion_attention_mask"],
|
1200 |
-
**model_kwargs,
|
1201 |
-
)
|
1202 |
-
completion_logits = outputs.logits
|
1203 |
-
|
1204 |
-
completion_logps = self.get_batch_logps(
|
1205 |
-
completion_logits,
|
1206 |
-
batch["completion_labels"],
|
1207 |
-
average_log_prob=False,
|
1208 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1209 |
-
label_pad_token_id=self.label_pad_token_id,
|
1210 |
-
)
|
1211 |
-
|
1212 |
-
if completion_logps.shape[0] != len(batch["label"]):
|
1213 |
-
raise ValueError(
|
1214 |
-
"There is a mismatch between the number of examples in this batch and the number of "
|
1215 |
-
"examples for which an output sequence was predicted."
|
1216 |
-
)
|
1217 |
-
|
1218 |
-
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
1219 |
-
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
1220 |
-
|
1221 |
-
chosen_logps = completion_logps[chosen_idx, ...]
|
1222 |
-
rejected_logps = completion_logps[rejected_idx, ...]
|
1223 |
-
|
1224 |
-
chosen_logits = completion_logits[chosen_idx, ...]
|
1225 |
-
rejected_logits = completion_logits[rejected_idx, ...]
|
1226 |
-
|
1227 |
-
if self.aux_loss_enabled:
|
1228 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
|
1229 |
-
else:
|
1230 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
1231 |
-
|
1232 |
-
def kto_loss(
|
1233 |
-
self,
|
1234 |
-
policy_chosen_logps: torch.FloatTensor,
|
1235 |
-
policy_rejected_logps: torch.FloatTensor,
|
1236 |
-
policy_KL_logps: torch.FloatTensor,
|
1237 |
-
reference_chosen_logps: torch.FloatTensor,
|
1238 |
-
reference_rejected_logps: torch.FloatTensor,
|
1239 |
-
reference_KL_logps: torch.FloatTensor,
|
1240 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1241 |
-
"""Compute the KTO loss for a batch of policy and reference model log probabilities.
|
1242 |
-
|
1243 |
-
Args:
|
1244 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1245 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1246 |
-
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
|
1247 |
-
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1248 |
-
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1249 |
-
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
|
1250 |
-
|
1251 |
-
Returns:
|
1252 |
-
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
|
1253 |
-
The losses tensor contains the KTO loss for each example in the batch.
|
1254 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
1255 |
-
The KL tensor contains the detached KL divergence estimate between the policy and reference models.
|
1256 |
-
"""
|
1257 |
-
if self.calculate_KL:
|
1258 |
-
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
|
1259 |
-
kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
|
1260 |
-
else:
|
1261 |
-
kl = torch.zeros(1).to(policy_chosen_logps.device)
|
1262 |
-
|
1263 |
-
# Chosen losses
|
1264 |
-
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1265 |
-
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
1266 |
-
|
1267 |
-
if self.loss_type == "kto":
|
1268 |
-
# Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
|
1269 |
-
chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
|
1270 |
-
elif self.loss_type == "apo_zero_unpaired":
|
1271 |
-
# Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
|
1272 |
-
# Use this loss when you believe the chosen outputs are better than your model's default output
|
1273 |
-
chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
|
1274 |
-
|
1275 |
-
chosen_rewards = self.beta * chosen_logratios.detach()
|
1276 |
-
|
1277 |
-
else:
|
1278 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1279 |
-
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
1280 |
-
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1281 |
-
|
1282 |
-
# Rejected losses
|
1283 |
-
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1284 |
-
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
1285 |
-
|
1286 |
-
if self.loss_type == "kto":
|
1287 |
-
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
|
1288 |
-
elif self.loss_type == "apo_zero_unpaired":
|
1289 |
-
rejected_losses = F.sigmoid(self.beta * rejected_logratios)
|
1290 |
-
|
1291 |
-
rejected_rewards = self.beta * rejected_logratios.detach()
|
1292 |
-
else:
|
1293 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1294 |
-
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
1295 |
-
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1296 |
-
|
1297 |
-
losses = torch.cat(
|
1298 |
-
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
|
1299 |
-
0,
|
1300 |
-
)
|
1301 |
-
|
1302 |
-
return losses, chosen_rewards, rejected_rewards, kl
|
1303 |
-
|
1304 |
-
def get_batch_loss_metrics(
|
1305 |
-
self,
|
1306 |
-
model,
|
1307 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
1308 |
-
):
|
1309 |
-
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
1310 |
-
metrics = {}
|
1311 |
-
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
1312 |
-
|
1313 |
-
forward_output = self.forward(model, batch)
|
1314 |
-
(
|
1315 |
-
policy_chosen_logps,
|
1316 |
-
policy_rejected_logps,
|
1317 |
-
policy_chosen_logits,
|
1318 |
-
policy_rejected_logits,
|
1319 |
-
policy_KL_logps,
|
1320 |
-
) = forward_output[:5]
|
1321 |
-
if self.aux_loss_enabled:
|
1322 |
-
aux_loss = forward_output[5]
|
1323 |
-
|
1324 |
-
# if reference_logps in batch use them, otherwise use the reference model
|
1325 |
-
if "reference_logps" in batch:
|
1326 |
-
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
1327 |
-
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
1328 |
-
|
1329 |
-
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
1330 |
-
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
1331 |
-
if self.calculate_KL:
|
1332 |
-
reference_KL_logps = batch["reference_KL_logps"]
|
1333 |
-
else:
|
1334 |
-
reference_KL_logps = None
|
1335 |
-
else:
|
1336 |
-
with torch.no_grad():
|
1337 |
-
if self.ref_model is None:
|
1338 |
-
with self.null_ref_context():
|
1339 |
-
(
|
1340 |
-
reference_chosen_logps,
|
1341 |
-
reference_rejected_logps,
|
1342 |
-
_,
|
1343 |
-
_,
|
1344 |
-
reference_KL_logps,
|
1345 |
-
) = self.forward(self.model, batch)[:5]
|
1346 |
-
else:
|
1347 |
-
(
|
1348 |
-
reference_chosen_logps,
|
1349 |
-
reference_rejected_logps,
|
1350 |
-
_,
|
1351 |
-
_,
|
1352 |
-
reference_KL_logps,
|
1353 |
-
) = self.forward(self.ref_model, batch)[:5]
|
1354 |
-
|
1355 |
-
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
1356 |
-
policy_chosen_logps,
|
1357 |
-
policy_rejected_logps,
|
1358 |
-
policy_KL_logps,
|
1359 |
-
reference_chosen_logps,
|
1360 |
-
reference_rejected_logps,
|
1361 |
-
reference_KL_logps,
|
1362 |
-
)
|
1363 |
-
metrics["kl"] = kl.item()
|
1364 |
-
|
1365 |
-
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
1366 |
-
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
1367 |
-
|
1368 |
-
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
1369 |
-
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
1370 |
-
|
1371 |
-
if all_num_chosen > 0:
|
1372 |
-
metrics["rewards/chosen_sum"] = (
|
1373 |
-
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
1374 |
-
)
|
1375 |
-
metrics["logps/chosen_sum"] = (
|
1376 |
-
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
1377 |
-
)
|
1378 |
-
metrics["logits/chosen_sum"] = (
|
1379 |
-
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
1380 |
-
)
|
1381 |
-
metrics["count/chosen"] = all_num_chosen
|
1382 |
-
|
1383 |
-
if all_num_rejected > 0:
|
1384 |
-
metrics["rewards/rejected_sum"] = (
|
1385 |
-
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
1386 |
-
)
|
1387 |
-
metrics["logps/rejected_sum"] = (
|
1388 |
-
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
1389 |
-
)
|
1390 |
-
metrics["logits/rejected_sum"] = (
|
1391 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
1392 |
-
)
|
1393 |
-
metrics["count/rejected"] = all_num_rejected
|
1394 |
-
|
1395 |
-
loss = losses.nanmean()
|
1396 |
-
if self.aux_loss_enabled:
|
1397 |
-
loss += self.aux_loss_coef * aux_loss
|
1398 |
-
|
1399 |
-
return loss, metrics
|
1400 |
-
|
1401 |
-
def compute_loss(
|
1402 |
-
self,
|
1403 |
-
model: Union[PreTrainedModel, nn.Module],
|
1404 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1405 |
-
return_outputs=False,
|
1406 |
-
num_items_in_batch=None,
|
1407 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1408 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1409 |
-
|
1410 |
-
with compute_loss_context_manager:
|
1411 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1412 |
-
|
1413 |
-
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1414 |
-
loss = loss.to(self.args.device)
|
1415 |
-
# force log the metrics
|
1416 |
-
if self.accelerator.is_main_process:
|
1417 |
-
self.store_metrics(metrics, train_eval="train")
|
1418 |
-
|
1419 |
-
if return_outputs:
|
1420 |
-
return (loss, metrics)
|
1421 |
-
return loss
|
1422 |
-
|
1423 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1424 |
-
for key, value in metrics.items():
|
1425 |
-
self._stored_metrics[train_eval][key].append(value)
|
1426 |
-
|
1427 |
-
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
1428 |
-
if self.train_dataset is None or not has_length(self.train_dataset):
|
1429 |
-
return None
|
1430 |
-
return SequentialSampler(self.train_dataset)
|
1431 |
-
|
1432 |
-
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
1433 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1434 |
-
|
1435 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1436 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1437 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1438 |
-
|
1439 |
-
with generate_context_manager:
|
1440 |
-
policy_output = model.generate(
|
1441 |
-
input_ids=batch["prompt_input_ids"],
|
1442 |
-
attention_mask=batch["prompt_attention_mask"],
|
1443 |
-
max_length=self.max_length,
|
1444 |
-
do_sample=True,
|
1445 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1446 |
-
)
|
1447 |
-
|
1448 |
-
# if reference_output in batch use that otherwise use the reference model
|
1449 |
-
if "reference_output" in batch:
|
1450 |
-
reference_output = batch["reference_output"]
|
1451 |
-
else:
|
1452 |
-
if self.ref_model is None:
|
1453 |
-
with self.null_ref_context():
|
1454 |
-
reference_output = self.model.generate(
|
1455 |
-
input_ids=batch["prompt_input_ids"],
|
1456 |
-
attention_mask=batch["prompt_attention_mask"],
|
1457 |
-
max_length=self.max_length,
|
1458 |
-
do_sample=True,
|
1459 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1460 |
-
)
|
1461 |
-
else:
|
1462 |
-
reference_output = self.ref_model.generate(
|
1463 |
-
input_ids=batch["prompt_input_ids"],
|
1464 |
-
attention_mask=batch["prompt_attention_mask"],
|
1465 |
-
max_length=self.max_length,
|
1466 |
-
do_sample=True,
|
1467 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1468 |
-
)
|
1469 |
-
|
1470 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1471 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1472 |
-
|
1473 |
-
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
1474 |
-
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
1475 |
-
|
1476 |
-
return policy_output_decoded, reference_output_decoded
|
1477 |
-
|
1478 |
-
def prediction_step(
|
1479 |
-
self,
|
1480 |
-
model: Union[PreTrainedModel, nn.Module],
|
1481 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1482 |
-
prediction_loss_only: bool,
|
1483 |
-
ignore_keys: Optional[list[str]] = None,
|
1484 |
-
):
|
1485 |
-
if ignore_keys is None:
|
1486 |
-
if hasattr(model, "config"):
|
1487 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1488 |
-
else:
|
1489 |
-
ignore_keys = []
|
1490 |
-
|
1491 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1492 |
-
with torch.no_grad(), prediction_context_manager:
|
1493 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1494 |
-
|
1495 |
-
# force log the metrics
|
1496 |
-
if self.accelerator.is_main_process:
|
1497 |
-
self.store_metrics(metrics, train_eval="eval")
|
1498 |
-
|
1499 |
-
if prediction_loss_only:
|
1500 |
-
return (loss.detach(), None, None)
|
1501 |
-
|
1502 |
-
# logits for the chosen and rejected samples from model
|
1503 |
-
logits_dict = {
|
1504 |
-
"eval_logits/chosen": metrics["logits/chosen"],
|
1505 |
-
"eval_logits/rejected": metrics["logits/rejected"],
|
1506 |
-
}
|
1507 |
-
logits = torch.tensor(
|
1508 |
-
[v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
|
1509 |
-
)
|
1510 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1511 |
-
|
1512 |
-
return (loss.detach(), logits, labels)
|
1513 |
-
|
1514 |
-
def evaluation_loop(
|
1515 |
-
self,
|
1516 |
-
dataloader: DataLoader,
|
1517 |
-
description: str,
|
1518 |
-
prediction_loss_only: Optional[bool] = None,
|
1519 |
-
ignore_keys: Optional[list[str]] = None,
|
1520 |
-
metric_key_prefix: str = "eval",
|
1521 |
-
) -> EvalLoopOutput:
|
1522 |
-
"""
|
1523 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
1524 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1525 |
-
|
1526 |
-
Works both with or without labels.
|
1527 |
-
"""
|
1528 |
-
|
1529 |
-
# Sample and save to game log if requested (for one batch to save time)
|
1530 |
-
if self.generate_during_eval:
|
1531 |
-
# Generate random indices within the range of the total number of samples
|
1532 |
-
num_samples = len(dataloader.dataset)
|
1533 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1534 |
-
|
1535 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1536 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1537 |
-
random_batch = self.data_collator(random_batch_dataset)
|
1538 |
-
random_batch = self._prepare_inputs(random_batch)
|
1539 |
-
|
1540 |
-
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
1541 |
-
target_batch = {
|
1542 |
-
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
1543 |
-
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
1544 |
-
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
1545 |
-
}
|
1546 |
-
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
1547 |
-
|
1548 |
-
table = pd.DataFrame(
|
1549 |
-
columns=["Prompt", "Policy", "Ref Model"],
|
1550 |
-
data=[
|
1551 |
-
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
1552 |
-
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
1553 |
-
],
|
1554 |
-
)
|
1555 |
-
if "wandb" in self.args.report_to:
|
1556 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
1557 |
-
|
1558 |
-
if "comet_ml" in self.args.report_to:
|
1559 |
-
log_table_to_comet_experiment(
|
1560 |
-
name="game_log.csv",
|
1561 |
-
table=table,
|
1562 |
-
)
|
1563 |
-
|
1564 |
-
# Base evaluation
|
1565 |
-
initial_output = super().evaluation_loop(
|
1566 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1567 |
-
)
|
1568 |
-
|
1569 |
-
return initial_output
|
1570 |
-
|
1571 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1572 |
-
"""
|
1573 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
1574 |
-
|
1575 |
-
Args:
|
1576 |
-
logs (`dict[str, float]`):
|
1577 |
-
The values to log.
|
1578 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1579 |
-
Start time of the training.
|
1580 |
-
"""
|
1581 |
-
# logs either has 'loss' or 'eval_loss'
|
1582 |
-
train_eval = "train" if "loss" in logs else "eval"
|
1583 |
-
# train metrics should have no prefix, eval should have 'eval_'
|
1584 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
1585 |
-
# accumulate average metrics from sums and lengths
|
1586 |
-
for split in ["chosen", "rejected"]:
|
1587 |
-
if f"count/{split}" in self._stored_metrics[train_eval]:
|
1588 |
-
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
1589 |
-
for metric in ["rewards", "logps", "logits"]:
|
1590 |
-
logs[f"{prefix}{metric}/{split}"] = (
|
1591 |
-
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
1592 |
-
/ count_sum
|
1593 |
-
)
|
1594 |
-
# delete obsolete metric
|
1595 |
-
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
1596 |
-
del self._stored_metrics[train_eval][f"count/{split}"]
|
1597 |
-
# calculate reward margin
|
1598 |
-
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
1599 |
-
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
1600 |
-
# Add averaged stored metrics to logs
|
1601 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
1602 |
-
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
1603 |
-
del self._stored_metrics[train_eval]
|
1604 |
-
|
1605 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1606 |
-
return super().log(logs, start_time)
|
1607 |
-
else: # transformers<=4.46
|
1608 |
-
return super().log(logs)
|
1609 |
-
|
1610 |
-
def create_model_card(
|
1611 |
-
self,
|
1612 |
-
model_name: Optional[str] = None,
|
1613 |
-
dataset_name: Optional[str] = None,
|
1614 |
-
tags: Union[str, list[str], None] = None,
|
1615 |
-
):
|
1616 |
-
"""
|
1617 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1618 |
-
|
1619 |
-
Args:
|
1620 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1621 |
-
Name of the model.
|
1622 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1623 |
-
Name of the dataset used for training.
|
1624 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1625 |
-
Tags to be associated with the model card.
|
1626 |
-
"""
|
1627 |
-
if not self.is_world_process_zero():
|
1628 |
-
return
|
1629 |
-
|
1630 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1631 |
-
base_model = self.model.config._name_or_path
|
1632 |
-
else:
|
1633 |
-
base_model = None
|
1634 |
-
|
1635 |
-
tags = tags or []
|
1636 |
-
if isinstance(tags, str):
|
1637 |
-
tags = [tags]
|
1638 |
-
|
1639 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1640 |
-
tags.append("unsloth")
|
1641 |
-
|
1642 |
-
citation = textwrap.dedent("""\
|
1643 |
-
@article{ethayarajh2024kto,
|
1644 |
-
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
|
1645 |
-
author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
|
1646 |
-
year = 2024,
|
1647 |
-
eprint = {arXiv:2402.01306},
|
1648 |
-
}""")
|
1649 |
-
|
1650 |
-
model_card = generate_model_card(
|
1651 |
-
base_model=base_model,
|
1652 |
-
model_name=model_name,
|
1653 |
-
hub_model_id=self.hub_model_id,
|
1654 |
-
dataset_name=dataset_name,
|
1655 |
-
tags=tags,
|
1656 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1657 |
-
comet_url=get_comet_experiment_url(),
|
1658 |
-
trainer_name="KTO",
|
1659 |
-
trainer_citation=citation,
|
1660 |
-
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
|
1661 |
-
paper_id="2402.01306",
|
1662 |
-
)
|
1663 |
-
|
1664 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1665 |
-
class UnslothKTOTrainer(_UnslothKTOTrainer):
|
1666 |
-
"""
|
1667 |
-
|
1668 |
-
Initialize KTOTrainer.
|
1669 |
-
|
1670 |
-
Args:
|
1671 |
-
model (`transformers.PreTrainedModel`):
|
1672 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1673 |
-
ref_model (`PreTrainedModelWrapper`):
|
1674 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
1675 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
1676 |
-
args (`KTOConfig`):
|
1677 |
-
The arguments to use for training.
|
1678 |
-
train_dataset (`datasets.Dataset`):
|
1679 |
-
The dataset to use for training.
|
1680 |
-
eval_dataset (`datasets.Dataset`):
|
1681 |
-
The dataset to use for evaluation.
|
1682 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1683 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1684 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1685 |
-
reuse the fine-tuned model.
|
1686 |
-
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
1687 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1688 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1689 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1690 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1691 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
1692 |
-
The callbacks to use for training.
|
1693 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1694 |
-
The optimizer and scheduler to use for training.
|
1695 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1696 |
-
The function to use to preprocess the logits before computing the metrics.
|
1697 |
-
peft_config (`dict`, defaults to `None`):
|
1698 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1699 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1700 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1701 |
-
a dictionary string to metric values.
|
1702 |
-
model_adapter_name (`str`, defaults to `None`):
|
1703 |
-
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
1704 |
-
ref_adapter_name (`str`, defaults to `None`):
|
1705 |
-
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
1706 |
-
|
1707 |
-
"""
|
1708 |
-
def __init__(
|
1709 |
-
self,
|
1710 |
-
model = None,
|
1711 |
-
ref_model = None,
|
1712 |
-
args = None,
|
1713 |
-
train_dataset = None,
|
1714 |
-
eval_dataset = None,
|
1715 |
-
processing_class = None,
|
1716 |
-
data_collator = None,
|
1717 |
-
model_init = None,
|
1718 |
-
callbacks = None,
|
1719 |
-
preprocess_logits_for_metrics = None,
|
1720 |
-
peft_config = None,
|
1721 |
-
compute_metrics = None,
|
1722 |
-
model_adapter_name = None,
|
1723 |
-
ref_adapter_name = None,
|
1724 |
-
**kwargs
|
1725 |
-
):
|
1726 |
-
if args is None: args = UnslothKTOConfig()
|
1727 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1728 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1729 |
-
force_float32 = False
|
1730 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1731 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1732 |
-
force_float32 = True
|
1733 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1734 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1735 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1736 |
-
from unsloth_zoo.utils import _get_dtype
|
1737 |
-
dtype = _get_dtype(dtype)
|
1738 |
-
float16 = dtype == torch.float16
|
1739 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1740 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1741 |
-
if force_float32:
|
1742 |
-
args.fp16 = False
|
1743 |
-
args.bf16 = False
|
1744 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1745 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1746 |
-
args.fp16 = float16
|
1747 |
-
args.bf16 = not float16
|
1748 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1749 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1750 |
-
args.eval_strategy = 'steps'
|
1751 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1752 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1753 |
-
if ga_steps is not None and ga_steps > 1:
|
1754 |
-
from transformers import __version__ as transformers_version
|
1755 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1756 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1757 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1758 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1759 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1760 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1761 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1762 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1763 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1764 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1765 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1766 |
-
if force_float32:
|
1767 |
-
args.bf16_full_eval = False
|
1768 |
-
args.fp16_full_eval = False
|
1769 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1770 |
-
args.bf16_full_eval = True
|
1771 |
-
args.fp16_full_eval = False
|
1772 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1773 |
-
args.bf16_full_eval = args.bf16
|
1774 |
-
args.fp16_full_eval = args.fp16
|
1775 |
-
_output_logits = False
|
1776 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1777 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1778 |
-
if _output_logits:
|
1779 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1780 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1781 |
-
pass
|
1782 |
-
else:
|
1783 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1784 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1785 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1786 |
-
max_seq_length = model.max_seq_length
|
1787 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1788 |
-
if model is not None and hasattr(model, 'for_training'):
|
1789 |
-
model.for_training()
|
1790 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1791 |
-
if 'processing_class' in locals():
|
1792 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1793 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1794 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1795 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1796 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1797 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1798 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1799 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1800 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1801 |
-
else:
|
1802 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1803 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1804 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1805 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1806 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1807 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1808 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1809 |
-
else:
|
1810 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1811 |
-
other_metrics = []
|
1812 |
-
|
1813 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1814 |
-
PatchRLStatistics('kto_trainer', other_metrics)
|
1815 |
-
|
1816 |
-
super().__init__(
|
1817 |
-
model = model,
|
1818 |
-
ref_model = ref_model,
|
1819 |
-
args = args,
|
1820 |
-
train_dataset = train_dataset,
|
1821 |
-
eval_dataset = eval_dataset,
|
1822 |
-
processing_class = processing_class,
|
1823 |
-
data_collator = data_collator,
|
1824 |
-
model_init = model_init,
|
1825 |
-
callbacks = callbacks,
|
1826 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1827 |
-
peft_config = peft_config,
|
1828 |
-
compute_metrics = compute_metrics,
|
1829 |
-
model_adapter_name = model_adapter_name,
|
1830 |
-
ref_adapter_name = ref_adapter_name,**kwargs)
|
1831 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1832 |
-
self.neftune_hook_handle.remove()
|
1833 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1834 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1835 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1836 |
-
pass
|
1837 |
-
|
1838 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothNashMDTrainer.py
DELETED
@@ -1,953 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothNashMDConfig(NashMDConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`NashMDTrainer`].
|
47 |
-
|
48 |
-
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
49 |
-
|
50 |
-
Parameters:
|
51 |
-
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
|
52 |
-
Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
|
53 |
-
mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
|
54 |
-
epochs.
|
55 |
-
|
56 |
-
"""
|
57 |
-
vllm_sampling_params: Optional[Any] = field(
|
58 |
-
default = None,
|
59 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
60 |
-
)
|
61 |
-
unsloth_num_chunks : Optional[int] = field(
|
62 |
-
default = -1,
|
63 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
64 |
-
)
|
65 |
-
def __init__(
|
66 |
-
self,
|
67 |
-
output_dir = None,
|
68 |
-
overwrite_output_dir = None,
|
69 |
-
do_train = False,
|
70 |
-
do_eval = False,
|
71 |
-
do_predict = False,
|
72 |
-
eval_strategy = 'no',
|
73 |
-
prediction_loss_only = False,
|
74 |
-
per_device_train_batch_size = 4,
|
75 |
-
per_device_eval_batch_size = 4,
|
76 |
-
per_gpu_train_batch_size = None,
|
77 |
-
per_gpu_eval_batch_size = None,
|
78 |
-
gradient_accumulation_steps = 2,
|
79 |
-
eval_accumulation_steps = 2,
|
80 |
-
eval_delay = 0,
|
81 |
-
torch_empty_cache_steps = 250,
|
82 |
-
learning_rate = 5e-05,
|
83 |
-
weight_decay = 0.01,
|
84 |
-
adam_beta1 = 0.9,
|
85 |
-
adam_beta2 = 0.999,
|
86 |
-
adam_epsilon = 1e-08,
|
87 |
-
max_grad_norm = 1.0,
|
88 |
-
num_train_epochs = 3.0,
|
89 |
-
max_steps = -1,
|
90 |
-
lr_scheduler_type = 'linear',
|
91 |
-
warmup_ratio = 0.1,
|
92 |
-
warmup_steps = 0,
|
93 |
-
log_level = 'passive',
|
94 |
-
log_level_replica = 'warning',
|
95 |
-
log_on_each_node = True,
|
96 |
-
logging_dir = None,
|
97 |
-
logging_strategy = 'steps',
|
98 |
-
logging_first_step = False,
|
99 |
-
logging_steps = 1,
|
100 |
-
logging_nan_inf_filter = False,
|
101 |
-
save_strategy = 'steps',
|
102 |
-
save_steps = 500,
|
103 |
-
save_total_limit = None,
|
104 |
-
save_safetensors = True,
|
105 |
-
save_on_each_node = False,
|
106 |
-
save_only_model = False,
|
107 |
-
restore_callback_states_from_checkpoint = False,
|
108 |
-
no_cuda = False,
|
109 |
-
use_cpu = False,
|
110 |
-
use_mps_device = False,
|
111 |
-
seed = 3407,
|
112 |
-
data_seed = 3407,
|
113 |
-
jit_mode_eval = False,
|
114 |
-
use_ipex = False,
|
115 |
-
bf16 = False,
|
116 |
-
fp16 = False,
|
117 |
-
fp16_opt_level = 'O1',
|
118 |
-
half_precision_backend = 'auto',
|
119 |
-
bf16_full_eval = False,
|
120 |
-
fp16_full_eval = False,
|
121 |
-
tf32 = None,
|
122 |
-
local_rank = -1,
|
123 |
-
ddp_backend = None,
|
124 |
-
tpu_num_cores = None,
|
125 |
-
tpu_metrics_debug = False,
|
126 |
-
debug = '',
|
127 |
-
dataloader_drop_last = False,
|
128 |
-
eval_steps = None,
|
129 |
-
dataloader_num_workers = 0,
|
130 |
-
dataloader_prefetch_factor = None,
|
131 |
-
past_index = -1,
|
132 |
-
run_name = None,
|
133 |
-
disable_tqdm = None,
|
134 |
-
remove_unused_columns = True,
|
135 |
-
label_names = None,
|
136 |
-
load_best_model_at_end = False,
|
137 |
-
metric_for_best_model = None,
|
138 |
-
greater_is_better = None,
|
139 |
-
ignore_data_skip = False,
|
140 |
-
fsdp = '',
|
141 |
-
fsdp_min_num_params = 0,
|
142 |
-
fsdp_config = None,
|
143 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
144 |
-
accelerator_config = None,
|
145 |
-
deepspeed = None,
|
146 |
-
label_smoothing_factor = 0.0,
|
147 |
-
optim = 'adamw_8bit',
|
148 |
-
optim_args = None,
|
149 |
-
adafactor = False,
|
150 |
-
group_by_length = False,
|
151 |
-
length_column_name = 'length',
|
152 |
-
report_to = None,
|
153 |
-
ddp_find_unused_parameters = None,
|
154 |
-
ddp_bucket_cap_mb = None,
|
155 |
-
ddp_broadcast_buffers = None,
|
156 |
-
dataloader_pin_memory = True,
|
157 |
-
dataloader_persistent_workers = False,
|
158 |
-
skip_memory_metrics = True,
|
159 |
-
use_legacy_prediction_loop = False,
|
160 |
-
push_to_hub = False,
|
161 |
-
resume_from_checkpoint = None,
|
162 |
-
hub_model_id = None,
|
163 |
-
hub_strategy = 'every_save',
|
164 |
-
hub_token = None,
|
165 |
-
hub_private_repo = None,
|
166 |
-
hub_always_push = False,
|
167 |
-
gradient_checkpointing = False,
|
168 |
-
gradient_checkpointing_kwargs = None,
|
169 |
-
include_inputs_for_metrics = False,
|
170 |
-
eval_do_concat_batches = True,
|
171 |
-
fp16_backend = 'auto',
|
172 |
-
evaluation_strategy = None,
|
173 |
-
push_to_hub_model_id = None,
|
174 |
-
push_to_hub_organization = None,
|
175 |
-
push_to_hub_token = None,
|
176 |
-
mp_parameters = '',
|
177 |
-
auto_find_batch_size = False,
|
178 |
-
full_determinism = False,
|
179 |
-
torchdynamo = None,
|
180 |
-
ray_scope = 'last',
|
181 |
-
ddp_timeout = 1800,
|
182 |
-
torch_compile = False,
|
183 |
-
torch_compile_backend = None,
|
184 |
-
torch_compile_mode = None,
|
185 |
-
dispatch_batches = None,
|
186 |
-
split_batches = None,
|
187 |
-
include_tokens_per_second = False,
|
188 |
-
include_num_input_tokens_seen = False,
|
189 |
-
neftune_noise_alpha = None,
|
190 |
-
optim_target_modules = None,
|
191 |
-
batch_eval_metrics = False,
|
192 |
-
eval_on_start = False,
|
193 |
-
use_liger_kernel = False,
|
194 |
-
eval_use_gather_object = False,
|
195 |
-
average_tokens_across_devices = False,
|
196 |
-
reward_model_path = None,
|
197 |
-
judge = None,
|
198 |
-
max_new_tokens = 64,
|
199 |
-
max_length = 512,
|
200 |
-
temperature = 0.9,
|
201 |
-
missing_eos_penalty = None,
|
202 |
-
loss_type = 'sigmoid',
|
203 |
-
dataset_num_proc = None,
|
204 |
-
disable_dropout = True,
|
205 |
-
use_vllm = False,
|
206 |
-
ds3_gather_for_generation = True,
|
207 |
-
vllm_sampling_params = None,
|
208 |
-
unsloth_num_chunks = -1,
|
209 |
-
**kwargs,
|
210 |
-
):
|
211 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
212 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
213 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
214 |
-
output_dir = 'unsloth_training_checkpoints'
|
215 |
-
save_strategy = 'no'
|
216 |
-
if dataset_num_proc is None:
|
217 |
-
from multiprocessing import cpu_count
|
218 |
-
dataset_num_proc = cpu_count()
|
219 |
-
|
220 |
-
super().__init__(
|
221 |
-
output_dir = output_dir,
|
222 |
-
overwrite_output_dir = overwrite_output_dir,
|
223 |
-
do_train = do_train,
|
224 |
-
do_eval = do_eval,
|
225 |
-
do_predict = do_predict,
|
226 |
-
eval_strategy = eval_strategy,
|
227 |
-
prediction_loss_only = prediction_loss_only,
|
228 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
229 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
230 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
231 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
232 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
233 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
234 |
-
eval_delay = eval_delay,
|
235 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
236 |
-
learning_rate = learning_rate,
|
237 |
-
weight_decay = weight_decay,
|
238 |
-
adam_beta1 = adam_beta1,
|
239 |
-
adam_beta2 = adam_beta2,
|
240 |
-
adam_epsilon = adam_epsilon,
|
241 |
-
max_grad_norm = max_grad_norm,
|
242 |
-
num_train_epochs = num_train_epochs,
|
243 |
-
max_steps = max_steps,
|
244 |
-
lr_scheduler_type = lr_scheduler_type,
|
245 |
-
warmup_ratio = warmup_ratio,
|
246 |
-
warmup_steps = warmup_steps,
|
247 |
-
log_level = log_level,
|
248 |
-
log_level_replica = log_level_replica,
|
249 |
-
log_on_each_node = log_on_each_node,
|
250 |
-
logging_dir = logging_dir,
|
251 |
-
logging_strategy = logging_strategy,
|
252 |
-
logging_first_step = logging_first_step,
|
253 |
-
logging_steps = logging_steps,
|
254 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
255 |
-
save_strategy = save_strategy,
|
256 |
-
save_steps = save_steps,
|
257 |
-
save_total_limit = save_total_limit,
|
258 |
-
save_safetensors = save_safetensors,
|
259 |
-
save_on_each_node = save_on_each_node,
|
260 |
-
save_only_model = save_only_model,
|
261 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
262 |
-
no_cuda = no_cuda,
|
263 |
-
use_cpu = use_cpu,
|
264 |
-
use_mps_device = use_mps_device,
|
265 |
-
seed = seed,
|
266 |
-
data_seed = data_seed,
|
267 |
-
jit_mode_eval = jit_mode_eval,
|
268 |
-
use_ipex = use_ipex,
|
269 |
-
bf16 = bf16,
|
270 |
-
fp16 = fp16,
|
271 |
-
fp16_opt_level = fp16_opt_level,
|
272 |
-
half_precision_backend = half_precision_backend,
|
273 |
-
bf16_full_eval = bf16_full_eval,
|
274 |
-
fp16_full_eval = fp16_full_eval,
|
275 |
-
tf32 = tf32,
|
276 |
-
local_rank = local_rank,
|
277 |
-
ddp_backend = ddp_backend,
|
278 |
-
tpu_num_cores = tpu_num_cores,
|
279 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
280 |
-
debug = debug,
|
281 |
-
dataloader_drop_last = dataloader_drop_last,
|
282 |
-
eval_steps = eval_steps,
|
283 |
-
dataloader_num_workers = dataloader_num_workers,
|
284 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
285 |
-
past_index = past_index,
|
286 |
-
run_name = run_name,
|
287 |
-
disable_tqdm = disable_tqdm,
|
288 |
-
remove_unused_columns = remove_unused_columns,
|
289 |
-
label_names = label_names,
|
290 |
-
load_best_model_at_end = load_best_model_at_end,
|
291 |
-
metric_for_best_model = metric_for_best_model,
|
292 |
-
greater_is_better = greater_is_better,
|
293 |
-
ignore_data_skip = ignore_data_skip,
|
294 |
-
fsdp = fsdp,
|
295 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
296 |
-
fsdp_config = fsdp_config,
|
297 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
298 |
-
accelerator_config = accelerator_config,
|
299 |
-
deepspeed = deepspeed,
|
300 |
-
label_smoothing_factor = label_smoothing_factor,
|
301 |
-
optim = optim,
|
302 |
-
optim_args = optim_args,
|
303 |
-
adafactor = adafactor,
|
304 |
-
group_by_length = group_by_length,
|
305 |
-
length_column_name = length_column_name,
|
306 |
-
report_to = report_to,
|
307 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
308 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
309 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
310 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
311 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
312 |
-
skip_memory_metrics = skip_memory_metrics,
|
313 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
314 |
-
push_to_hub = push_to_hub,
|
315 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
316 |
-
hub_model_id = hub_model_id,
|
317 |
-
hub_strategy = hub_strategy,
|
318 |
-
hub_token = hub_token,
|
319 |
-
hub_private_repo = hub_private_repo,
|
320 |
-
hub_always_push = hub_always_push,
|
321 |
-
gradient_checkpointing = gradient_checkpointing,
|
322 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
323 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
324 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
325 |
-
fp16_backend = fp16_backend,
|
326 |
-
evaluation_strategy = evaluation_strategy,
|
327 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
328 |
-
push_to_hub_organization = push_to_hub_organization,
|
329 |
-
push_to_hub_token = push_to_hub_token,
|
330 |
-
mp_parameters = mp_parameters,
|
331 |
-
auto_find_batch_size = auto_find_batch_size,
|
332 |
-
full_determinism = full_determinism,
|
333 |
-
torchdynamo = torchdynamo,
|
334 |
-
ray_scope = ray_scope,
|
335 |
-
ddp_timeout = ddp_timeout,
|
336 |
-
torch_compile = torch_compile,
|
337 |
-
torch_compile_backend = torch_compile_backend,
|
338 |
-
torch_compile_mode = torch_compile_mode,
|
339 |
-
dispatch_batches = dispatch_batches,
|
340 |
-
split_batches = split_batches,
|
341 |
-
include_tokens_per_second = include_tokens_per_second,
|
342 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
343 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
344 |
-
optim_target_modules = optim_target_modules,
|
345 |
-
batch_eval_metrics = batch_eval_metrics,
|
346 |
-
eval_on_start = eval_on_start,
|
347 |
-
use_liger_kernel = use_liger_kernel,
|
348 |
-
eval_use_gather_object = eval_use_gather_object,
|
349 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
350 |
-
reward_model_path = reward_model_path,
|
351 |
-
judge = judge,
|
352 |
-
max_new_tokens = max_new_tokens,
|
353 |
-
max_length = max_length,
|
354 |
-
temperature = temperature,
|
355 |
-
missing_eos_penalty = missing_eos_penalty,
|
356 |
-
loss_type = loss_type,
|
357 |
-
dataset_num_proc = dataset_num_proc,
|
358 |
-
disable_dropout = disable_dropout,
|
359 |
-
use_vllm = use_vllm,
|
360 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
361 |
-
self.vllm_sampling_params = vllm_sampling_params
|
362 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
363 |
-
pass
|
364 |
-
|
365 |
-
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
366 |
-
r""""""
|
367 |
-
|
368 |
-
_tag_names = ["trl", "nash-md"]
|
369 |
-
|
370 |
-
def __init__(
|
371 |
-
self,
|
372 |
-
model: Union[PreTrainedModel, nn.Module] = None,
|
373 |
-
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
374 |
-
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
375 |
-
judge: Optional[BasePairwiseJudge] = None,
|
376 |
-
args: Optional[NashMDConfig] = None,
|
377 |
-
data_collator: Optional[Callable] = None,
|
378 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
379 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
380 |
-
processing_class: Optional[
|
381 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
382 |
-
] = None,
|
383 |
-
peft_config: Optional[dict] = None,
|
384 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
385 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
386 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
387 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
388 |
-
) -> None:
|
389 |
-
super().__init__(
|
390 |
-
model=model,
|
391 |
-
ref_model=ref_model,
|
392 |
-
reward_model=reward_model,
|
393 |
-
judge=judge,
|
394 |
-
args=args,
|
395 |
-
data_collator=data_collator,
|
396 |
-
train_dataset=train_dataset,
|
397 |
-
eval_dataset=eval_dataset,
|
398 |
-
processing_class=processing_class,
|
399 |
-
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
|
400 |
-
peft_config=peft_config,
|
401 |
-
compute_metrics=compute_metrics,
|
402 |
-
callbacks=callbacks,
|
403 |
-
optimizers=optimizers,
|
404 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
405 |
-
)
|
406 |
-
|
407 |
-
self._mixture_coef = self.args.mixture_coef
|
408 |
-
|
409 |
-
# Overwrite the stats dictionary to include NashMD specific statistics
|
410 |
-
self.stats = {
|
411 |
-
# Remove "non_score_reward", "rlhf_reward", "scores_margin"
|
412 |
-
# Add "mixture_coef"
|
413 |
-
"loss/kl": [],
|
414 |
-
"objective/entropy": [],
|
415 |
-
"loss/score": [],
|
416 |
-
"rewards/probabilities": [],
|
417 |
-
"rewards/accuracies": [],
|
418 |
-
"rewards/margins": [],
|
419 |
-
"logps/chosen": [],
|
420 |
-
"logps/rejected": [],
|
421 |
-
"val/model_contain_eos_token": [],
|
422 |
-
"val/ref_contain_eos_token": [],
|
423 |
-
"beta": [],
|
424 |
-
"mixture_coef": [],
|
425 |
-
}
|
426 |
-
if self.reward_model is not None:
|
427 |
-
self.stats["rewards/chosen"] = []
|
428 |
-
self.stats["rewards/rejected"] = []
|
429 |
-
|
430 |
-
@property
|
431 |
-
def mixture_coef(self):
|
432 |
-
if isinstance(self._mixture_coef, list):
|
433 |
-
epoch = self.state.epoch
|
434 |
-
return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
|
435 |
-
else:
|
436 |
-
return self._mixture_coef
|
437 |
-
|
438 |
-
def _generate_completions(self, model, prompts):
|
439 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
440 |
-
model_output = unwrapped_model.generate(
|
441 |
-
input_ids=prompts["input_ids"],
|
442 |
-
attention_mask=prompts["attention_mask"],
|
443 |
-
generation_config=self.generation_config,
|
444 |
-
)
|
445 |
-
|
446 |
-
ref_model = model if self.ref_model is None else self.ref_model
|
447 |
-
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
448 |
-
mixture_model = GeometricMixtureWrapper(
|
449 |
-
model=unwrapped_model,
|
450 |
-
ref_model=unwrapped_ref_model,
|
451 |
-
generation_config=self.generation_config,
|
452 |
-
mixture_coef=self.mixture_coef,
|
453 |
-
device=self.accelerator.device,
|
454 |
-
)
|
455 |
-
|
456 |
-
mixture_output = mixture_model.generate(
|
457 |
-
input_ids=prompts["input_ids"],
|
458 |
-
attention_mask=prompts["attention_mask"],
|
459 |
-
generation_config=self.generation_config,
|
460 |
-
)
|
461 |
-
|
462 |
-
return model_output, mixture_output
|
463 |
-
|
464 |
-
def _process_completions(self, model_output, mixture_output, prompts):
|
465 |
-
context_length = prompts["input_ids"].shape[1]
|
466 |
-
|
467 |
-
# Process model completions
|
468 |
-
model_completion_ids = model_output[:, context_length:]
|
469 |
-
model_completion_ids, model_completion_mask = truncate_right(
|
470 |
-
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
471 |
-
)
|
472 |
-
model_data = {
|
473 |
-
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
474 |
-
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
475 |
-
"raw": prompts["raw"],
|
476 |
-
}
|
477 |
-
|
478 |
-
# Process reference model completions
|
479 |
-
mixture_completion_ids = mixture_output[:, context_length:]
|
480 |
-
mixture_completion_ids, mixture_completion_mask = truncate_right(
|
481 |
-
mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
482 |
-
)
|
483 |
-
mixture_data = {
|
484 |
-
"input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
|
485 |
-
"attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
|
486 |
-
"raw": prompts["raw"],
|
487 |
-
}
|
488 |
-
|
489 |
-
return model_data, mixture_data
|
490 |
-
|
491 |
-
def _compute_rewards(self, model_data, mixture_data, context_length):
|
492 |
-
with torch.no_grad():
|
493 |
-
_, model_scores, _ = get_reward(
|
494 |
-
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
495 |
-
)
|
496 |
-
_, mixture_scores, _ = get_reward(
|
497 |
-
self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
|
498 |
-
)
|
499 |
-
|
500 |
-
# Apply EOS penalty if needed
|
501 |
-
if self.args.missing_eos_penalty is not None:
|
502 |
-
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
503 |
-
mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
504 |
-
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
505 |
-
mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
|
506 |
-
|
507 |
-
return model_scores, mixture_scores
|
508 |
-
|
509 |
-
def _compute_judge(self, model_data, mixture_data, context_length):
|
510 |
-
prompts = model_data["raw"]
|
511 |
-
model_data_completions = self.processing_class.batch_decode(
|
512 |
-
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
513 |
-
)
|
514 |
-
model_data_completions = [completion.strip() for completion in model_data_completions]
|
515 |
-
|
516 |
-
mixture_data_completions = self.processing_class.batch_decode(
|
517 |
-
mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
|
518 |
-
)
|
519 |
-
mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
|
520 |
-
if is_conversational({"prompt": prompts[0]}):
|
521 |
-
model_data_completions = [
|
522 |
-
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
523 |
-
]
|
524 |
-
environment = jinja2.Environment()
|
525 |
-
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
526 |
-
prompts = [template.render(messages=message) for message in prompts]
|
527 |
-
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
528 |
-
|
529 |
-
mixture_data_completions = [
|
530 |
-
[{"role": "assistant", "content": completion}] for completion in mixture_data_completions
|
531 |
-
]
|
532 |
-
mixture_data_completions = [
|
533 |
-
template.render(messages=completion) for completion in mixture_data_completions
|
534 |
-
]
|
535 |
-
|
536 |
-
probability = self.judge.judge(
|
537 |
-
prompts,
|
538 |
-
list(zip(model_data_completions, mixture_data_completions)),
|
539 |
-
return_scores=True,
|
540 |
-
)
|
541 |
-
return torch.tensor(probability, device=model_data["input_ids"].device)
|
542 |
-
|
543 |
-
def _compute_logprobs(self, model, model_data, context_length):
|
544 |
-
def compute_logprobs_for_data(m, data):
|
545 |
-
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
546 |
-
logits = output.logits[:, context_length - 1 : -1]
|
547 |
-
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
548 |
-
return token_logprobs
|
549 |
-
|
550 |
-
# Compute logprobs for model completions under the model
|
551 |
-
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
552 |
-
|
553 |
-
# Compute logprobs of model completions under the reference model
|
554 |
-
with torch.no_grad():
|
555 |
-
if self.ref_model is None:
|
556 |
-
with model.disable_adapter():
|
557 |
-
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
558 |
-
else:
|
559 |
-
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
560 |
-
|
561 |
-
# Mask padding tokens
|
562 |
-
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
563 |
-
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
564 |
-
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
565 |
-
|
566 |
-
return (model_logprobs_model_data, ref_logprobs_model_data)
|
567 |
-
|
568 |
-
def _compute_losses(
|
569 |
-
self,
|
570 |
-
model_logprobs_model_data,
|
571 |
-
ref_logprobs_model_data,
|
572 |
-
probability,
|
573 |
-
):
|
574 |
-
# reinforce score where 0.5 is a control variate
|
575 |
-
score = (probability - 0.5) * model_logprobs_model_data.sum(1)
|
576 |
-
|
577 |
-
# kl divergence via reinforce
|
578 |
-
with torch.no_grad():
|
579 |
-
log_ratio = model_logprobs_model_data - ref_logprobs_model_data
|
580 |
-
kl_div_log = log_ratio.sum(1)
|
581 |
-
kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
|
582 |
-
|
583 |
-
# final loss
|
584 |
-
loss = self.beta * kl_div_loss - score
|
585 |
-
|
586 |
-
return loss.mean(), score, kl_div_log
|
587 |
-
|
588 |
-
def _log_statistics(
|
589 |
-
self,
|
590 |
-
model_data,
|
591 |
-
mixture_data,
|
592 |
-
model_logprobs_model_data,
|
593 |
-
ref_logprobs_model_data,
|
594 |
-
probability,
|
595 |
-
score,
|
596 |
-
kl_div,
|
597 |
-
context_length,
|
598 |
-
model_scores=None,
|
599 |
-
mixture_scores=None,
|
600 |
-
):
|
601 |
-
# Helper function to gather and compute mean
|
602 |
-
def gather_mean(tensor):
|
603 |
-
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
604 |
-
|
605 |
-
# Log score
|
606 |
-
self.stats["loss/score"].append(gather_mean(score))
|
607 |
-
# Log KL divergence
|
608 |
-
self.stats["loss/kl"].append(gather_mean(kl_div))
|
609 |
-
|
610 |
-
# Log logprobs
|
611 |
-
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
612 |
-
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
613 |
-
|
614 |
-
self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
|
615 |
-
self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
|
616 |
-
|
617 |
-
# Log rewards
|
618 |
-
if self.reward_model is not None:
|
619 |
-
self.stats["rewards/chosen"].append(gather_mean(model_scores))
|
620 |
-
self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
|
621 |
-
|
622 |
-
# Log probabilities
|
623 |
-
self.stats["rewards/probabilities"].append(gather_mean(probability))
|
624 |
-
|
625 |
-
# Calculate entropy for model data
|
626 |
-
entropy_model_data = -model_logprobs_model_data.sum(1)
|
627 |
-
self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
|
628 |
-
|
629 |
-
# Calculate margins
|
630 |
-
margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
|
631 |
-
self.stats["rewards/margins"].append(gather_mean(margin))
|
632 |
-
|
633 |
-
# Calculate accuracy
|
634 |
-
accuracy = (margin > 0).float()
|
635 |
-
self.stats["rewards/accuracies"].append(gather_mean(accuracy))
|
636 |
-
|
637 |
-
# Log EOS token statistics
|
638 |
-
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
639 |
-
mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
640 |
-
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
641 |
-
self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
|
642 |
-
|
643 |
-
# Log beta and mixture coef
|
644 |
-
self.stats["beta"].append(self.beta)
|
645 |
-
self.stats["mixture_coef"].append(self.mixture_coef)
|
646 |
-
|
647 |
-
def training_step(
|
648 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
649 |
-
) -> torch.Tensor:
|
650 |
-
model.train()
|
651 |
-
|
652 |
-
# Apply chat template and tokenize the input
|
653 |
-
batch_size = len(next(iter(inputs.values())))
|
654 |
-
prompts = inputs["prompt"]
|
655 |
-
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
656 |
-
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
657 |
-
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
658 |
-
inputs = self.data_collator(inputs)
|
659 |
-
|
660 |
-
# need the prompt_ only
|
661 |
-
inputs = self._prepare_inputs(inputs)
|
662 |
-
context_length = inputs["prompt_input_ids"].shape[1]
|
663 |
-
prompts = {
|
664 |
-
"input_ids": inputs["prompt_input_ids"],
|
665 |
-
"attention_mask": inputs["prompt_attention_mask"],
|
666 |
-
"raw": prompts,
|
667 |
-
}
|
668 |
-
del inputs
|
669 |
-
|
670 |
-
# Sample completions from both the model and the reference model
|
671 |
-
model_output, mixture_output = self._generate_completions(model, prompts)
|
672 |
-
|
673 |
-
# Process model completions
|
674 |
-
model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
|
675 |
-
|
676 |
-
# Compute rewards
|
677 |
-
if self.reward_model is not None:
|
678 |
-
model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
|
679 |
-
# probability of the model data vs the mixture data
|
680 |
-
probability = F.sigmoid(model_scores - mixture_scores)
|
681 |
-
else:
|
682 |
-
model_scores, mixture_scores = None, None
|
683 |
-
probability = self._compute_judge(model_data, mixture_data, context_length)
|
684 |
-
|
685 |
-
# Compute logprobs
|
686 |
-
model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
|
687 |
-
|
688 |
-
# Compute loss
|
689 |
-
loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
|
690 |
-
|
691 |
-
# Log everything
|
692 |
-
self._log_statistics(
|
693 |
-
model_data,
|
694 |
-
mixture_data,
|
695 |
-
model_logprobs_model_data.detach(),
|
696 |
-
ref_logprobs_model_data,
|
697 |
-
probability,
|
698 |
-
score.detach(),
|
699 |
-
kl_div.detach(),
|
700 |
-
context_length,
|
701 |
-
model_scores,
|
702 |
-
mixture_scores,
|
703 |
-
)
|
704 |
-
|
705 |
-
if (
|
706 |
-
self.args.torch_empty_cache_steps is not None
|
707 |
-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
708 |
-
):
|
709 |
-
empty_cache()
|
710 |
-
|
711 |
-
kwargs = {}
|
712 |
-
# For LOMO optimizers you need to explicitly use the learning rate
|
713 |
-
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
714 |
-
kwargs["learning_rate"] = self._get_learning_rate()
|
715 |
-
|
716 |
-
if self.args.n_gpu > 1:
|
717 |
-
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
718 |
-
|
719 |
-
if self.use_apex:
|
720 |
-
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
721 |
-
scaled_loss.backward()
|
722 |
-
else:
|
723 |
-
self.accelerator.backward(loss, **kwargs)
|
724 |
-
|
725 |
-
return loss.detach() / self.args.gradient_accumulation_steps
|
726 |
-
|
727 |
-
def create_model_card(
|
728 |
-
self,
|
729 |
-
model_name: Optional[str] = None,
|
730 |
-
dataset_name: Optional[str] = None,
|
731 |
-
tags: Union[str, list[str], None] = None,
|
732 |
-
):
|
733 |
-
"""
|
734 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
735 |
-
|
736 |
-
Args:
|
737 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
738 |
-
Name of the model.
|
739 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
740 |
-
Name of the dataset used for training.
|
741 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
742 |
-
Tags to be associated with the model card.
|
743 |
-
"""
|
744 |
-
if not self.is_world_process_zero():
|
745 |
-
return
|
746 |
-
|
747 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
748 |
-
base_model = self.model.config._name_or_path
|
749 |
-
else:
|
750 |
-
base_model = None
|
751 |
-
|
752 |
-
tags = tags or []
|
753 |
-
if isinstance(tags, str):
|
754 |
-
tags = [tags]
|
755 |
-
|
756 |
-
if hasattr(self.model.config, "unsloth_version"):
|
757 |
-
tags.append("unsloth")
|
758 |
-
|
759 |
-
citation = textwrap.dedent("""\
|
760 |
-
@inproceedings{munos2024nash,
|
761 |
-
title = {{Nash Learning from Human Feedback}},
|
762 |
-
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
|
763 |
-
year = 2024,
|
764 |
-
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
765 |
-
publisher = {OpenReview.net},
|
766 |
-
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
|
767 |
-
}""")
|
768 |
-
|
769 |
-
model_card = generate_model_card(
|
770 |
-
base_model=base_model,
|
771 |
-
model_name=model_name,
|
772 |
-
hub_model_id=self.hub_model_id,
|
773 |
-
dataset_name=dataset_name,
|
774 |
-
tags=tags,
|
775 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
776 |
-
comet_url=get_comet_experiment_url(),
|
777 |
-
trainer_name="Nash-MD",
|
778 |
-
trainer_citation=citation,
|
779 |
-
paper_title="Nash Learning from Human Feedback",
|
780 |
-
paper_id="2312.00886",
|
781 |
-
)
|
782 |
-
|
783 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
784 |
-
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
|
785 |
-
"""
|
786 |
-
|
787 |
-
Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
|
788 |
-
|
789 |
-
Args:
|
790 |
-
model (`transformers.PreTrainedModel`):
|
791 |
-
The model to train, preferably an `AutoModelForCausalLM`.
|
792 |
-
ref_model (`PreTrainedModelWrapper`):
|
793 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
794 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
795 |
-
reward_model (`transformers.PreTrainedModel`):
|
796 |
-
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
797 |
-
judge (`BasePairwiseJudge`):
|
798 |
-
The judge to use for pairwise comparison of model completions.
|
799 |
-
args (`NashMDConfig`):
|
800 |
-
The NashMD config arguments to use for training.
|
801 |
-
data_collator (`transformers.DataCollator`):
|
802 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
803 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
804 |
-
train_dataset (`datasets.Dataset`):
|
805 |
-
The dataset to use for training.
|
806 |
-
eval_dataset (`datasets.Dataset`):
|
807 |
-
The dataset to use for evaluation.
|
808 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
809 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
810 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
811 |
-
reuse the fine-tuned model.
|
812 |
-
peft_config (`dict`):
|
813 |
-
The peft config to use for training.
|
814 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
815 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
816 |
-
a dictionary string to metric values.
|
817 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
818 |
-
The callbacks to use for training.
|
819 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
820 |
-
The optimizer and scheduler to use for training.
|
821 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
822 |
-
The function to use to preprocess the logits before computing the metrics.
|
823 |
-
|
824 |
-
"""
|
825 |
-
def __init__(
|
826 |
-
self,
|
827 |
-
model = None,
|
828 |
-
ref_model = None,
|
829 |
-
reward_model = None,
|
830 |
-
judge = None,
|
831 |
-
args = None,
|
832 |
-
data_collator = None,
|
833 |
-
train_dataset = None,
|
834 |
-
eval_dataset = None,
|
835 |
-
processing_class = None,
|
836 |
-
peft_config = None,
|
837 |
-
compute_metrics = None,
|
838 |
-
callbacks = None,
|
839 |
-
preprocess_logits_for_metrics = None,
|
840 |
-
**kwargs
|
841 |
-
):
|
842 |
-
if args is None: args = UnslothNashMDConfig()
|
843 |
-
use_bf16 = getattr(args, 'bf16', False)
|
844 |
-
use_fp16 = getattr(args, 'fp16', False)
|
845 |
-
force_float32 = False
|
846 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
847 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
848 |
-
force_float32 = True
|
849 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
850 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
851 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
852 |
-
from unsloth_zoo.utils import _get_dtype
|
853 |
-
dtype = _get_dtype(dtype)
|
854 |
-
float16 = dtype == torch.float16
|
855 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
856 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
857 |
-
if force_float32:
|
858 |
-
args.fp16 = False
|
859 |
-
args.bf16 = False
|
860 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
861 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
862 |
-
args.fp16 = float16
|
863 |
-
args.bf16 = not float16
|
864 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
865 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
866 |
-
args.eval_strategy = 'steps'
|
867 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
868 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
869 |
-
if ga_steps is not None and ga_steps > 1:
|
870 |
-
from transformers import __version__ as transformers_version
|
871 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
872 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
873 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
874 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
875 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
876 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
877 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
878 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
879 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
880 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
881 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
882 |
-
if force_float32:
|
883 |
-
args.bf16_full_eval = False
|
884 |
-
args.fp16_full_eval = False
|
885 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
886 |
-
args.bf16_full_eval = True
|
887 |
-
args.fp16_full_eval = False
|
888 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
889 |
-
args.bf16_full_eval = args.bf16
|
890 |
-
args.fp16_full_eval = args.fp16
|
891 |
-
_output_logits = False
|
892 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
893 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
894 |
-
if _output_logits:
|
895 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
896 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
897 |
-
pass
|
898 |
-
else:
|
899 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
900 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
901 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
902 |
-
max_seq_length = model.max_seq_length
|
903 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
904 |
-
if model is not None and hasattr(model, 'for_training'):
|
905 |
-
model.for_training()
|
906 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
907 |
-
if 'processing_class' in locals():
|
908 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
909 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
910 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
911 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
912 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
913 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
914 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
915 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
916 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
917 |
-
else:
|
918 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
919 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
920 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
921 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
922 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
923 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
924 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
925 |
-
else:
|
926 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
927 |
-
other_metrics = []
|
928 |
-
|
929 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
930 |
-
PatchRLStatistics('nash_md_trainer', other_metrics)
|
931 |
-
|
932 |
-
super().__init__(
|
933 |
-
model = model,
|
934 |
-
ref_model = ref_model,
|
935 |
-
reward_model = reward_model,
|
936 |
-
judge = judge,
|
937 |
-
args = args,
|
938 |
-
data_collator = data_collator,
|
939 |
-
train_dataset = train_dataset,
|
940 |
-
eval_dataset = eval_dataset,
|
941 |
-
processing_class = processing_class,
|
942 |
-
peft_config = peft_config,
|
943 |
-
compute_metrics = compute_metrics,
|
944 |
-
callbacks = callbacks,
|
945 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
946 |
-
if hasattr(self, 'neftune_hook_handle'):
|
947 |
-
self.neftune_hook_handle.remove()
|
948 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
949 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
950 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
951 |
-
pass
|
952 |
-
|
953 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothORPOTrainer.py
DELETED
@@ -1,1541 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothORPOConfig(ORPOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`ORPOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
54 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
-
[`~transformers.TrainingArguments`].
|
56 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
-
to use the default data collator.
|
59 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
-
and your model is an encoder-decoder.
|
64 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
-
Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
|
66 |
-
it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
|
67 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
68 |
-
Whether to disable dropout in the model.
|
69 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
70 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
71 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
72 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
73 |
-
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
74 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
75 |
-
This argument is required if you want to use the default data collator.
|
76 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
77 |
-
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
78 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
79 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
80 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
81 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
82 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
83 |
-
string.
|
84 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
85 |
-
Number of processes to use for processing the dataset.
|
86 |
-
|
87 |
-
"""
|
88 |
-
vllm_sampling_params: Optional[Any] = field(
|
89 |
-
default = None,
|
90 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
91 |
-
)
|
92 |
-
unsloth_num_chunks : Optional[int] = field(
|
93 |
-
default = -1,
|
94 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
95 |
-
)
|
96 |
-
def __init__(
|
97 |
-
self,
|
98 |
-
output_dir = None,
|
99 |
-
overwrite_output_dir = None,
|
100 |
-
do_train = False,
|
101 |
-
do_eval = False,
|
102 |
-
do_predict = False,
|
103 |
-
eval_strategy = 'no',
|
104 |
-
prediction_loss_only = False,
|
105 |
-
per_device_train_batch_size = 4,
|
106 |
-
per_device_eval_batch_size = 4,
|
107 |
-
per_gpu_train_batch_size = None,
|
108 |
-
per_gpu_eval_batch_size = None,
|
109 |
-
gradient_accumulation_steps = 2,
|
110 |
-
eval_accumulation_steps = 2,
|
111 |
-
eval_delay = 0,
|
112 |
-
torch_empty_cache_steps = 250,
|
113 |
-
learning_rate = 5e-05,
|
114 |
-
weight_decay = 0.01,
|
115 |
-
adam_beta1 = 0.9,
|
116 |
-
adam_beta2 = 0.999,
|
117 |
-
adam_epsilon = 1e-08,
|
118 |
-
max_grad_norm = 1.0,
|
119 |
-
num_train_epochs = 3.0,
|
120 |
-
max_steps = -1,
|
121 |
-
lr_scheduler_type = 'linear',
|
122 |
-
warmup_ratio = 0.1,
|
123 |
-
warmup_steps = 0,
|
124 |
-
log_level = 'passive',
|
125 |
-
log_level_replica = 'warning',
|
126 |
-
log_on_each_node = True,
|
127 |
-
logging_dir = None,
|
128 |
-
logging_strategy = 'steps',
|
129 |
-
logging_first_step = False,
|
130 |
-
logging_steps = 1,
|
131 |
-
logging_nan_inf_filter = False,
|
132 |
-
save_strategy = 'steps',
|
133 |
-
save_steps = 500,
|
134 |
-
save_total_limit = None,
|
135 |
-
save_safetensors = True,
|
136 |
-
save_on_each_node = False,
|
137 |
-
save_only_model = False,
|
138 |
-
restore_callback_states_from_checkpoint = False,
|
139 |
-
no_cuda = False,
|
140 |
-
use_cpu = False,
|
141 |
-
use_mps_device = False,
|
142 |
-
seed = 3407,
|
143 |
-
data_seed = 3407,
|
144 |
-
jit_mode_eval = False,
|
145 |
-
use_ipex = False,
|
146 |
-
bf16 = False,
|
147 |
-
fp16 = False,
|
148 |
-
fp16_opt_level = 'O1',
|
149 |
-
half_precision_backend = 'auto',
|
150 |
-
bf16_full_eval = False,
|
151 |
-
fp16_full_eval = False,
|
152 |
-
tf32 = None,
|
153 |
-
local_rank = -1,
|
154 |
-
ddp_backend = None,
|
155 |
-
tpu_num_cores = None,
|
156 |
-
tpu_metrics_debug = False,
|
157 |
-
debug = '',
|
158 |
-
dataloader_drop_last = False,
|
159 |
-
eval_steps = None,
|
160 |
-
dataloader_num_workers = 0,
|
161 |
-
dataloader_prefetch_factor = None,
|
162 |
-
past_index = -1,
|
163 |
-
run_name = None,
|
164 |
-
disable_tqdm = None,
|
165 |
-
remove_unused_columns = True,
|
166 |
-
label_names = None,
|
167 |
-
load_best_model_at_end = False,
|
168 |
-
metric_for_best_model = None,
|
169 |
-
greater_is_better = None,
|
170 |
-
ignore_data_skip = False,
|
171 |
-
fsdp = '',
|
172 |
-
fsdp_min_num_params = 0,
|
173 |
-
fsdp_config = None,
|
174 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
175 |
-
accelerator_config = None,
|
176 |
-
deepspeed = None,
|
177 |
-
label_smoothing_factor = 0.0,
|
178 |
-
optim = 'adamw_8bit',
|
179 |
-
optim_args = None,
|
180 |
-
adafactor = False,
|
181 |
-
group_by_length = False,
|
182 |
-
length_column_name = 'length',
|
183 |
-
report_to = None,
|
184 |
-
ddp_find_unused_parameters = None,
|
185 |
-
ddp_bucket_cap_mb = None,
|
186 |
-
ddp_broadcast_buffers = None,
|
187 |
-
dataloader_pin_memory = True,
|
188 |
-
dataloader_persistent_workers = False,
|
189 |
-
skip_memory_metrics = True,
|
190 |
-
use_legacy_prediction_loop = False,
|
191 |
-
push_to_hub = False,
|
192 |
-
resume_from_checkpoint = None,
|
193 |
-
hub_model_id = None,
|
194 |
-
hub_strategy = 'every_save',
|
195 |
-
hub_token = None,
|
196 |
-
hub_private_repo = None,
|
197 |
-
hub_always_push = False,
|
198 |
-
gradient_checkpointing = False,
|
199 |
-
gradient_checkpointing_kwargs = None,
|
200 |
-
include_inputs_for_metrics = False,
|
201 |
-
eval_do_concat_batches = True,
|
202 |
-
fp16_backend = 'auto',
|
203 |
-
evaluation_strategy = None,
|
204 |
-
push_to_hub_model_id = None,
|
205 |
-
push_to_hub_organization = None,
|
206 |
-
push_to_hub_token = None,
|
207 |
-
mp_parameters = '',
|
208 |
-
auto_find_batch_size = False,
|
209 |
-
full_determinism = False,
|
210 |
-
torchdynamo = None,
|
211 |
-
ray_scope = 'last',
|
212 |
-
ddp_timeout = 1800,
|
213 |
-
torch_compile = False,
|
214 |
-
torch_compile_backend = None,
|
215 |
-
torch_compile_mode = None,
|
216 |
-
dispatch_batches = None,
|
217 |
-
split_batches = None,
|
218 |
-
include_tokens_per_second = False,
|
219 |
-
include_num_input_tokens_seen = False,
|
220 |
-
neftune_noise_alpha = None,
|
221 |
-
optim_target_modules = None,
|
222 |
-
batch_eval_metrics = False,
|
223 |
-
eval_on_start = False,
|
224 |
-
use_liger_kernel = False,
|
225 |
-
eval_use_gather_object = False,
|
226 |
-
average_tokens_across_devices = False,
|
227 |
-
max_length = 1024,
|
228 |
-
max_prompt_length = 512,
|
229 |
-
max_completion_length = None,
|
230 |
-
beta = 0.1,
|
231 |
-
disable_dropout = True,
|
232 |
-
label_pad_token_id = -100,
|
233 |
-
padding_value = None,
|
234 |
-
truncation_mode = 'keep_end',
|
235 |
-
generate_during_eval = False,
|
236 |
-
is_encoder_decoder = None,
|
237 |
-
model_init_kwargs = None,
|
238 |
-
dataset_num_proc = None,
|
239 |
-
vllm_sampling_params = None,
|
240 |
-
unsloth_num_chunks = -1,
|
241 |
-
**kwargs,
|
242 |
-
):
|
243 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
244 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
245 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
246 |
-
output_dir = 'unsloth_training_checkpoints'
|
247 |
-
save_strategy = 'no'
|
248 |
-
if dataset_num_proc is None:
|
249 |
-
from multiprocessing import cpu_count
|
250 |
-
dataset_num_proc = cpu_count()
|
251 |
-
|
252 |
-
super().__init__(
|
253 |
-
output_dir = output_dir,
|
254 |
-
overwrite_output_dir = overwrite_output_dir,
|
255 |
-
do_train = do_train,
|
256 |
-
do_eval = do_eval,
|
257 |
-
do_predict = do_predict,
|
258 |
-
eval_strategy = eval_strategy,
|
259 |
-
prediction_loss_only = prediction_loss_only,
|
260 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
261 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
262 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
263 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
264 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
265 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
266 |
-
eval_delay = eval_delay,
|
267 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
268 |
-
learning_rate = learning_rate,
|
269 |
-
weight_decay = weight_decay,
|
270 |
-
adam_beta1 = adam_beta1,
|
271 |
-
adam_beta2 = adam_beta2,
|
272 |
-
adam_epsilon = adam_epsilon,
|
273 |
-
max_grad_norm = max_grad_norm,
|
274 |
-
num_train_epochs = num_train_epochs,
|
275 |
-
max_steps = max_steps,
|
276 |
-
lr_scheduler_type = lr_scheduler_type,
|
277 |
-
warmup_ratio = warmup_ratio,
|
278 |
-
warmup_steps = warmup_steps,
|
279 |
-
log_level = log_level,
|
280 |
-
log_level_replica = log_level_replica,
|
281 |
-
log_on_each_node = log_on_each_node,
|
282 |
-
logging_dir = logging_dir,
|
283 |
-
logging_strategy = logging_strategy,
|
284 |
-
logging_first_step = logging_first_step,
|
285 |
-
logging_steps = logging_steps,
|
286 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
287 |
-
save_strategy = save_strategy,
|
288 |
-
save_steps = save_steps,
|
289 |
-
save_total_limit = save_total_limit,
|
290 |
-
save_safetensors = save_safetensors,
|
291 |
-
save_on_each_node = save_on_each_node,
|
292 |
-
save_only_model = save_only_model,
|
293 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
294 |
-
no_cuda = no_cuda,
|
295 |
-
use_cpu = use_cpu,
|
296 |
-
use_mps_device = use_mps_device,
|
297 |
-
seed = seed,
|
298 |
-
data_seed = data_seed,
|
299 |
-
jit_mode_eval = jit_mode_eval,
|
300 |
-
use_ipex = use_ipex,
|
301 |
-
bf16 = bf16,
|
302 |
-
fp16 = fp16,
|
303 |
-
fp16_opt_level = fp16_opt_level,
|
304 |
-
half_precision_backend = half_precision_backend,
|
305 |
-
bf16_full_eval = bf16_full_eval,
|
306 |
-
fp16_full_eval = fp16_full_eval,
|
307 |
-
tf32 = tf32,
|
308 |
-
local_rank = local_rank,
|
309 |
-
ddp_backend = ddp_backend,
|
310 |
-
tpu_num_cores = tpu_num_cores,
|
311 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
312 |
-
debug = debug,
|
313 |
-
dataloader_drop_last = dataloader_drop_last,
|
314 |
-
eval_steps = eval_steps,
|
315 |
-
dataloader_num_workers = dataloader_num_workers,
|
316 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
317 |
-
past_index = past_index,
|
318 |
-
run_name = run_name,
|
319 |
-
disable_tqdm = disable_tqdm,
|
320 |
-
remove_unused_columns = remove_unused_columns,
|
321 |
-
label_names = label_names,
|
322 |
-
load_best_model_at_end = load_best_model_at_end,
|
323 |
-
metric_for_best_model = metric_for_best_model,
|
324 |
-
greater_is_better = greater_is_better,
|
325 |
-
ignore_data_skip = ignore_data_skip,
|
326 |
-
fsdp = fsdp,
|
327 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
328 |
-
fsdp_config = fsdp_config,
|
329 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
330 |
-
accelerator_config = accelerator_config,
|
331 |
-
deepspeed = deepspeed,
|
332 |
-
label_smoothing_factor = label_smoothing_factor,
|
333 |
-
optim = optim,
|
334 |
-
optim_args = optim_args,
|
335 |
-
adafactor = adafactor,
|
336 |
-
group_by_length = group_by_length,
|
337 |
-
length_column_name = length_column_name,
|
338 |
-
report_to = report_to,
|
339 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
340 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
341 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
342 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
343 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
344 |
-
skip_memory_metrics = skip_memory_metrics,
|
345 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
346 |
-
push_to_hub = push_to_hub,
|
347 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
348 |
-
hub_model_id = hub_model_id,
|
349 |
-
hub_strategy = hub_strategy,
|
350 |
-
hub_token = hub_token,
|
351 |
-
hub_private_repo = hub_private_repo,
|
352 |
-
hub_always_push = hub_always_push,
|
353 |
-
gradient_checkpointing = gradient_checkpointing,
|
354 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
355 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
356 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
357 |
-
fp16_backend = fp16_backend,
|
358 |
-
evaluation_strategy = evaluation_strategy,
|
359 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
360 |
-
push_to_hub_organization = push_to_hub_organization,
|
361 |
-
push_to_hub_token = push_to_hub_token,
|
362 |
-
mp_parameters = mp_parameters,
|
363 |
-
auto_find_batch_size = auto_find_batch_size,
|
364 |
-
full_determinism = full_determinism,
|
365 |
-
torchdynamo = torchdynamo,
|
366 |
-
ray_scope = ray_scope,
|
367 |
-
ddp_timeout = ddp_timeout,
|
368 |
-
torch_compile = torch_compile,
|
369 |
-
torch_compile_backend = torch_compile_backend,
|
370 |
-
torch_compile_mode = torch_compile_mode,
|
371 |
-
dispatch_batches = dispatch_batches,
|
372 |
-
split_batches = split_batches,
|
373 |
-
include_tokens_per_second = include_tokens_per_second,
|
374 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
375 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
376 |
-
optim_target_modules = optim_target_modules,
|
377 |
-
batch_eval_metrics = batch_eval_metrics,
|
378 |
-
eval_on_start = eval_on_start,
|
379 |
-
use_liger_kernel = use_liger_kernel,
|
380 |
-
eval_use_gather_object = eval_use_gather_object,
|
381 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
382 |
-
max_length = max_length,
|
383 |
-
max_prompt_length = max_prompt_length,
|
384 |
-
max_completion_length = max_completion_length,
|
385 |
-
beta = beta,
|
386 |
-
disable_dropout = disable_dropout,
|
387 |
-
label_pad_token_id = label_pad_token_id,
|
388 |
-
padding_value = padding_value,
|
389 |
-
truncation_mode = truncation_mode,
|
390 |
-
generate_during_eval = generate_during_eval,
|
391 |
-
is_encoder_decoder = is_encoder_decoder,
|
392 |
-
model_init_kwargs = model_init_kwargs,
|
393 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
394 |
-
self.vllm_sampling_params = vllm_sampling_params
|
395 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
396 |
-
pass
|
397 |
-
|
398 |
-
class _UnslothORPOTrainer(Trainer):
|
399 |
-
r""""""
|
400 |
-
|
401 |
-
_tag_names = ["trl", "orpo"]
|
402 |
-
|
403 |
-
def __init__(
|
404 |
-
self,
|
405 |
-
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
406 |
-
args: Optional[ORPOConfig] = None,
|
407 |
-
data_collator: Optional[DataCollator] = None,
|
408 |
-
train_dataset: Optional[Dataset] = None,
|
409 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
410 |
-
processing_class: Optional[
|
411 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
412 |
-
] = None,
|
413 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
414 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
415 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
416 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
417 |
-
peft_config: Optional[dict] = None,
|
418 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
419 |
-
):
|
420 |
-
if args.model_init_kwargs is None:
|
421 |
-
model_init_kwargs = {}
|
422 |
-
elif not isinstance(model, str):
|
423 |
-
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
|
424 |
-
else:
|
425 |
-
model_init_kwargs = args.model_init_kwargs
|
426 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
427 |
-
if torch_dtype is not None:
|
428 |
-
# Convert to `torch.dtype` if an str is passed
|
429 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
430 |
-
torch_dtype = getattr(torch, torch_dtype)
|
431 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
432 |
-
raise ValueError(
|
433 |
-
f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
434 |
-
)
|
435 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
436 |
-
|
437 |
-
if isinstance(model, str):
|
438 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
439 |
-
|
440 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
441 |
-
# has been called in order to properly call autocast if needed.
|
442 |
-
self._peft_has_been_casted_to_bf16 = False
|
443 |
-
|
444 |
-
if not is_peft_available() and peft_config is not None:
|
445 |
-
raise ValueError(
|
446 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
447 |
-
)
|
448 |
-
elif is_peft_available() and peft_config is not None:
|
449 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
450 |
-
if isinstance(model, PeftModel):
|
451 |
-
model = model.merge_and_unload()
|
452 |
-
|
453 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
454 |
-
_support_gc_kwargs = hasattr(
|
455 |
-
args, "gradient_checkpointing_kwargs"
|
456 |
-
) and "gradient_checkpointing_kwargs" in list(
|
457 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
458 |
-
)
|
459 |
-
|
460 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
461 |
-
|
462 |
-
if _support_gc_kwargs:
|
463 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
464 |
-
|
465 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
466 |
-
elif getattr(args, "gradient_checkpointing", False):
|
467 |
-
# For backward compatibility with older versions of transformers
|
468 |
-
if hasattr(model, "enable_input_require_grads"):
|
469 |
-
model.enable_input_require_grads()
|
470 |
-
else:
|
471 |
-
|
472 |
-
def make_inputs_require_grad(module, input, output):
|
473 |
-
output.requires_grad_(True)
|
474 |
-
|
475 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
476 |
-
|
477 |
-
# get peft model with the given config
|
478 |
-
model = model
|
479 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
480 |
-
peft_module_casting_to_bf16(model)
|
481 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
482 |
-
self._peft_has_been_casted_to_bf16 = True
|
483 |
-
|
484 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
485 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
486 |
-
# fail or completely fail.
|
487 |
-
elif getattr(args, "gradient_checkpointing", False):
|
488 |
-
# For backward compatibility with older versions of transformers
|
489 |
-
if hasattr(model, "enable_input_require_grads"):
|
490 |
-
model.enable_input_require_grads()
|
491 |
-
else:
|
492 |
-
|
493 |
-
def make_inputs_require_grad(module, input, output):
|
494 |
-
output.requires_grad_(True)
|
495 |
-
|
496 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
497 |
-
|
498 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
499 |
-
raise ValueError(
|
500 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
501 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
502 |
-
)
|
503 |
-
|
504 |
-
if model is not None:
|
505 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
506 |
-
elif args.is_encoder_decoder is None:
|
507 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
508 |
-
else:
|
509 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
510 |
-
|
511 |
-
if self.is_encoder_decoder:
|
512 |
-
self.decoder_start_token_id = model.config.decoder_start_token_id
|
513 |
-
self.pad_token_id = model.config.pad_token_id
|
514 |
-
|
515 |
-
if processing_class is None:
|
516 |
-
raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
|
517 |
-
if args.max_length is None:
|
518 |
-
warnings.warn(
|
519 |
-
"`max_length` is not set in the ORPOConfig's init"
|
520 |
-
" it will default to `512` by default, but you should do it yourself in the future.",
|
521 |
-
UserWarning,
|
522 |
-
)
|
523 |
-
max_length = 512
|
524 |
-
else:
|
525 |
-
max_length = args.max_length
|
526 |
-
if args.max_prompt_length is None:
|
527 |
-
warnings.warn(
|
528 |
-
"`max_prompt_length` is not set in the ORPOConfig's init"
|
529 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
530 |
-
UserWarning,
|
531 |
-
)
|
532 |
-
max_prompt_length = 128
|
533 |
-
else:
|
534 |
-
max_prompt_length = args.max_prompt_length
|
535 |
-
|
536 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
537 |
-
warnings.warn(
|
538 |
-
"When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
|
539 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
540 |
-
UserWarning,
|
541 |
-
)
|
542 |
-
self.max_completion_length = 128
|
543 |
-
else:
|
544 |
-
self.max_completion_length = args.max_completion_length
|
545 |
-
|
546 |
-
if data_collator is None:
|
547 |
-
data_collator = DPODataCollatorWithPadding(
|
548 |
-
pad_token_id=processing_class.pad_token_id,
|
549 |
-
label_pad_token_id=args.label_pad_token_id,
|
550 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
551 |
-
)
|
552 |
-
|
553 |
-
if args.remove_unused_columns:
|
554 |
-
args.remove_unused_columns = False
|
555 |
-
# warn users
|
556 |
-
warnings.warn(
|
557 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
558 |
-
" we have set it for you, but you should do it yourself in the future.",
|
559 |
-
UserWarning,
|
560 |
-
)
|
561 |
-
|
562 |
-
self.use_dpo_data_collator = True
|
563 |
-
else:
|
564 |
-
self.use_dpo_data_collator = False
|
565 |
-
|
566 |
-
# Disable dropout in the model and reference model
|
567 |
-
if args.disable_dropout:
|
568 |
-
disable_dropout_in_model(model)
|
569 |
-
|
570 |
-
self.max_length = max_length
|
571 |
-
self.generate_during_eval = args.generate_during_eval
|
572 |
-
self.label_pad_token_id = args.label_pad_token_id
|
573 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
574 |
-
self.max_prompt_length = max_prompt_length
|
575 |
-
self.truncation_mode = args.truncation_mode
|
576 |
-
self.processing_class = processing_class
|
577 |
-
|
578 |
-
self.beta = args.beta
|
579 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
580 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
581 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
582 |
-
warnings.warn(
|
583 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
584 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
585 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
586 |
-
"loss.",
|
587 |
-
UserWarning,
|
588 |
-
)
|
589 |
-
|
590 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
591 |
-
|
592 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
593 |
-
# input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
|
594 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
595 |
-
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
596 |
-
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
597 |
-
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
598 |
-
# that the warning has already been issued.
|
599 |
-
model.warnings_issued["estimate_tokens"] = True
|
600 |
-
|
601 |
-
# Compute that only on the main process for faster data processing.
|
602 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
603 |
-
with PartialState().local_main_process_first():
|
604 |
-
# Extract the prompt if needed, and apply the chat template if needed
|
605 |
-
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
606 |
-
train_dataset = train_dataset.map(
|
607 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
608 |
-
)
|
609 |
-
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
610 |
-
if eval_dataset is not None:
|
611 |
-
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
612 |
-
eval_dataset = eval_dataset.map(
|
613 |
-
maybe_apply_chat_template,
|
614 |
-
fn_kwargs={"tokenizer": processing_class},
|
615 |
-
num_proc=args.dataset_num_proc,
|
616 |
-
)
|
617 |
-
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
618 |
-
|
619 |
-
super().__init__(
|
620 |
-
model=model,
|
621 |
-
args=args,
|
622 |
-
data_collator=data_collator,
|
623 |
-
train_dataset=train_dataset,
|
624 |
-
eval_dataset=eval_dataset,
|
625 |
-
processing_class=processing_class,
|
626 |
-
model_init=model_init,
|
627 |
-
compute_metrics=compute_metrics,
|
628 |
-
callbacks=callbacks,
|
629 |
-
optimizers=optimizers,
|
630 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
631 |
-
)
|
632 |
-
|
633 |
-
# Add tags for models that have been loaded with the correct transformers version
|
634 |
-
if hasattr(self.model, "add_model_tags"):
|
635 |
-
self.model.add_model_tags(self._tag_names)
|
636 |
-
|
637 |
-
if not hasattr(self, "accelerator"):
|
638 |
-
raise AttributeError(
|
639 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
640 |
-
)
|
641 |
-
|
642 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
643 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
644 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
645 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
646 |
-
|
647 |
-
if model is not None:
|
648 |
-
if hasattr(model, "config"):
|
649 |
-
hidden_size = (
|
650 |
-
max(model.config.hidden_sizes)
|
651 |
-
if getattr(model.config, "hidden_sizes", None)
|
652 |
-
else getattr(model.config, "hidden_size", None)
|
653 |
-
)
|
654 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
655 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
656 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
657 |
-
config_kwargs.update(
|
658 |
-
{
|
659 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
660 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
661 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
662 |
-
}
|
663 |
-
)
|
664 |
-
|
665 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
666 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
667 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
668 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
669 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
670 |
-
model.eval()
|
671 |
-
return model
|
672 |
-
|
673 |
-
def build_tokenized_answer(self, prompt, answer):
|
674 |
-
"""
|
675 |
-
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
676 |
-
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
677 |
-
Reference:
|
678 |
-
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
679 |
-
"""
|
680 |
-
|
681 |
-
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
682 |
-
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
683 |
-
|
684 |
-
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
685 |
-
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
686 |
-
|
687 |
-
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
688 |
-
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
689 |
-
|
690 |
-
# Prepare input tokens for token by token comparison
|
691 |
-
full_input_ids = np.array(full_tokenized["input_ids"])
|
692 |
-
|
693 |
-
if len(full_input_ids) != len(full_concat_input_ids):
|
694 |
-
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
695 |
-
|
696 |
-
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
697 |
-
# can be merged together when tokenizing prompt+answer. This could result
|
698 |
-
# on the last token from the prompt being different when tokenized on its own
|
699 |
-
# vs when done as prompt+answer.
|
700 |
-
response_token_ids_start_idx = len(prompt_input_ids)
|
701 |
-
|
702 |
-
# If tokenized prompt is different than both prompt+answer, then it means the
|
703 |
-
# last token has changed due to merging.
|
704 |
-
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
705 |
-
response_token_ids_start_idx -= 1
|
706 |
-
|
707 |
-
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
708 |
-
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
709 |
-
|
710 |
-
if len(prompt_input_ids) != len(prompt_attention_mask):
|
711 |
-
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
712 |
-
|
713 |
-
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
714 |
-
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
715 |
-
|
716 |
-
return dict(
|
717 |
-
prompt_input_ids=prompt_input_ids,
|
718 |
-
prompt_attention_mask=prompt_attention_mask,
|
719 |
-
input_ids=answer_input_ids,
|
720 |
-
attention_mask=answer_attention_mask,
|
721 |
-
)
|
722 |
-
|
723 |
-
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
724 |
-
"""Tokenize a single row from a ORPO specific dataset.
|
725 |
-
|
726 |
-
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
727 |
-
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
728 |
-
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
729 |
-
|
730 |
-
We also create the labels for the chosen/rejected responses, which are of length equal to
|
731 |
-
the sum of the length of the prompt and the chosen/rejected response, with
|
732 |
-
label_pad_token_id for the prompt tokens.
|
733 |
-
"""
|
734 |
-
batch = {}
|
735 |
-
prompt = feature["prompt"]
|
736 |
-
chosen = feature["chosen"]
|
737 |
-
rejected = feature["rejected"]
|
738 |
-
|
739 |
-
if not self.is_encoder_decoder:
|
740 |
-
# Check issues below for more details
|
741 |
-
# 1. https://github.com/huggingface/trl/issues/907
|
742 |
-
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
743 |
-
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
744 |
-
|
745 |
-
if not isinstance(prompt, str):
|
746 |
-
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
747 |
-
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
748 |
-
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
749 |
-
|
750 |
-
if not isinstance(chosen, str):
|
751 |
-
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
752 |
-
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
753 |
-
|
754 |
-
if not isinstance(rejected, str):
|
755 |
-
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
756 |
-
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
757 |
-
|
758 |
-
# Last prompt token might get merged by tokenizer and
|
759 |
-
# it should not be included for generation if that happens
|
760 |
-
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
761 |
-
|
762 |
-
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
763 |
-
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
764 |
-
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
765 |
-
|
766 |
-
for k, v in prompt_tokens.items():
|
767 |
-
prompt_tokens[k] = v[:prompt_len_input_ids]
|
768 |
-
|
769 |
-
# Make sure prompts only have one different token at most an
|
770 |
-
# and length only differs by 1 at most
|
771 |
-
num_diff_tokens = sum(
|
772 |
-
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
773 |
-
)
|
774 |
-
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
775 |
-
if num_diff_tokens > 1 or num_diff_len > 1:
|
776 |
-
raise ValueError(
|
777 |
-
"Chosen and rejected prompt_input_ids might only differ on the "
|
778 |
-
"last token due to tokenizer merge ops."
|
779 |
-
)
|
780 |
-
|
781 |
-
# add BOS token to head of prompt. Avoid adding if it's already there
|
782 |
-
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
783 |
-
self.processing_class.bos_token_id,
|
784 |
-
prompt_len_input_ids,
|
785 |
-
prompt_tokens,
|
786 |
-
chosen_prompt_len_input_ids,
|
787 |
-
chosen_tokens,
|
788 |
-
rejected_prompt_len_input_ids,
|
789 |
-
rejected_tokens,
|
790 |
-
)
|
791 |
-
|
792 |
-
# add EOS token to end of answer. Avoid adding if it's already there
|
793 |
-
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
794 |
-
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
795 |
-
)
|
796 |
-
|
797 |
-
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
798 |
-
|
799 |
-
# if combined sequence is too long, truncate the prompt
|
800 |
-
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
801 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
802 |
-
if self.truncation_mode == "keep_start":
|
803 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
804 |
-
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
805 |
-
elif self.truncation_mode == "keep_end":
|
806 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
807 |
-
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
808 |
-
else:
|
809 |
-
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
810 |
-
|
811 |
-
# if that's still too long, truncate the response
|
812 |
-
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
813 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
814 |
-
for k in ["input_ids", "attention_mask"]:
|
815 |
-
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
816 |
-
|
817 |
-
# Create labels
|
818 |
-
chosen_sequence_tokens = {
|
819 |
-
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
820 |
-
}
|
821 |
-
rejected_sequence_tokens = {
|
822 |
-
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
823 |
-
}
|
824 |
-
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
825 |
-
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
826 |
-
self.label_pad_token_id
|
827 |
-
] * len(chosen_tokens["prompt_input_ids"])
|
828 |
-
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
829 |
-
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
830 |
-
self.label_pad_token_id
|
831 |
-
] * len(rejected_tokens["prompt_input_ids"])
|
832 |
-
|
833 |
-
for k, toks in {
|
834 |
-
"chosen_": chosen_sequence_tokens,
|
835 |
-
"rejected_": rejected_sequence_tokens,
|
836 |
-
"": prompt_tokens,
|
837 |
-
}.items():
|
838 |
-
for type_key, tokens in toks.items():
|
839 |
-
if type_key == "token_type_ids":
|
840 |
-
continue
|
841 |
-
batch[f"{k}{type_key}"] = tokens
|
842 |
-
|
843 |
-
else:
|
844 |
-
chosen_tokens = self.processing_class(
|
845 |
-
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
846 |
-
)
|
847 |
-
rejected_tokens = self.processing_class(
|
848 |
-
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
849 |
-
)
|
850 |
-
prompt_tokens = self.processing_class(
|
851 |
-
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
852 |
-
)
|
853 |
-
|
854 |
-
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
855 |
-
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
856 |
-
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
857 |
-
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
858 |
-
|
859 |
-
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
860 |
-
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
861 |
-
labels=torch.tensor(batch["rejected_labels"])
|
862 |
-
)
|
863 |
-
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
864 |
-
labels=torch.tensor(batch["chosen_labels"])
|
865 |
-
)
|
866 |
-
|
867 |
-
if is_torch_xla_available():
|
868 |
-
# Pad the sequences to global max_length to avoid TorchXLA recompilation
|
869 |
-
for k in batch:
|
870 |
-
if "labels" in k or self.is_encoder_decoder:
|
871 |
-
pad_value = self.label_pad_token_id
|
872 |
-
elif k.endswith("_input_ids"):
|
873 |
-
pad_value = self.padding_value
|
874 |
-
elif k.endswith("_attention_mask"):
|
875 |
-
pad_value = 0
|
876 |
-
batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
|
877 |
-
return batch
|
878 |
-
|
879 |
-
@staticmethod
|
880 |
-
def concatenated_inputs(
|
881 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
882 |
-
is_encoder_decoder: bool = False,
|
883 |
-
label_pad_token_id: int = -100,
|
884 |
-
padding_value: int = 0,
|
885 |
-
device: Optional[torch.device] = None,
|
886 |
-
) -> dict[str, torch.LongTensor]:
|
887 |
-
"""Concatenate the chosen and rejected inputs into a single tensor.
|
888 |
-
|
889 |
-
Args:
|
890 |
-
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
891 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
892 |
-
label_pad_token_id: The label pad token id.
|
893 |
-
padding_value: The padding value to use for the concatenated inputs_ids.
|
894 |
-
device: The device for the concatenated inputs.
|
895 |
-
|
896 |
-
Returns:
|
897 |
-
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
898 |
-
"""
|
899 |
-
concatenated_batch = {}
|
900 |
-
|
901 |
-
if is_encoder_decoder:
|
902 |
-
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
903 |
-
else:
|
904 |
-
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
905 |
-
|
906 |
-
for k in batch:
|
907 |
-
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
908 |
-
if "labels" in k or is_encoder_decoder:
|
909 |
-
pad_value = label_pad_token_id
|
910 |
-
elif k.endswith("_input_ids"):
|
911 |
-
pad_value = padding_value
|
912 |
-
elif k.endswith("_attention_mask"):
|
913 |
-
pad_value = 0
|
914 |
-
concatenated_key = k.replace("chosen", "concatenated")
|
915 |
-
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
916 |
-
for k in batch:
|
917 |
-
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
918 |
-
if "labels" in k or is_encoder_decoder:
|
919 |
-
pad_value = label_pad_token_id
|
920 |
-
elif k.endswith("_input_ids"):
|
921 |
-
pad_value = padding_value
|
922 |
-
elif k.endswith("_attention_mask"):
|
923 |
-
pad_value = 0
|
924 |
-
concatenated_key = k.replace("rejected", "concatenated")
|
925 |
-
concatenated_batch[concatenated_key] = torch.cat(
|
926 |
-
(
|
927 |
-
concatenated_batch[concatenated_key],
|
928 |
-
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
929 |
-
),
|
930 |
-
dim=0,
|
931 |
-
).to(device=device)
|
932 |
-
|
933 |
-
if is_encoder_decoder:
|
934 |
-
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
935 |
-
concatenated_batch["concatenated_attention_mask"] = (
|
936 |
-
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
937 |
-
)
|
938 |
-
|
939 |
-
return concatenated_batch
|
940 |
-
|
941 |
-
def odds_ratio_loss(
|
942 |
-
self,
|
943 |
-
policy_chosen_logps: torch.FloatTensor,
|
944 |
-
policy_rejected_logps: torch.FloatTensor,
|
945 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
946 |
-
"""Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
|
947 |
-
|
948 |
-
Args:
|
949 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
950 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
951 |
-
|
952 |
-
Returns:
|
953 |
-
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
954 |
-
The losses tensor contains the ORPO loss for each example in the batch.
|
955 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
956 |
-
The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
|
957 |
-
The `log(sigmoid(log_odds_chosen))` for logging purposes.
|
958 |
-
"""
|
959 |
-
|
960 |
-
# Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
|
961 |
-
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
962 |
-
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
|
963 |
-
)
|
964 |
-
ratio = F.logsigmoid(log_odds)
|
965 |
-
losses = self.beta * ratio
|
966 |
-
|
967 |
-
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
968 |
-
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
969 |
-
|
970 |
-
return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
|
971 |
-
|
972 |
-
@staticmethod
|
973 |
-
def get_batch_logps(
|
974 |
-
logits: torch.FloatTensor,
|
975 |
-
labels: torch.LongTensor,
|
976 |
-
average_log_prob: bool = False,
|
977 |
-
label_pad_token_id: int = -100,
|
978 |
-
is_encoder_decoder: bool = False,
|
979 |
-
) -> torch.FloatTensor:
|
980 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
981 |
-
|
982 |
-
Args:
|
983 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
984 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
985 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
986 |
-
label_pad_token_id: The label pad token id.
|
987 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
988 |
-
|
989 |
-
Returns:
|
990 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
991 |
-
"""
|
992 |
-
if logits.shape[:-1] != labels.shape:
|
993 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
994 |
-
|
995 |
-
if not is_encoder_decoder:
|
996 |
-
labels = labels[:, 1:].clone()
|
997 |
-
logits = logits[:, :-1, :]
|
998 |
-
loss_mask = labels != label_pad_token_id
|
999 |
-
|
1000 |
-
# dummy token; we'll ignore the losses on these tokens later
|
1001 |
-
labels = torch.where(labels == label_pad_token_id, 0, labels)
|
1002 |
-
|
1003 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
1004 |
-
|
1005 |
-
if average_log_prob:
|
1006 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1007 |
-
else:
|
1008 |
-
return (per_token_logps * loss_mask).sum(-1)
|
1009 |
-
|
1010 |
-
def concatenated_forward(
|
1011 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1012 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1013 |
-
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
1014 |
-
|
1015 |
-
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
1016 |
-
"""
|
1017 |
-
concatenated_batch = self.concatenated_inputs(
|
1018 |
-
batch,
|
1019 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1020 |
-
label_pad_token_id=self.label_pad_token_id,
|
1021 |
-
padding_value=self.padding_value,
|
1022 |
-
device=self.accelerator.device,
|
1023 |
-
)
|
1024 |
-
len_chosen = batch["chosen_labels"].shape[0]
|
1025 |
-
|
1026 |
-
model_kwargs = (
|
1027 |
-
{
|
1028 |
-
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
1029 |
-
}
|
1030 |
-
if self.is_encoder_decoder
|
1031 |
-
else {}
|
1032 |
-
)
|
1033 |
-
|
1034 |
-
if self.aux_loss_enabled:
|
1035 |
-
model_kwargs["output_router_logits"] = True
|
1036 |
-
|
1037 |
-
outputs = model(
|
1038 |
-
concatenated_batch["concatenated_input_ids"],
|
1039 |
-
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
1040 |
-
use_cache=False,
|
1041 |
-
**model_kwargs,
|
1042 |
-
)
|
1043 |
-
all_logits = outputs.logits
|
1044 |
-
|
1045 |
-
def cross_entropy_loss(logits, labels):
|
1046 |
-
if not self.is_encoder_decoder:
|
1047 |
-
# Shift so that tokens < n predict n
|
1048 |
-
logits = logits[..., :-1, :].contiguous()
|
1049 |
-
labels = labels[..., 1:].contiguous()
|
1050 |
-
# Flatten the tokens
|
1051 |
-
loss_fct = nn.CrossEntropyLoss()
|
1052 |
-
logits = logits.view(-1, logits.shape[-1])
|
1053 |
-
labels = labels.view(-1)
|
1054 |
-
# Enable model parallelism
|
1055 |
-
labels = labels.to(logits.device)
|
1056 |
-
loss = loss_fct(logits, labels)
|
1057 |
-
return loss
|
1058 |
-
|
1059 |
-
if self.is_encoder_decoder:
|
1060 |
-
labels = concatenated_batch["concatenated_labels"].clone()
|
1061 |
-
else:
|
1062 |
-
labels = concatenated_batch["concatenated_input_ids"].clone()
|
1063 |
-
attention_mask = concatenated_batch["concatenated_attention_mask"]
|
1064 |
-
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
|
1065 |
-
# orpo chosen nll loss is computed over the full prompt and response
|
1066 |
-
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
1067 |
-
|
1068 |
-
all_logps = self.get_batch_logps(
|
1069 |
-
all_logits,
|
1070 |
-
concatenated_batch["concatenated_labels"],
|
1071 |
-
average_log_prob=True,
|
1072 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
1073 |
-
label_pad_token_id=self.label_pad_token_id,
|
1074 |
-
)
|
1075 |
-
|
1076 |
-
chosen_logps = all_logps[:len_chosen]
|
1077 |
-
rejected_logps = all_logps[len_chosen:]
|
1078 |
-
|
1079 |
-
if not self.is_encoder_decoder:
|
1080 |
-
chosen_logits = all_logits[:len_chosen, :-1, :]
|
1081 |
-
rejected_logits = all_logits[len_chosen:, :-1, :]
|
1082 |
-
else:
|
1083 |
-
chosen_logits = all_logits[:len_chosen]
|
1084 |
-
rejected_logits = all_logits[len_chosen:]
|
1085 |
-
|
1086 |
-
if self.aux_loss_enabled:
|
1087 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
|
1088 |
-
|
1089 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
1090 |
-
|
1091 |
-
def get_batch_loss_metrics(
|
1092 |
-
self,
|
1093 |
-
model,
|
1094 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
1095 |
-
train_eval: Literal["train", "eval"] = "train",
|
1096 |
-
):
|
1097 |
-
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
1098 |
-
metrics = {}
|
1099 |
-
|
1100 |
-
forward_output = self.concatenated_forward(model, batch)
|
1101 |
-
(
|
1102 |
-
policy_chosen_logps,
|
1103 |
-
policy_rejected_logps,
|
1104 |
-
policy_chosen_logits,
|
1105 |
-
policy_rejected_logits,
|
1106 |
-
policy_nll_loss,
|
1107 |
-
) = forward_output[:5]
|
1108 |
-
if self.aux_loss_enabled:
|
1109 |
-
aux_loss = forward_output[5]
|
1110 |
-
|
1111 |
-
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
|
1112 |
-
policy_chosen_logps, policy_rejected_logps
|
1113 |
-
)
|
1114 |
-
# full ORPO loss
|
1115 |
-
loss = policy_nll_loss - losses.mean()
|
1116 |
-
|
1117 |
-
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
1118 |
-
|
1119 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
1120 |
-
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
|
1121 |
-
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
|
1122 |
-
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
|
1123 |
-
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
1124 |
-
chosen_rewards - rejected_rewards
|
1125 |
-
).mean()
|
1126 |
-
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
1127 |
-
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
1128 |
-
metrics[f"{prefix}logits/rejected"] = (
|
1129 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
|
1130 |
-
)
|
1131 |
-
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
|
1132 |
-
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
1133 |
-
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
|
1134 |
-
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
|
1135 |
-
if is_torch_xla_available():
|
1136 |
-
xm.mark_step() # needed because .item() calls
|
1137 |
-
for k, v in metrics.items():
|
1138 |
-
metrics[k] = v.item()
|
1139 |
-
if self.aux_loss_enabled:
|
1140 |
-
loss += self.aux_loss_coef * aux_loss
|
1141 |
-
|
1142 |
-
return loss, metrics
|
1143 |
-
|
1144 |
-
def compute_loss(
|
1145 |
-
self,
|
1146 |
-
model: Union[PreTrainedModel, nn.Module],
|
1147 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1148 |
-
return_outputs=False,
|
1149 |
-
num_items_in_batch=None,
|
1150 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1151 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1152 |
-
|
1153 |
-
with compute_loss_context_manager:
|
1154 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
1155 |
-
|
1156 |
-
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1157 |
-
loss = loss.to(self.args.device)
|
1158 |
-
|
1159 |
-
# force log the metrics
|
1160 |
-
self.store_metrics(metrics, train_eval="train")
|
1161 |
-
|
1162 |
-
if return_outputs:
|
1163 |
-
return (loss, metrics)
|
1164 |
-
return loss
|
1165 |
-
|
1166 |
-
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
1167 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1168 |
-
|
1169 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1170 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1171 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1172 |
-
|
1173 |
-
with generate_context_manager:
|
1174 |
-
policy_output = model.generate(
|
1175 |
-
input_ids=batch["prompt_input_ids"],
|
1176 |
-
attention_mask=batch["prompt_attention_mask"],
|
1177 |
-
max_length=self.max_length,
|
1178 |
-
do_sample=True,
|
1179 |
-
pad_token_id=self.processing_class.pad_token_id,
|
1180 |
-
)
|
1181 |
-
|
1182 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1183 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1184 |
-
|
1185 |
-
return policy_output_decoded
|
1186 |
-
|
1187 |
-
def prediction_step(
|
1188 |
-
self,
|
1189 |
-
model: Union[PreTrainedModel, nn.Module],
|
1190 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
1191 |
-
prediction_loss_only: bool,
|
1192 |
-
ignore_keys: Optional[list[str]] = None,
|
1193 |
-
):
|
1194 |
-
if not self.use_dpo_data_collator:
|
1195 |
-
warnings.warn(
|
1196 |
-
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
1197 |
-
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
1198 |
-
)
|
1199 |
-
if ignore_keys is None:
|
1200 |
-
if hasattr(model, "config"):
|
1201 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1202 |
-
else:
|
1203 |
-
ignore_keys = []
|
1204 |
-
|
1205 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1206 |
-
|
1207 |
-
with torch.no_grad(), prediction_context_manager:
|
1208 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
1209 |
-
|
1210 |
-
# force log the metrics
|
1211 |
-
self.store_metrics(metrics, train_eval="eval")
|
1212 |
-
|
1213 |
-
if prediction_loss_only:
|
1214 |
-
return (loss.detach(), None, None)
|
1215 |
-
|
1216 |
-
# logits for the chosen and rejected samples from model
|
1217 |
-
logits_dict = {
|
1218 |
-
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
1219 |
-
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
1220 |
-
}
|
1221 |
-
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
1222 |
-
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
1223 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1224 |
-
|
1225 |
-
return (loss.detach(), logits, labels)
|
1226 |
-
|
1227 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1228 |
-
for key, value in metrics.items():
|
1229 |
-
self._stored_metrics[train_eval][key].append(value)
|
1230 |
-
|
1231 |
-
def evaluation_loop(
|
1232 |
-
self,
|
1233 |
-
dataloader: DataLoader,
|
1234 |
-
description: str,
|
1235 |
-
prediction_loss_only: Optional[bool] = None,
|
1236 |
-
ignore_keys: Optional[list[str]] = None,
|
1237 |
-
metric_key_prefix: str = "eval",
|
1238 |
-
) -> EvalLoopOutput:
|
1239 |
-
"""
|
1240 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
1241 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1242 |
-
|
1243 |
-
Works both with or without labels.
|
1244 |
-
"""
|
1245 |
-
|
1246 |
-
# Sample and save to game log if requested (for one batch to save time)
|
1247 |
-
if self.generate_during_eval:
|
1248 |
-
# Generate random indices within the range of the total number of samples
|
1249 |
-
num_samples = len(dataloader.dataset)
|
1250 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1251 |
-
|
1252 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1253 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1254 |
-
random_batch = self.data_collator(random_batch_dataset)
|
1255 |
-
random_batch = self._prepare_inputs(random_batch)
|
1256 |
-
|
1257 |
-
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
1258 |
-
|
1259 |
-
table = pd.DataFrame(
|
1260 |
-
columns=["Prompt", "Policy"],
|
1261 |
-
data=[
|
1262 |
-
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
1263 |
-
],
|
1264 |
-
)
|
1265 |
-
if "wandb" in self.args.report_to:
|
1266 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
1267 |
-
|
1268 |
-
if "comet_ml" in self.args.report_to:
|
1269 |
-
log_table_to_comet_experiment(
|
1270 |
-
name="game_log.csv",
|
1271 |
-
table=table,
|
1272 |
-
)
|
1273 |
-
|
1274 |
-
# Base evaluation
|
1275 |
-
initial_output = super().evaluation_loop(
|
1276 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1277 |
-
)
|
1278 |
-
|
1279 |
-
return initial_output
|
1280 |
-
|
1281 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1282 |
-
"""
|
1283 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
1284 |
-
|
1285 |
-
Args:
|
1286 |
-
logs (`dict[str, float]`):
|
1287 |
-
The values to log.
|
1288 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1289 |
-
Start time of the training.
|
1290 |
-
"""
|
1291 |
-
# logs either has 'loss' or 'eval_loss'
|
1292 |
-
train_eval = "train" if "loss" in logs else "eval"
|
1293 |
-
# Add averaged stored metrics to logs
|
1294 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
1295 |
-
logs[key] = torch.tensor(metrics).mean().item()
|
1296 |
-
del self._stored_metrics[train_eval]
|
1297 |
-
|
1298 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1299 |
-
return super().log(logs, start_time)
|
1300 |
-
else: # transformers<=4.46
|
1301 |
-
return super().log(logs)
|
1302 |
-
|
1303 |
-
def _shift_right(self, input_ids):
|
1304 |
-
if self.decoder_start_token_id is None:
|
1305 |
-
raise ValueError(
|
1306 |
-
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
1307 |
-
)
|
1308 |
-
|
1309 |
-
# shift inputs to the right
|
1310 |
-
if is_torch_fx_proxy(input_ids):
|
1311 |
-
# Item assignment is not supported natively for proxies.
|
1312 |
-
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
1313 |
-
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
1314 |
-
else:
|
1315 |
-
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1316 |
-
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1317 |
-
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1318 |
-
|
1319 |
-
if self.pad_token_id is None:
|
1320 |
-
raise ValueError("model.config.pad_token_id has to be defined.")
|
1321 |
-
# replace possible -100 values in labels by `pad_token_id`
|
1322 |
-
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
1323 |
-
|
1324 |
-
return shifted_input_ids
|
1325 |
-
|
1326 |
-
def create_model_card(
|
1327 |
-
self,
|
1328 |
-
model_name: Optional[str] = None,
|
1329 |
-
dataset_name: Optional[str] = None,
|
1330 |
-
tags: Union[str, list[str], None] = None,
|
1331 |
-
):
|
1332 |
-
"""
|
1333 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1334 |
-
|
1335 |
-
Args:
|
1336 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1337 |
-
Name of the model.
|
1338 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1339 |
-
Name of the dataset used for training.
|
1340 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1341 |
-
Tags to be associated with the model card.
|
1342 |
-
"""
|
1343 |
-
if not self.is_world_process_zero():
|
1344 |
-
return
|
1345 |
-
|
1346 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1347 |
-
base_model = self.model.config._name_or_path
|
1348 |
-
else:
|
1349 |
-
base_model = None
|
1350 |
-
|
1351 |
-
tags = tags or []
|
1352 |
-
if isinstance(tags, str):
|
1353 |
-
tags = [tags]
|
1354 |
-
|
1355 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1356 |
-
tags.append("unsloth")
|
1357 |
-
|
1358 |
-
citation = textwrap.dedent("""\
|
1359 |
-
@article{hong2024orpo,
|
1360 |
-
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
|
1361 |
-
author = {Jiwoo Hong and Noah Lee and James Thorne},
|
1362 |
-
year = 2024,
|
1363 |
-
eprint = {arXiv:2403.07691}
|
1364 |
-
}""")
|
1365 |
-
|
1366 |
-
model_card = generate_model_card(
|
1367 |
-
base_model=base_model,
|
1368 |
-
model_name=model_name,
|
1369 |
-
hub_model_id=self.hub_model_id,
|
1370 |
-
dataset_name=dataset_name,
|
1371 |
-
tags=tags,
|
1372 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1373 |
-
comet_url=get_comet_experiment_url(),
|
1374 |
-
trainer_name="ORPO",
|
1375 |
-
trainer_citation=citation,
|
1376 |
-
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
|
1377 |
-
paper_id="2403.07691",
|
1378 |
-
)
|
1379 |
-
|
1380 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1381 |
-
class UnslothORPOTrainer(_UnslothORPOTrainer):
|
1382 |
-
"""
|
1383 |
-
|
1384 |
-
Initialize ORPOTrainer.
|
1385 |
-
|
1386 |
-
Args:
|
1387 |
-
model (`transformers.PreTrainedModel`):
|
1388 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1389 |
-
args (`ORPOConfig`):
|
1390 |
-
The ORPO config arguments to use for training.
|
1391 |
-
data_collator (`transformers.DataCollator`):
|
1392 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1393 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1394 |
-
train_dataset (`datasets.Dataset`):
|
1395 |
-
The dataset to use for training.
|
1396 |
-
eval_dataset (`datasets.Dataset`):
|
1397 |
-
The dataset to use for evaluation.
|
1398 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1399 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1400 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1401 |
-
reuse the fine-tuned model.
|
1402 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1403 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1404 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
1405 |
-
The callbacks to use for training.
|
1406 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1407 |
-
The optimizer and scheduler to use for training.
|
1408 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1409 |
-
The function to use to preprocess the logits before computing the metrics.
|
1410 |
-
peft_config (`dict`, defaults to `None`):
|
1411 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1412 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1413 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1414 |
-
a dictionary string to metric values.
|
1415 |
-
|
1416 |
-
"""
|
1417 |
-
def __init__(
|
1418 |
-
self,
|
1419 |
-
model = None,
|
1420 |
-
args = None,
|
1421 |
-
data_collator = None,
|
1422 |
-
train_dataset = None,
|
1423 |
-
eval_dataset = None,
|
1424 |
-
processing_class = None,
|
1425 |
-
model_init = None,
|
1426 |
-
callbacks = None,
|
1427 |
-
preprocess_logits_for_metrics = None,
|
1428 |
-
peft_config = None,
|
1429 |
-
compute_metrics = None,
|
1430 |
-
**kwargs
|
1431 |
-
):
|
1432 |
-
if args is None: args = UnslothORPOConfig()
|
1433 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1434 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1435 |
-
force_float32 = False
|
1436 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1437 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1438 |
-
force_float32 = True
|
1439 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1440 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1441 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1442 |
-
from unsloth_zoo.utils import _get_dtype
|
1443 |
-
dtype = _get_dtype(dtype)
|
1444 |
-
float16 = dtype == torch.float16
|
1445 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1446 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1447 |
-
if force_float32:
|
1448 |
-
args.fp16 = False
|
1449 |
-
args.bf16 = False
|
1450 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1451 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1452 |
-
args.fp16 = float16
|
1453 |
-
args.bf16 = not float16
|
1454 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1455 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1456 |
-
args.eval_strategy = 'steps'
|
1457 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1458 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1459 |
-
if ga_steps is not None and ga_steps > 1:
|
1460 |
-
from transformers import __version__ as transformers_version
|
1461 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1462 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1463 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1464 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1465 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1466 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1467 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1468 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1469 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1470 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1471 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1472 |
-
if force_float32:
|
1473 |
-
args.bf16_full_eval = False
|
1474 |
-
args.fp16_full_eval = False
|
1475 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1476 |
-
args.bf16_full_eval = True
|
1477 |
-
args.fp16_full_eval = False
|
1478 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1479 |
-
args.bf16_full_eval = args.bf16
|
1480 |
-
args.fp16_full_eval = args.fp16
|
1481 |
-
_output_logits = False
|
1482 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1483 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1484 |
-
if _output_logits:
|
1485 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1486 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1487 |
-
pass
|
1488 |
-
else:
|
1489 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1490 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1491 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1492 |
-
max_seq_length = model.max_seq_length
|
1493 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1494 |
-
if model is not None and hasattr(model, 'for_training'):
|
1495 |
-
model.for_training()
|
1496 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1497 |
-
if 'processing_class' in locals():
|
1498 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1499 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1500 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1501 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1502 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1503 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1504 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1505 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1506 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1507 |
-
else:
|
1508 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1509 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1510 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1511 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1512 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1513 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1514 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1515 |
-
else:
|
1516 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1517 |
-
other_metrics = []
|
1518 |
-
|
1519 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1520 |
-
PatchRLStatistics('orpo_trainer', other_metrics)
|
1521 |
-
|
1522 |
-
super().__init__(
|
1523 |
-
model = model,
|
1524 |
-
args = args,
|
1525 |
-
data_collator = data_collator,
|
1526 |
-
train_dataset = train_dataset,
|
1527 |
-
eval_dataset = eval_dataset,
|
1528 |
-
processing_class = processing_class,
|
1529 |
-
model_init = model_init,
|
1530 |
-
callbacks = callbacks,
|
1531 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1532 |
-
peft_config = peft_config,
|
1533 |
-
compute_metrics = compute_metrics,**kwargs)
|
1534 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1535 |
-
self.neftune_hook_handle.remove()
|
1536 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1537 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1538 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1539 |
-
pass
|
1540 |
-
|
1541 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
DELETED
@@ -1,1267 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
def vLLMSamplingParams(**kwargs):
|
43 |
-
from vllm import SamplingParams
|
44 |
-
sampling_params = SamplingParams(**kwargs)
|
45 |
-
sampling_params._set_kwargs = kwargs
|
46 |
-
return sampling_params
|
47 |
-
@dataclass
|
48 |
-
class UnslothOnlineDPOConfig(OnlineDPOConfig):
|
49 |
-
"""
|
50 |
-
|
51 |
-
Configuration class for the [`OnlineDPOTrainer`].
|
52 |
-
|
53 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
54 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
55 |
-
command line.
|
56 |
-
|
57 |
-
Parameters:
|
58 |
-
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
59 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
60 |
-
[`~transformers.TrainingArguments`].
|
61 |
-
reward_model_path (`str` or `None`, *optional*, defaults to `None`):
|
62 |
-
Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
|
63 |
-
judge (`str` or `None`, *optional*, defaults to `None`):
|
64 |
-
Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
|
65 |
-
max_new_tokens (`int`, *optional*, defaults to `64`):
|
66 |
-
Maximum number of tokens to generate per completion.
|
67 |
-
max_length (`int`, *optional*, defaults to `256`):
|
68 |
-
Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
|
69 |
-
sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
|
70 |
-
possible.
|
71 |
-
temperature (`float`, *optional*, defaults to `0.9`):
|
72 |
-
Temperature for sampling. The higher the temperature, the more random the completions.
|
73 |
-
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
|
74 |
-
Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
|
75 |
-
to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
|
76 |
-
value.
|
77 |
-
beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
|
78 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
79 |
-
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
80 |
-
the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
|
81 |
-
selected for each new epoch and the last β is used for the rest of the epochs.
|
82 |
-
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
83 |
-
Type of loss to use. Possible values are:
|
84 |
-
|
85 |
-
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
86 |
-
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
87 |
-
|
88 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
89 |
-
Number of processes to use for processing the dataset.
|
90 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
91 |
-
Whether to disable dropout in the model and reference model.
|
92 |
-
use_vllm (`bool`, *optional*, defaults to `False`):
|
93 |
-
Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
|
94 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
95 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
96 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
97 |
-
capacity of a single GPU, albeit at the cost of slower generation.
|
98 |
-
|
99 |
-
"""
|
100 |
-
vllm_sampling_params: Optional[Any] = field(
|
101 |
-
default = None,
|
102 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
103 |
-
)
|
104 |
-
unsloth_num_chunks : Optional[int] = field(
|
105 |
-
default = -1,
|
106 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
107 |
-
)
|
108 |
-
def __init__(
|
109 |
-
self,
|
110 |
-
output_dir = None,
|
111 |
-
overwrite_output_dir = None,
|
112 |
-
do_train = False,
|
113 |
-
do_eval = False,
|
114 |
-
do_predict = False,
|
115 |
-
eval_strategy = 'no',
|
116 |
-
prediction_loss_only = False,
|
117 |
-
per_device_train_batch_size = 4,
|
118 |
-
per_device_eval_batch_size = 4,
|
119 |
-
per_gpu_train_batch_size = None,
|
120 |
-
per_gpu_eval_batch_size = None,
|
121 |
-
gradient_accumulation_steps = 2,
|
122 |
-
eval_accumulation_steps = 2,
|
123 |
-
eval_delay = 0,
|
124 |
-
torch_empty_cache_steps = 250,
|
125 |
-
learning_rate = 5e-05,
|
126 |
-
weight_decay = 0.01,
|
127 |
-
adam_beta1 = 0.9,
|
128 |
-
adam_beta2 = 0.999,
|
129 |
-
adam_epsilon = 1e-08,
|
130 |
-
max_grad_norm = 1.0,
|
131 |
-
num_train_epochs = 3.0,
|
132 |
-
max_steps = -1,
|
133 |
-
lr_scheduler_type = 'linear',
|
134 |
-
warmup_ratio = 0.1,
|
135 |
-
warmup_steps = 0,
|
136 |
-
log_level = 'passive',
|
137 |
-
log_level_replica = 'warning',
|
138 |
-
log_on_each_node = True,
|
139 |
-
logging_dir = None,
|
140 |
-
logging_strategy = 'steps',
|
141 |
-
logging_first_step = False,
|
142 |
-
logging_steps = 1,
|
143 |
-
logging_nan_inf_filter = False,
|
144 |
-
save_strategy = 'steps',
|
145 |
-
save_steps = 500,
|
146 |
-
save_total_limit = None,
|
147 |
-
save_safetensors = True,
|
148 |
-
save_on_each_node = False,
|
149 |
-
save_only_model = False,
|
150 |
-
restore_callback_states_from_checkpoint = False,
|
151 |
-
no_cuda = False,
|
152 |
-
use_cpu = False,
|
153 |
-
use_mps_device = False,
|
154 |
-
seed = 3407,
|
155 |
-
data_seed = 3407,
|
156 |
-
jit_mode_eval = False,
|
157 |
-
use_ipex = False,
|
158 |
-
bf16 = False,
|
159 |
-
fp16 = False,
|
160 |
-
fp16_opt_level = 'O1',
|
161 |
-
half_precision_backend = 'auto',
|
162 |
-
bf16_full_eval = False,
|
163 |
-
fp16_full_eval = False,
|
164 |
-
tf32 = None,
|
165 |
-
local_rank = -1,
|
166 |
-
ddp_backend = None,
|
167 |
-
tpu_num_cores = None,
|
168 |
-
tpu_metrics_debug = False,
|
169 |
-
debug = '',
|
170 |
-
dataloader_drop_last = False,
|
171 |
-
eval_steps = None,
|
172 |
-
dataloader_num_workers = 0,
|
173 |
-
dataloader_prefetch_factor = None,
|
174 |
-
past_index = -1,
|
175 |
-
run_name = None,
|
176 |
-
disable_tqdm = None,
|
177 |
-
remove_unused_columns = True,
|
178 |
-
label_names = None,
|
179 |
-
load_best_model_at_end = False,
|
180 |
-
metric_for_best_model = None,
|
181 |
-
greater_is_better = None,
|
182 |
-
ignore_data_skip = False,
|
183 |
-
fsdp = '',
|
184 |
-
fsdp_min_num_params = 0,
|
185 |
-
fsdp_config = None,
|
186 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
187 |
-
accelerator_config = None,
|
188 |
-
deepspeed = None,
|
189 |
-
label_smoothing_factor = 0.0,
|
190 |
-
optim = 'adamw_8bit',
|
191 |
-
optim_args = None,
|
192 |
-
adafactor = False,
|
193 |
-
group_by_length = False,
|
194 |
-
length_column_name = 'length',
|
195 |
-
report_to = None,
|
196 |
-
ddp_find_unused_parameters = None,
|
197 |
-
ddp_bucket_cap_mb = None,
|
198 |
-
ddp_broadcast_buffers = None,
|
199 |
-
dataloader_pin_memory = True,
|
200 |
-
dataloader_persistent_workers = False,
|
201 |
-
skip_memory_metrics = True,
|
202 |
-
use_legacy_prediction_loop = False,
|
203 |
-
push_to_hub = False,
|
204 |
-
resume_from_checkpoint = None,
|
205 |
-
hub_model_id = None,
|
206 |
-
hub_strategy = 'every_save',
|
207 |
-
hub_token = None,
|
208 |
-
hub_private_repo = None,
|
209 |
-
hub_always_push = False,
|
210 |
-
gradient_checkpointing = False,
|
211 |
-
gradient_checkpointing_kwargs = None,
|
212 |
-
include_inputs_for_metrics = False,
|
213 |
-
eval_do_concat_batches = True,
|
214 |
-
fp16_backend = 'auto',
|
215 |
-
evaluation_strategy = None,
|
216 |
-
push_to_hub_model_id = None,
|
217 |
-
push_to_hub_organization = None,
|
218 |
-
push_to_hub_token = None,
|
219 |
-
mp_parameters = '',
|
220 |
-
auto_find_batch_size = False,
|
221 |
-
full_determinism = False,
|
222 |
-
torchdynamo = None,
|
223 |
-
ray_scope = 'last',
|
224 |
-
ddp_timeout = 1800,
|
225 |
-
torch_compile = False,
|
226 |
-
torch_compile_backend = None,
|
227 |
-
torch_compile_mode = None,
|
228 |
-
dispatch_batches = None,
|
229 |
-
split_batches = None,
|
230 |
-
include_tokens_per_second = False,
|
231 |
-
include_num_input_tokens_seen = False,
|
232 |
-
neftune_noise_alpha = None,
|
233 |
-
optim_target_modules = None,
|
234 |
-
batch_eval_metrics = False,
|
235 |
-
eval_on_start = False,
|
236 |
-
use_liger_kernel = False,
|
237 |
-
eval_use_gather_object = False,
|
238 |
-
average_tokens_across_devices = False,
|
239 |
-
reward_model_path = None,
|
240 |
-
judge = None,
|
241 |
-
max_new_tokens = 64,
|
242 |
-
max_length = 512,
|
243 |
-
temperature = 0.9,
|
244 |
-
missing_eos_penalty = None,
|
245 |
-
loss_type = 'sigmoid',
|
246 |
-
dataset_num_proc = None,
|
247 |
-
disable_dropout = True,
|
248 |
-
use_vllm = False,
|
249 |
-
ds3_gather_for_generation = True,
|
250 |
-
vllm_sampling_params = None,
|
251 |
-
unsloth_num_chunks = -1,
|
252 |
-
**kwargs,
|
253 |
-
):
|
254 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
255 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
256 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
257 |
-
output_dir = 'unsloth_training_checkpoints'
|
258 |
-
save_strategy = 'no'
|
259 |
-
if dataset_num_proc is None:
|
260 |
-
from multiprocessing import cpu_count
|
261 |
-
dataset_num_proc = cpu_count()
|
262 |
-
|
263 |
-
super().__init__(
|
264 |
-
output_dir = output_dir,
|
265 |
-
overwrite_output_dir = overwrite_output_dir,
|
266 |
-
do_train = do_train,
|
267 |
-
do_eval = do_eval,
|
268 |
-
do_predict = do_predict,
|
269 |
-
eval_strategy = eval_strategy,
|
270 |
-
prediction_loss_only = prediction_loss_only,
|
271 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
272 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
273 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
274 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
275 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
276 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
277 |
-
eval_delay = eval_delay,
|
278 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
279 |
-
learning_rate = learning_rate,
|
280 |
-
weight_decay = weight_decay,
|
281 |
-
adam_beta1 = adam_beta1,
|
282 |
-
adam_beta2 = adam_beta2,
|
283 |
-
adam_epsilon = adam_epsilon,
|
284 |
-
max_grad_norm = max_grad_norm,
|
285 |
-
num_train_epochs = num_train_epochs,
|
286 |
-
max_steps = max_steps,
|
287 |
-
lr_scheduler_type = lr_scheduler_type,
|
288 |
-
warmup_ratio = warmup_ratio,
|
289 |
-
warmup_steps = warmup_steps,
|
290 |
-
log_level = log_level,
|
291 |
-
log_level_replica = log_level_replica,
|
292 |
-
log_on_each_node = log_on_each_node,
|
293 |
-
logging_dir = logging_dir,
|
294 |
-
logging_strategy = logging_strategy,
|
295 |
-
logging_first_step = logging_first_step,
|
296 |
-
logging_steps = logging_steps,
|
297 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
298 |
-
save_strategy = save_strategy,
|
299 |
-
save_steps = save_steps,
|
300 |
-
save_total_limit = save_total_limit,
|
301 |
-
save_safetensors = save_safetensors,
|
302 |
-
save_on_each_node = save_on_each_node,
|
303 |
-
save_only_model = save_only_model,
|
304 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
305 |
-
no_cuda = no_cuda,
|
306 |
-
use_cpu = use_cpu,
|
307 |
-
use_mps_device = use_mps_device,
|
308 |
-
seed = seed,
|
309 |
-
data_seed = data_seed,
|
310 |
-
jit_mode_eval = jit_mode_eval,
|
311 |
-
use_ipex = use_ipex,
|
312 |
-
bf16 = bf16,
|
313 |
-
fp16 = fp16,
|
314 |
-
fp16_opt_level = fp16_opt_level,
|
315 |
-
half_precision_backend = half_precision_backend,
|
316 |
-
bf16_full_eval = bf16_full_eval,
|
317 |
-
fp16_full_eval = fp16_full_eval,
|
318 |
-
tf32 = tf32,
|
319 |
-
local_rank = local_rank,
|
320 |
-
ddp_backend = ddp_backend,
|
321 |
-
tpu_num_cores = tpu_num_cores,
|
322 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
323 |
-
debug = debug,
|
324 |
-
dataloader_drop_last = dataloader_drop_last,
|
325 |
-
eval_steps = eval_steps,
|
326 |
-
dataloader_num_workers = dataloader_num_workers,
|
327 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
328 |
-
past_index = past_index,
|
329 |
-
run_name = run_name,
|
330 |
-
disable_tqdm = disable_tqdm,
|
331 |
-
remove_unused_columns = remove_unused_columns,
|
332 |
-
label_names = label_names,
|
333 |
-
load_best_model_at_end = load_best_model_at_end,
|
334 |
-
metric_for_best_model = metric_for_best_model,
|
335 |
-
greater_is_better = greater_is_better,
|
336 |
-
ignore_data_skip = ignore_data_skip,
|
337 |
-
fsdp = fsdp,
|
338 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
339 |
-
fsdp_config = fsdp_config,
|
340 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
341 |
-
accelerator_config = accelerator_config,
|
342 |
-
deepspeed = deepspeed,
|
343 |
-
label_smoothing_factor = label_smoothing_factor,
|
344 |
-
optim = optim,
|
345 |
-
optim_args = optim_args,
|
346 |
-
adafactor = adafactor,
|
347 |
-
group_by_length = group_by_length,
|
348 |
-
length_column_name = length_column_name,
|
349 |
-
report_to = report_to,
|
350 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
351 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
352 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
353 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
354 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
355 |
-
skip_memory_metrics = skip_memory_metrics,
|
356 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
357 |
-
push_to_hub = push_to_hub,
|
358 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
359 |
-
hub_model_id = hub_model_id,
|
360 |
-
hub_strategy = hub_strategy,
|
361 |
-
hub_token = hub_token,
|
362 |
-
hub_private_repo = hub_private_repo,
|
363 |
-
hub_always_push = hub_always_push,
|
364 |
-
gradient_checkpointing = gradient_checkpointing,
|
365 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
366 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
367 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
368 |
-
fp16_backend = fp16_backend,
|
369 |
-
evaluation_strategy = evaluation_strategy,
|
370 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
371 |
-
push_to_hub_organization = push_to_hub_organization,
|
372 |
-
push_to_hub_token = push_to_hub_token,
|
373 |
-
mp_parameters = mp_parameters,
|
374 |
-
auto_find_batch_size = auto_find_batch_size,
|
375 |
-
full_determinism = full_determinism,
|
376 |
-
torchdynamo = torchdynamo,
|
377 |
-
ray_scope = ray_scope,
|
378 |
-
ddp_timeout = ddp_timeout,
|
379 |
-
torch_compile = torch_compile,
|
380 |
-
torch_compile_backend = torch_compile_backend,
|
381 |
-
torch_compile_mode = torch_compile_mode,
|
382 |
-
dispatch_batches = dispatch_batches,
|
383 |
-
split_batches = split_batches,
|
384 |
-
include_tokens_per_second = include_tokens_per_second,
|
385 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
386 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
387 |
-
optim_target_modules = optim_target_modules,
|
388 |
-
batch_eval_metrics = batch_eval_metrics,
|
389 |
-
eval_on_start = eval_on_start,
|
390 |
-
use_liger_kernel = use_liger_kernel,
|
391 |
-
eval_use_gather_object = eval_use_gather_object,
|
392 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
393 |
-
reward_model_path = reward_model_path,
|
394 |
-
judge = judge,
|
395 |
-
max_new_tokens = max_new_tokens,
|
396 |
-
max_length = max_length,
|
397 |
-
temperature = temperature,
|
398 |
-
missing_eos_penalty = missing_eos_penalty,
|
399 |
-
loss_type = loss_type,
|
400 |
-
dataset_num_proc = dataset_num_proc,
|
401 |
-
disable_dropout = disable_dropout,
|
402 |
-
use_vllm = use_vllm,
|
403 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
404 |
-
self.vllm_sampling_params = vllm_sampling_params
|
405 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
406 |
-
pass
|
407 |
-
|
408 |
-
class _UnslothOnlineDPOTrainer(Trainer):
|
409 |
-
r""""""
|
410 |
-
|
411 |
-
_tag_names = ["trl", "online-dpo"]
|
412 |
-
|
413 |
-
def __init__(
|
414 |
-
self,
|
415 |
-
model: Union[PreTrainedModel, nn.Module],
|
416 |
-
ref_model: Union[PreTrainedModel, nn.Module, None] = None,
|
417 |
-
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
418 |
-
judge: Optional[BasePairwiseJudge] = None,
|
419 |
-
args: Optional[OnlineDPOConfig] = None,
|
420 |
-
data_collator: Optional[DataCollator] = None,
|
421 |
-
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
|
422 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
|
423 |
-
processing_class: Optional[
|
424 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
425 |
-
] = None,
|
426 |
-
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
427 |
-
peft_config: Optional[dict] = None,
|
428 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
429 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
430 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
431 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
432 |
-
) -> None:
|
433 |
-
|
434 |
-
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
|
435 |
-
if ref_model is model:
|
436 |
-
raise ValueError(
|
437 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
438 |
-
"same as `model`, either omit the `ref_model` argument or pass `None`."
|
439 |
-
)
|
440 |
-
|
441 |
-
self.ref_model = ref_model
|
442 |
-
|
443 |
-
if reward_model is not None and judge is not None:
|
444 |
-
warnings.warn(
|
445 |
-
"Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
|
446 |
-
"Ignoring `judge` and using `reward_model`.",
|
447 |
-
UserWarning,
|
448 |
-
)
|
449 |
-
judge = None
|
450 |
-
elif reward_model is None and judge is None:
|
451 |
-
raise ValueError("Either `reward_model` or `judge` must be provided.")
|
452 |
-
|
453 |
-
self.reward_model = reward_model
|
454 |
-
self.reward_processing_class = reward_processing_class
|
455 |
-
self.judge = judge
|
456 |
-
|
457 |
-
if args.missing_eos_penalty is not None and judge is not None:
|
458 |
-
raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
|
459 |
-
|
460 |
-
if args is None:
|
461 |
-
raise ValueError("`args` must be provided.")
|
462 |
-
|
463 |
-
# Check that the processing_class is provided
|
464 |
-
if processing_class is None:
|
465 |
-
raise ValueError("`processing_class` must be provided.")
|
466 |
-
|
467 |
-
# Convert to PEFT model if peft_config is provided
|
468 |
-
if False:
|
469 |
-
# Check if PEFT is available
|
470 |
-
if not is_peft_available():
|
471 |
-
raise ImportError(
|
472 |
-
"PEFT is not available and passed `peft_config`. Please install PEFT with "
|
473 |
-
"`pip install peft` to use it."
|
474 |
-
)
|
475 |
-
|
476 |
-
# If the model is already a PeftModel, we need to merge and unload it.
|
477 |
-
# Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
|
478 |
-
if isinstance(model, PeftModel):
|
479 |
-
model = model.merge_and_unload()
|
480 |
-
|
481 |
-
# Get peft model with the given config
|
482 |
-
model = model
|
483 |
-
|
484 |
-
# Disable dropout in the model and reference model
|
485 |
-
if args.disable_dropout:
|
486 |
-
disable_dropout_in_model(model)
|
487 |
-
if self.ref_model is not None:
|
488 |
-
disable_dropout_in_model(self.ref_model)
|
489 |
-
|
490 |
-
# Handle the ref_model
|
491 |
-
# Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
|
492 |
-
# get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
|
493 |
-
# the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
|
494 |
-
if ref_model is None: # No ref model provided, the most common case
|
495 |
-
if False:
|
496 |
-
self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
|
497 |
-
else:
|
498 |
-
self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
|
499 |
-
else: # rare case, the user provided a ref model
|
500 |
-
self.ref_model = ref_model
|
501 |
-
self.ref_model.eval()
|
502 |
-
|
503 |
-
# Disable the gradient and set the reward model in eval mode
|
504 |
-
if self.reward_model is not None:
|
505 |
-
self.reward_model.eval()
|
506 |
-
|
507 |
-
# Define the collator is not provided
|
508 |
-
if data_collator is None:
|
509 |
-
data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
|
510 |
-
|
511 |
-
self.max_length = args.max_length
|
512 |
-
|
513 |
-
self.stats = {
|
514 |
-
"objective/kl": [],
|
515 |
-
"objective/entropy": [],
|
516 |
-
"objective/non_score_reward": [],
|
517 |
-
"rewards/chosen": [],
|
518 |
-
"rewards/rejected": [],
|
519 |
-
"rewards/accuracies": [],
|
520 |
-
"rewards/margins": [],
|
521 |
-
"logps/chosen": [],
|
522 |
-
"logps/rejected": [],
|
523 |
-
"val/contain_eos_token": [],
|
524 |
-
"beta": [],
|
525 |
-
}
|
526 |
-
if self.reward_model is not None:
|
527 |
-
self.stats["objective/rlhf_reward"] = []
|
528 |
-
self.stats["objective/scores_margin"] = []
|
529 |
-
self.stats["objective/scores"] = []
|
530 |
-
|
531 |
-
if args.use_vllm:
|
532 |
-
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
|
533 |
-
n=2, max_tokens=args.max_new_tokens,
|
534 |
-
temperature=args.temperature,
|
535 |
-
top_k=50,
|
536 |
-
top_p=1.0,
|
537 |
-
detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
|
538 |
-
else:
|
539 |
-
self.generation_config = GenerationConfig(
|
540 |
-
max_new_tokens=args.max_new_tokens,
|
541 |
-
temperature=args.temperature,
|
542 |
-
top_k=50,
|
543 |
-
top_p=1.0,
|
544 |
-
do_sample=True,
|
545 |
-
use_cache=False if args.gradient_checkpointing else True,
|
546 |
-
)
|
547 |
-
|
548 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
549 |
-
# input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
|
550 |
-
# the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
551 |
-
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
552 |
-
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
553 |
-
# that the warning has already been issued.
|
554 |
-
model.warnings_issued["estimate_tokens"] = True
|
555 |
-
|
556 |
-
super().__init__(
|
557 |
-
model=model,
|
558 |
-
args=args,
|
559 |
-
data_collator=data_collator,
|
560 |
-
train_dataset=train_dataset,
|
561 |
-
eval_dataset=eval_dataset,
|
562 |
-
processing_class=processing_class,
|
563 |
-
compute_metrics=compute_metrics,
|
564 |
-
callbacks=callbacks,
|
565 |
-
optimizers=optimizers,
|
566 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
567 |
-
)
|
568 |
-
|
569 |
-
# Add tags for models that have been loaded with the correct transformers version
|
570 |
-
if hasattr(self.model, "add_model_tags"):
|
571 |
-
self.model.add_model_tags(self._tag_names)
|
572 |
-
|
573 |
-
self._beta = args.beta
|
574 |
-
|
575 |
-
# Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
|
576 |
-
if self.is_deepspeed_enabled:
|
577 |
-
if self.reward_model is not None:
|
578 |
-
self.reward_model = prepare_deepspeed(
|
579 |
-
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
580 |
-
)
|
581 |
-
if self.ref_model is not None:
|
582 |
-
self.ref_model = prepare_deepspeed(
|
583 |
-
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
584 |
-
)
|
585 |
-
else:
|
586 |
-
if self.ref_model is not None:
|
587 |
-
self.ref_model = self.ref_model.to(self.accelerator.device)
|
588 |
-
if self.reward_model is not None:
|
589 |
-
self.reward_model = self.reward_model.to(self.accelerator.device)
|
590 |
-
|
591 |
-
@property
|
592 |
-
def beta(self):
|
593 |
-
if isinstance(self._beta, list):
|
594 |
-
epoch = self.state.epoch
|
595 |
-
return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
|
596 |
-
else:
|
597 |
-
return self._beta
|
598 |
-
|
599 |
-
@staticmethod
|
600 |
-
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
|
601 |
-
"""Tokenize a single row from a DPO specific dataset."""
|
602 |
-
if not is_encoder_decoder:
|
603 |
-
batch = tokenizer(feature["prompt"], add_special_tokens=False)
|
604 |
-
# Add BOS token to head of prompt. Avoid adding if it's already there
|
605 |
-
if tokenizer.bos_token_id is not None:
|
606 |
-
prompt_len_input_ids = len(batch["input_ids"])
|
607 |
-
if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
|
608 |
-
batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
|
609 |
-
batch["attention_mask"] = [1] + batch["attention_mask"]
|
610 |
-
else:
|
611 |
-
batch = tokenizer(feature["prompt"], add_special_tokens=True)
|
612 |
-
batch = {f"prompt_{key}": value for key, value in batch.items()}
|
613 |
-
return batch
|
614 |
-
|
615 |
-
# Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
|
616 |
-
@wraps(Trainer.get_train_dataloader)
|
617 |
-
def get_train_dataloader(self) -> DataLoader:
|
618 |
-
if self.train_dataset is None:
|
619 |
-
raise ValueError("Trainer: training requires a train_dataset.")
|
620 |
-
|
621 |
-
train_dataset = self.train_dataset
|
622 |
-
data_collator = self.data_collator
|
623 |
-
dataloader_params = {
|
624 |
-
"batch_size": self._train_batch_size,
|
625 |
-
"collate_fn": data_collator,
|
626 |
-
"num_workers": self.args.dataloader_num_workers,
|
627 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
628 |
-
"persistent_workers": self.args.dataloader_persistent_workers,
|
629 |
-
}
|
630 |
-
|
631 |
-
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
632 |
-
dataloader_params["sampler"] = self._get_train_sampler()
|
633 |
-
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
634 |
-
dataloader_params["worker_init_fn"] = seed_worker
|
635 |
-
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
636 |
-
|
637 |
-
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
638 |
-
|
639 |
-
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
|
640 |
-
@wraps(Trainer.get_eval_dataloader)
|
641 |
-
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
|
642 |
-
if eval_dataset is None and self.eval_dataset is None:
|
643 |
-
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
644 |
-
|
645 |
-
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
646 |
-
# don't change during training
|
647 |
-
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
|
648 |
-
if (
|
649 |
-
hasattr(self, "_eval_dataloaders")
|
650 |
-
and dataloader_key in self._eval_dataloaders
|
651 |
-
and self.args.dataloader_persistent_workers
|
652 |
-
):
|
653 |
-
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
|
654 |
-
|
655 |
-
eval_dataset = (
|
656 |
-
self.eval_dataset[eval_dataset]
|
657 |
-
if isinstance(eval_dataset, str)
|
658 |
-
else eval_dataset
|
659 |
-
if eval_dataset is not None
|
660 |
-
else self.eval_dataset
|
661 |
-
)
|
662 |
-
data_collator = self.data_collator
|
663 |
-
|
664 |
-
dataloader_params = {
|
665 |
-
"batch_size": self.args.eval_batch_size,
|
666 |
-
"collate_fn": data_collator,
|
667 |
-
"num_workers": self.args.dataloader_num_workers,
|
668 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
669 |
-
"persistent_workers": self.args.dataloader_persistent_workers,
|
670 |
-
}
|
671 |
-
|
672 |
-
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
673 |
-
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
|
674 |
-
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
675 |
-
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
676 |
-
|
677 |
-
# accelerator.free_memory() will destroy the references, so
|
678 |
-
# we need to store the non-prepared version
|
679 |
-
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
680 |
-
if self.args.dataloader_persistent_workers:
|
681 |
-
if hasattr(self, "_eval_dataloaders"):
|
682 |
-
self._eval_dataloaders[dataloader_key] = eval_dataloader
|
683 |
-
else:
|
684 |
-
self._eval_dataloaders = {dataloader_key: eval_dataloader}
|
685 |
-
|
686 |
-
return self.accelerator.prepare(eval_dataloader)
|
687 |
-
|
688 |
-
def _generate_vllm(self, model, prompts):
|
689 |
-
eos_token_id = self.processing_class.eos_token_id
|
690 |
-
pad_token_id = self.processing_class.pad_token_id
|
691 |
-
|
692 |
-
# Load the latest weights
|
693 |
-
|
694 |
-
pass
|
695 |
-
|
696 |
-
pass
|
697 |
-
|
698 |
-
if is_conversational({"prompt": prompts[0]}):
|
699 |
-
outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
|
700 |
-
else:
|
701 |
-
outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
|
702 |
-
|
703 |
-
completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
|
704 |
-
prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
|
705 |
-
|
706 |
-
# Create mask and pad the prompt and completion
|
707 |
-
max_prompt_length = max(len(ids) for ids in prompt_ids)
|
708 |
-
prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
|
709 |
-
prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
|
710 |
-
max_tokens = self.generation_config.max_tokens
|
711 |
-
completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
|
712 |
-
completion_ids = [
|
713 |
-
ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
|
714 |
-
for ids in completion_ids
|
715 |
-
]
|
716 |
-
completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
|
717 |
-
|
718 |
-
# Convert to tensors
|
719 |
-
prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
|
720 |
-
prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
|
721 |
-
completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
|
722 |
-
completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
|
723 |
-
|
724 |
-
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
725 |
-
|
726 |
-
def _generate(self, model, prompts):
|
727 |
-
eos_token_id = self.processing_class.eos_token_id
|
728 |
-
pad_token_id = self.processing_class.pad_token_id
|
729 |
-
|
730 |
-
# Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
|
731 |
-
# policies with different tokenizers / chat templates.
|
732 |
-
inputs = [{"prompt": prompt} for prompt in prompts]
|
733 |
-
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
734 |
-
inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
735 |
-
inputs = self.data_collator(inputs)
|
736 |
-
|
737 |
-
# Sample 2 completions per prompt of size `max_new_tokens` from the model
|
738 |
-
inputs = self._prepare_inputs(inputs)
|
739 |
-
prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
|
740 |
-
prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
|
741 |
-
with unwrap_model_for_generation(
|
742 |
-
model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
743 |
-
) as unwrapped_model:
|
744 |
-
output = unwrapped_model.generate(
|
745 |
-
input_ids=prompt_ids,
|
746 |
-
attention_mask=prompt_mask,
|
747 |
-
generation_config=self.generation_config,
|
748 |
-
)
|
749 |
-
|
750 |
-
completion_ids = output[:, prompt_ids.size(1) :]
|
751 |
-
completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
|
752 |
-
|
753 |
-
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
754 |
-
|
755 |
-
def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
|
756 |
-
# Get the number of tokens to truncate from prompt
|
757 |
-
num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
|
758 |
-
|
759 |
-
# Truncate left to avoid oom
|
760 |
-
prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
|
761 |
-
prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
|
762 |
-
|
763 |
-
# Concat the prompt and completion
|
764 |
-
prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
|
765 |
-
prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
|
766 |
-
|
767 |
-
# Get the logprobs of the completions from the model
|
768 |
-
output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
|
769 |
-
|
770 |
-
# There is 1 offset, because the model predict the next token
|
771 |
-
logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
|
772 |
-
|
773 |
-
# Take the completion tokens logprob
|
774 |
-
logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
|
775 |
-
return logprobs
|
776 |
-
|
777 |
-
def training_step(
|
778 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
779 |
-
) -> torch.Tensor:
|
780 |
-
model.train()
|
781 |
-
|
782 |
-
prompts = inputs["prompt"]
|
783 |
-
batch_size = len(prompts)
|
784 |
-
|
785 |
-
if self.args.use_vllm:
|
786 |
-
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
|
787 |
-
else:
|
788 |
-
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
|
789 |
-
|
790 |
-
contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
|
791 |
-
|
792 |
-
logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
793 |
-
with torch.no_grad():
|
794 |
-
if self.ref_model is not None:
|
795 |
-
ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
796 |
-
else: # peft case: we just need to disable the adapter
|
797 |
-
with self.model.disable_adapter():
|
798 |
-
ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
799 |
-
|
800 |
-
# Decode the completions, and format them if the input is conversational
|
801 |
-
device = logprobs.device
|
802 |
-
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
803 |
-
if is_conversational({"prompt": prompts[0]}):
|
804 |
-
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
805 |
-
|
806 |
-
# Get the reward from the reward model or judge
|
807 |
-
if self.judge is not None:
|
808 |
-
# Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
|
809 |
-
# directly understandable by the judge and could alter its judgment. To avoid this and make the judge
|
810 |
-
# independent of the model's chat template, we use the raw conversation data, and apply our own chat
|
811 |
-
# template to it.
|
812 |
-
if is_conversational({"prompt": prompts[0]}):
|
813 |
-
environment = jinja2.Environment()
|
814 |
-
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
815 |
-
prompts = [template.render(messages=prompt) for prompt in prompts]
|
816 |
-
completions = [template.render(messages=completion) for completion in completions]
|
817 |
-
|
818 |
-
ranks_of_first_completion = self.judge.judge(
|
819 |
-
prompts, list(zip(completions[:batch_size], completions[batch_size:]))
|
820 |
-
)
|
821 |
-
|
822 |
-
# convert ranks to a True/False mask:
|
823 |
-
# when rank == 0, it means the first completion is the best
|
824 |
-
# when rank == 1, it means the second completion is the best
|
825 |
-
mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
|
826 |
-
else:
|
827 |
-
# The reward model may not have the same chat template or tokenizer as the model, so we need to use the
|
828 |
-
# raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
|
829 |
-
prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
|
830 |
-
if is_conversational({"prompt": prompts[0]}):
|
831 |
-
examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
|
832 |
-
examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
|
833 |
-
prompts = [example["prompt"] for example in examples]
|
834 |
-
completions = [example["completion"] for example in examples]
|
835 |
-
|
836 |
-
# Tokenize the prompts
|
837 |
-
prompts_ids = self.reward_processing_class(
|
838 |
-
prompts, padding=True, return_tensors="pt", padding_side="left"
|
839 |
-
)["input_ids"].to(device)
|
840 |
-
context_length = prompts_ids.shape[1]
|
841 |
-
|
842 |
-
# Tokenize the completions
|
843 |
-
completions_ids = self.reward_processing_class(
|
844 |
-
completions, padding=True, return_tensors="pt", padding_side="right"
|
845 |
-
)["input_ids"].to(device)
|
846 |
-
|
847 |
-
# Concatenate the prompts and completions and get the reward
|
848 |
-
prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
|
849 |
-
with torch.inference_mode():
|
850 |
-
_, scores, _ = get_reward(
|
851 |
-
self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
|
852 |
-
)
|
853 |
-
|
854 |
-
# Filter completion. Ensure that the sample contains stop_token_id
|
855 |
-
# Completions not passing that filter will receive a lower score.
|
856 |
-
if self.args.missing_eos_penalty is not None:
|
857 |
-
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
858 |
-
|
859 |
-
# Split the scores in 2 (the prompts of the first half are the same as the second half)
|
860 |
-
first_half, second_half = scores.split(batch_size)
|
861 |
-
|
862 |
-
# Get the indices of the chosen and rejected examples
|
863 |
-
mask = first_half >= second_half
|
864 |
-
|
865 |
-
batch_range = torch.arange(batch_size, device=device)
|
866 |
-
chosen_indices = batch_range + (~mask * batch_size)
|
867 |
-
rejected_indices = batch_range + (mask * batch_size)
|
868 |
-
|
869 |
-
# Build tensor so that the first half is the chosen examples and the second half the rejected examples
|
870 |
-
cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
|
871 |
-
cr_logprobs = logprobs[cr_indices]
|
872 |
-
cr_ref_logprobs = ref_logprobs[cr_indices]
|
873 |
-
|
874 |
-
# mask out the padding tokens
|
875 |
-
padding_mask = ~completion_mask.bool()
|
876 |
-
cr_padding_mask = padding_mask[cr_indices]
|
877 |
-
|
878 |
-
cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
|
879 |
-
cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
|
880 |
-
|
881 |
-
# Split the chosen and rejected examples
|
882 |
-
chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
|
883 |
-
chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
|
884 |
-
pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
|
885 |
-
ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
|
886 |
-
|
887 |
-
logits = pi_logratios - ref_logratios
|
888 |
-
|
889 |
-
if self.args.loss_type == "sigmoid":
|
890 |
-
losses = -F.logsigmoid(self.beta * logits)
|
891 |
-
elif self.args.loss_type == "ipo":
|
892 |
-
losses = (logits - 1 / (2 * self.beta)) ** 2
|
893 |
-
else:
|
894 |
-
raise NotImplementedError(f"invalid loss type {self.loss_type}")
|
895 |
-
|
896 |
-
loss = losses.mean()
|
897 |
-
|
898 |
-
# Log everything
|
899 |
-
if self.reward_model is not None:
|
900 |
-
scores_margin = scores[chosen_indices] - scores[rejected_indices]
|
901 |
-
self.stats["objective/scores_margin"].append(
|
902 |
-
self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
|
903 |
-
)
|
904 |
-
self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
|
905 |
-
self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
|
906 |
-
self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
|
907 |
-
self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
|
908 |
-
|
909 |
-
kl = logprobs - ref_logprobs
|
910 |
-
mean_kl = kl.sum(1).mean()
|
911 |
-
self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
912 |
-
non_score_reward = (-self.beta * kl).sum(1)
|
913 |
-
mean_non_score_reward = non_score_reward.mean()
|
914 |
-
self.stats["objective/non_score_reward"].append(
|
915 |
-
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
916 |
-
)
|
917 |
-
if self.reward_model is not None:
|
918 |
-
rlhf_reward = scores + non_score_reward
|
919 |
-
self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
|
920 |
-
mean_entropy = -logprobs.sum(1).mean()
|
921 |
-
self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
|
922 |
-
chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
|
923 |
-
gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
|
924 |
-
self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
|
925 |
-
rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
|
926 |
-
gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
|
927 |
-
self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
|
928 |
-
margin = gathered_chosen_rewards - gathered_rejected_rewards
|
929 |
-
self.stats["rewards/margins"].append(margin.mean().item())
|
930 |
-
accuracy = margin > 0
|
931 |
-
self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
|
932 |
-
self.stats["beta"].append(self.beta)
|
933 |
-
|
934 |
-
if (
|
935 |
-
self.args.torch_empty_cache_steps is not None
|
936 |
-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
937 |
-
):
|
938 |
-
empty_cache()
|
939 |
-
|
940 |
-
kwargs = {}
|
941 |
-
|
942 |
-
# For LOMO optimizers you need to explicitly use the learnign rate
|
943 |
-
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
944 |
-
kwargs["learning_rate"] = self._get_learning_rate()
|
945 |
-
|
946 |
-
if self.args.n_gpu > 1:
|
947 |
-
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
948 |
-
|
949 |
-
if self.use_apex:
|
950 |
-
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
951 |
-
scaled_loss.backward()
|
952 |
-
else:
|
953 |
-
self.accelerator.backward(loss, **kwargs)
|
954 |
-
|
955 |
-
return loss.detach() / self.args.gradient_accumulation_steps
|
956 |
-
|
957 |
-
# Same as Trainer._maybe_log_save_evaluate but log our metrics
|
958 |
-
# start_time defaults to None to allow compatibility with transformers<=4.46
|
959 |
-
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
|
960 |
-
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
961 |
-
logs: dict[str, float] = {}
|
962 |
-
|
963 |
-
# all_gather + mean() to get average loss over all processes
|
964 |
-
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
|
965 |
-
|
966 |
-
# reset tr_loss to zero
|
967 |
-
tr_loss -= tr_loss
|
968 |
-
|
969 |
-
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
970 |
-
if grad_norm is not None:
|
971 |
-
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
972 |
-
logs["learning_rate"] = self._get_learning_rate()
|
973 |
-
|
974 |
-
# Add our metrics
|
975 |
-
for key, val in self.stats.items():
|
976 |
-
logs[key] = sum(val) / len(val)
|
977 |
-
self.stats = {key: [] for key in self.stats} # reset stats
|
978 |
-
|
979 |
-
self._total_loss_scalar += tr_loss_scalar
|
980 |
-
self._globalstep_last_logged = self.state.global_step
|
981 |
-
self.store_flos()
|
982 |
-
|
983 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
984 |
-
self.log(logs, start_time)
|
985 |
-
else: # transformers<=4.46
|
986 |
-
self.log(logs)
|
987 |
-
|
988 |
-
metrics = None
|
989 |
-
if self.control.should_evaluate:
|
990 |
-
metrics = self._evaluate(trial, ignore_keys_for_eval)
|
991 |
-
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
|
992 |
-
|
993 |
-
if self.args.save_strategy == "best":
|
994 |
-
self.control.should_save = is_new_best_metric
|
995 |
-
|
996 |
-
if self.control.should_save:
|
997 |
-
self._save_checkpoint(model, trial)
|
998 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
999 |
-
|
1000 |
-
# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
|
1001 |
-
# This can be removed once the minimum transformers version is updated to 4.47.
|
1002 |
-
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
|
1003 |
-
def _determine_best_metric(self, metrics, trial):
|
1004 |
-
"""
|
1005 |
-
Determine if the model should be saved based on the evaluation metrics.
|
1006 |
-
If args.metric_for_best_model is not set, the loss is used.
|
1007 |
-
Returns:
|
1008 |
-
bool: True if a new best metric was found, else False
|
1009 |
-
"""
|
1010 |
-
is_new_best_metric = False
|
1011 |
-
|
1012 |
-
if self.args.metric_for_best_model is not None:
|
1013 |
-
metric_to_check = self.args.metric_for_best_model
|
1014 |
-
|
1015 |
-
if not metric_to_check.startswith("eval_"):
|
1016 |
-
metric_to_check = f"eval_{metric_to_check}"
|
1017 |
-
|
1018 |
-
try:
|
1019 |
-
metric_value = metrics[metric_to_check]
|
1020 |
-
except KeyError as exc:
|
1021 |
-
raise KeyError(
|
1022 |
-
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
|
1023 |
-
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
|
1024 |
-
) from exc
|
1025 |
-
|
1026 |
-
operator = np.greater if self.args.greater_is_better else np.less
|
1027 |
-
|
1028 |
-
if self.state.best_metric is None:
|
1029 |
-
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
|
1030 |
-
|
1031 |
-
if operator(metric_value, self.state.best_metric):
|
1032 |
-
run_dir = self._get_output_dir(trial=trial)
|
1033 |
-
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
1034 |
-
output_dir = os.path.join(run_dir, checkpoint_folder)
|
1035 |
-
self.state.best_metric = metric_value
|
1036 |
-
self.state.best_model_checkpoint = output_dir
|
1037 |
-
|
1038 |
-
is_new_best_metric = True
|
1039 |
-
|
1040 |
-
return is_new_best_metric
|
1041 |
-
|
1042 |
-
def create_model_card(
|
1043 |
-
self,
|
1044 |
-
model_name: Optional[str] = None,
|
1045 |
-
dataset_name: Optional[str] = None,
|
1046 |
-
tags: Union[str, list[str], None] = None,
|
1047 |
-
):
|
1048 |
-
"""
|
1049 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1050 |
-
|
1051 |
-
Args:
|
1052 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1053 |
-
Name of the model.
|
1054 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1055 |
-
Name of the dataset used for training.
|
1056 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1057 |
-
Tags to be associated with the model card.
|
1058 |
-
"""
|
1059 |
-
if not self.is_world_process_zero():
|
1060 |
-
return
|
1061 |
-
|
1062 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1063 |
-
base_model = self.model.config._name_or_path
|
1064 |
-
else:
|
1065 |
-
base_model = None
|
1066 |
-
|
1067 |
-
tags = tags or []
|
1068 |
-
if isinstance(tags, str):
|
1069 |
-
tags = [tags]
|
1070 |
-
|
1071 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1072 |
-
tags.append("unsloth")
|
1073 |
-
|
1074 |
-
citation = textwrap.dedent("""\
|
1075 |
-
@article{guo2024direct,
|
1076 |
-
title = {{Direct Language Model Alignment from Online AI Feedback}},
|
1077 |
-
author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
|
1078 |
-
year = 2024,
|
1079 |
-
eprint = {arXiv:2402.04792}
|
1080 |
-
}""")
|
1081 |
-
|
1082 |
-
model_card = generate_model_card(
|
1083 |
-
base_model=base_model,
|
1084 |
-
model_name=model_name,
|
1085 |
-
hub_model_id=self.hub_model_id,
|
1086 |
-
dataset_name=dataset_name,
|
1087 |
-
tags=tags,
|
1088 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1089 |
-
comet_url=get_comet_experiment_url(),
|
1090 |
-
trainer_name="Online DPO",
|
1091 |
-
trainer_citation=citation,
|
1092 |
-
paper_title="Direct Language Model Alignment from Online AI Feedback",
|
1093 |
-
paper_id="2402.04792",
|
1094 |
-
)
|
1095 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1096 |
-
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
|
1097 |
-
"""
|
1098 |
-
|
1099 |
-
Initialize OnlineDPOTrainer.
|
1100 |
-
|
1101 |
-
Args:
|
1102 |
-
model (`transformers.PreTrainedModel` or `torch.nn.Module`):
|
1103 |
-
The model to train, preferably an `AutoModelForCausalLM`.
|
1104 |
-
ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
1105 |
-
The reference model to use for training. If None is specified, the reference model will be created from
|
1106 |
-
the model.
|
1107 |
-
reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
1108 |
-
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
1109 |
-
judge (`BasePairwiseJudge`):
|
1110 |
-
The judge to use for pairwise comparison of model completions.
|
1111 |
-
args (`OnlineDPOConfig`):
|
1112 |
-
The online DPO config arguments to use for training.
|
1113 |
-
data_collator (`transformers.DataCollator`):
|
1114 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1115 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1116 |
-
train_dataset (`datasets.Dataset`):
|
1117 |
-
The dataset to use for training.
|
1118 |
-
eval_dataset (`datasets.Dataset`):
|
1119 |
-
The dataset to use for evaluation.
|
1120 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1121 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1122 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1123 |
-
reuse the fine-tuned model.
|
1124 |
-
peft_config (`dict`):
|
1125 |
-
The peft config to use for training.
|
1126 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1127 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1128 |
-
a dictionary string to metric values.
|
1129 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
1130 |
-
The callbacks to use for training.
|
1131 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1132 |
-
The optimizer and scheduler to use for training.
|
1133 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1134 |
-
The function to use to preprocess the logits before computing the metrics.
|
1135 |
-
|
1136 |
-
"""
|
1137 |
-
def __init__(
|
1138 |
-
self,
|
1139 |
-
model,
|
1140 |
-
ref_model = None,
|
1141 |
-
reward_model = None,
|
1142 |
-
judge = None,
|
1143 |
-
args = None,
|
1144 |
-
data_collator = None,
|
1145 |
-
train_dataset = None,
|
1146 |
-
eval_dataset = None,
|
1147 |
-
processing_class = None,
|
1148 |
-
reward_processing_class = None,
|
1149 |
-
peft_config = None,
|
1150 |
-
compute_metrics = None,
|
1151 |
-
callbacks = None,
|
1152 |
-
preprocess_logits_for_metrics = None,
|
1153 |
-
**kwargs
|
1154 |
-
):
|
1155 |
-
if args is None: args = UnslothOnlineDPOConfig()
|
1156 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1157 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1158 |
-
force_float32 = False
|
1159 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1160 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1161 |
-
force_float32 = True
|
1162 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1163 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1164 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1165 |
-
from unsloth_zoo.utils import _get_dtype
|
1166 |
-
dtype = _get_dtype(dtype)
|
1167 |
-
float16 = dtype == torch.float16
|
1168 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1169 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1170 |
-
if force_float32:
|
1171 |
-
args.fp16 = False
|
1172 |
-
args.bf16 = False
|
1173 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1174 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1175 |
-
args.fp16 = float16
|
1176 |
-
args.bf16 = not float16
|
1177 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1178 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1179 |
-
args.eval_strategy = 'steps'
|
1180 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1181 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1182 |
-
if ga_steps is not None and ga_steps > 1:
|
1183 |
-
from transformers import __version__ as transformers_version
|
1184 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1185 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1186 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1187 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1188 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1189 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1190 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1191 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1192 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1193 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1194 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1195 |
-
if force_float32:
|
1196 |
-
args.bf16_full_eval = False
|
1197 |
-
args.fp16_full_eval = False
|
1198 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1199 |
-
args.bf16_full_eval = True
|
1200 |
-
args.fp16_full_eval = False
|
1201 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1202 |
-
args.bf16_full_eval = args.bf16
|
1203 |
-
args.fp16_full_eval = args.fp16
|
1204 |
-
_output_logits = False
|
1205 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1206 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1207 |
-
if _output_logits:
|
1208 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1209 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1210 |
-
pass
|
1211 |
-
else:
|
1212 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1213 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1214 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1215 |
-
max_seq_length = model.max_seq_length
|
1216 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1217 |
-
if model is not None and hasattr(model, 'for_training'):
|
1218 |
-
model.for_training()
|
1219 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1220 |
-
if 'processing_class' in locals():
|
1221 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1222 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1223 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1224 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1225 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1226 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1227 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1228 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1229 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1230 |
-
else:
|
1231 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1232 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1233 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1234 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1235 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1236 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1237 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1238 |
-
else:
|
1239 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1240 |
-
other_metrics = []
|
1241 |
-
|
1242 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1243 |
-
PatchRLStatistics('online_dpo_trainer', other_metrics)
|
1244 |
-
|
1245 |
-
super().__init__(
|
1246 |
-
model = model,
|
1247 |
-
ref_model = ref_model,
|
1248 |
-
reward_model = reward_model,
|
1249 |
-
judge = judge,
|
1250 |
-
args = args,
|
1251 |
-
data_collator = data_collator,
|
1252 |
-
train_dataset = train_dataset,
|
1253 |
-
eval_dataset = eval_dataset,
|
1254 |
-
processing_class = processing_class,
|
1255 |
-
reward_processing_class = reward_processing_class,
|
1256 |
-
peft_config = peft_config,
|
1257 |
-
compute_metrics = compute_metrics,
|
1258 |
-
callbacks = callbacks,
|
1259 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
1260 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1261 |
-
self.neftune_hook_handle.remove()
|
1262 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1263 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1264 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1265 |
-
pass
|
1266 |
-
|
1267 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothPPOTrainer.py
DELETED
@@ -1,1257 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothPPOConfig(PPOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`PPOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
54 |
-
Name of this experiment.
|
55 |
-
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
56 |
-
Path to the reward model.
|
57 |
-
model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
58 |
-
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
59 |
-
ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
60 |
-
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
61 |
-
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
62 |
-
Number of epochs to train.
|
63 |
-
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
64 |
-
Whether to whiten the rewards.
|
65 |
-
kl_coef (`float`, *optional*, defaults to `0.05`):
|
66 |
-
KL coefficient.
|
67 |
-
cliprange (`float`, *optional*, defaults to `0.2`):
|
68 |
-
Clip range.
|
69 |
-
vf_coef (`float`, *optional*, defaults to `0.1`):
|
70 |
-
Value function coefficient.
|
71 |
-
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
72 |
-
Clip range for the value function.
|
73 |
-
gamma (`float`, *optional*, defaults to `1.0`):
|
74 |
-
Discount factor.
|
75 |
-
lam (`float`, *optional*, defaults to `0.95`):
|
76 |
-
Lambda value for GAE.
|
77 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
78 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
79 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
80 |
-
capacity of a single GPU, albeit at the cost of slower generation.
|
81 |
-
|
82 |
-
"""
|
83 |
-
vllm_sampling_params: Optional[Any] = field(
|
84 |
-
default = None,
|
85 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
86 |
-
)
|
87 |
-
unsloth_num_chunks : Optional[int] = field(
|
88 |
-
default = -1,
|
89 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
90 |
-
)
|
91 |
-
def __init__(
|
92 |
-
self,
|
93 |
-
output_dir = None,
|
94 |
-
overwrite_output_dir = None,
|
95 |
-
do_train = False,
|
96 |
-
do_eval = False,
|
97 |
-
do_predict = False,
|
98 |
-
eval_strategy = 'no',
|
99 |
-
prediction_loss_only = False,
|
100 |
-
per_device_train_batch_size = 4,
|
101 |
-
per_device_eval_batch_size = 4,
|
102 |
-
per_gpu_train_batch_size = None,
|
103 |
-
per_gpu_eval_batch_size = None,
|
104 |
-
gradient_accumulation_steps = 2,
|
105 |
-
eval_accumulation_steps = 2,
|
106 |
-
eval_delay = 0,
|
107 |
-
torch_empty_cache_steps = 250,
|
108 |
-
learning_rate = 5e-05,
|
109 |
-
weight_decay = 0.01,
|
110 |
-
adam_beta1 = 0.9,
|
111 |
-
adam_beta2 = 0.999,
|
112 |
-
adam_epsilon = 1e-08,
|
113 |
-
max_grad_norm = 1.0,
|
114 |
-
num_train_epochs = 3.0,
|
115 |
-
max_steps = -1,
|
116 |
-
lr_scheduler_type = 'linear',
|
117 |
-
warmup_ratio = 0.1,
|
118 |
-
warmup_steps = 0,
|
119 |
-
log_level = 'passive',
|
120 |
-
log_level_replica = 'warning',
|
121 |
-
log_on_each_node = True,
|
122 |
-
logging_dir = None,
|
123 |
-
logging_strategy = 'steps',
|
124 |
-
logging_first_step = False,
|
125 |
-
logging_steps = 1,
|
126 |
-
logging_nan_inf_filter = False,
|
127 |
-
save_strategy = 'steps',
|
128 |
-
save_steps = 500,
|
129 |
-
save_total_limit = None,
|
130 |
-
save_safetensors = True,
|
131 |
-
save_on_each_node = False,
|
132 |
-
save_only_model = False,
|
133 |
-
restore_callback_states_from_checkpoint = False,
|
134 |
-
no_cuda = False,
|
135 |
-
use_cpu = False,
|
136 |
-
use_mps_device = False,
|
137 |
-
seed = 3407,
|
138 |
-
data_seed = 3407,
|
139 |
-
jit_mode_eval = False,
|
140 |
-
use_ipex = False,
|
141 |
-
bf16 = False,
|
142 |
-
fp16 = False,
|
143 |
-
fp16_opt_level = 'O1',
|
144 |
-
half_precision_backend = 'auto',
|
145 |
-
bf16_full_eval = False,
|
146 |
-
fp16_full_eval = False,
|
147 |
-
tf32 = None,
|
148 |
-
local_rank = -1,
|
149 |
-
ddp_backend = None,
|
150 |
-
tpu_num_cores = None,
|
151 |
-
tpu_metrics_debug = False,
|
152 |
-
debug = '',
|
153 |
-
dataloader_drop_last = False,
|
154 |
-
eval_steps = None,
|
155 |
-
dataloader_num_workers = 0,
|
156 |
-
dataloader_prefetch_factor = None,
|
157 |
-
past_index = -1,
|
158 |
-
run_name = None,
|
159 |
-
disable_tqdm = None,
|
160 |
-
remove_unused_columns = True,
|
161 |
-
label_names = None,
|
162 |
-
load_best_model_at_end = False,
|
163 |
-
metric_for_best_model = None,
|
164 |
-
greater_is_better = None,
|
165 |
-
ignore_data_skip = False,
|
166 |
-
fsdp = '',
|
167 |
-
fsdp_min_num_params = 0,
|
168 |
-
fsdp_config = None,
|
169 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
170 |
-
accelerator_config = None,
|
171 |
-
deepspeed = None,
|
172 |
-
label_smoothing_factor = 0.0,
|
173 |
-
optim = 'adamw_8bit',
|
174 |
-
optim_args = None,
|
175 |
-
adafactor = False,
|
176 |
-
group_by_length = False,
|
177 |
-
length_column_name = 'length',
|
178 |
-
report_to = None,
|
179 |
-
ddp_find_unused_parameters = None,
|
180 |
-
ddp_bucket_cap_mb = None,
|
181 |
-
ddp_broadcast_buffers = None,
|
182 |
-
dataloader_pin_memory = True,
|
183 |
-
dataloader_persistent_workers = False,
|
184 |
-
skip_memory_metrics = True,
|
185 |
-
use_legacy_prediction_loop = False,
|
186 |
-
push_to_hub = False,
|
187 |
-
resume_from_checkpoint = None,
|
188 |
-
hub_model_id = None,
|
189 |
-
hub_strategy = 'every_save',
|
190 |
-
hub_token = None,
|
191 |
-
hub_private_repo = None,
|
192 |
-
hub_always_push = False,
|
193 |
-
gradient_checkpointing = False,
|
194 |
-
gradient_checkpointing_kwargs = None,
|
195 |
-
include_inputs_for_metrics = False,
|
196 |
-
eval_do_concat_batches = True,
|
197 |
-
fp16_backend = 'auto',
|
198 |
-
evaluation_strategy = None,
|
199 |
-
push_to_hub_model_id = None,
|
200 |
-
push_to_hub_organization = None,
|
201 |
-
push_to_hub_token = None,
|
202 |
-
mp_parameters = '',
|
203 |
-
auto_find_batch_size = False,
|
204 |
-
full_determinism = False,
|
205 |
-
torchdynamo = None,
|
206 |
-
ray_scope = 'last',
|
207 |
-
ddp_timeout = 1800,
|
208 |
-
torch_compile = False,
|
209 |
-
torch_compile_backend = None,
|
210 |
-
torch_compile_mode = None,
|
211 |
-
dispatch_batches = None,
|
212 |
-
split_batches = None,
|
213 |
-
include_tokens_per_second = False,
|
214 |
-
include_num_input_tokens_seen = False,
|
215 |
-
neftune_noise_alpha = None,
|
216 |
-
optim_target_modules = None,
|
217 |
-
batch_eval_metrics = False,
|
218 |
-
eval_on_start = False,
|
219 |
-
use_liger_kernel = False,
|
220 |
-
eval_use_gather_object = False,
|
221 |
-
average_tokens_across_devices = False,
|
222 |
-
dataset_num_proc = None,
|
223 |
-
num_mini_batches = 1,
|
224 |
-
total_episodes = None,
|
225 |
-
local_rollout_forward_batch_size = 64,
|
226 |
-
num_sample_generations = 10,
|
227 |
-
response_length = 53,
|
228 |
-
stop_token = None,
|
229 |
-
stop_token_id = None,
|
230 |
-
temperature = 0.7,
|
231 |
-
missing_eos_penalty = None,
|
232 |
-
sft_model_path = 'EleutherAI/pythia-160m',
|
233 |
-
world_size = None,
|
234 |
-
num_total_batches = None,
|
235 |
-
micro_batch_size = None,
|
236 |
-
local_batch_size = None,
|
237 |
-
batch_size = None,
|
238 |
-
local_mini_batch_size = None,
|
239 |
-
mini_batch_size = None,
|
240 |
-
exp_name = 'ppo_config',
|
241 |
-
reward_model_path = 'EleutherAI/pythia-160m',
|
242 |
-
model_adapter_name = None,
|
243 |
-
ref_adapter_name = None,
|
244 |
-
num_ppo_epochs = 4,
|
245 |
-
whiten_rewards = False,
|
246 |
-
kl_coef = 0.05,
|
247 |
-
cliprange = 0.2,
|
248 |
-
vf_coef = 0.1,
|
249 |
-
cliprange_value = 0.2,
|
250 |
-
gamma = 1.0,
|
251 |
-
lam = 0.95,
|
252 |
-
ds3_gather_for_generation = True,
|
253 |
-
vllm_sampling_params = None,
|
254 |
-
unsloth_num_chunks = -1,
|
255 |
-
**kwargs,
|
256 |
-
):
|
257 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
258 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
259 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
260 |
-
output_dir = 'unsloth_training_checkpoints'
|
261 |
-
save_strategy = 'no'
|
262 |
-
if dataset_num_proc is None:
|
263 |
-
from multiprocessing import cpu_count
|
264 |
-
dataset_num_proc = cpu_count()
|
265 |
-
|
266 |
-
super().__init__(
|
267 |
-
output_dir = output_dir,
|
268 |
-
overwrite_output_dir = overwrite_output_dir,
|
269 |
-
do_train = do_train,
|
270 |
-
do_eval = do_eval,
|
271 |
-
do_predict = do_predict,
|
272 |
-
eval_strategy = eval_strategy,
|
273 |
-
prediction_loss_only = prediction_loss_only,
|
274 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
275 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
276 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
277 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
278 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
279 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
280 |
-
eval_delay = eval_delay,
|
281 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
282 |
-
learning_rate = learning_rate,
|
283 |
-
weight_decay = weight_decay,
|
284 |
-
adam_beta1 = adam_beta1,
|
285 |
-
adam_beta2 = adam_beta2,
|
286 |
-
adam_epsilon = adam_epsilon,
|
287 |
-
max_grad_norm = max_grad_norm,
|
288 |
-
num_train_epochs = num_train_epochs,
|
289 |
-
max_steps = max_steps,
|
290 |
-
lr_scheduler_type = lr_scheduler_type,
|
291 |
-
warmup_ratio = warmup_ratio,
|
292 |
-
warmup_steps = warmup_steps,
|
293 |
-
log_level = log_level,
|
294 |
-
log_level_replica = log_level_replica,
|
295 |
-
log_on_each_node = log_on_each_node,
|
296 |
-
logging_dir = logging_dir,
|
297 |
-
logging_strategy = logging_strategy,
|
298 |
-
logging_first_step = logging_first_step,
|
299 |
-
logging_steps = logging_steps,
|
300 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
301 |
-
save_strategy = save_strategy,
|
302 |
-
save_steps = save_steps,
|
303 |
-
save_total_limit = save_total_limit,
|
304 |
-
save_safetensors = save_safetensors,
|
305 |
-
save_on_each_node = save_on_each_node,
|
306 |
-
save_only_model = save_only_model,
|
307 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
308 |
-
no_cuda = no_cuda,
|
309 |
-
use_cpu = use_cpu,
|
310 |
-
use_mps_device = use_mps_device,
|
311 |
-
seed = seed,
|
312 |
-
data_seed = data_seed,
|
313 |
-
jit_mode_eval = jit_mode_eval,
|
314 |
-
use_ipex = use_ipex,
|
315 |
-
bf16 = bf16,
|
316 |
-
fp16 = fp16,
|
317 |
-
fp16_opt_level = fp16_opt_level,
|
318 |
-
half_precision_backend = half_precision_backend,
|
319 |
-
bf16_full_eval = bf16_full_eval,
|
320 |
-
fp16_full_eval = fp16_full_eval,
|
321 |
-
tf32 = tf32,
|
322 |
-
local_rank = local_rank,
|
323 |
-
ddp_backend = ddp_backend,
|
324 |
-
tpu_num_cores = tpu_num_cores,
|
325 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
326 |
-
debug = debug,
|
327 |
-
dataloader_drop_last = dataloader_drop_last,
|
328 |
-
eval_steps = eval_steps,
|
329 |
-
dataloader_num_workers = dataloader_num_workers,
|
330 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
331 |
-
past_index = past_index,
|
332 |
-
run_name = run_name,
|
333 |
-
disable_tqdm = disable_tqdm,
|
334 |
-
remove_unused_columns = remove_unused_columns,
|
335 |
-
label_names = label_names,
|
336 |
-
load_best_model_at_end = load_best_model_at_end,
|
337 |
-
metric_for_best_model = metric_for_best_model,
|
338 |
-
greater_is_better = greater_is_better,
|
339 |
-
ignore_data_skip = ignore_data_skip,
|
340 |
-
fsdp = fsdp,
|
341 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
342 |
-
fsdp_config = fsdp_config,
|
343 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
344 |
-
accelerator_config = accelerator_config,
|
345 |
-
deepspeed = deepspeed,
|
346 |
-
label_smoothing_factor = label_smoothing_factor,
|
347 |
-
optim = optim,
|
348 |
-
optim_args = optim_args,
|
349 |
-
adafactor = adafactor,
|
350 |
-
group_by_length = group_by_length,
|
351 |
-
length_column_name = length_column_name,
|
352 |
-
report_to = report_to,
|
353 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
354 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
355 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
356 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
357 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
358 |
-
skip_memory_metrics = skip_memory_metrics,
|
359 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
360 |
-
push_to_hub = push_to_hub,
|
361 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
362 |
-
hub_model_id = hub_model_id,
|
363 |
-
hub_strategy = hub_strategy,
|
364 |
-
hub_token = hub_token,
|
365 |
-
hub_private_repo = hub_private_repo,
|
366 |
-
hub_always_push = hub_always_push,
|
367 |
-
gradient_checkpointing = gradient_checkpointing,
|
368 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
369 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
370 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
371 |
-
fp16_backend = fp16_backend,
|
372 |
-
evaluation_strategy = evaluation_strategy,
|
373 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
374 |
-
push_to_hub_organization = push_to_hub_organization,
|
375 |
-
push_to_hub_token = push_to_hub_token,
|
376 |
-
mp_parameters = mp_parameters,
|
377 |
-
auto_find_batch_size = auto_find_batch_size,
|
378 |
-
full_determinism = full_determinism,
|
379 |
-
torchdynamo = torchdynamo,
|
380 |
-
ray_scope = ray_scope,
|
381 |
-
ddp_timeout = ddp_timeout,
|
382 |
-
torch_compile = torch_compile,
|
383 |
-
torch_compile_backend = torch_compile_backend,
|
384 |
-
torch_compile_mode = torch_compile_mode,
|
385 |
-
dispatch_batches = dispatch_batches,
|
386 |
-
split_batches = split_batches,
|
387 |
-
include_tokens_per_second = include_tokens_per_second,
|
388 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
389 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
390 |
-
optim_target_modules = optim_target_modules,
|
391 |
-
batch_eval_metrics = batch_eval_metrics,
|
392 |
-
eval_on_start = eval_on_start,
|
393 |
-
use_liger_kernel = use_liger_kernel,
|
394 |
-
eval_use_gather_object = eval_use_gather_object,
|
395 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
396 |
-
dataset_num_proc = dataset_num_proc,
|
397 |
-
num_mini_batches = num_mini_batches,
|
398 |
-
total_episodes = total_episodes,
|
399 |
-
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
400 |
-
num_sample_generations = num_sample_generations,
|
401 |
-
response_length = response_length,
|
402 |
-
stop_token = stop_token,
|
403 |
-
stop_token_id = stop_token_id,
|
404 |
-
temperature = temperature,
|
405 |
-
missing_eos_penalty = missing_eos_penalty,
|
406 |
-
sft_model_path = sft_model_path,
|
407 |
-
world_size = world_size,
|
408 |
-
num_total_batches = num_total_batches,
|
409 |
-
micro_batch_size = micro_batch_size,
|
410 |
-
local_batch_size = local_batch_size,
|
411 |
-
batch_size = batch_size,
|
412 |
-
local_mini_batch_size = local_mini_batch_size,
|
413 |
-
mini_batch_size = mini_batch_size,
|
414 |
-
exp_name = exp_name,
|
415 |
-
reward_model_path = reward_model_path,
|
416 |
-
model_adapter_name = model_adapter_name,
|
417 |
-
ref_adapter_name = ref_adapter_name,
|
418 |
-
num_ppo_epochs = num_ppo_epochs,
|
419 |
-
whiten_rewards = whiten_rewards,
|
420 |
-
kl_coef = kl_coef,
|
421 |
-
cliprange = cliprange,
|
422 |
-
vf_coef = vf_coef,
|
423 |
-
cliprange_value = cliprange_value,
|
424 |
-
gamma = gamma,
|
425 |
-
lam = lam,
|
426 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
427 |
-
self.vllm_sampling_params = vllm_sampling_params
|
428 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
429 |
-
pass
|
430 |
-
|
431 |
-
class _UnslothPPOTrainer(Trainer):
|
432 |
-
_tag_names = ["trl", "ppo"]
|
433 |
-
|
434 |
-
def __init__(
|
435 |
-
self,
|
436 |
-
args: PPOConfig,
|
437 |
-
processing_class: Optional[
|
438 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
439 |
-
],
|
440 |
-
model: nn.Module,
|
441 |
-
ref_model: Optional[nn.Module],
|
442 |
-
reward_model: nn.Module,
|
443 |
-
train_dataset: Dataset,
|
444 |
-
value_model: Optional[nn.Module] = None,
|
445 |
-
data_collator: Optional[DataCollatorWithPadding] = None,
|
446 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
447 |
-
# less commonly used
|
448 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
449 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
450 |
-
peft_config: Optional["PeftConfig"] = None,
|
451 |
-
) -> None:
|
452 |
-
if ref_model is model:
|
453 |
-
raise ValueError(
|
454 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
455 |
-
"same as `model`, you must make a copy of it, or `None` if you use peft."
|
456 |
-
)
|
457 |
-
|
458 |
-
self.args = args
|
459 |
-
self.processing_class = processing_class
|
460 |
-
self.policy_model = model
|
461 |
-
|
462 |
-
# Define the collator if not provided
|
463 |
-
if data_collator is None:
|
464 |
-
data_collator = DataCollatorWithPadding(self.processing_class)
|
465 |
-
|
466 |
-
# Handle stop token settings: update policy model's generation_config to use provided stop token
|
467 |
-
if args.stop_token and args.stop_token_id:
|
468 |
-
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
|
469 |
-
elif args.stop_token:
|
470 |
-
if args.stop_token == "eos":
|
471 |
-
self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
|
472 |
-
else:
|
473 |
-
raise ValueError(
|
474 |
-
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
|
475 |
-
)
|
476 |
-
else:
|
477 |
-
self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
|
478 |
-
|
479 |
-
# peft support
|
480 |
-
if not is_peft_available() and peft_config is not None:
|
481 |
-
raise ImportError(
|
482 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
483 |
-
)
|
484 |
-
elif is_peft_available() and peft_config is not None:
|
485 |
-
# if model is a peft model and we have a peft_confg, we merge and unload it first
|
486 |
-
if isinstance(self.policy_model, PeftModel):
|
487 |
-
self.policy_model = self.policy_model.merge_and_unload()
|
488 |
-
|
489 |
-
# get peft model with the given config
|
490 |
-
self.policy_model = get_peft_model(self.policy_model, peft_config)
|
491 |
-
if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
|
492 |
-
peft_module_casting_to_bf16(self.policy_model)
|
493 |
-
|
494 |
-
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
|
495 |
-
self.model_adapter_name = args.model_adapter_name
|
496 |
-
self.ref_adapter_name = args.ref_adapter_name
|
497 |
-
|
498 |
-
if ref_model:
|
499 |
-
self.ref_model = ref_model
|
500 |
-
elif self.is_peft_model:
|
501 |
-
self.ref_model = None
|
502 |
-
else:
|
503 |
-
self.ref_model = create_reference_model(self.policy_model)
|
504 |
-
|
505 |
-
self.reward_model = reward_model
|
506 |
-
self.train_dataset = train_dataset
|
507 |
-
self.train_dataset_len = len(train_dataset)
|
508 |
-
self.value_model = value_model
|
509 |
-
self.data_collator = data_collator
|
510 |
-
self.eval_dataset = eval_dataset
|
511 |
-
self.optimizer, self.lr_scheduler = optimizers
|
512 |
-
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
513 |
-
|
514 |
-
#########
|
515 |
-
# calculate various batch sizes
|
516 |
-
#########
|
517 |
-
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
518 |
-
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
519 |
-
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
520 |
-
self.accelerator = accelerator
|
521 |
-
args.world_size = accelerator.num_processes
|
522 |
-
args.local_batch_size = (
|
523 |
-
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
524 |
-
)
|
525 |
-
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
526 |
-
args.batch_size = int(args.local_batch_size * args.world_size)
|
527 |
-
args.mini_batch_size = exact_div(
|
528 |
-
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
529 |
-
)
|
530 |
-
args.local_mini_batch_size = exact_div(
|
531 |
-
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
532 |
-
)
|
533 |
-
if args.whiten_rewards:
|
534 |
-
assert (
|
535 |
-
args.local_mini_batch_size >= 8
|
536 |
-
), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
|
537 |
-
# `per_rank_rollout_batch_size` is our `args.local_batch_size`
|
538 |
-
# `per_rank_minibatch_size` is our `args.local_mini_batch_size`
|
539 |
-
args.num_total_batches = math.ceil(
|
540 |
-
args.total_episodes / args.batch_size
|
541 |
-
) # we may train for more than `total_episodes`
|
542 |
-
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
543 |
-
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
544 |
-
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
545 |
-
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
546 |
-
if args.num_sample_generations > 0:
|
547 |
-
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
548 |
-
self.local_dataloader_batch_size = args.local_batch_size
|
549 |
-
|
550 |
-
#########
|
551 |
-
# setup model, optimizer, and others
|
552 |
-
#########
|
553 |
-
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
|
554 |
-
if module is not None:
|
555 |
-
disable_dropout_in_model(module)
|
556 |
-
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
|
557 |
-
self.model.config = self.policy_model.config # needed for pushing to hub
|
558 |
-
self.create_optimizer_and_scheduler(
|
559 |
-
num_training_steps=args.num_total_batches
|
560 |
-
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
|
561 |
-
|
562 |
-
#########
|
563 |
-
### trainer specifics
|
564 |
-
#########
|
565 |
-
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
566 |
-
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
567 |
-
self.callback_handler = CallbackHandler(
|
568 |
-
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
569 |
-
)
|
570 |
-
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
571 |
-
self.control = TrainerControl()
|
572 |
-
self.state = OnlineTrainerState(
|
573 |
-
is_local_process_zero=self.is_local_process_zero(),
|
574 |
-
is_world_process_zero=self.is_world_process_zero(),
|
575 |
-
stateful_callbacks=[
|
576 |
-
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
577 |
-
],
|
578 |
-
)
|
579 |
-
self.current_flos = 0
|
580 |
-
self.hp_search_backend = None
|
581 |
-
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
582 |
-
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
583 |
-
# Create distant repo and output directory if needed
|
584 |
-
self.hub_model_id = None
|
585 |
-
if self.args.push_to_hub:
|
586 |
-
self.init_hf_repo()
|
587 |
-
if self.args.should_save:
|
588 |
-
os.makedirs(self.args.output_dir, exist_ok=True)
|
589 |
-
|
590 |
-
# Add tags for models that have been loaded with the correct transformers version
|
591 |
-
if hasattr(self.model, "add_model_tags"):
|
592 |
-
self.model.add_model_tags(self._tag_names)
|
593 |
-
|
594 |
-
#########
|
595 |
-
### setup dataloader
|
596 |
-
#########
|
597 |
-
self.dataloader = DataLoader(
|
598 |
-
self.train_dataset,
|
599 |
-
batch_size=self.local_dataloader_batch_size,
|
600 |
-
shuffle=True,
|
601 |
-
collate_fn=self.data_collator,
|
602 |
-
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
603 |
-
)
|
604 |
-
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
|
605 |
-
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
606 |
-
torch.manual_seed(args.seed)
|
607 |
-
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
608 |
-
torch.manual_seed(self.local_seed) # reset the local seed again
|
609 |
-
|
610 |
-
self.eval_dataloader = DataLoader(
|
611 |
-
self.eval_dataset,
|
612 |
-
batch_size=args.per_device_eval_batch_size,
|
613 |
-
collate_fn=self.data_collator,
|
614 |
-
drop_last=True,
|
615 |
-
) # no need to shuffle eval dataset
|
616 |
-
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
617 |
-
|
618 |
-
if self.is_deepspeed_enabled:
|
619 |
-
self.reward_model = prepare_deepspeed(
|
620 |
-
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
621 |
-
)
|
622 |
-
|
623 |
-
if self.ref_model is None:
|
624 |
-
if not self.is_peft_model:
|
625 |
-
raise ValueError("No reference model and model is not a Peft model.")
|
626 |
-
else:
|
627 |
-
self.ref_model = prepare_deepspeed(
|
628 |
-
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
629 |
-
)
|
630 |
-
else:
|
631 |
-
if self.ref_model is None:
|
632 |
-
if not self.is_peft_model:
|
633 |
-
raise ValueError("No reference model and model is not a Peft model.")
|
634 |
-
else:
|
635 |
-
self.ref_model = self.ref_model.to(self.accelerator.device)
|
636 |
-
self.reward_model = self.reward_model.to(self.accelerator.device)
|
637 |
-
|
638 |
-
def get_train_dataloader(self) -> DataLoader:
|
639 |
-
return self.dataloader
|
640 |
-
|
641 |
-
def get_eval_dataloader(self) -> DataLoader:
|
642 |
-
return self.eval_dataloader
|
643 |
-
|
644 |
-
@contextmanager
|
645 |
-
def null_ref_context(self):
|
646 |
-
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
647 |
-
with (
|
648 |
-
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
|
649 |
-
if self.is_peft_model and not self.ref_adapter_name
|
650 |
-
else nullcontext()
|
651 |
-
):
|
652 |
-
if self.ref_adapter_name:
|
653 |
-
self.model.policy.set_adapter(self.ref_adapter_name)
|
654 |
-
yield
|
655 |
-
if self.ref_adapter_name:
|
656 |
-
self.model.policy.set_adapter(self.model_adapter_name or "default")
|
657 |
-
|
658 |
-
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
659 |
-
backup_model = self.model
|
660 |
-
self.model = self.model.policy # save only the policy
|
661 |
-
|
662 |
-
if self.is_deepspeed_enabled:
|
663 |
-
backup_deepspeed = self.deepspeed
|
664 |
-
self.deepspeed = self.model
|
665 |
-
|
666 |
-
super().save_model(output_dir, _internal_call)
|
667 |
-
|
668 |
-
self.model = backup_model
|
669 |
-
|
670 |
-
if self.is_deepspeed_enabled:
|
671 |
-
self.deepspeed = backup_deepspeed
|
672 |
-
|
673 |
-
def train(self):
|
674 |
-
args = self.args
|
675 |
-
accelerator = self.accelerator
|
676 |
-
optimizer = self.optimizer
|
677 |
-
model = self.model
|
678 |
-
ref_policy = self.ref_model
|
679 |
-
reward_model = self.reward_model
|
680 |
-
processing_class = self.processing_class
|
681 |
-
dataloader = self.dataloader
|
682 |
-
device = accelerator.device
|
683 |
-
|
684 |
-
def repeat_generator():
|
685 |
-
while True:
|
686 |
-
yield from dataloader
|
687 |
-
|
688 |
-
iter_dataloader = iter(repeat_generator())
|
689 |
-
generation_config = GenerationConfig(
|
690 |
-
max_new_tokens=args.response_length,
|
691 |
-
temperature=(args.temperature + 1e-7),
|
692 |
-
top_k=0.0,
|
693 |
-
top_p=1.0,
|
694 |
-
do_sample=True,
|
695 |
-
)
|
696 |
-
|
697 |
-
accelerator.print("===training policy===")
|
698 |
-
start_time = time.time()
|
699 |
-
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
700 |
-
approxkl_stats = torch.zeros(stats_shape, device=device)
|
701 |
-
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
702 |
-
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
703 |
-
vf_loss_stats = torch.zeros(stats_shape, device=device)
|
704 |
-
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
705 |
-
entropy_stats = torch.zeros(stats_shape, device=device)
|
706 |
-
ratio_stats = torch.zeros(stats_shape, device=device)
|
707 |
-
model.train()
|
708 |
-
|
709 |
-
# trainer state initialization
|
710 |
-
self.state.global_step = 0
|
711 |
-
self.state.episode = 0
|
712 |
-
self.state.max_steps = args.num_total_batches * args.num_mini_batches
|
713 |
-
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
714 |
-
# Compute absolute values for logging, eval, and save if given as ratio
|
715 |
-
if args.logging_steps is not None:
|
716 |
-
if args.logging_steps < 1:
|
717 |
-
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
718 |
-
else:
|
719 |
-
self.state.logging_steps = args.logging_steps
|
720 |
-
if args.eval_steps is not None:
|
721 |
-
if args.eval_steps < 1:
|
722 |
-
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
723 |
-
else:
|
724 |
-
self.state.eval_steps = args.eval_steps
|
725 |
-
if args.save_steps is not None:
|
726 |
-
if args.save_steps < 1:
|
727 |
-
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
728 |
-
else:
|
729 |
-
self.state.save_steps = args.save_steps
|
730 |
-
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
731 |
-
|
732 |
-
# backward compatibility
|
733 |
-
if self.is_deepspeed_enabled:
|
734 |
-
self.deepspeed = self.model
|
735 |
-
self.model_wrapped = self.model
|
736 |
-
|
737 |
-
for update in range(1, args.num_total_batches + 1):
|
738 |
-
self.state.episode += 1 * args.batch_size
|
739 |
-
data = next(iter_dataloader)
|
740 |
-
with torch.no_grad():
|
741 |
-
queries = data["input_ids"].to(device)
|
742 |
-
context_length = queries.shape[1]
|
743 |
-
responses = []
|
744 |
-
postprocessed_responses = []
|
745 |
-
logprobs = []
|
746 |
-
ref_logprobs = []
|
747 |
-
scores = []
|
748 |
-
sequence_lengths = []
|
749 |
-
values = []
|
750 |
-
with unwrap_model_for_generation(
|
751 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
752 |
-
) as unwrapped_model:
|
753 |
-
query_responses, logitss = batch_generation(
|
754 |
-
unwrapped_model.policy,
|
755 |
-
queries,
|
756 |
-
args.local_rollout_forward_batch_size,
|
757 |
-
processing_class.pad_token_id,
|
758 |
-
generation_config,
|
759 |
-
)
|
760 |
-
|
761 |
-
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
762 |
-
query = queries[i : i + args.local_rollout_forward_batch_size]
|
763 |
-
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
764 |
-
response = query_response[:, context_length:]
|
765 |
-
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
766 |
-
logprob = selective_log_softmax(logits, response)
|
767 |
-
del logits
|
768 |
-
torch.cuda.empty_cache()
|
769 |
-
|
770 |
-
if ref_policy is None:
|
771 |
-
with self.null_ref_context():
|
772 |
-
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
|
773 |
-
else:
|
774 |
-
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
775 |
-
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
776 |
-
ref_logits /= args.temperature + 1e-7
|
777 |
-
ref_logprob = selective_log_softmax(ref_logits, response)
|
778 |
-
del ref_output, ref_logits
|
779 |
-
torch.cuda.empty_cache()
|
780 |
-
|
781 |
-
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
782 |
-
postprocessed_response = response
|
783 |
-
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
784 |
-
postprocessed_response = truncate_response(
|
785 |
-
self.stop_token_id, processing_class.pad_token_id, response
|
786 |
-
)
|
787 |
-
|
788 |
-
# Response Processing 2. run reward model on the truncated responses
|
789 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
790 |
-
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
791 |
-
unwrapped_value_model = accelerator.unwrap_model(model).value_model
|
792 |
-
full_value, _, _ = get_reward(
|
793 |
-
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
|
794 |
-
)
|
795 |
-
value = full_value[:, context_length - 1 : -1].squeeze(-1)
|
796 |
-
_, score, _ = get_reward(
|
797 |
-
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
798 |
-
)
|
799 |
-
|
800 |
-
responses.append(response)
|
801 |
-
postprocessed_responses.append(postprocessed_response)
|
802 |
-
logprobs.append(logprob)
|
803 |
-
ref_logprobs.append(ref_logprob)
|
804 |
-
sequence_lengths.append(sequence_length)
|
805 |
-
scores.append(score)
|
806 |
-
values.append(value)
|
807 |
-
responses = torch.cat(responses, 0)
|
808 |
-
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
809 |
-
logprobs = torch.cat(logprobs, 0)
|
810 |
-
ref_logprobs = torch.cat(ref_logprobs, 0)
|
811 |
-
sequence_lengths = torch.cat(sequence_lengths, 0)
|
812 |
-
scores = torch.cat(scores, 0)
|
813 |
-
values = torch.cat(values, 0)
|
814 |
-
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
|
815 |
-
torch.cuda.empty_cache()
|
816 |
-
gc.collect()
|
817 |
-
|
818 |
-
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
|
819 |
-
# Completions not passing that filter will receive a lower score.
|
820 |
-
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
|
821 |
-
if self.args.missing_eos_penalty is not None:
|
822 |
-
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
823 |
-
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
824 |
-
|
825 |
-
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
826 |
-
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
827 |
-
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
828 |
-
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
829 |
-
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
830 |
-
sequence_lengths_p1 = sequence_lengths + 1
|
831 |
-
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
|
832 |
-
values = torch.masked_fill(values, padding_mask_p1, 0)
|
833 |
-
|
834 |
-
# 4. compute rewards
|
835 |
-
kl = logprobs - ref_logprobs
|
836 |
-
non_score_reward = -args.kl_coef * kl
|
837 |
-
rewards = non_score_reward.clone()
|
838 |
-
actual_start = torch.arange(rewards.size(0), device=rewards.device)
|
839 |
-
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
|
840 |
-
rewards[[actual_start, actual_end]] += scores
|
841 |
-
|
842 |
-
# 5. whiten rewards
|
843 |
-
if args.whiten_rewards:
|
844 |
-
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
|
845 |
-
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
|
846 |
-
|
847 |
-
# 6. compute advantages and returns
|
848 |
-
lastgaelam = 0
|
849 |
-
advantages_reversed = []
|
850 |
-
gen_length = responses.shape[1]
|
851 |
-
for t in reversed(range(gen_length)):
|
852 |
-
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
|
853 |
-
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
|
854 |
-
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
855 |
-
advantages_reversed.append(lastgaelam)
|
856 |
-
advantages = torch.stack(advantages_reversed[::-1], axis=1)
|
857 |
-
returns = advantages + values
|
858 |
-
advantages = masked_whiten(advantages, ~padding_mask)
|
859 |
-
advantages = torch.masked_fill(advantages, padding_mask, 0)
|
860 |
-
torch.cuda.empty_cache()
|
861 |
-
|
862 |
-
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
863 |
-
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
864 |
-
b_inds = np.random.permutation(args.local_batch_size)
|
865 |
-
minibatch_idx = 0
|
866 |
-
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
867 |
-
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
868 |
-
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
869 |
-
gradient_accumulation_idx = 0
|
870 |
-
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
871 |
-
with accelerator.accumulate(model):
|
872 |
-
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
873 |
-
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
874 |
-
mb_advantage = advantages[micro_batch_inds]
|
875 |
-
mb_responses = responses[micro_batch_inds]
|
876 |
-
mb_query_responses = query_responses[micro_batch_inds]
|
877 |
-
mb_logprobs = logprobs[micro_batch_inds]
|
878 |
-
mb_return = returns[micro_batch_inds]
|
879 |
-
mb_values = values[micro_batch_inds]
|
880 |
-
|
881 |
-
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
|
882 |
-
logits = output.logits[:, context_length - 1 : -1]
|
883 |
-
logits /= args.temperature + 1e-7
|
884 |
-
new_logprobs = selective_log_softmax(logits, mb_responses)
|
885 |
-
new_logprobs = torch.masked_fill(
|
886 |
-
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
887 |
-
)
|
888 |
-
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
|
889 |
-
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
|
890 |
-
vpredclipped = torch.clamp(
|
891 |
-
vpred,
|
892 |
-
mb_values - args.cliprange_value,
|
893 |
-
mb_values + args.cliprange_value,
|
894 |
-
)
|
895 |
-
vf_losses1 = torch.square(vpred - mb_return)
|
896 |
-
vf_losses2 = torch.square(vpredclipped - mb_return)
|
897 |
-
vf_loss_max = torch.max(vf_losses1, vf_losses2)
|
898 |
-
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
|
899 |
-
vf_clipfrac = masked_mean(
|
900 |
-
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
|
901 |
-
)
|
902 |
-
logprobs_diff = new_logprobs - mb_logprobs
|
903 |
-
ratio = torch.exp(logprobs_diff)
|
904 |
-
pg_losses = -mb_advantage * ratio
|
905 |
-
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
906 |
-
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
907 |
-
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
|
908 |
-
loss = pg_loss + args.vf_coef * vf_loss
|
909 |
-
accelerator.backward(loss)
|
910 |
-
optimizer.step()
|
911 |
-
optimizer.zero_grad()
|
912 |
-
with torch.no_grad():
|
913 |
-
pg_clipfrac = masked_mean(
|
914 |
-
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
|
915 |
-
)
|
916 |
-
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
917 |
-
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
918 |
-
approxkl = 0.5 * (logprobs_diff**2).mean()
|
919 |
-
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
920 |
-
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
921 |
-
pg_clipfrac
|
922 |
-
)
|
923 |
-
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
924 |
-
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
|
925 |
-
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
926 |
-
vf_clipfrac
|
927 |
-
)
|
928 |
-
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
929 |
-
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
|
930 |
-
gradient_accumulation_idx += 1
|
931 |
-
minibatch_idx += 1
|
932 |
-
# del everything and empty cache
|
933 |
-
# fmt: off
|
934 |
-
del (
|
935 |
-
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
|
936 |
-
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
|
937 |
-
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
|
938 |
-
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
|
939 |
-
)
|
940 |
-
# fmt: on
|
941 |
-
torch.cuda.empty_cache()
|
942 |
-
with torch.no_grad():
|
943 |
-
mean_kl = kl.sum(1).mean()
|
944 |
-
mean_entropy = (-logprobs).sum(1).mean()
|
945 |
-
mean_non_score_reward = non_score_reward.sum(1).mean()
|
946 |
-
rlhf_reward = mean_non_score_reward + scores.mean()
|
947 |
-
eps = int(self.state.episode / (time.time() - start_time))
|
948 |
-
metrics = {}
|
949 |
-
metrics["eps"] = eps
|
950 |
-
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
951 |
-
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
952 |
-
metrics["objective/non_score_reward"] = (
|
953 |
-
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
954 |
-
)
|
955 |
-
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
956 |
-
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
957 |
-
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
958 |
-
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
959 |
-
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
960 |
-
metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
|
961 |
-
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
962 |
-
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
963 |
-
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
964 |
-
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
965 |
-
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
966 |
-
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
967 |
-
metrics["episode"] = self.state.episode
|
968 |
-
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
|
969 |
-
self.state.global_step += 1
|
970 |
-
self.log(metrics)
|
971 |
-
|
972 |
-
self.lr_scheduler.step()
|
973 |
-
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
974 |
-
if self.control.should_save:
|
975 |
-
self._save_checkpoint(model, trial=None)
|
976 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
977 |
-
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
|
978 |
-
torch.cuda.empty_cache()
|
979 |
-
gc.collect()
|
980 |
-
|
981 |
-
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
982 |
-
self.generate_completions(sampling=True)
|
983 |
-
torch.cuda.empty_cache()
|
984 |
-
del (
|
985 |
-
query_responses,
|
986 |
-
responses,
|
987 |
-
postprocessed_responses,
|
988 |
-
logprobs,
|
989 |
-
ref_logprobs,
|
990 |
-
values,
|
991 |
-
sequence_lengths,
|
992 |
-
contain_eos_token,
|
993 |
-
sequence_lengths_p1,
|
994 |
-
response_idxs,
|
995 |
-
padding_mask,
|
996 |
-
padding_mask_p1,
|
997 |
-
rewards,
|
998 |
-
actual_start,
|
999 |
-
actual_end,
|
1000 |
-
advantages,
|
1001 |
-
returns,
|
1002 |
-
)
|
1003 |
-
torch.cuda.empty_cache()
|
1004 |
-
|
1005 |
-
# HF trainer specifics
|
1006 |
-
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
1007 |
-
if self.control.should_save:
|
1008 |
-
self._save_checkpoint(model, trial=None, metrics=None)
|
1009 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
1010 |
-
|
1011 |
-
def generate_completions(self, sampling: bool = False):
|
1012 |
-
args = self.args
|
1013 |
-
processing_class = self.processing_class
|
1014 |
-
generation_config = GenerationConfig(
|
1015 |
-
max_new_tokens=self.args.response_length,
|
1016 |
-
temperature=(0.01 + 1e-7),
|
1017 |
-
top_k=0.0,
|
1018 |
-
top_p=1.0,
|
1019 |
-
do_sample=True,
|
1020 |
-
)
|
1021 |
-
|
1022 |
-
table = defaultdict(list)
|
1023 |
-
with unwrap_model_for_generation(
|
1024 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
1025 |
-
) as unwrapped_model:
|
1026 |
-
for batch in self.eval_dataloader:
|
1027 |
-
query = batch["input_ids"]
|
1028 |
-
with torch.no_grad():
|
1029 |
-
context_length = query.shape[1]
|
1030 |
-
query_response, _ = batch_generation(
|
1031 |
-
unwrapped_model.policy,
|
1032 |
-
query,
|
1033 |
-
query.shape[0],
|
1034 |
-
processing_class.pad_token_id,
|
1035 |
-
generation_config,
|
1036 |
-
)
|
1037 |
-
response = query_response[:, context_length:]
|
1038 |
-
postprocessed_response = response
|
1039 |
-
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
1040 |
-
postprocessed_response = truncate_response(
|
1041 |
-
self.stop_token_id, processing_class.pad_token_id, response
|
1042 |
-
)
|
1043 |
-
table["query"].extend(
|
1044 |
-
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
1045 |
-
)
|
1046 |
-
table["model response"].extend(
|
1047 |
-
gather_object(processing_class.batch_decode(postprocessed_response))
|
1048 |
-
)
|
1049 |
-
|
1050 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
1051 |
-
_, score, _ = get_reward(
|
1052 |
-
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
1053 |
-
)
|
1054 |
-
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
1055 |
-
|
1056 |
-
if sampling:
|
1057 |
-
break
|
1058 |
-
df = pd.DataFrame(table)
|
1059 |
-
|
1060 |
-
if self.accelerator.is_main_process:
|
1061 |
-
print_rich_table(df.iloc[0 : 0 + 5])
|
1062 |
-
if "wandb" in args.report_to:
|
1063 |
-
import wandb
|
1064 |
-
|
1065 |
-
if wandb.run is not None:
|
1066 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
1067 |
-
|
1068 |
-
if "comet_ml" in args.report_to:
|
1069 |
-
log_table_to_comet_experiment(
|
1070 |
-
name="completions.csv",
|
1071 |
-
table=df,
|
1072 |
-
)
|
1073 |
-
|
1074 |
-
def create_model_card(
|
1075 |
-
self,
|
1076 |
-
model_name: Optional[str] = None,
|
1077 |
-
dataset_name: Optional[str] = None,
|
1078 |
-
tags: Union[str, list[str], None] = None,
|
1079 |
-
):
|
1080 |
-
"""
|
1081 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1082 |
-
|
1083 |
-
Args:
|
1084 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1085 |
-
Name of the model.
|
1086 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1087 |
-
Name of the dataset used for training.
|
1088 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1089 |
-
Tags to be associated with the model card.
|
1090 |
-
"""
|
1091 |
-
if not self.is_world_process_zero():
|
1092 |
-
return
|
1093 |
-
|
1094 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1095 |
-
base_model = self.model.config._name_or_path
|
1096 |
-
else:
|
1097 |
-
base_model = None
|
1098 |
-
|
1099 |
-
tags = tags or []
|
1100 |
-
if isinstance(tags, str):
|
1101 |
-
tags = [tags]
|
1102 |
-
|
1103 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1104 |
-
tags.append("unsloth")
|
1105 |
-
|
1106 |
-
citation = textwrap.dedent("""\
|
1107 |
-
@article{mziegler2019fine-tuning,
|
1108 |
-
title = {{Fine-Tuning Language Models from Human Preferences}},
|
1109 |
-
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
|
1110 |
-
year = 2019,
|
1111 |
-
eprint = {arXiv:1909.08593}
|
1112 |
-
}""")
|
1113 |
-
|
1114 |
-
model_card = generate_model_card(
|
1115 |
-
base_model=base_model,
|
1116 |
-
model_name=model_name,
|
1117 |
-
hub_model_id=self.hub_model_id,
|
1118 |
-
dataset_name=dataset_name,
|
1119 |
-
tags=tags,
|
1120 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1121 |
-
comet_url=get_comet_experiment_url(),
|
1122 |
-
trainer_name="PPO",
|
1123 |
-
trainer_citation=citation,
|
1124 |
-
paper_title="Fine-Tuning Language Models from Human Preferences",
|
1125 |
-
paper_id="1909.08593",
|
1126 |
-
)
|
1127 |
-
|
1128 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1129 |
-
class UnslothPPOTrainer(_UnslothPPOTrainer):
|
1130 |
-
"""
|
1131 |
-
|
1132 |
-
"""
|
1133 |
-
def __init__(
|
1134 |
-
self,
|
1135 |
-
args,
|
1136 |
-
processing_class,
|
1137 |
-
model,
|
1138 |
-
ref_model,
|
1139 |
-
reward_model,
|
1140 |
-
train_dataset,
|
1141 |
-
value_model = None,
|
1142 |
-
data_collator = None,
|
1143 |
-
eval_dataset = None,
|
1144 |
-
callbacks = None,
|
1145 |
-
peft_config = None,
|
1146 |
-
**kwargs
|
1147 |
-
):
|
1148 |
-
if args is None: args = UnslothPPOConfig()
|
1149 |
-
use_bf16 = getattr(args, 'bf16', False)
|
1150 |
-
use_fp16 = getattr(args, 'fp16', False)
|
1151 |
-
force_float32 = False
|
1152 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1153 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1154 |
-
force_float32 = True
|
1155 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1156 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
1157 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1158 |
-
from unsloth_zoo.utils import _get_dtype
|
1159 |
-
dtype = _get_dtype(dtype)
|
1160 |
-
float16 = dtype == torch.float16
|
1161 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1162 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1163 |
-
if force_float32:
|
1164 |
-
args.fp16 = False
|
1165 |
-
args.bf16 = False
|
1166 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1167 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1168 |
-
args.fp16 = float16
|
1169 |
-
args.bf16 = not float16
|
1170 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1171 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1172 |
-
args.eval_strategy = 'steps'
|
1173 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1174 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1175 |
-
if ga_steps is not None and ga_steps > 1:
|
1176 |
-
from transformers import __version__ as transformers_version
|
1177 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
1178 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1179 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1180 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1181 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1182 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1183 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1184 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1185 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1186 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1187 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1188 |
-
if force_float32:
|
1189 |
-
args.bf16_full_eval = False
|
1190 |
-
args.fp16_full_eval = False
|
1191 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1192 |
-
args.bf16_full_eval = True
|
1193 |
-
args.fp16_full_eval = False
|
1194 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
1195 |
-
args.bf16_full_eval = args.bf16
|
1196 |
-
args.fp16_full_eval = args.fp16
|
1197 |
-
_output_logits = False
|
1198 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1199 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1200 |
-
if _output_logits:
|
1201 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1202 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1203 |
-
pass
|
1204 |
-
else:
|
1205 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1206 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1207 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1208 |
-
max_seq_length = model.max_seq_length
|
1209 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1210 |
-
if model is not None and hasattr(model, 'for_training'):
|
1211 |
-
model.for_training()
|
1212 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1213 |
-
if 'processing_class' in locals():
|
1214 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1215 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1216 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1217 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1218 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1219 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1220 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1221 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1222 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1223 |
-
else:
|
1224 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1225 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1226 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1227 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1228 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1229 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1230 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1231 |
-
else:
|
1232 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1233 |
-
other_metrics = []
|
1234 |
-
|
1235 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1236 |
-
PatchRLStatistics('ppo_trainer', other_metrics)
|
1237 |
-
|
1238 |
-
super().__init__(
|
1239 |
-
args = args,
|
1240 |
-
processing_class = processing_class,
|
1241 |
-
model = model,
|
1242 |
-
ref_model = ref_model,
|
1243 |
-
reward_model = reward_model,
|
1244 |
-
train_dataset = train_dataset,
|
1245 |
-
value_model = value_model,
|
1246 |
-
data_collator = data_collator,
|
1247 |
-
eval_dataset = eval_dataset,
|
1248 |
-
callbacks = callbacks,
|
1249 |
-
peft_config = peft_config,**kwargs)
|
1250 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1251 |
-
self.neftune_hook_handle.remove()
|
1252 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1253 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1254 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1255 |
-
pass
|
1256 |
-
|
1257 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothPRMTrainer.py
DELETED
@@ -1,798 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothPRMConfig(PRMConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`PRMTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
learning_rate (`float`, *optional*, defaults to `1e-5`):
|
54 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
-
[`~transformers.TrainingArguments`].
|
56 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
-
Maximum length of the sequences (prompt + completion) used for truncation.
|
58 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
59 |
-
Maximum length of the prompt used for truncation.
|
60 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
61 |
-
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
|
62 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
63 |
-
Whether to disable dropout in the model.
|
64 |
-
step_separator (`str`, *optional*, defaults to `"\n"`):
|
65 |
-
Separator used to separate each step of the reasoning process.
|
66 |
-
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
|
67 |
-
Whether to train only on the last step.
|
68 |
-
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
69 |
-
Number of processes to use for processing the dataset.
|
70 |
-
|
71 |
-
"""
|
72 |
-
vllm_sampling_params: Optional[Any] = field(
|
73 |
-
default = None,
|
74 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
75 |
-
)
|
76 |
-
unsloth_num_chunks : Optional[int] = field(
|
77 |
-
default = -1,
|
78 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
79 |
-
)
|
80 |
-
def __init__(
|
81 |
-
self,
|
82 |
-
output_dir = None,
|
83 |
-
overwrite_output_dir = None,
|
84 |
-
do_train = False,
|
85 |
-
do_eval = False,
|
86 |
-
do_predict = False,
|
87 |
-
eval_strategy = 'no',
|
88 |
-
prediction_loss_only = False,
|
89 |
-
per_device_train_batch_size = 4,
|
90 |
-
per_device_eval_batch_size = 4,
|
91 |
-
per_gpu_train_batch_size = None,
|
92 |
-
per_gpu_eval_batch_size = None,
|
93 |
-
gradient_accumulation_steps = 2,
|
94 |
-
eval_accumulation_steps = 2,
|
95 |
-
eval_delay = 0,
|
96 |
-
torch_empty_cache_steps = 250,
|
97 |
-
learning_rate = 5e-05,
|
98 |
-
weight_decay = 0.01,
|
99 |
-
adam_beta1 = 0.9,
|
100 |
-
adam_beta2 = 0.999,
|
101 |
-
adam_epsilon = 1e-08,
|
102 |
-
max_grad_norm = 1.0,
|
103 |
-
num_train_epochs = 3.0,
|
104 |
-
max_steps = -1,
|
105 |
-
lr_scheduler_type = 'linear',
|
106 |
-
warmup_ratio = 0.1,
|
107 |
-
warmup_steps = 0,
|
108 |
-
log_level = 'passive',
|
109 |
-
log_level_replica = 'warning',
|
110 |
-
log_on_each_node = True,
|
111 |
-
logging_dir = None,
|
112 |
-
logging_strategy = 'steps',
|
113 |
-
logging_first_step = False,
|
114 |
-
logging_steps = 1,
|
115 |
-
logging_nan_inf_filter = False,
|
116 |
-
save_strategy = 'steps',
|
117 |
-
save_steps = 500,
|
118 |
-
save_total_limit = None,
|
119 |
-
save_safetensors = True,
|
120 |
-
save_on_each_node = False,
|
121 |
-
save_only_model = False,
|
122 |
-
restore_callback_states_from_checkpoint = False,
|
123 |
-
no_cuda = False,
|
124 |
-
use_cpu = False,
|
125 |
-
use_mps_device = False,
|
126 |
-
seed = 3407,
|
127 |
-
data_seed = 3407,
|
128 |
-
jit_mode_eval = False,
|
129 |
-
use_ipex = False,
|
130 |
-
bf16 = False,
|
131 |
-
fp16 = False,
|
132 |
-
fp16_opt_level = 'O1',
|
133 |
-
half_precision_backend = 'auto',
|
134 |
-
bf16_full_eval = False,
|
135 |
-
fp16_full_eval = False,
|
136 |
-
tf32 = None,
|
137 |
-
local_rank = -1,
|
138 |
-
ddp_backend = None,
|
139 |
-
tpu_num_cores = None,
|
140 |
-
tpu_metrics_debug = False,
|
141 |
-
debug = '',
|
142 |
-
dataloader_drop_last = False,
|
143 |
-
eval_steps = None,
|
144 |
-
dataloader_num_workers = 0,
|
145 |
-
dataloader_prefetch_factor = None,
|
146 |
-
past_index = -1,
|
147 |
-
run_name = None,
|
148 |
-
disable_tqdm = None,
|
149 |
-
remove_unused_columns = True,
|
150 |
-
label_names = None,
|
151 |
-
load_best_model_at_end = False,
|
152 |
-
metric_for_best_model = None,
|
153 |
-
greater_is_better = None,
|
154 |
-
ignore_data_skip = False,
|
155 |
-
fsdp = '',
|
156 |
-
fsdp_min_num_params = 0,
|
157 |
-
fsdp_config = None,
|
158 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
159 |
-
accelerator_config = None,
|
160 |
-
deepspeed = None,
|
161 |
-
label_smoothing_factor = 0.0,
|
162 |
-
optim = 'adamw_8bit',
|
163 |
-
optim_args = None,
|
164 |
-
adafactor = False,
|
165 |
-
group_by_length = False,
|
166 |
-
length_column_name = 'length',
|
167 |
-
report_to = None,
|
168 |
-
ddp_find_unused_parameters = None,
|
169 |
-
ddp_bucket_cap_mb = None,
|
170 |
-
ddp_broadcast_buffers = None,
|
171 |
-
dataloader_pin_memory = True,
|
172 |
-
dataloader_persistent_workers = False,
|
173 |
-
skip_memory_metrics = True,
|
174 |
-
use_legacy_prediction_loop = False,
|
175 |
-
push_to_hub = False,
|
176 |
-
resume_from_checkpoint = None,
|
177 |
-
hub_model_id = None,
|
178 |
-
hub_strategy = 'every_save',
|
179 |
-
hub_token = None,
|
180 |
-
hub_private_repo = None,
|
181 |
-
hub_always_push = False,
|
182 |
-
gradient_checkpointing = False,
|
183 |
-
gradient_checkpointing_kwargs = None,
|
184 |
-
include_inputs_for_metrics = False,
|
185 |
-
eval_do_concat_batches = True,
|
186 |
-
fp16_backend = 'auto',
|
187 |
-
evaluation_strategy = None,
|
188 |
-
push_to_hub_model_id = None,
|
189 |
-
push_to_hub_organization = None,
|
190 |
-
push_to_hub_token = None,
|
191 |
-
mp_parameters = '',
|
192 |
-
auto_find_batch_size = False,
|
193 |
-
full_determinism = False,
|
194 |
-
torchdynamo = None,
|
195 |
-
ray_scope = 'last',
|
196 |
-
ddp_timeout = 1800,
|
197 |
-
torch_compile = False,
|
198 |
-
torch_compile_backend = None,
|
199 |
-
torch_compile_mode = None,
|
200 |
-
dispatch_batches = None,
|
201 |
-
split_batches = None,
|
202 |
-
include_tokens_per_second = False,
|
203 |
-
include_num_input_tokens_seen = False,
|
204 |
-
neftune_noise_alpha = None,
|
205 |
-
optim_target_modules = None,
|
206 |
-
batch_eval_metrics = False,
|
207 |
-
eval_on_start = False,
|
208 |
-
use_liger_kernel = False,
|
209 |
-
eval_use_gather_object = False,
|
210 |
-
average_tokens_across_devices = False,
|
211 |
-
max_length = 1024,
|
212 |
-
max_prompt_length = 512,
|
213 |
-
max_completion_length = None,
|
214 |
-
disable_dropout = True,
|
215 |
-
step_separator = '\
|
216 |
-
',
|
217 |
-
train_on_last_step_only = False,
|
218 |
-
dataset_num_proc = None,
|
219 |
-
vllm_sampling_params = None,
|
220 |
-
unsloth_num_chunks = -1,
|
221 |
-
**kwargs,
|
222 |
-
):
|
223 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
224 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
225 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
226 |
-
output_dir = 'unsloth_training_checkpoints'
|
227 |
-
save_strategy = 'no'
|
228 |
-
if dataset_num_proc is None:
|
229 |
-
from multiprocessing import cpu_count
|
230 |
-
dataset_num_proc = cpu_count()
|
231 |
-
|
232 |
-
super().__init__(
|
233 |
-
output_dir = output_dir,
|
234 |
-
overwrite_output_dir = overwrite_output_dir,
|
235 |
-
do_train = do_train,
|
236 |
-
do_eval = do_eval,
|
237 |
-
do_predict = do_predict,
|
238 |
-
eval_strategy = eval_strategy,
|
239 |
-
prediction_loss_only = prediction_loss_only,
|
240 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
241 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
242 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
243 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
244 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
245 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
246 |
-
eval_delay = eval_delay,
|
247 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
248 |
-
learning_rate = learning_rate,
|
249 |
-
weight_decay = weight_decay,
|
250 |
-
adam_beta1 = adam_beta1,
|
251 |
-
adam_beta2 = adam_beta2,
|
252 |
-
adam_epsilon = adam_epsilon,
|
253 |
-
max_grad_norm = max_grad_norm,
|
254 |
-
num_train_epochs = num_train_epochs,
|
255 |
-
max_steps = max_steps,
|
256 |
-
lr_scheduler_type = lr_scheduler_type,
|
257 |
-
warmup_ratio = warmup_ratio,
|
258 |
-
warmup_steps = warmup_steps,
|
259 |
-
log_level = log_level,
|
260 |
-
log_level_replica = log_level_replica,
|
261 |
-
log_on_each_node = log_on_each_node,
|
262 |
-
logging_dir = logging_dir,
|
263 |
-
logging_strategy = logging_strategy,
|
264 |
-
logging_first_step = logging_first_step,
|
265 |
-
logging_steps = logging_steps,
|
266 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
267 |
-
save_strategy = save_strategy,
|
268 |
-
save_steps = save_steps,
|
269 |
-
save_total_limit = save_total_limit,
|
270 |
-
save_safetensors = save_safetensors,
|
271 |
-
save_on_each_node = save_on_each_node,
|
272 |
-
save_only_model = save_only_model,
|
273 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
274 |
-
no_cuda = no_cuda,
|
275 |
-
use_cpu = use_cpu,
|
276 |
-
use_mps_device = use_mps_device,
|
277 |
-
seed = seed,
|
278 |
-
data_seed = data_seed,
|
279 |
-
jit_mode_eval = jit_mode_eval,
|
280 |
-
use_ipex = use_ipex,
|
281 |
-
bf16 = bf16,
|
282 |
-
fp16 = fp16,
|
283 |
-
fp16_opt_level = fp16_opt_level,
|
284 |
-
half_precision_backend = half_precision_backend,
|
285 |
-
bf16_full_eval = bf16_full_eval,
|
286 |
-
fp16_full_eval = fp16_full_eval,
|
287 |
-
tf32 = tf32,
|
288 |
-
local_rank = local_rank,
|
289 |
-
ddp_backend = ddp_backend,
|
290 |
-
tpu_num_cores = tpu_num_cores,
|
291 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
292 |
-
debug = debug,
|
293 |
-
dataloader_drop_last = dataloader_drop_last,
|
294 |
-
eval_steps = eval_steps,
|
295 |
-
dataloader_num_workers = dataloader_num_workers,
|
296 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
297 |
-
past_index = past_index,
|
298 |
-
run_name = run_name,
|
299 |
-
disable_tqdm = disable_tqdm,
|
300 |
-
remove_unused_columns = remove_unused_columns,
|
301 |
-
label_names = label_names,
|
302 |
-
load_best_model_at_end = load_best_model_at_end,
|
303 |
-
metric_for_best_model = metric_for_best_model,
|
304 |
-
greater_is_better = greater_is_better,
|
305 |
-
ignore_data_skip = ignore_data_skip,
|
306 |
-
fsdp = fsdp,
|
307 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
308 |
-
fsdp_config = fsdp_config,
|
309 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
310 |
-
accelerator_config = accelerator_config,
|
311 |
-
deepspeed = deepspeed,
|
312 |
-
label_smoothing_factor = label_smoothing_factor,
|
313 |
-
optim = optim,
|
314 |
-
optim_args = optim_args,
|
315 |
-
adafactor = adafactor,
|
316 |
-
group_by_length = group_by_length,
|
317 |
-
length_column_name = length_column_name,
|
318 |
-
report_to = report_to,
|
319 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
320 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
321 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
322 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
323 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
324 |
-
skip_memory_metrics = skip_memory_metrics,
|
325 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
326 |
-
push_to_hub = push_to_hub,
|
327 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
328 |
-
hub_model_id = hub_model_id,
|
329 |
-
hub_strategy = hub_strategy,
|
330 |
-
hub_token = hub_token,
|
331 |
-
hub_private_repo = hub_private_repo,
|
332 |
-
hub_always_push = hub_always_push,
|
333 |
-
gradient_checkpointing = gradient_checkpointing,
|
334 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
335 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
336 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
337 |
-
fp16_backend = fp16_backend,
|
338 |
-
evaluation_strategy = evaluation_strategy,
|
339 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
340 |
-
push_to_hub_organization = push_to_hub_organization,
|
341 |
-
push_to_hub_token = push_to_hub_token,
|
342 |
-
mp_parameters = mp_parameters,
|
343 |
-
auto_find_batch_size = auto_find_batch_size,
|
344 |
-
full_determinism = full_determinism,
|
345 |
-
torchdynamo = torchdynamo,
|
346 |
-
ray_scope = ray_scope,
|
347 |
-
ddp_timeout = ddp_timeout,
|
348 |
-
torch_compile = torch_compile,
|
349 |
-
torch_compile_backend = torch_compile_backend,
|
350 |
-
torch_compile_mode = torch_compile_mode,
|
351 |
-
dispatch_batches = dispatch_batches,
|
352 |
-
split_batches = split_batches,
|
353 |
-
include_tokens_per_second = include_tokens_per_second,
|
354 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
355 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
356 |
-
optim_target_modules = optim_target_modules,
|
357 |
-
batch_eval_metrics = batch_eval_metrics,
|
358 |
-
eval_on_start = eval_on_start,
|
359 |
-
use_liger_kernel = use_liger_kernel,
|
360 |
-
eval_use_gather_object = eval_use_gather_object,
|
361 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
362 |
-
max_length = max_length,
|
363 |
-
max_prompt_length = max_prompt_length,
|
364 |
-
max_completion_length = max_completion_length,
|
365 |
-
disable_dropout = disable_dropout,
|
366 |
-
step_separator = step_separator,
|
367 |
-
train_on_last_step_only = train_on_last_step_only,
|
368 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
369 |
-
self.vllm_sampling_params = vllm_sampling_params
|
370 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
371 |
-
pass
|
372 |
-
|
373 |
-
class _UnslothPRMTrainer(Trainer):
|
374 |
-
""""""
|
375 |
-
|
376 |
-
_tag_names = ["trl", "prm"]
|
377 |
-
|
378 |
-
def __init__(
|
379 |
-
self,
|
380 |
-
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
381 |
-
args: Optional[PRMConfig] = None,
|
382 |
-
data_collator: Optional[DataCollator] = None,
|
383 |
-
train_dataset: Optional[Dataset] = None,
|
384 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
385 |
-
processing_class: Optional[
|
386 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
387 |
-
] = None,
|
388 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
389 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
390 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
391 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
392 |
-
None,
|
393 |
-
None,
|
394 |
-
),
|
395 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
396 |
-
peft_config: Optional[dict] = None,
|
397 |
-
):
|
398 |
-
if not is_peft_available() and peft_config is not None:
|
399 |
-
raise ValueError(
|
400 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
401 |
-
)
|
402 |
-
elif is_peft_available() and peft_config is not None:
|
403 |
-
if not isinstance(model, PeftModel):
|
404 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
405 |
-
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
406 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
407 |
-
)
|
408 |
-
|
409 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
410 |
-
|
411 |
-
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
412 |
-
warnings.warn(
|
413 |
-
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
414 |
-
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
|
415 |
-
)
|
416 |
-
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
417 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
418 |
-
|
419 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
420 |
-
|
421 |
-
model = model
|
422 |
-
|
423 |
-
# Disable dropout in the model
|
424 |
-
if args.disable_dropout:
|
425 |
-
disable_dropout_in_model(model)
|
426 |
-
|
427 |
-
if compute_metrics is None:
|
428 |
-
compute_metrics = compute_accuracy
|
429 |
-
|
430 |
-
if data_collator is None:
|
431 |
-
if processing_class is None:
|
432 |
-
raise ValueError(
|
433 |
-
"A processing_class must be specified when using the default DataCollatorForTokenClassification"
|
434 |
-
)
|
435 |
-
data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
|
436 |
-
|
437 |
-
if "input_ids" not in train_dataset.column_names:
|
438 |
-
with PartialState().local_main_process_first():
|
439 |
-
fn_kwargs = {
|
440 |
-
"tokenizer": processing_class,
|
441 |
-
"step_separator": args.step_separator,
|
442 |
-
"max_length": args.max_length,
|
443 |
-
"max_prompt_length": args.max_prompt_length,
|
444 |
-
"max_completion_length": args.max_completion_length,
|
445 |
-
"train_on_last_step_only": args.train_on_last_step_only,
|
446 |
-
}
|
447 |
-
train_fn_kwargs = {**fn_kwargs, "is_eval": False}
|
448 |
-
train_dataset = train_dataset.map(
|
449 |
-
self.tokenize_row,
|
450 |
-
fn_kwargs=train_fn_kwargs,
|
451 |
-
num_proc=args.dataset_num_proc,
|
452 |
-
remove_columns=train_dataset.features,
|
453 |
-
desc="Tokenizing train dataset",
|
454 |
-
features=features.Features( # needed to avoid map to cast labels to bool
|
455 |
-
{
|
456 |
-
"labels": features.Sequence(features.Value("int64")),
|
457 |
-
"input_ids": features.Sequence(features.Value("int64")),
|
458 |
-
}
|
459 |
-
),
|
460 |
-
)
|
461 |
-
|
462 |
-
eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
|
463 |
-
if eval_dataset is not None:
|
464 |
-
eval_dataset = eval_dataset.map(
|
465 |
-
self.tokenize_row,
|
466 |
-
fn_kwargs=eval_fn_kwargs,
|
467 |
-
num_proc=args.dataset_num_proc,
|
468 |
-
remove_columns=eval_dataset.features,
|
469 |
-
desc="Tokenizing eval dataset",
|
470 |
-
features=features.Features( # needed to avoid map to cast labels to bool
|
471 |
-
{
|
472 |
-
"labels": features.Sequence(features.Value("int64")),
|
473 |
-
"input_ids": features.Sequence(features.Value("int64")),
|
474 |
-
}
|
475 |
-
),
|
476 |
-
)
|
477 |
-
|
478 |
-
super().__init__(
|
479 |
-
model=model,
|
480 |
-
args=args,
|
481 |
-
data_collator=data_collator,
|
482 |
-
train_dataset=train_dataset,
|
483 |
-
eval_dataset=eval_dataset,
|
484 |
-
processing_class=processing_class,
|
485 |
-
model_init=model_init,
|
486 |
-
compute_metrics=compute_metrics,
|
487 |
-
callbacks=callbacks,
|
488 |
-
optimizers=optimizers,
|
489 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
490 |
-
)
|
491 |
-
|
492 |
-
# Add tags for models that have been loaded with the correct transformers version
|
493 |
-
if hasattr(self.model, "add_model_tags"):
|
494 |
-
self.model.add_model_tags(self._tag_names)
|
495 |
-
|
496 |
-
@staticmethod
|
497 |
-
def tokenize_row(
|
498 |
-
features,
|
499 |
-
tokenizer,
|
500 |
-
step_separator,
|
501 |
-
max_length,
|
502 |
-
max_prompt_length,
|
503 |
-
max_completion_length,
|
504 |
-
train_on_last_step_only,
|
505 |
-
is_eval,
|
506 |
-
):
|
507 |
-
r"""
|
508 |
-
Tokenize a row of the dataset.
|
509 |
-
|
510 |
-
Args:
|
511 |
-
features (`dict[str, str]`):
|
512 |
-
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
|
513 |
-
tokenizer (`PreTrainedTokenizerBase`):
|
514 |
-
Tokenizer used to process the data.
|
515 |
-
step_separator (`str`):
|
516 |
-
Separator between steps in the completion.
|
517 |
-
max_length (`int` or `None`):
|
518 |
-
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
|
519 |
-
max_prompt_length (`int` or `None`):
|
520 |
-
Maximum length of the prompt. If `None`, the prompt is not truncated.
|
521 |
-
max_completion_length (`int` or `None`):
|
522 |
-
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
|
523 |
-
train_on_last_step_only (`bool`):
|
524 |
-
Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
|
525 |
-
token of the completion.
|
526 |
-
is_eval (`bool`):
|
527 |
-
Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
|
528 |
-
|
529 |
-
Returns:
|
530 |
-
`dict[str, list[int]]`:
|
531 |
-
Tokenized sequences with the keys `"input_ids"`, and `"labels".
|
532 |
-
|
533 |
-
Example:
|
534 |
-
```python
|
535 |
-
>>> from transformers import AutoTokenizer
|
536 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
537 |
-
>>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
|
538 |
-
... "completions": ["11 is greater than 8.",
|
539 |
-
... "Hence, 9.11 > 9.8."],
|
540 |
-
... "labels": [True, False]}
|
541 |
-
>>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
|
542 |
-
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
|
543 |
-
'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
|
544 |
-
```
|
545 |
-
"""
|
546 |
-
# Tokenize the prompt and completions
|
547 |
-
prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
|
548 |
-
completions_ids = [
|
549 |
-
tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
|
550 |
-
]
|
551 |
-
if train_on_last_step_only and not is_eval:
|
552 |
-
labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
|
553 |
-
else:
|
554 |
-
labels = [int(label) for label in features["labels"]]
|
555 |
-
|
556 |
-
# Get the ID of the separator token and add it to the completions
|
557 |
-
separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
|
558 |
-
completions_ids = [completion + separator_ids for completion in completions_ids]
|
559 |
-
|
560 |
-
# Create the label
|
561 |
-
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
562 |
-
|
563 |
-
# Join the completions and labels steps
|
564 |
-
completion_ids = list(chain(*completions_ids))
|
565 |
-
labels = list(chain(*labels))
|
566 |
-
|
567 |
-
if tokenizer.bos_token_id is not None:
|
568 |
-
prompt_ids = [tokenizer.bos_token_id] + prompt_ids
|
569 |
-
|
570 |
-
# Truncate prompt and completion sequences
|
571 |
-
if max_prompt_length is not None:
|
572 |
-
prompt_ids = prompt_ids[-max_prompt_length:]
|
573 |
-
if max_completion_length is not None:
|
574 |
-
completion_ids = completion_ids[:max_completion_length]
|
575 |
-
labels = labels[:max_completion_length]
|
576 |
-
|
577 |
-
input_ids = prompt_ids + completion_ids
|
578 |
-
labels = [-100] * len(prompt_ids) + labels
|
579 |
-
|
580 |
-
if max_length is not None:
|
581 |
-
input_ids = input_ids[:max_length]
|
582 |
-
labels = labels[:max_length]
|
583 |
-
|
584 |
-
return {"input_ids": input_ids, "labels": labels}
|
585 |
-
|
586 |
-
def create_model_card(
|
587 |
-
self,
|
588 |
-
model_name: Optional[str] = None,
|
589 |
-
dataset_name: Optional[str] = None,
|
590 |
-
tags: Union[str, list[str], None] = None,
|
591 |
-
):
|
592 |
-
"""
|
593 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
594 |
-
|
595 |
-
Args:
|
596 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
597 |
-
Name of the model.
|
598 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
599 |
-
Name of the dataset used for training.
|
600 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
601 |
-
Tags to be associated with the model card.
|
602 |
-
"""
|
603 |
-
if not self.is_world_process_zero():
|
604 |
-
return
|
605 |
-
|
606 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
607 |
-
base_model = self.model.config._name_or_path
|
608 |
-
else:
|
609 |
-
base_model = None
|
610 |
-
|
611 |
-
tags = tags or []
|
612 |
-
if isinstance(tags, str):
|
613 |
-
tags = [tags]
|
614 |
-
|
615 |
-
if hasattr(self.model.config, "unsloth_version"):
|
616 |
-
tags.append("unsloth")
|
617 |
-
|
618 |
-
citation = textwrap.dedent("""\
|
619 |
-
@article{uesato2022solving,
|
620 |
-
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
|
621 |
-
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
|
622 |
-
year = 2022,
|
623 |
-
journal = {arXiv preprint arXiv:2211.14275}
|
624 |
-
}""")
|
625 |
-
|
626 |
-
model_card = generate_model_card(
|
627 |
-
base_model=base_model,
|
628 |
-
model_name=model_name,
|
629 |
-
hub_model_id=self.hub_model_id,
|
630 |
-
dataset_name=dataset_name,
|
631 |
-
tags=tags,
|
632 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
633 |
-
trainer_name="PRM",
|
634 |
-
trainer_citation=citation,
|
635 |
-
paper_title="Solving math word problems with process-and outcome-based feedback",
|
636 |
-
)
|
637 |
-
|
638 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
639 |
-
class UnslothPRMTrainer(_UnslothPRMTrainer):
|
640 |
-
"""
|
641 |
-
|
642 |
-
Initialize PRMTrainer.
|
643 |
-
|
644 |
-
Args:
|
645 |
-
model (`transformers.PreTrainedModel`):
|
646 |
-
The model to train, preferably an `AutoModelForTokenClassification`.
|
647 |
-
args (`PRMConfig`):
|
648 |
-
The arguments to use for training.
|
649 |
-
data_collator (`transformers.DataCollator`):
|
650 |
-
The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
|
651 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
652 |
-
train_dataset (`datasets.Dataset`):
|
653 |
-
The dataset to use for training.
|
654 |
-
eval_dataset (`datasets.Dataset`):
|
655 |
-
The dataset to use for evaluation.
|
656 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
657 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
658 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
659 |
-
reuse the fine-tuned model.
|
660 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
661 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
662 |
-
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
663 |
-
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
664 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
665 |
-
The callbacks to use for training.
|
666 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
667 |
-
The optimizer and scheduler to use for training.
|
668 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
669 |
-
The function to use to preprocess the logits before computing the metrics.
|
670 |
-
peft_config (`dict`, defaults to `None`):
|
671 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
672 |
-
|
673 |
-
"""
|
674 |
-
def __init__(
|
675 |
-
self,
|
676 |
-
model = None,
|
677 |
-
args = None,
|
678 |
-
data_collator = None,
|
679 |
-
train_dataset = None,
|
680 |
-
eval_dataset = None,
|
681 |
-
processing_class = None,
|
682 |
-
model_init = None,
|
683 |
-
compute_metrics = None,
|
684 |
-
callbacks = None,
|
685 |
-
preprocess_logits_for_metrics = None,
|
686 |
-
peft_config = None,
|
687 |
-
**kwargs
|
688 |
-
):
|
689 |
-
if args is None: args = UnslothPRMConfig()
|
690 |
-
use_bf16 = getattr(args, 'bf16', False)
|
691 |
-
use_fp16 = getattr(args, 'fp16', False)
|
692 |
-
force_float32 = False
|
693 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
694 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
695 |
-
force_float32 = True
|
696 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
697 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
698 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
699 |
-
from unsloth_zoo.utils import _get_dtype
|
700 |
-
dtype = _get_dtype(dtype)
|
701 |
-
float16 = dtype == torch.float16
|
702 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
703 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
704 |
-
if force_float32:
|
705 |
-
args.fp16 = False
|
706 |
-
args.bf16 = False
|
707 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
708 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
709 |
-
args.fp16 = float16
|
710 |
-
args.bf16 = not float16
|
711 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
712 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
713 |
-
args.eval_strategy = 'steps'
|
714 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
715 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
716 |
-
if ga_steps is not None and ga_steps > 1:
|
717 |
-
from transformers import __version__ as transformers_version
|
718 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
719 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
720 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
721 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
722 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
723 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
724 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
725 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
726 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
727 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
728 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
729 |
-
if force_float32:
|
730 |
-
args.bf16_full_eval = False
|
731 |
-
args.fp16_full_eval = False
|
732 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
733 |
-
args.bf16_full_eval = True
|
734 |
-
args.fp16_full_eval = False
|
735 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
736 |
-
args.bf16_full_eval = args.bf16
|
737 |
-
args.fp16_full_eval = args.fp16
|
738 |
-
_output_logits = False
|
739 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
740 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
741 |
-
if _output_logits:
|
742 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
743 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
744 |
-
pass
|
745 |
-
else:
|
746 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
747 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
748 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
749 |
-
max_seq_length = model.max_seq_length
|
750 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
751 |
-
if model is not None and hasattr(model, 'for_training'):
|
752 |
-
model.for_training()
|
753 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
754 |
-
if 'processing_class' in locals():
|
755 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
756 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
757 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
758 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
759 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
760 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
761 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
762 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
763 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
764 |
-
else:
|
765 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
766 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
767 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
768 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
769 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
770 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
771 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
772 |
-
else:
|
773 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
774 |
-
other_metrics = []
|
775 |
-
|
776 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
777 |
-
PatchRLStatistics('prm_trainer', other_metrics)
|
778 |
-
|
779 |
-
super().__init__(
|
780 |
-
model = model,
|
781 |
-
args = args,
|
782 |
-
data_collator = data_collator,
|
783 |
-
train_dataset = train_dataset,
|
784 |
-
eval_dataset = eval_dataset,
|
785 |
-
processing_class = processing_class,
|
786 |
-
model_init = model_init,
|
787 |
-
compute_metrics = compute_metrics,
|
788 |
-
callbacks = callbacks,
|
789 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
790 |
-
peft_config = peft_config,**kwargs)
|
791 |
-
if hasattr(self, 'neftune_hook_handle'):
|
792 |
-
self.neftune_hook_handle.remove()
|
793 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
794 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
795 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
796 |
-
pass
|
797 |
-
|
798 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothRLOOTrainer.py
DELETED
@@ -1,1131 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothRLOOConfig(RLOOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`RLOOTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
|
54 |
-
Name of this experiment.
|
55 |
-
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
56 |
-
Path to the reward model.
|
57 |
-
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
58 |
-
Number of epochs to train.
|
59 |
-
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
60 |
-
Whether to whiten the rewards.
|
61 |
-
kl_coef (`float`, *optional*, defaults to `0.05`):
|
62 |
-
KL coefficient.
|
63 |
-
cliprange (`float`, *optional*, defaults to `0.2`):
|
64 |
-
Clip range.
|
65 |
-
rloo_k (`int`, *optional*, defaults to `2`):
|
66 |
-
REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
|
67 |
-
normalize_reward (`bool`, *optional*, defaults to `False`):
|
68 |
-
Whether to normalize rewards.
|
69 |
-
reward_clip_range (`float`, *optional*, defaults to `10.0`):
|
70 |
-
Clip range for rewards.
|
71 |
-
normalize_advantage (`bool`, *optional*, defaults to `False`):
|
72 |
-
Whether to normalize advantages.
|
73 |
-
token_level_kl (`bool`, *optional*, defaults to `True`):
|
74 |
-
Whether to use token-level KL penalty or sequence-level KL penalty.
|
75 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
76 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
77 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
78 |
-
capacity of a single GPU, albeit at the cost of slower generation.
|
79 |
-
|
80 |
-
"""
|
81 |
-
vllm_sampling_params: Optional[Any] = field(
|
82 |
-
default = None,
|
83 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
84 |
-
)
|
85 |
-
unsloth_num_chunks : Optional[int] = field(
|
86 |
-
default = -1,
|
87 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
88 |
-
)
|
89 |
-
def __init__(
|
90 |
-
self,
|
91 |
-
output_dir = None,
|
92 |
-
overwrite_output_dir = None,
|
93 |
-
do_train = False,
|
94 |
-
do_eval = False,
|
95 |
-
do_predict = False,
|
96 |
-
eval_strategy = 'no',
|
97 |
-
prediction_loss_only = False,
|
98 |
-
per_device_train_batch_size = 4,
|
99 |
-
per_device_eval_batch_size = 4,
|
100 |
-
per_gpu_train_batch_size = None,
|
101 |
-
per_gpu_eval_batch_size = None,
|
102 |
-
gradient_accumulation_steps = 2,
|
103 |
-
eval_accumulation_steps = 2,
|
104 |
-
eval_delay = 0,
|
105 |
-
torch_empty_cache_steps = 250,
|
106 |
-
learning_rate = 5e-05,
|
107 |
-
weight_decay = 0.01,
|
108 |
-
adam_beta1 = 0.9,
|
109 |
-
adam_beta2 = 0.999,
|
110 |
-
adam_epsilon = 1e-08,
|
111 |
-
max_grad_norm = 1.0,
|
112 |
-
num_train_epochs = 3.0,
|
113 |
-
max_steps = -1,
|
114 |
-
lr_scheduler_type = 'linear',
|
115 |
-
warmup_ratio = 0.1,
|
116 |
-
warmup_steps = 0,
|
117 |
-
log_level = 'passive',
|
118 |
-
log_level_replica = 'warning',
|
119 |
-
log_on_each_node = True,
|
120 |
-
logging_dir = None,
|
121 |
-
logging_strategy = 'steps',
|
122 |
-
logging_first_step = False,
|
123 |
-
logging_steps = 1,
|
124 |
-
logging_nan_inf_filter = False,
|
125 |
-
save_strategy = 'steps',
|
126 |
-
save_steps = 500,
|
127 |
-
save_total_limit = None,
|
128 |
-
save_safetensors = True,
|
129 |
-
save_on_each_node = False,
|
130 |
-
save_only_model = False,
|
131 |
-
restore_callback_states_from_checkpoint = False,
|
132 |
-
no_cuda = False,
|
133 |
-
use_cpu = False,
|
134 |
-
use_mps_device = False,
|
135 |
-
seed = 3407,
|
136 |
-
data_seed = 3407,
|
137 |
-
jit_mode_eval = False,
|
138 |
-
use_ipex = False,
|
139 |
-
bf16 = False,
|
140 |
-
fp16 = False,
|
141 |
-
fp16_opt_level = 'O1',
|
142 |
-
half_precision_backend = 'auto',
|
143 |
-
bf16_full_eval = False,
|
144 |
-
fp16_full_eval = False,
|
145 |
-
tf32 = None,
|
146 |
-
local_rank = -1,
|
147 |
-
ddp_backend = None,
|
148 |
-
tpu_num_cores = None,
|
149 |
-
tpu_metrics_debug = False,
|
150 |
-
debug = '',
|
151 |
-
dataloader_drop_last = False,
|
152 |
-
eval_steps = None,
|
153 |
-
dataloader_num_workers = 0,
|
154 |
-
dataloader_prefetch_factor = None,
|
155 |
-
past_index = -1,
|
156 |
-
run_name = None,
|
157 |
-
disable_tqdm = None,
|
158 |
-
remove_unused_columns = True,
|
159 |
-
label_names = None,
|
160 |
-
load_best_model_at_end = False,
|
161 |
-
metric_for_best_model = None,
|
162 |
-
greater_is_better = None,
|
163 |
-
ignore_data_skip = False,
|
164 |
-
fsdp = '',
|
165 |
-
fsdp_min_num_params = 0,
|
166 |
-
fsdp_config = None,
|
167 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
168 |
-
accelerator_config = None,
|
169 |
-
deepspeed = None,
|
170 |
-
label_smoothing_factor = 0.0,
|
171 |
-
optim = 'adamw_8bit',
|
172 |
-
optim_args = None,
|
173 |
-
adafactor = False,
|
174 |
-
group_by_length = False,
|
175 |
-
length_column_name = 'length',
|
176 |
-
report_to = None,
|
177 |
-
ddp_find_unused_parameters = None,
|
178 |
-
ddp_bucket_cap_mb = None,
|
179 |
-
ddp_broadcast_buffers = None,
|
180 |
-
dataloader_pin_memory = True,
|
181 |
-
dataloader_persistent_workers = False,
|
182 |
-
skip_memory_metrics = True,
|
183 |
-
use_legacy_prediction_loop = False,
|
184 |
-
push_to_hub = False,
|
185 |
-
resume_from_checkpoint = None,
|
186 |
-
hub_model_id = None,
|
187 |
-
hub_strategy = 'every_save',
|
188 |
-
hub_token = None,
|
189 |
-
hub_private_repo = None,
|
190 |
-
hub_always_push = False,
|
191 |
-
gradient_checkpointing = False,
|
192 |
-
gradient_checkpointing_kwargs = None,
|
193 |
-
include_inputs_for_metrics = False,
|
194 |
-
eval_do_concat_batches = True,
|
195 |
-
fp16_backend = 'auto',
|
196 |
-
evaluation_strategy = None,
|
197 |
-
push_to_hub_model_id = None,
|
198 |
-
push_to_hub_organization = None,
|
199 |
-
push_to_hub_token = None,
|
200 |
-
mp_parameters = '',
|
201 |
-
auto_find_batch_size = False,
|
202 |
-
full_determinism = False,
|
203 |
-
torchdynamo = None,
|
204 |
-
ray_scope = 'last',
|
205 |
-
ddp_timeout = 1800,
|
206 |
-
torch_compile = False,
|
207 |
-
torch_compile_backend = None,
|
208 |
-
torch_compile_mode = None,
|
209 |
-
dispatch_batches = None,
|
210 |
-
split_batches = None,
|
211 |
-
include_tokens_per_second = False,
|
212 |
-
include_num_input_tokens_seen = False,
|
213 |
-
neftune_noise_alpha = None,
|
214 |
-
optim_target_modules = None,
|
215 |
-
batch_eval_metrics = False,
|
216 |
-
eval_on_start = False,
|
217 |
-
use_liger_kernel = False,
|
218 |
-
eval_use_gather_object = False,
|
219 |
-
average_tokens_across_devices = False,
|
220 |
-
dataset_num_proc = None,
|
221 |
-
num_mini_batches = 1,
|
222 |
-
total_episodes = None,
|
223 |
-
local_rollout_forward_batch_size = 64,
|
224 |
-
num_sample_generations = 10,
|
225 |
-
response_length = 53,
|
226 |
-
stop_token = None,
|
227 |
-
stop_token_id = None,
|
228 |
-
temperature = 0.7,
|
229 |
-
missing_eos_penalty = None,
|
230 |
-
sft_model_path = 'EleutherAI/pythia-160m',
|
231 |
-
world_size = None,
|
232 |
-
num_total_batches = None,
|
233 |
-
micro_batch_size = None,
|
234 |
-
local_batch_size = None,
|
235 |
-
batch_size = None,
|
236 |
-
local_mini_batch_size = None,
|
237 |
-
mini_batch_size = None,
|
238 |
-
exp_name = 'rloo_config',
|
239 |
-
reward_model_path = 'EleutherAI/pythia-160m',
|
240 |
-
num_ppo_epochs = 4,
|
241 |
-
whiten_rewards = False,
|
242 |
-
kl_coef = 0.05,
|
243 |
-
cliprange = 0.2,
|
244 |
-
rloo_k = 2,
|
245 |
-
normalize_reward = False,
|
246 |
-
reward_clip_range = 10.0,
|
247 |
-
normalize_advantage = False,
|
248 |
-
token_level_kl = False,
|
249 |
-
ds3_gather_for_generation = True,
|
250 |
-
vllm_sampling_params = None,
|
251 |
-
unsloth_num_chunks = -1,
|
252 |
-
**kwargs,
|
253 |
-
):
|
254 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
255 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
256 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
257 |
-
output_dir = 'unsloth_training_checkpoints'
|
258 |
-
save_strategy = 'no'
|
259 |
-
if dataset_num_proc is None:
|
260 |
-
from multiprocessing import cpu_count
|
261 |
-
dataset_num_proc = cpu_count()
|
262 |
-
|
263 |
-
super().__init__(
|
264 |
-
output_dir = output_dir,
|
265 |
-
overwrite_output_dir = overwrite_output_dir,
|
266 |
-
do_train = do_train,
|
267 |
-
do_eval = do_eval,
|
268 |
-
do_predict = do_predict,
|
269 |
-
eval_strategy = eval_strategy,
|
270 |
-
prediction_loss_only = prediction_loss_only,
|
271 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
272 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
273 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
274 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
275 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
276 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
277 |
-
eval_delay = eval_delay,
|
278 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
279 |
-
learning_rate = learning_rate,
|
280 |
-
weight_decay = weight_decay,
|
281 |
-
adam_beta1 = adam_beta1,
|
282 |
-
adam_beta2 = adam_beta2,
|
283 |
-
adam_epsilon = adam_epsilon,
|
284 |
-
max_grad_norm = max_grad_norm,
|
285 |
-
num_train_epochs = num_train_epochs,
|
286 |
-
max_steps = max_steps,
|
287 |
-
lr_scheduler_type = lr_scheduler_type,
|
288 |
-
warmup_ratio = warmup_ratio,
|
289 |
-
warmup_steps = warmup_steps,
|
290 |
-
log_level = log_level,
|
291 |
-
log_level_replica = log_level_replica,
|
292 |
-
log_on_each_node = log_on_each_node,
|
293 |
-
logging_dir = logging_dir,
|
294 |
-
logging_strategy = logging_strategy,
|
295 |
-
logging_first_step = logging_first_step,
|
296 |
-
logging_steps = logging_steps,
|
297 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
298 |
-
save_strategy = save_strategy,
|
299 |
-
save_steps = save_steps,
|
300 |
-
save_total_limit = save_total_limit,
|
301 |
-
save_safetensors = save_safetensors,
|
302 |
-
save_on_each_node = save_on_each_node,
|
303 |
-
save_only_model = save_only_model,
|
304 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
305 |
-
no_cuda = no_cuda,
|
306 |
-
use_cpu = use_cpu,
|
307 |
-
use_mps_device = use_mps_device,
|
308 |
-
seed = seed,
|
309 |
-
data_seed = data_seed,
|
310 |
-
jit_mode_eval = jit_mode_eval,
|
311 |
-
use_ipex = use_ipex,
|
312 |
-
bf16 = bf16,
|
313 |
-
fp16 = fp16,
|
314 |
-
fp16_opt_level = fp16_opt_level,
|
315 |
-
half_precision_backend = half_precision_backend,
|
316 |
-
bf16_full_eval = bf16_full_eval,
|
317 |
-
fp16_full_eval = fp16_full_eval,
|
318 |
-
tf32 = tf32,
|
319 |
-
local_rank = local_rank,
|
320 |
-
ddp_backend = ddp_backend,
|
321 |
-
tpu_num_cores = tpu_num_cores,
|
322 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
323 |
-
debug = debug,
|
324 |
-
dataloader_drop_last = dataloader_drop_last,
|
325 |
-
eval_steps = eval_steps,
|
326 |
-
dataloader_num_workers = dataloader_num_workers,
|
327 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
328 |
-
past_index = past_index,
|
329 |
-
run_name = run_name,
|
330 |
-
disable_tqdm = disable_tqdm,
|
331 |
-
remove_unused_columns = remove_unused_columns,
|
332 |
-
label_names = label_names,
|
333 |
-
load_best_model_at_end = load_best_model_at_end,
|
334 |
-
metric_for_best_model = metric_for_best_model,
|
335 |
-
greater_is_better = greater_is_better,
|
336 |
-
ignore_data_skip = ignore_data_skip,
|
337 |
-
fsdp = fsdp,
|
338 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
339 |
-
fsdp_config = fsdp_config,
|
340 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
341 |
-
accelerator_config = accelerator_config,
|
342 |
-
deepspeed = deepspeed,
|
343 |
-
label_smoothing_factor = label_smoothing_factor,
|
344 |
-
optim = optim,
|
345 |
-
optim_args = optim_args,
|
346 |
-
adafactor = adafactor,
|
347 |
-
group_by_length = group_by_length,
|
348 |
-
length_column_name = length_column_name,
|
349 |
-
report_to = report_to,
|
350 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
351 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
352 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
353 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
354 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
355 |
-
skip_memory_metrics = skip_memory_metrics,
|
356 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
357 |
-
push_to_hub = push_to_hub,
|
358 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
359 |
-
hub_model_id = hub_model_id,
|
360 |
-
hub_strategy = hub_strategy,
|
361 |
-
hub_token = hub_token,
|
362 |
-
hub_private_repo = hub_private_repo,
|
363 |
-
hub_always_push = hub_always_push,
|
364 |
-
gradient_checkpointing = gradient_checkpointing,
|
365 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
366 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
367 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
368 |
-
fp16_backend = fp16_backend,
|
369 |
-
evaluation_strategy = evaluation_strategy,
|
370 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
371 |
-
push_to_hub_organization = push_to_hub_organization,
|
372 |
-
push_to_hub_token = push_to_hub_token,
|
373 |
-
mp_parameters = mp_parameters,
|
374 |
-
auto_find_batch_size = auto_find_batch_size,
|
375 |
-
full_determinism = full_determinism,
|
376 |
-
torchdynamo = torchdynamo,
|
377 |
-
ray_scope = ray_scope,
|
378 |
-
ddp_timeout = ddp_timeout,
|
379 |
-
torch_compile = torch_compile,
|
380 |
-
torch_compile_backend = torch_compile_backend,
|
381 |
-
torch_compile_mode = torch_compile_mode,
|
382 |
-
dispatch_batches = dispatch_batches,
|
383 |
-
split_batches = split_batches,
|
384 |
-
include_tokens_per_second = include_tokens_per_second,
|
385 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
386 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
387 |
-
optim_target_modules = optim_target_modules,
|
388 |
-
batch_eval_metrics = batch_eval_metrics,
|
389 |
-
eval_on_start = eval_on_start,
|
390 |
-
use_liger_kernel = use_liger_kernel,
|
391 |
-
eval_use_gather_object = eval_use_gather_object,
|
392 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
393 |
-
dataset_num_proc = dataset_num_proc,
|
394 |
-
num_mini_batches = num_mini_batches,
|
395 |
-
total_episodes = total_episodes,
|
396 |
-
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
397 |
-
num_sample_generations = num_sample_generations,
|
398 |
-
response_length = response_length,
|
399 |
-
stop_token = stop_token,
|
400 |
-
stop_token_id = stop_token_id,
|
401 |
-
temperature = temperature,
|
402 |
-
missing_eos_penalty = missing_eos_penalty,
|
403 |
-
sft_model_path = sft_model_path,
|
404 |
-
world_size = world_size,
|
405 |
-
num_total_batches = num_total_batches,
|
406 |
-
micro_batch_size = micro_batch_size,
|
407 |
-
local_batch_size = local_batch_size,
|
408 |
-
batch_size = batch_size,
|
409 |
-
local_mini_batch_size = local_mini_batch_size,
|
410 |
-
mini_batch_size = mini_batch_size,
|
411 |
-
exp_name = exp_name,
|
412 |
-
reward_model_path = reward_model_path,
|
413 |
-
num_ppo_epochs = num_ppo_epochs,
|
414 |
-
whiten_rewards = whiten_rewards,
|
415 |
-
kl_coef = kl_coef,
|
416 |
-
cliprange = cliprange,
|
417 |
-
rloo_k = rloo_k,
|
418 |
-
normalize_reward = normalize_reward,
|
419 |
-
reward_clip_range = reward_clip_range,
|
420 |
-
normalize_advantage = normalize_advantage,
|
421 |
-
token_level_kl = token_level_kl,
|
422 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
423 |
-
self.vllm_sampling_params = vllm_sampling_params
|
424 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
425 |
-
pass
|
426 |
-
|
427 |
-
class _UnslothRLOOTrainer(Trainer):
|
428 |
-
_tag_names = ["trl", "rloo"]
|
429 |
-
|
430 |
-
def __init__(
|
431 |
-
self,
|
432 |
-
config: RLOOConfig,
|
433 |
-
processing_class: Optional[
|
434 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
435 |
-
],
|
436 |
-
policy: nn.Module,
|
437 |
-
ref_policy: nn.Module,
|
438 |
-
reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
|
439 |
-
train_dataset: Dataset,
|
440 |
-
data_collator: Optional[DataCollatorWithPadding] = None,
|
441 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
442 |
-
# less commonly used
|
443 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
444 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
445 |
-
) -> None:
|
446 |
-
if ref_policy is policy:
|
447 |
-
raise ValueError(
|
448 |
-
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
|
449 |
-
"same as `policy`, you must mass a copy of it, or `None` if you use peft."
|
450 |
-
)
|
451 |
-
|
452 |
-
self.args = config
|
453 |
-
args = config
|
454 |
-
self.processing_class = processing_class
|
455 |
-
self.policy = policy
|
456 |
-
|
457 |
-
# Define the collator if not provided
|
458 |
-
if data_collator is None:
|
459 |
-
data_collator = DataCollatorWithPadding(self.processing_class)
|
460 |
-
|
461 |
-
self.policy.generation_config.eos_token_id = (
|
462 |
-
None # disable `pad_token_id` and `eos_token_id` because we just want to
|
463 |
-
)
|
464 |
-
self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
|
465 |
-
|
466 |
-
self.ref_policy = ref_policy
|
467 |
-
self.reward_model = reward_model
|
468 |
-
self.train_dataset = train_dataset
|
469 |
-
self.train_dataset_len = len(train_dataset)
|
470 |
-
self.data_collator = data_collator
|
471 |
-
self.eval_dataset = eval_dataset
|
472 |
-
self.optimizer, self.lr_scheduler = optimizers
|
473 |
-
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
474 |
-
|
475 |
-
#########
|
476 |
-
# calculate various batch sizes
|
477 |
-
#########
|
478 |
-
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
479 |
-
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
480 |
-
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
481 |
-
self.accelerator = accelerator
|
482 |
-
args.world_size = accelerator.num_processes
|
483 |
-
args.local_batch_size = (
|
484 |
-
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
485 |
-
)
|
486 |
-
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
487 |
-
args.batch_size = int(args.local_batch_size * args.world_size)
|
488 |
-
args.mini_batch_size = exact_div(
|
489 |
-
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
490 |
-
)
|
491 |
-
args.local_mini_batch_size = exact_div(
|
492 |
-
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
493 |
-
)
|
494 |
-
args.num_total_batches = math.ceil(
|
495 |
-
args.total_episodes / args.batch_size
|
496 |
-
) # we may train for more than `total_episodes`
|
497 |
-
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
498 |
-
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
499 |
-
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
500 |
-
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
501 |
-
if args.num_sample_generations > 0:
|
502 |
-
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
503 |
-
self.local_dataloader_batch_size = exact_div(
|
504 |
-
args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
|
505 |
-
) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
|
506 |
-
|
507 |
-
#########
|
508 |
-
# setup model, optimizer, and others
|
509 |
-
#########
|
510 |
-
for module in [policy, ref_policy, reward_model]:
|
511 |
-
if isinstance(module, nn.Module):
|
512 |
-
disable_dropout_in_model(module)
|
513 |
-
if args.stop_token and args.stop_token == "eos":
|
514 |
-
args.stop_token_id = self.processing_class.eos_token_id
|
515 |
-
self.model = policy
|
516 |
-
self.create_optimizer_and_scheduler(
|
517 |
-
num_training_steps=args.num_total_batches
|
518 |
-
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
|
519 |
-
|
520 |
-
#########
|
521 |
-
### trainer specifics
|
522 |
-
#########
|
523 |
-
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
524 |
-
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
525 |
-
self.callback_handler = CallbackHandler(
|
526 |
-
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
527 |
-
)
|
528 |
-
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
529 |
-
self.control = TrainerControl()
|
530 |
-
self.state = OnlineTrainerState(
|
531 |
-
is_local_process_zero=self.is_local_process_zero(),
|
532 |
-
is_world_process_zero=self.is_world_process_zero(),
|
533 |
-
stateful_callbacks=[
|
534 |
-
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
535 |
-
],
|
536 |
-
)
|
537 |
-
|
538 |
-
self.current_flos = 0
|
539 |
-
self.hp_search_backend = None
|
540 |
-
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
541 |
-
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
542 |
-
# Create distant repo and output directory if needed
|
543 |
-
self.hub_model_id = None
|
544 |
-
if self.args.push_to_hub:
|
545 |
-
self.init_hf_repo()
|
546 |
-
if self.args.should_save:
|
547 |
-
os.makedirs(self.args.output_dir, exist_ok=True)
|
548 |
-
self.backup_model = None
|
549 |
-
|
550 |
-
# Add tags for models that have been loaded with the correct transformers version
|
551 |
-
if hasattr(self.model, "add_model_tags"):
|
552 |
-
self.model.add_model_tags(self._tag_names)
|
553 |
-
|
554 |
-
#########
|
555 |
-
### setup dataloader
|
556 |
-
#########
|
557 |
-
self.dataloader = DataLoader(
|
558 |
-
self.train_dataset,
|
559 |
-
batch_size=self.local_dataloader_batch_size,
|
560 |
-
shuffle=True,
|
561 |
-
collate_fn=self.data_collator,
|
562 |
-
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
563 |
-
)
|
564 |
-
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
|
565 |
-
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
566 |
-
torch.manual_seed(args.seed)
|
567 |
-
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
568 |
-
torch.manual_seed(self.local_seed) # reset the local seed again
|
569 |
-
|
570 |
-
self.eval_dataloader = DataLoader(
|
571 |
-
self.eval_dataset,
|
572 |
-
batch_size=args.per_device_eval_batch_size,
|
573 |
-
collate_fn=self.data_collator,
|
574 |
-
drop_last=True,
|
575 |
-
) # no need to shuffle eval dataset
|
576 |
-
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
577 |
-
|
578 |
-
if self.is_deepspeed_enabled:
|
579 |
-
if isinstance(self.reward_model, nn.Module):
|
580 |
-
self.reward_model = prepare_deepspeed(
|
581 |
-
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
582 |
-
)
|
583 |
-
self.ref_policy = prepare_deepspeed(
|
584 |
-
self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
|
585 |
-
)
|
586 |
-
self.deepspeed = self.model
|
587 |
-
else:
|
588 |
-
self.ref_policy = self.ref_policy.to(self.accelerator.device)
|
589 |
-
if isinstance(self.reward_model, nn.Module):
|
590 |
-
self.reward_model = self.reward_model.to(self.accelerator.device)
|
591 |
-
|
592 |
-
def get_train_dataloader(self) -> DataLoader:
|
593 |
-
return self.dataloader
|
594 |
-
|
595 |
-
def get_eval_dataloader(self) -> DataLoader:
|
596 |
-
return self.eval_dataloader
|
597 |
-
|
598 |
-
def train(self):
|
599 |
-
args = self.args
|
600 |
-
accelerator = self.accelerator
|
601 |
-
optimizer = self.optimizer
|
602 |
-
model = self.model
|
603 |
-
self.model_wrapped = self.model
|
604 |
-
ref_policy = self.ref_policy
|
605 |
-
reward_model = self.reward_model
|
606 |
-
processing_class = self.processing_class
|
607 |
-
dataloader = self.dataloader
|
608 |
-
device = accelerator.device
|
609 |
-
|
610 |
-
def repeat_generator():
|
611 |
-
while True:
|
612 |
-
yield from dataloader
|
613 |
-
|
614 |
-
iter_dataloader = iter(repeat_generator())
|
615 |
-
generation_config = GenerationConfig(
|
616 |
-
max_new_tokens=args.response_length,
|
617 |
-
temperature=(args.temperature + 1e-7),
|
618 |
-
top_k=0.0,
|
619 |
-
top_p=1.0,
|
620 |
-
do_sample=True,
|
621 |
-
)
|
622 |
-
|
623 |
-
accelerator.print("===training policy===")
|
624 |
-
start_time = time.time()
|
625 |
-
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
626 |
-
approxkl_stats = torch.zeros(stats_shape, device=device)
|
627 |
-
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
628 |
-
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
629 |
-
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
630 |
-
entropy_stats = torch.zeros(stats_shape, device=device)
|
631 |
-
ratio_stats = torch.zeros(stats_shape, device=device)
|
632 |
-
model.train()
|
633 |
-
|
634 |
-
# trainer state initialization
|
635 |
-
self.state.global_step = 0
|
636 |
-
self.state.episode = 0
|
637 |
-
self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
|
638 |
-
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
639 |
-
# Compute absolute values for logging, eval, and save if given as ratio
|
640 |
-
if args.logging_steps is not None:
|
641 |
-
if args.logging_steps < 1:
|
642 |
-
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
643 |
-
else:
|
644 |
-
self.state.logging_steps = args.logging_steps
|
645 |
-
if args.eval_steps is not None:
|
646 |
-
if args.eval_steps < 1:
|
647 |
-
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
648 |
-
else:
|
649 |
-
self.state.eval_steps = args.eval_steps
|
650 |
-
if args.save_steps is not None:
|
651 |
-
if args.save_steps < 1:
|
652 |
-
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
653 |
-
else:
|
654 |
-
self.state.save_steps = args.save_steps
|
655 |
-
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
656 |
-
|
657 |
-
for update in range(1, args.num_total_batches + 1):
|
658 |
-
self.state.episode += 1 * args.batch_size
|
659 |
-
data = next(iter_dataloader)
|
660 |
-
with torch.no_grad():
|
661 |
-
queries = data["input_ids"].to(device)
|
662 |
-
queries = queries.repeat(args.rloo_k, 1)
|
663 |
-
context_length = queries.shape[1]
|
664 |
-
responses = []
|
665 |
-
postprocessed_responses = []
|
666 |
-
logprobs = []
|
667 |
-
ref_logprobs = []
|
668 |
-
scores = []
|
669 |
-
sequence_lengths = []
|
670 |
-
|
671 |
-
# Generate responses and compute logprobs
|
672 |
-
with unwrap_model_for_generation(
|
673 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
674 |
-
) as unwrapped_model:
|
675 |
-
query_responses, logitss = batch_generation(
|
676 |
-
unwrapped_model,
|
677 |
-
queries,
|
678 |
-
args.local_rollout_forward_batch_size,
|
679 |
-
processing_class.pad_token_id,
|
680 |
-
generation_config,
|
681 |
-
)
|
682 |
-
|
683 |
-
# Process responses in batches
|
684 |
-
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
685 |
-
query = queries[i : i + args.local_rollout_forward_batch_size]
|
686 |
-
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
687 |
-
response = query_response[:, context_length:]
|
688 |
-
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
689 |
-
logprob = selective_log_softmax(logits, response)
|
690 |
-
del logits
|
691 |
-
torch.cuda.empty_cache()
|
692 |
-
|
693 |
-
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
694 |
-
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
695 |
-
ref_logits /= args.temperature + 1e-7
|
696 |
-
ref_logprob = selective_log_softmax(ref_logits, response)
|
697 |
-
del ref_output, ref_logits
|
698 |
-
torch.cuda.empty_cache()
|
699 |
-
|
700 |
-
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
701 |
-
postprocessed_response = response
|
702 |
-
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
703 |
-
postprocessed_response = truncate_response(
|
704 |
-
args.stop_token_id, processing_class.pad_token_id, response
|
705 |
-
)
|
706 |
-
|
707 |
-
# Response Processing 2. run reward model on the truncated responses
|
708 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
709 |
-
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
710 |
-
|
711 |
-
if isinstance(reward_model, nn.Module):
|
712 |
-
_, score, _ = get_reward(
|
713 |
-
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
714 |
-
)
|
715 |
-
else:
|
716 |
-
score = torch.tensor(
|
717 |
-
reward_model(
|
718 |
-
processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
|
719 |
-
),
|
720 |
-
dtype=torch.float,
|
721 |
-
).to(device)
|
722 |
-
|
723 |
-
# Store batch results
|
724 |
-
responses.append(response)
|
725 |
-
postprocessed_responses.append(postprocessed_response)
|
726 |
-
logprobs.append(logprob)
|
727 |
-
ref_logprobs.append(ref_logprob)
|
728 |
-
sequence_lengths.append(sequence_length)
|
729 |
-
scores.append(score)
|
730 |
-
|
731 |
-
# Concatenate all batched results
|
732 |
-
responses = torch.cat(responses, 0)
|
733 |
-
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
734 |
-
logprobs = torch.cat(logprobs, 0)
|
735 |
-
ref_logprobs = torch.cat(ref_logprobs, 0)
|
736 |
-
sequence_lengths = torch.cat(sequence_lengths, 0)
|
737 |
-
scores = torch.cat(scores, 0)
|
738 |
-
del (logprob, ref_logprob, score)
|
739 |
-
torch.cuda.empty_cache()
|
740 |
-
gc.collect()
|
741 |
-
|
742 |
-
# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
|
743 |
-
# responses not passing that filter will receive a low (fixed) score
|
744 |
-
# only query humans on responses that pass that filter
|
745 |
-
contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
|
746 |
-
if args.missing_eos_penalty is not None:
|
747 |
-
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
748 |
-
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
749 |
-
|
750 |
-
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
751 |
-
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
752 |
-
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
753 |
-
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
754 |
-
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
755 |
-
|
756 |
-
# 4. compute rewards
|
757 |
-
# Compute KL divergence
|
758 |
-
kl = logprobs - ref_logprobs
|
759 |
-
|
760 |
-
# Normalize rewards
|
761 |
-
if args.normalize_reward:
|
762 |
-
scores = (scores - scores.mean()) / (scores.std() + 1e-8)
|
763 |
-
scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
|
764 |
-
|
765 |
-
# Compute total reward with KL penalty
|
766 |
-
if args.token_level_kl:
|
767 |
-
# Token-level KL penalty: apply KL penalty per token
|
768 |
-
kl_reward = -args.kl_coef * kl
|
769 |
-
|
770 |
-
# Get the index of the last non-padded token for each sequence
|
771 |
-
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
|
772 |
-
last_reward = torch.zeros_like(kl)
|
773 |
-
# Ensure scores has correct shape and type
|
774 |
-
scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
|
775 |
-
last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
|
776 |
-
|
777 |
-
# Combine KL reward and last reward
|
778 |
-
non_score_reward = kl_reward.sum(1) # Keep this for logging
|
779 |
-
reward = last_reward + kl_reward
|
780 |
-
rlhf_reward = reward.sum(1) # Sum across sequence length
|
781 |
-
else:
|
782 |
-
# Sequence-level KL penalty: sum KL across tokens first
|
783 |
-
sequence_kl = kl.sum(1)
|
784 |
-
non_score_reward = -args.kl_coef * sequence_kl
|
785 |
-
rlhf_reward = non_score_reward + scores
|
786 |
-
|
787 |
-
# vectorized RLOO advantages implementation
|
788 |
-
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
|
789 |
-
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
|
790 |
-
advantages = rlhf_reward - baseline
|
791 |
-
advantages = advantages.flatten()
|
792 |
-
|
793 |
-
# Normalize advantages
|
794 |
-
if args.normalize_advantage:
|
795 |
-
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
796 |
-
|
797 |
-
torch.cuda.empty_cache()
|
798 |
-
|
799 |
-
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
800 |
-
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
801 |
-
b_inds = np.random.permutation(args.local_batch_size)
|
802 |
-
minibatch_idx = 0
|
803 |
-
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
804 |
-
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
805 |
-
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
806 |
-
gradient_accumulation_idx = 0
|
807 |
-
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
808 |
-
with accelerator.accumulate(model):
|
809 |
-
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
810 |
-
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
811 |
-
|
812 |
-
# Get batch data
|
813 |
-
mb_advantage = advantages[micro_batch_inds]
|
814 |
-
mb_responses = responses[micro_batch_inds]
|
815 |
-
mb_query_responses = query_responses[micro_batch_inds]
|
816 |
-
mb_logprobs = logprobs[micro_batch_inds]
|
817 |
-
|
818 |
-
# Forward pass
|
819 |
-
output = forward(model, mb_query_responses, processing_class.pad_token_id)
|
820 |
-
logits = output.logits[:, context_length - 1 : -1]
|
821 |
-
logits /= args.temperature + 1e-7
|
822 |
-
|
823 |
-
# Compute new logprobs
|
824 |
-
new_logprobs = selective_log_softmax(logits, mb_responses)
|
825 |
-
new_logprobs = torch.masked_fill(
|
826 |
-
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
827 |
-
)
|
828 |
-
|
829 |
-
# Compute probability ratios
|
830 |
-
new_ratio = (new_logprobs - mb_logprobs).exp()
|
831 |
-
new_logprobs = new_logprobs.sum(1)
|
832 |
-
mb_logprobs = mb_logprobs.sum(1)
|
833 |
-
logprobs_diff = new_logprobs - mb_logprobs
|
834 |
-
ratio = torch.exp(logprobs_diff)
|
835 |
-
|
836 |
-
# PPO clipped loss
|
837 |
-
pg_losses = -mb_advantage * ratio
|
838 |
-
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
839 |
-
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
840 |
-
pg_loss = pg_loss_max.mean()
|
841 |
-
|
842 |
-
# Final loss
|
843 |
-
loss = pg_loss
|
844 |
-
|
845 |
-
# Optimization step
|
846 |
-
accelerator.backward(loss)
|
847 |
-
optimizer.step()
|
848 |
-
optimizer.zero_grad()
|
849 |
-
|
850 |
-
with torch.no_grad():
|
851 |
-
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
|
852 |
-
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
853 |
-
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
854 |
-
approxkl = 0.5 * (logprobs_diff**2).mean()
|
855 |
-
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
856 |
-
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
857 |
-
pg_clipfrac
|
858 |
-
)
|
859 |
-
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
860 |
-
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
861 |
-
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
|
862 |
-
gradient_accumulation_idx += 1
|
863 |
-
minibatch_idx += 1
|
864 |
-
|
865 |
-
# del everything and empty cache
|
866 |
-
# fmt: off
|
867 |
-
del (
|
868 |
-
output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
|
869 |
-
pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
|
870 |
-
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
|
871 |
-
)
|
872 |
-
# fmt: on
|
873 |
-
torch.cuda.empty_cache()
|
874 |
-
|
875 |
-
# Compute metrics
|
876 |
-
with torch.no_grad():
|
877 |
-
mean_kl = kl.sum(1).mean()
|
878 |
-
mean_entropy = (-logprobs).sum(1).mean()
|
879 |
-
mean_non_score_reward = non_score_reward.mean()
|
880 |
-
eps = int(self.state.episode / (time.time() - start_time))
|
881 |
-
metrics = {}
|
882 |
-
metrics["eps"] = eps
|
883 |
-
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
884 |
-
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
885 |
-
metrics["objective/non_score_reward"] = (
|
886 |
-
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
887 |
-
)
|
888 |
-
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
889 |
-
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
890 |
-
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
891 |
-
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
892 |
-
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
893 |
-
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
894 |
-
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
895 |
-
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
896 |
-
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
897 |
-
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
898 |
-
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
899 |
-
metrics["episode"] = self.state.episode
|
900 |
-
self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
|
901 |
-
self.log(metrics)
|
902 |
-
del kl, mean_kl, mean_entropy, scores
|
903 |
-
|
904 |
-
self.lr_scheduler.step()
|
905 |
-
self.state.global_step += 1
|
906 |
-
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
907 |
-
if self.control.should_save:
|
908 |
-
self._save_checkpoint(model, trial=None)
|
909 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
910 |
-
torch.cuda.empty_cache()
|
911 |
-
gc.collect()
|
912 |
-
|
913 |
-
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
914 |
-
self.generate_completions(sampling=True)
|
915 |
-
|
916 |
-
# HF trainer specifics
|
917 |
-
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
918 |
-
if self.control.should_save:
|
919 |
-
self._save_checkpoint(model, trial=None, metrics=None)
|
920 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
921 |
-
|
922 |
-
def generate_completions(self, sampling: bool = False):
|
923 |
-
args = self.args
|
924 |
-
processing_class = self.processing_class
|
925 |
-
generation_config = GenerationConfig(
|
926 |
-
max_new_tokens=self.args.response_length,
|
927 |
-
temperature=(0.01 + 1e-7),
|
928 |
-
top_k=0.0,
|
929 |
-
top_p=1.0,
|
930 |
-
do_sample=True,
|
931 |
-
)
|
932 |
-
|
933 |
-
table = defaultdict(list)
|
934 |
-
with unwrap_model_for_generation(
|
935 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
936 |
-
) as unwrapped_model:
|
937 |
-
for batch in self.eval_dataloader:
|
938 |
-
query = batch["input_ids"]
|
939 |
-
with torch.no_grad():
|
940 |
-
context_length = query.shape[1]
|
941 |
-
query_response, _ = batch_generation(
|
942 |
-
unwrapped_model,
|
943 |
-
query,
|
944 |
-
query.shape[0],
|
945 |
-
processing_class.pad_token_id,
|
946 |
-
generation_config,
|
947 |
-
)
|
948 |
-
response = query_response[:, context_length:]
|
949 |
-
postprocessed_response = response
|
950 |
-
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
951 |
-
postprocessed_response = truncate_response(
|
952 |
-
args.stop_token_id, processing_class.pad_token_id, response
|
953 |
-
)
|
954 |
-
table["query"].extend(
|
955 |
-
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
956 |
-
)
|
957 |
-
table["model response"].extend(
|
958 |
-
gather_object(processing_class.batch_decode(postprocessed_response))
|
959 |
-
)
|
960 |
-
|
961 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
962 |
-
|
963 |
-
if isinstance(self.reward_model, nn.Module):
|
964 |
-
_, score, _ = get_reward(
|
965 |
-
self.reward_model,
|
966 |
-
postprocessed_query_response,
|
967 |
-
processing_class.pad_token_id,
|
968 |
-
context_length,
|
969 |
-
)
|
970 |
-
else:
|
971 |
-
score = torch.tensor(
|
972 |
-
self.reward_model(
|
973 |
-
processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
|
974 |
-
),
|
975 |
-
dtype=torch.float,
|
976 |
-
).to(postprocessed_query_response.device)
|
977 |
-
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
978 |
-
|
979 |
-
if sampling:
|
980 |
-
break
|
981 |
-
df = pd.DataFrame(table)
|
982 |
-
|
983 |
-
if self.accelerator.is_main_process:
|
984 |
-
print_rich_table(df.iloc[0 : 0 + 5])
|
985 |
-
if "wandb" in args.report_to:
|
986 |
-
import wandb
|
987 |
-
|
988 |
-
if wandb.run is not None:
|
989 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
990 |
-
|
991 |
-
if "comet_ml" in args.report_to:
|
992 |
-
log_table_to_comet_experiment(
|
993 |
-
name="completions.csv",
|
994 |
-
table=df,
|
995 |
-
)
|
996 |
-
|
997 |
-
def create_model_card(
|
998 |
-
self,
|
999 |
-
model_name: Optional[str] = None,
|
1000 |
-
dataset_name: Optional[str] = None,
|
1001 |
-
tags: Union[str, list[str], None] = None,
|
1002 |
-
):
|
1003 |
-
"""
|
1004 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
1005 |
-
|
1006 |
-
Args:
|
1007 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1008 |
-
Name of the model.
|
1009 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1010 |
-
Name of the dataset used for training.
|
1011 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1012 |
-
Tags to be associated with the model card.
|
1013 |
-
"""
|
1014 |
-
if not self.is_world_process_zero():
|
1015 |
-
return
|
1016 |
-
|
1017 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1018 |
-
base_model = self.model.config._name_or_path
|
1019 |
-
else:
|
1020 |
-
base_model = None
|
1021 |
-
|
1022 |
-
tags = tags or []
|
1023 |
-
if isinstance(tags, str):
|
1024 |
-
tags = [tags]
|
1025 |
-
|
1026 |
-
if hasattr(self.model.config, "unsloth_version"):
|
1027 |
-
tags.append("unsloth")
|
1028 |
-
|
1029 |
-
citation = textwrap.dedent("""\
|
1030 |
-
@inproceedings{ahmadian2024back,
|
1031 |
-
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
|
1032 |
-
author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
|
1033 |
-
year = 2024,
|
1034 |
-
booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
|
1035 |
-
publisher = {Association for Computational Linguistics},
|
1036 |
-
pages = {12248--12267},
|
1037 |
-
editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
|
1038 |
-
}""")
|
1039 |
-
|
1040 |
-
model_card = generate_model_card(
|
1041 |
-
base_model=base_model,
|
1042 |
-
model_name=model_name,
|
1043 |
-
hub_model_id=self.hub_model_id,
|
1044 |
-
dataset_name=dataset_name,
|
1045 |
-
tags=tags,
|
1046 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1047 |
-
comet_url=get_comet_experiment_url(),
|
1048 |
-
trainer_name="RLOO",
|
1049 |
-
trainer_citation=citation,
|
1050 |
-
paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
|
1051 |
-
paper_id="2402.14740",
|
1052 |
-
)
|
1053 |
-
|
1054 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1055 |
-
class UnslothRLOOTrainer(_UnslothRLOOTrainer):
|
1056 |
-
"""
|
1057 |
-
|
1058 |
-
"""
|
1059 |
-
def __init__(
|
1060 |
-
self,
|
1061 |
-
config,
|
1062 |
-
processing_class,
|
1063 |
-
policy,
|
1064 |
-
ref_policy,
|
1065 |
-
reward_model,
|
1066 |
-
train_dataset,
|
1067 |
-
data_collator = None,
|
1068 |
-
eval_dataset = None,
|
1069 |
-
callbacks = None,
|
1070 |
-
**kwargs
|
1071 |
-
):
|
1072 |
-
if args is None: args = UnslothRLOOConfig()
|
1073 |
-
_output_logits = False
|
1074 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1075 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1076 |
-
if _output_logits:
|
1077 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1078 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1079 |
-
pass
|
1080 |
-
else:
|
1081 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1082 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1083 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
1084 |
-
max_seq_length = model.max_seq_length
|
1085 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1086 |
-
if model is not None and hasattr(model, 'for_training'):
|
1087 |
-
model.for_training()
|
1088 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1089 |
-
if 'processing_class' in locals():
|
1090 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1091 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1092 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1093 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1094 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1095 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1096 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1097 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1098 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1099 |
-
else:
|
1100 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1101 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1102 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1103 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1104 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1105 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1106 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1107 |
-
else:
|
1108 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1109 |
-
other_metrics = []
|
1110 |
-
|
1111 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1112 |
-
PatchRLStatistics('rloo_trainer', other_metrics)
|
1113 |
-
|
1114 |
-
super().__init__(
|
1115 |
-
config = config,
|
1116 |
-
processing_class = processing_class,
|
1117 |
-
policy = policy,
|
1118 |
-
ref_policy = ref_policy,
|
1119 |
-
reward_model = reward_model,
|
1120 |
-
train_dataset = train_dataset,
|
1121 |
-
data_collator = data_collator,
|
1122 |
-
eval_dataset = eval_dataset,
|
1123 |
-
callbacks = callbacks,**kwargs)
|
1124 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1125 |
-
self.neftune_hook_handle.remove()
|
1126 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1127 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1128 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1129 |
-
pass
|
1130 |
-
|
1131 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothRewardTrainer.py
DELETED
@@ -1,817 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothRewardConfig(RewardConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`RewardTrainer`].
|
47 |
-
|
48 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
-
command line.
|
51 |
-
|
52 |
-
Parameters:
|
53 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
54 |
-
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
|
55 |
-
limit. This argument is required if you want to use the default data collator.
|
56 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
57 |
-
Whether to disable dropout in the model.
|
58 |
-
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
59 |
-
Number of processes to use for processing the dataset.
|
60 |
-
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
|
61 |
-
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
62 |
-
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
63 |
-
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
64 |
-
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
|
65 |
-
the dataset is pretokenized.
|
66 |
-
|
67 |
-
"""
|
68 |
-
vllm_sampling_params: Optional[Any] = field(
|
69 |
-
default = None,
|
70 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
71 |
-
)
|
72 |
-
unsloth_num_chunks : Optional[int] = field(
|
73 |
-
default = -1,
|
74 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
75 |
-
)
|
76 |
-
def __init__(
|
77 |
-
self,
|
78 |
-
output_dir = None,
|
79 |
-
overwrite_output_dir = None,
|
80 |
-
do_train = False,
|
81 |
-
do_eval = False,
|
82 |
-
do_predict = False,
|
83 |
-
eval_strategy = 'no',
|
84 |
-
prediction_loss_only = False,
|
85 |
-
per_device_train_batch_size = 4,
|
86 |
-
per_device_eval_batch_size = 4,
|
87 |
-
per_gpu_train_batch_size = None,
|
88 |
-
per_gpu_eval_batch_size = None,
|
89 |
-
gradient_accumulation_steps = 2,
|
90 |
-
eval_accumulation_steps = 2,
|
91 |
-
eval_delay = 0,
|
92 |
-
torch_empty_cache_steps = 250,
|
93 |
-
learning_rate = 5e-05,
|
94 |
-
weight_decay = 0.01,
|
95 |
-
adam_beta1 = 0.9,
|
96 |
-
adam_beta2 = 0.999,
|
97 |
-
adam_epsilon = 1e-08,
|
98 |
-
max_grad_norm = 1.0,
|
99 |
-
num_train_epochs = 3.0,
|
100 |
-
max_steps = -1,
|
101 |
-
lr_scheduler_type = 'linear',
|
102 |
-
warmup_ratio = 0.1,
|
103 |
-
warmup_steps = 0,
|
104 |
-
log_level = 'passive',
|
105 |
-
log_level_replica = 'warning',
|
106 |
-
log_on_each_node = True,
|
107 |
-
logging_dir = None,
|
108 |
-
logging_strategy = 'steps',
|
109 |
-
logging_first_step = False,
|
110 |
-
logging_steps = 1,
|
111 |
-
logging_nan_inf_filter = False,
|
112 |
-
save_strategy = 'steps',
|
113 |
-
save_steps = 500,
|
114 |
-
save_total_limit = None,
|
115 |
-
save_safetensors = True,
|
116 |
-
save_on_each_node = False,
|
117 |
-
save_only_model = False,
|
118 |
-
restore_callback_states_from_checkpoint = False,
|
119 |
-
no_cuda = False,
|
120 |
-
use_cpu = False,
|
121 |
-
use_mps_device = False,
|
122 |
-
seed = 3407,
|
123 |
-
data_seed = 3407,
|
124 |
-
jit_mode_eval = False,
|
125 |
-
use_ipex = False,
|
126 |
-
bf16 = False,
|
127 |
-
fp16 = False,
|
128 |
-
fp16_opt_level = 'O1',
|
129 |
-
half_precision_backend = 'auto',
|
130 |
-
bf16_full_eval = False,
|
131 |
-
fp16_full_eval = False,
|
132 |
-
tf32 = None,
|
133 |
-
local_rank = -1,
|
134 |
-
ddp_backend = None,
|
135 |
-
tpu_num_cores = None,
|
136 |
-
tpu_metrics_debug = False,
|
137 |
-
debug = '',
|
138 |
-
dataloader_drop_last = False,
|
139 |
-
eval_steps = None,
|
140 |
-
dataloader_num_workers = 0,
|
141 |
-
dataloader_prefetch_factor = None,
|
142 |
-
past_index = -1,
|
143 |
-
run_name = None,
|
144 |
-
disable_tqdm = None,
|
145 |
-
remove_unused_columns = False,
|
146 |
-
label_names = None,
|
147 |
-
load_best_model_at_end = False,
|
148 |
-
metric_for_best_model = None,
|
149 |
-
greater_is_better = None,
|
150 |
-
ignore_data_skip = False,
|
151 |
-
fsdp = '',
|
152 |
-
fsdp_min_num_params = 0,
|
153 |
-
fsdp_config = None,
|
154 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
155 |
-
accelerator_config = None,
|
156 |
-
deepspeed = None,
|
157 |
-
label_smoothing_factor = 0.0,
|
158 |
-
optim = 'adamw_8bit',
|
159 |
-
optim_args = None,
|
160 |
-
adafactor = False,
|
161 |
-
group_by_length = False,
|
162 |
-
length_column_name = 'length',
|
163 |
-
report_to = None,
|
164 |
-
ddp_find_unused_parameters = None,
|
165 |
-
ddp_bucket_cap_mb = None,
|
166 |
-
ddp_broadcast_buffers = None,
|
167 |
-
dataloader_pin_memory = True,
|
168 |
-
dataloader_persistent_workers = False,
|
169 |
-
skip_memory_metrics = True,
|
170 |
-
use_legacy_prediction_loop = False,
|
171 |
-
push_to_hub = False,
|
172 |
-
resume_from_checkpoint = None,
|
173 |
-
hub_model_id = None,
|
174 |
-
hub_strategy = 'every_save',
|
175 |
-
hub_token = None,
|
176 |
-
hub_private_repo = None,
|
177 |
-
hub_always_push = False,
|
178 |
-
gradient_checkpointing = False,
|
179 |
-
gradient_checkpointing_kwargs = None,
|
180 |
-
include_inputs_for_metrics = False,
|
181 |
-
eval_do_concat_batches = True,
|
182 |
-
fp16_backend = 'auto',
|
183 |
-
evaluation_strategy = None,
|
184 |
-
push_to_hub_model_id = None,
|
185 |
-
push_to_hub_organization = None,
|
186 |
-
push_to_hub_token = None,
|
187 |
-
mp_parameters = '',
|
188 |
-
auto_find_batch_size = False,
|
189 |
-
full_determinism = False,
|
190 |
-
torchdynamo = None,
|
191 |
-
ray_scope = 'last',
|
192 |
-
ddp_timeout = 1800,
|
193 |
-
torch_compile = False,
|
194 |
-
torch_compile_backend = None,
|
195 |
-
torch_compile_mode = None,
|
196 |
-
dispatch_batches = None,
|
197 |
-
split_batches = None,
|
198 |
-
include_tokens_per_second = False,
|
199 |
-
include_num_input_tokens_seen = False,
|
200 |
-
neftune_noise_alpha = None,
|
201 |
-
optim_target_modules = None,
|
202 |
-
batch_eval_metrics = False,
|
203 |
-
eval_on_start = False,
|
204 |
-
use_liger_kernel = False,
|
205 |
-
eval_use_gather_object = False,
|
206 |
-
average_tokens_across_devices = False,
|
207 |
-
max_length = 1024,
|
208 |
-
disable_dropout = True,
|
209 |
-
dataset_num_proc = None,
|
210 |
-
center_rewards_coefficient = None,
|
211 |
-
vllm_sampling_params = None,
|
212 |
-
unsloth_num_chunks = -1,
|
213 |
-
**kwargs,
|
214 |
-
):
|
215 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
216 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
217 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
218 |
-
output_dir = 'unsloth_training_checkpoints'
|
219 |
-
save_strategy = 'no'
|
220 |
-
if dataset_num_proc is None:
|
221 |
-
from multiprocessing import cpu_count
|
222 |
-
dataset_num_proc = cpu_count()
|
223 |
-
|
224 |
-
super().__init__(
|
225 |
-
output_dir = output_dir,
|
226 |
-
overwrite_output_dir = overwrite_output_dir,
|
227 |
-
do_train = do_train,
|
228 |
-
do_eval = do_eval,
|
229 |
-
do_predict = do_predict,
|
230 |
-
eval_strategy = eval_strategy,
|
231 |
-
prediction_loss_only = prediction_loss_only,
|
232 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
233 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
234 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
235 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
236 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
237 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
238 |
-
eval_delay = eval_delay,
|
239 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
240 |
-
learning_rate = learning_rate,
|
241 |
-
weight_decay = weight_decay,
|
242 |
-
adam_beta1 = adam_beta1,
|
243 |
-
adam_beta2 = adam_beta2,
|
244 |
-
adam_epsilon = adam_epsilon,
|
245 |
-
max_grad_norm = max_grad_norm,
|
246 |
-
num_train_epochs = num_train_epochs,
|
247 |
-
max_steps = max_steps,
|
248 |
-
lr_scheduler_type = lr_scheduler_type,
|
249 |
-
warmup_ratio = warmup_ratio,
|
250 |
-
warmup_steps = warmup_steps,
|
251 |
-
log_level = log_level,
|
252 |
-
log_level_replica = log_level_replica,
|
253 |
-
log_on_each_node = log_on_each_node,
|
254 |
-
logging_dir = logging_dir,
|
255 |
-
logging_strategy = logging_strategy,
|
256 |
-
logging_first_step = logging_first_step,
|
257 |
-
logging_steps = logging_steps,
|
258 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
259 |
-
save_strategy = save_strategy,
|
260 |
-
save_steps = save_steps,
|
261 |
-
save_total_limit = save_total_limit,
|
262 |
-
save_safetensors = save_safetensors,
|
263 |
-
save_on_each_node = save_on_each_node,
|
264 |
-
save_only_model = save_only_model,
|
265 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
266 |
-
no_cuda = no_cuda,
|
267 |
-
use_cpu = use_cpu,
|
268 |
-
use_mps_device = use_mps_device,
|
269 |
-
seed = seed,
|
270 |
-
data_seed = data_seed,
|
271 |
-
jit_mode_eval = jit_mode_eval,
|
272 |
-
use_ipex = use_ipex,
|
273 |
-
bf16 = bf16,
|
274 |
-
fp16 = fp16,
|
275 |
-
fp16_opt_level = fp16_opt_level,
|
276 |
-
half_precision_backend = half_precision_backend,
|
277 |
-
bf16_full_eval = bf16_full_eval,
|
278 |
-
fp16_full_eval = fp16_full_eval,
|
279 |
-
tf32 = tf32,
|
280 |
-
local_rank = local_rank,
|
281 |
-
ddp_backend = ddp_backend,
|
282 |
-
tpu_num_cores = tpu_num_cores,
|
283 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
284 |
-
debug = debug,
|
285 |
-
dataloader_drop_last = dataloader_drop_last,
|
286 |
-
eval_steps = eval_steps,
|
287 |
-
dataloader_num_workers = dataloader_num_workers,
|
288 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
289 |
-
past_index = past_index,
|
290 |
-
run_name = run_name,
|
291 |
-
disable_tqdm = disable_tqdm,
|
292 |
-
remove_unused_columns = remove_unused_columns,
|
293 |
-
label_names = label_names,
|
294 |
-
load_best_model_at_end = load_best_model_at_end,
|
295 |
-
metric_for_best_model = metric_for_best_model,
|
296 |
-
greater_is_better = greater_is_better,
|
297 |
-
ignore_data_skip = ignore_data_skip,
|
298 |
-
fsdp = fsdp,
|
299 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
300 |
-
fsdp_config = fsdp_config,
|
301 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
302 |
-
accelerator_config = accelerator_config,
|
303 |
-
deepspeed = deepspeed,
|
304 |
-
label_smoothing_factor = label_smoothing_factor,
|
305 |
-
optim = optim,
|
306 |
-
optim_args = optim_args,
|
307 |
-
adafactor = adafactor,
|
308 |
-
group_by_length = group_by_length,
|
309 |
-
length_column_name = length_column_name,
|
310 |
-
report_to = report_to,
|
311 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
312 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
313 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
314 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
315 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
316 |
-
skip_memory_metrics = skip_memory_metrics,
|
317 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
318 |
-
push_to_hub = push_to_hub,
|
319 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
320 |
-
hub_model_id = hub_model_id,
|
321 |
-
hub_strategy = hub_strategy,
|
322 |
-
hub_token = hub_token,
|
323 |
-
hub_private_repo = hub_private_repo,
|
324 |
-
hub_always_push = hub_always_push,
|
325 |
-
gradient_checkpointing = gradient_checkpointing,
|
326 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
327 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
328 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
329 |
-
fp16_backend = fp16_backend,
|
330 |
-
evaluation_strategy = evaluation_strategy,
|
331 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
332 |
-
push_to_hub_organization = push_to_hub_organization,
|
333 |
-
push_to_hub_token = push_to_hub_token,
|
334 |
-
mp_parameters = mp_parameters,
|
335 |
-
auto_find_batch_size = auto_find_batch_size,
|
336 |
-
full_determinism = full_determinism,
|
337 |
-
torchdynamo = torchdynamo,
|
338 |
-
ray_scope = ray_scope,
|
339 |
-
ddp_timeout = ddp_timeout,
|
340 |
-
torch_compile = torch_compile,
|
341 |
-
torch_compile_backend = torch_compile_backend,
|
342 |
-
torch_compile_mode = torch_compile_mode,
|
343 |
-
dispatch_batches = dispatch_batches,
|
344 |
-
split_batches = split_batches,
|
345 |
-
include_tokens_per_second = include_tokens_per_second,
|
346 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
347 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
348 |
-
optim_target_modules = optim_target_modules,
|
349 |
-
batch_eval_metrics = batch_eval_metrics,
|
350 |
-
eval_on_start = eval_on_start,
|
351 |
-
use_liger_kernel = use_liger_kernel,
|
352 |
-
eval_use_gather_object = eval_use_gather_object,
|
353 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
354 |
-
max_length = max_length,
|
355 |
-
disable_dropout = disable_dropout,
|
356 |
-
dataset_num_proc = dataset_num_proc,
|
357 |
-
center_rewards_coefficient = center_rewards_coefficient,**kwargs)
|
358 |
-
self.vllm_sampling_params = vllm_sampling_params
|
359 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
360 |
-
pass
|
361 |
-
|
362 |
-
class _UnslothRewardTrainer(Trainer):
|
363 |
-
_tag_names = ["trl", "reward-trainer"]
|
364 |
-
|
365 |
-
def __init__(
|
366 |
-
self,
|
367 |
-
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
368 |
-
args: Optional[RewardConfig] = None,
|
369 |
-
data_collator: Optional[DataCollator] = None,
|
370 |
-
train_dataset: Optional[Dataset] = None,
|
371 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
372 |
-
processing_class: Optional[
|
373 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
374 |
-
] = None,
|
375 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
376 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
377 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
378 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
379 |
-
None,
|
380 |
-
None,
|
381 |
-
),
|
382 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
383 |
-
peft_config: Optional[dict] = None,
|
384 |
-
):
|
385 |
-
"""
|
386 |
-
Initialize RewardTrainer.
|
387 |
-
|
388 |
-
Args:
|
389 |
-
model (`transformers.PreTrainedModel`):
|
390 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
391 |
-
args (`RewardConfig`):
|
392 |
-
The arguments to use for training.
|
393 |
-
data_collator (`transformers.DataCollator`):
|
394 |
-
The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
|
395 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
396 |
-
train_dataset (`datasets.Dataset`):
|
397 |
-
The dataset to use for training.
|
398 |
-
eval_dataset (`datasets.Dataset`):
|
399 |
-
The dataset to use for evaluation.
|
400 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
401 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
402 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
403 |
-
reuse the fine-tuned model.
|
404 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
405 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
406 |
-
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
407 |
-
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
408 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
409 |
-
The callbacks to use for training.
|
410 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
411 |
-
The optimizer and scheduler to use for training.
|
412 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
413 |
-
The function to use to preprocess the logits before computing the metrics.
|
414 |
-
peft_config (`dict`, defaults to `None`):
|
415 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
416 |
-
"""
|
417 |
-
if not is_peft_available() and peft_config is not None:
|
418 |
-
raise ValueError(
|
419 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
420 |
-
)
|
421 |
-
elif is_peft_available() and peft_config is not None:
|
422 |
-
if not isinstance(model, PeftModel):
|
423 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
424 |
-
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
425 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
426 |
-
)
|
427 |
-
|
428 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
429 |
-
|
430 |
-
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
431 |
-
warnings.warn(
|
432 |
-
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
433 |
-
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
|
434 |
-
UserWarning,
|
435 |
-
)
|
436 |
-
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
437 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
438 |
-
|
439 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
440 |
-
|
441 |
-
model = model
|
442 |
-
|
443 |
-
# Disable dropout in the model
|
444 |
-
if args.disable_dropout:
|
445 |
-
disable_dropout_in_model(model)
|
446 |
-
|
447 |
-
if compute_metrics is None:
|
448 |
-
compute_metrics = compute_accuracy
|
449 |
-
|
450 |
-
if data_collator is None:
|
451 |
-
if processing_class is None:
|
452 |
-
raise ValueError(
|
453 |
-
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
|
454 |
-
)
|
455 |
-
|
456 |
-
max_length = args.max_length
|
457 |
-
|
458 |
-
data_collator = RewardDataCollatorWithPadding(processing_class)
|
459 |
-
|
460 |
-
if args.remove_unused_columns:
|
461 |
-
try: # for bc before https://github.com/huggingface/transformers/pull/25435
|
462 |
-
args.remove_unused_columns = False
|
463 |
-
except FrozenInstanceError:
|
464 |
-
args = replace(args, remove_unused_columns=False)
|
465 |
-
# warn users
|
466 |
-
warnings.warn(
|
467 |
-
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
|
468 |
-
" we have set it for you, but you should do it yourself in the future.",
|
469 |
-
UserWarning,
|
470 |
-
)
|
471 |
-
|
472 |
-
self.use_reward_data_collator = True
|
473 |
-
else:
|
474 |
-
self.use_reward_data_collator = False
|
475 |
-
|
476 |
-
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
477 |
-
# input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
|
478 |
-
# "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
|
479 |
-
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
480 |
-
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
481 |
-
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
482 |
-
# issued.
|
483 |
-
model.warnings_issued["estimate_tokens"] = True
|
484 |
-
|
485 |
-
if "input_ids_chosen" not in train_dataset.column_names:
|
486 |
-
with PartialState().local_main_process_first():
|
487 |
-
fn_kwargs = {"tokenizer": processing_class}
|
488 |
-
train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
|
489 |
-
train_dataset = train_dataset.map(
|
490 |
-
_tokenize,
|
491 |
-
batched=True,
|
492 |
-
fn_kwargs=fn_kwargs,
|
493 |
-
num_proc=args.dataset_num_proc,
|
494 |
-
)
|
495 |
-
# This filter is important because otherwise you get samples that exceed the model's context length and
|
496 |
-
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
497 |
-
# user might get surprised if N samples are missing from training.
|
498 |
-
train_dataset = train_dataset.filter(
|
499 |
-
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
|
500 |
-
num_proc=args.dataset_num_proc,
|
501 |
-
)
|
502 |
-
if eval_dataset is not None:
|
503 |
-
eval_dataset = eval_dataset.map(
|
504 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
|
505 |
-
)
|
506 |
-
eval_dataset = eval_dataset.map(
|
507 |
-
_tokenize,
|
508 |
-
fn_kwargs=fn_kwargs,
|
509 |
-
batched=True,
|
510 |
-
num_proc=args.dataset_num_proc,
|
511 |
-
)
|
512 |
-
# This filter is important because otherwise you get samples that exceed the model's context length and
|
513 |
-
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
514 |
-
# user might get surprised if N samples are missing from training.
|
515 |
-
eval_dataset = eval_dataset.filter(
|
516 |
-
lambda x: len(x["input_ids_chosen"]) <= max_length
|
517 |
-
and len(x["input_ids_rejected"]) <= max_length,
|
518 |
-
num_proc=args.dataset_num_proc,
|
519 |
-
)
|
520 |
-
|
521 |
-
super().__init__(
|
522 |
-
model=model,
|
523 |
-
args=args,
|
524 |
-
data_collator=data_collator,
|
525 |
-
train_dataset=train_dataset,
|
526 |
-
eval_dataset=eval_dataset,
|
527 |
-
processing_class=processing_class,
|
528 |
-
model_init=model_init,
|
529 |
-
compute_metrics=compute_metrics,
|
530 |
-
callbacks=callbacks,
|
531 |
-
optimizers=optimizers,
|
532 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
533 |
-
)
|
534 |
-
|
535 |
-
# Add tags for models that have been loaded with the correct transformers version
|
536 |
-
if hasattr(self.model, "add_model_tags"):
|
537 |
-
self.model.add_model_tags(self._tag_names)
|
538 |
-
|
539 |
-
def compute_loss(
|
540 |
-
self,
|
541 |
-
model: Union[PreTrainedModel, nn.Module],
|
542 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
543 |
-
return_outputs=False,
|
544 |
-
num_items_in_batch=None,
|
545 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
546 |
-
rewards_chosen = model(
|
547 |
-
input_ids=inputs["input_ids_chosen"],
|
548 |
-
attention_mask=inputs["attention_mask_chosen"],
|
549 |
-
return_dict=True,
|
550 |
-
)["logits"]
|
551 |
-
rewards_rejected = model(
|
552 |
-
input_ids=inputs["input_ids_rejected"],
|
553 |
-
attention_mask=inputs["attention_mask_rejected"],
|
554 |
-
return_dict=True,
|
555 |
-
)["logits"]
|
556 |
-
# calculate loss, optionally modulate with margin
|
557 |
-
if "margin" in inputs:
|
558 |
-
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
|
559 |
-
else:
|
560 |
-
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
561 |
-
|
562 |
-
if self.args.center_rewards_coefficient is not None:
|
563 |
-
loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
|
564 |
-
|
565 |
-
if return_outputs:
|
566 |
-
return loss, {
|
567 |
-
"rewards_chosen": rewards_chosen,
|
568 |
-
"rewards_rejected": rewards_rejected,
|
569 |
-
}
|
570 |
-
return loss
|
571 |
-
|
572 |
-
def prediction_step(
|
573 |
-
self,
|
574 |
-
model: Union[PreTrainedModel, nn.Module],
|
575 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
576 |
-
prediction_loss_only: bool,
|
577 |
-
ignore_keys: Optional[list[str]] = None,
|
578 |
-
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
579 |
-
inputs = self._prepare_inputs(inputs)
|
580 |
-
if ignore_keys is None:
|
581 |
-
if hasattr(self.model, "config"):
|
582 |
-
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
583 |
-
else:
|
584 |
-
ignore_keys = []
|
585 |
-
|
586 |
-
with torch.no_grad():
|
587 |
-
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
|
588 |
-
|
589 |
-
if prediction_loss_only:
|
590 |
-
return (loss, None, None)
|
591 |
-
|
592 |
-
loss = loss.detach()
|
593 |
-
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
|
594 |
-
logits = nested_detach(logits)
|
595 |
-
# Stack accepted against rejected, mean over logits
|
596 |
-
# and softmax to get preferences between accepted and rejected to sum to 1
|
597 |
-
logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
|
598 |
-
|
599 |
-
labels = torch.zeros(logits.shape[0])
|
600 |
-
labels = self._prepare_inputs(labels)
|
601 |
-
|
602 |
-
return loss, logits, labels
|
603 |
-
|
604 |
-
def evaluate(self, *args, **kwargs):
|
605 |
-
num_print_samples = kwargs.pop("num_print_samples", 4)
|
606 |
-
self.visualize_samples(num_print_samples)
|
607 |
-
return super().evaluate(*args, **kwargs)
|
608 |
-
|
609 |
-
def visualize_samples(self, num_print_samples: int):
|
610 |
-
"""
|
611 |
-
Visualize the reward model logits prediction
|
612 |
-
|
613 |
-
Args:
|
614 |
-
num_print_samples (`int`, defaults to `4`):
|
615 |
-
The number of samples to print. Set to `-1` to print all samples.
|
616 |
-
"""
|
617 |
-
eval_dataloader = self.get_eval_dataloader()
|
618 |
-
table = defaultdict(list)
|
619 |
-
for _, inputs in enumerate(eval_dataloader):
|
620 |
-
_, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
|
621 |
-
chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
|
622 |
-
rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
|
623 |
-
table["chosen_text"].extend(gather_object(chosen_text))
|
624 |
-
table["rejected_text"].extend(gather_object(rejected_text))
|
625 |
-
table["logits"].extend(
|
626 |
-
gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
|
627 |
-
)
|
628 |
-
if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
|
629 |
-
break
|
630 |
-
df = pd.DataFrame(table)
|
631 |
-
if self.accelerator.process_index == 0:
|
632 |
-
print_rich_table(df[:num_print_samples])
|
633 |
-
if "wandb" in self.args.report_to:
|
634 |
-
import wandb
|
635 |
-
|
636 |
-
if wandb.run is not None:
|
637 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
638 |
-
|
639 |
-
if "comet_ml" in self.args.report_to:
|
640 |
-
log_table_to_comet_experiment(
|
641 |
-
name="completions.csv",
|
642 |
-
table=df,
|
643 |
-
)
|
644 |
-
|
645 |
-
def create_model_card(
|
646 |
-
self,
|
647 |
-
model_name: Optional[str] = None,
|
648 |
-
dataset_name: Optional[str] = None,
|
649 |
-
tags: Union[str, list[str], None] = None,
|
650 |
-
):
|
651 |
-
"""
|
652 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
653 |
-
|
654 |
-
Args:
|
655 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
656 |
-
Name of the model.
|
657 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
658 |
-
Name of the dataset used for training.
|
659 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
660 |
-
Tags to be associated with the model card.
|
661 |
-
"""
|
662 |
-
if not self.is_world_process_zero():
|
663 |
-
return
|
664 |
-
|
665 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
666 |
-
base_model = self.model.config._name_or_path
|
667 |
-
else:
|
668 |
-
base_model = None
|
669 |
-
|
670 |
-
tags = tags or []
|
671 |
-
if isinstance(tags, str):
|
672 |
-
tags = [tags]
|
673 |
-
|
674 |
-
if hasattr(self.model.config, "unsloth_version"):
|
675 |
-
tags.append("unsloth")
|
676 |
-
|
677 |
-
model_card = generate_model_card(
|
678 |
-
base_model=base_model,
|
679 |
-
model_name=model_name,
|
680 |
-
hub_model_id=self.hub_model_id,
|
681 |
-
dataset_name=dataset_name,
|
682 |
-
tags=tags,
|
683 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
684 |
-
comet_url=get_comet_experiment_url(),
|
685 |
-
trainer_name="Reward",
|
686 |
-
)
|
687 |
-
|
688 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
689 |
-
class UnslothRewardTrainer(_UnslothRewardTrainer):
|
690 |
-
"""
|
691 |
-
|
692 |
-
"""
|
693 |
-
def __init__(
|
694 |
-
self,
|
695 |
-
model = None,
|
696 |
-
args = None,
|
697 |
-
data_collator = None,
|
698 |
-
train_dataset = None,
|
699 |
-
eval_dataset = None,
|
700 |
-
processing_class = None,
|
701 |
-
model_init = None,
|
702 |
-
compute_metrics = None,
|
703 |
-
callbacks = None,
|
704 |
-
preprocess_logits_for_metrics = None,
|
705 |
-
peft_config = None,
|
706 |
-
**kwargs
|
707 |
-
):
|
708 |
-
if args is None: args = UnslothRewardConfig()
|
709 |
-
use_bf16 = getattr(args, 'bf16', False)
|
710 |
-
use_fp16 = getattr(args, 'fp16', False)
|
711 |
-
force_float32 = False
|
712 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
713 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
714 |
-
force_float32 = True
|
715 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
716 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
717 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
718 |
-
from unsloth_zoo.utils import _get_dtype
|
719 |
-
dtype = _get_dtype(dtype)
|
720 |
-
float16 = dtype == torch.float16
|
721 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
722 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
723 |
-
if force_float32:
|
724 |
-
args.fp16 = False
|
725 |
-
args.bf16 = False
|
726 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
727 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
728 |
-
args.fp16 = float16
|
729 |
-
args.bf16 = not float16
|
730 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
731 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
732 |
-
args.eval_strategy = 'steps'
|
733 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
734 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
735 |
-
if ga_steps is not None and ga_steps > 1:
|
736 |
-
from transformers import __version__ as transformers_version
|
737 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
738 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
739 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
740 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
741 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
742 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
743 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
744 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
745 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
746 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
747 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
748 |
-
if force_float32:
|
749 |
-
args.bf16_full_eval = False
|
750 |
-
args.fp16_full_eval = False
|
751 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
752 |
-
args.bf16_full_eval = True
|
753 |
-
args.fp16_full_eval = False
|
754 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
755 |
-
args.bf16_full_eval = args.bf16
|
756 |
-
args.fp16_full_eval = args.fp16
|
757 |
-
_output_logits = False
|
758 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
759 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
760 |
-
if _output_logits:
|
761 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
762 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
763 |
-
pass
|
764 |
-
else:
|
765 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
766 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
767 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
768 |
-
max_seq_length = model.max_seq_length
|
769 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
770 |
-
if model is not None and hasattr(model, 'for_training'):
|
771 |
-
model.for_training()
|
772 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
773 |
-
if 'processing_class' in locals():
|
774 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
775 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
776 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
777 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
778 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
779 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
780 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
781 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
782 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
783 |
-
else:
|
784 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
785 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
786 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
787 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
788 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
789 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
790 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
791 |
-
else:
|
792 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
793 |
-
other_metrics = []
|
794 |
-
|
795 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
796 |
-
PatchRLStatistics('reward_trainer', other_metrics)
|
797 |
-
|
798 |
-
super().__init__(
|
799 |
-
model = model,
|
800 |
-
args = args,
|
801 |
-
data_collator = data_collator,
|
802 |
-
train_dataset = train_dataset,
|
803 |
-
eval_dataset = eval_dataset,
|
804 |
-
processing_class = processing_class,
|
805 |
-
model_init = model_init,
|
806 |
-
compute_metrics = compute_metrics,
|
807 |
-
callbacks = callbacks,
|
808 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
809 |
-
peft_config = peft_config,**kwargs)
|
810 |
-
if hasattr(self, 'neftune_hook_handle'):
|
811 |
-
self.neftune_hook_handle.remove()
|
812 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
813 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
814 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
815 |
-
pass
|
816 |
-
|
817 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothSFTTrainer.py
DELETED
@@ -1,1025 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothSFTConfig(SFTConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`SFTTrainer`].
|
47 |
-
|
48 |
-
Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
|
49 |
-
[`~transformers.TrainingArguments`] documentation.
|
50 |
-
|
51 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
52 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
53 |
-
command line.
|
54 |
-
|
55 |
-
Parameters:
|
56 |
-
> Parameters that control the model
|
57 |
-
|
58 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
59 |
-
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
60 |
-
argument of the [`SFTTrainer`] is provided as a string.
|
61 |
-
use_liger (`bool`, *optional*, defaults to `False`):
|
62 |
-
Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
|
63 |
-
|
64 |
-
> Parameters that control the data preprocessing
|
65 |
-
|
66 |
-
dataset_text_field (`str`, *optional*, defaults to `"text"`):
|
67 |
-
Name of the column that contains text data in the dataset.
|
68 |
-
dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
69 |
-
Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
|
70 |
-
`skip_prepare_dataset`.
|
71 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
72 |
-
Number of processes to use for processing the dataset.
|
73 |
-
max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
|
74 |
-
Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
|
75 |
-
right.
|
76 |
-
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
|
77 |
-
packing (`bool`, *optional*, defaults to `False`):
|
78 |
-
Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
|
79 |
-
length.
|
80 |
-
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
|
81 |
-
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
|
82 |
-
|
83 |
-
> Parameters that control the training
|
84 |
-
|
85 |
-
learning_rate (`float`, *optional*, defaults to `2e-5`):
|
86 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
87 |
-
[`~transformers.TrainingArguments`].
|
88 |
-
|
89 |
-
"""
|
90 |
-
vllm_sampling_params: Optional[Any] = field(
|
91 |
-
default = None,
|
92 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
93 |
-
)
|
94 |
-
unsloth_num_chunks : Optional[int] = field(
|
95 |
-
default = -1,
|
96 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
97 |
-
)
|
98 |
-
def __init__(
|
99 |
-
self,
|
100 |
-
output_dir = None,
|
101 |
-
overwrite_output_dir = None,
|
102 |
-
do_train = False,
|
103 |
-
do_eval = False,
|
104 |
-
do_predict = False,
|
105 |
-
eval_strategy = 'no',
|
106 |
-
prediction_loss_only = False,
|
107 |
-
per_device_train_batch_size = 4,
|
108 |
-
per_device_eval_batch_size = 4,
|
109 |
-
per_gpu_train_batch_size = None,
|
110 |
-
per_gpu_eval_batch_size = None,
|
111 |
-
gradient_accumulation_steps = 2,
|
112 |
-
eval_accumulation_steps = 2,
|
113 |
-
eval_delay = 0,
|
114 |
-
torch_empty_cache_steps = 250,
|
115 |
-
learning_rate = 5e-05,
|
116 |
-
weight_decay = 0.01,
|
117 |
-
adam_beta1 = 0.9,
|
118 |
-
adam_beta2 = 0.999,
|
119 |
-
adam_epsilon = 1e-08,
|
120 |
-
max_grad_norm = 1.0,
|
121 |
-
num_train_epochs = 3.0,
|
122 |
-
max_steps = -1,
|
123 |
-
lr_scheduler_type = 'linear',
|
124 |
-
warmup_ratio = 0.1,
|
125 |
-
warmup_steps = 0,
|
126 |
-
log_level = 'passive',
|
127 |
-
log_level_replica = 'warning',
|
128 |
-
log_on_each_node = True,
|
129 |
-
logging_dir = None,
|
130 |
-
logging_strategy = 'steps',
|
131 |
-
logging_first_step = False,
|
132 |
-
logging_steps = 1,
|
133 |
-
logging_nan_inf_filter = False,
|
134 |
-
save_strategy = 'steps',
|
135 |
-
save_steps = 500,
|
136 |
-
save_total_limit = None,
|
137 |
-
save_safetensors = True,
|
138 |
-
save_on_each_node = False,
|
139 |
-
save_only_model = False,
|
140 |
-
restore_callback_states_from_checkpoint = False,
|
141 |
-
no_cuda = False,
|
142 |
-
use_cpu = False,
|
143 |
-
use_mps_device = False,
|
144 |
-
seed = 3407,
|
145 |
-
data_seed = 3407,
|
146 |
-
jit_mode_eval = False,
|
147 |
-
use_ipex = False,
|
148 |
-
bf16 = False,
|
149 |
-
fp16 = False,
|
150 |
-
fp16_opt_level = 'O1',
|
151 |
-
half_precision_backend = 'auto',
|
152 |
-
bf16_full_eval = False,
|
153 |
-
fp16_full_eval = False,
|
154 |
-
tf32 = None,
|
155 |
-
local_rank = -1,
|
156 |
-
ddp_backend = None,
|
157 |
-
tpu_num_cores = None,
|
158 |
-
tpu_metrics_debug = False,
|
159 |
-
debug = '',
|
160 |
-
dataloader_drop_last = False,
|
161 |
-
eval_steps = None,
|
162 |
-
dataloader_num_workers = 0,
|
163 |
-
dataloader_prefetch_factor = None,
|
164 |
-
past_index = -1,
|
165 |
-
run_name = None,
|
166 |
-
disable_tqdm = None,
|
167 |
-
remove_unused_columns = True,
|
168 |
-
label_names = None,
|
169 |
-
load_best_model_at_end = False,
|
170 |
-
metric_for_best_model = None,
|
171 |
-
greater_is_better = None,
|
172 |
-
ignore_data_skip = False,
|
173 |
-
fsdp = '',
|
174 |
-
fsdp_min_num_params = 0,
|
175 |
-
fsdp_config = None,
|
176 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
177 |
-
accelerator_config = None,
|
178 |
-
deepspeed = None,
|
179 |
-
label_smoothing_factor = 0.0,
|
180 |
-
optim = 'adamw_8bit',
|
181 |
-
optim_args = None,
|
182 |
-
adafactor = False,
|
183 |
-
group_by_length = False,
|
184 |
-
length_column_name = 'length',
|
185 |
-
report_to = None,
|
186 |
-
ddp_find_unused_parameters = None,
|
187 |
-
ddp_bucket_cap_mb = None,
|
188 |
-
ddp_broadcast_buffers = None,
|
189 |
-
dataloader_pin_memory = True,
|
190 |
-
dataloader_persistent_workers = False,
|
191 |
-
skip_memory_metrics = True,
|
192 |
-
use_legacy_prediction_loop = False,
|
193 |
-
push_to_hub = False,
|
194 |
-
resume_from_checkpoint = None,
|
195 |
-
hub_model_id = None,
|
196 |
-
hub_strategy = 'every_save',
|
197 |
-
hub_token = None,
|
198 |
-
hub_private_repo = None,
|
199 |
-
hub_always_push = False,
|
200 |
-
gradient_checkpointing = False,
|
201 |
-
gradient_checkpointing_kwargs = None,
|
202 |
-
include_inputs_for_metrics = False,
|
203 |
-
eval_do_concat_batches = True,
|
204 |
-
fp16_backend = 'auto',
|
205 |
-
evaluation_strategy = None,
|
206 |
-
push_to_hub_model_id = None,
|
207 |
-
push_to_hub_organization = None,
|
208 |
-
push_to_hub_token = None,
|
209 |
-
mp_parameters = '',
|
210 |
-
auto_find_batch_size = False,
|
211 |
-
full_determinism = False,
|
212 |
-
torchdynamo = None,
|
213 |
-
ray_scope = 'last',
|
214 |
-
ddp_timeout = 1800,
|
215 |
-
torch_compile = False,
|
216 |
-
torch_compile_backend = None,
|
217 |
-
torch_compile_mode = None,
|
218 |
-
dispatch_batches = None,
|
219 |
-
split_batches = None,
|
220 |
-
include_tokens_per_second = False,
|
221 |
-
include_num_input_tokens_seen = False,
|
222 |
-
neftune_noise_alpha = None,
|
223 |
-
optim_target_modules = None,
|
224 |
-
batch_eval_metrics = False,
|
225 |
-
eval_on_start = False,
|
226 |
-
use_liger_kernel = False,
|
227 |
-
eval_use_gather_object = False,
|
228 |
-
average_tokens_across_devices = False,
|
229 |
-
model_init_kwargs = None,
|
230 |
-
use_liger = False,
|
231 |
-
dataset_text_field = 'text',
|
232 |
-
dataset_kwargs = None,
|
233 |
-
dataset_num_proc = None,
|
234 |
-
max_seq_length = None,
|
235 |
-
packing = False,
|
236 |
-
eval_packing = None,
|
237 |
-
dataset_batch_size = None,
|
238 |
-
num_of_sequences = None,
|
239 |
-
chars_per_token = None,
|
240 |
-
vllm_sampling_params = None,
|
241 |
-
unsloth_num_chunks = -1,
|
242 |
-
**kwargs,
|
243 |
-
):
|
244 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
245 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
246 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
247 |
-
output_dir = 'unsloth_training_checkpoints'
|
248 |
-
save_strategy = 'no'
|
249 |
-
if dataset_num_proc is None:
|
250 |
-
from multiprocessing import cpu_count
|
251 |
-
dataset_num_proc = cpu_count()
|
252 |
-
|
253 |
-
super().__init__(
|
254 |
-
output_dir = output_dir,
|
255 |
-
overwrite_output_dir = overwrite_output_dir,
|
256 |
-
do_train = do_train,
|
257 |
-
do_eval = do_eval,
|
258 |
-
do_predict = do_predict,
|
259 |
-
eval_strategy = eval_strategy,
|
260 |
-
prediction_loss_only = prediction_loss_only,
|
261 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
262 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
263 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
264 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
265 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
266 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
267 |
-
eval_delay = eval_delay,
|
268 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
269 |
-
learning_rate = learning_rate,
|
270 |
-
weight_decay = weight_decay,
|
271 |
-
adam_beta1 = adam_beta1,
|
272 |
-
adam_beta2 = adam_beta2,
|
273 |
-
adam_epsilon = adam_epsilon,
|
274 |
-
max_grad_norm = max_grad_norm,
|
275 |
-
num_train_epochs = num_train_epochs,
|
276 |
-
max_steps = max_steps,
|
277 |
-
lr_scheduler_type = lr_scheduler_type,
|
278 |
-
warmup_ratio = warmup_ratio,
|
279 |
-
warmup_steps = warmup_steps,
|
280 |
-
log_level = log_level,
|
281 |
-
log_level_replica = log_level_replica,
|
282 |
-
log_on_each_node = log_on_each_node,
|
283 |
-
logging_dir = logging_dir,
|
284 |
-
logging_strategy = logging_strategy,
|
285 |
-
logging_first_step = logging_first_step,
|
286 |
-
logging_steps = logging_steps,
|
287 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
288 |
-
save_strategy = save_strategy,
|
289 |
-
save_steps = save_steps,
|
290 |
-
save_total_limit = save_total_limit,
|
291 |
-
save_safetensors = save_safetensors,
|
292 |
-
save_on_each_node = save_on_each_node,
|
293 |
-
save_only_model = save_only_model,
|
294 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
295 |
-
no_cuda = no_cuda,
|
296 |
-
use_cpu = use_cpu,
|
297 |
-
use_mps_device = use_mps_device,
|
298 |
-
seed = seed,
|
299 |
-
data_seed = data_seed,
|
300 |
-
jit_mode_eval = jit_mode_eval,
|
301 |
-
use_ipex = use_ipex,
|
302 |
-
bf16 = bf16,
|
303 |
-
fp16 = fp16,
|
304 |
-
fp16_opt_level = fp16_opt_level,
|
305 |
-
half_precision_backend = half_precision_backend,
|
306 |
-
bf16_full_eval = bf16_full_eval,
|
307 |
-
fp16_full_eval = fp16_full_eval,
|
308 |
-
tf32 = tf32,
|
309 |
-
local_rank = local_rank,
|
310 |
-
ddp_backend = ddp_backend,
|
311 |
-
tpu_num_cores = tpu_num_cores,
|
312 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
313 |
-
debug = debug,
|
314 |
-
dataloader_drop_last = dataloader_drop_last,
|
315 |
-
eval_steps = eval_steps,
|
316 |
-
dataloader_num_workers = dataloader_num_workers,
|
317 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
318 |
-
past_index = past_index,
|
319 |
-
run_name = run_name,
|
320 |
-
disable_tqdm = disable_tqdm,
|
321 |
-
remove_unused_columns = remove_unused_columns,
|
322 |
-
label_names = label_names,
|
323 |
-
load_best_model_at_end = load_best_model_at_end,
|
324 |
-
metric_for_best_model = metric_for_best_model,
|
325 |
-
greater_is_better = greater_is_better,
|
326 |
-
ignore_data_skip = ignore_data_skip,
|
327 |
-
fsdp = fsdp,
|
328 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
329 |
-
fsdp_config = fsdp_config,
|
330 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
331 |
-
accelerator_config = accelerator_config,
|
332 |
-
deepspeed = deepspeed,
|
333 |
-
label_smoothing_factor = label_smoothing_factor,
|
334 |
-
optim = optim,
|
335 |
-
optim_args = optim_args,
|
336 |
-
adafactor = adafactor,
|
337 |
-
group_by_length = group_by_length,
|
338 |
-
length_column_name = length_column_name,
|
339 |
-
report_to = report_to,
|
340 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
341 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
342 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
343 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
344 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
345 |
-
skip_memory_metrics = skip_memory_metrics,
|
346 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
347 |
-
push_to_hub = push_to_hub,
|
348 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
349 |
-
hub_model_id = hub_model_id,
|
350 |
-
hub_strategy = hub_strategy,
|
351 |
-
hub_token = hub_token,
|
352 |
-
hub_private_repo = hub_private_repo,
|
353 |
-
hub_always_push = hub_always_push,
|
354 |
-
gradient_checkpointing = gradient_checkpointing,
|
355 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
356 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
357 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
358 |
-
fp16_backend = fp16_backend,
|
359 |
-
evaluation_strategy = evaluation_strategy,
|
360 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
361 |
-
push_to_hub_organization = push_to_hub_organization,
|
362 |
-
push_to_hub_token = push_to_hub_token,
|
363 |
-
mp_parameters = mp_parameters,
|
364 |
-
auto_find_batch_size = auto_find_batch_size,
|
365 |
-
full_determinism = full_determinism,
|
366 |
-
torchdynamo = torchdynamo,
|
367 |
-
ray_scope = ray_scope,
|
368 |
-
ddp_timeout = ddp_timeout,
|
369 |
-
torch_compile = torch_compile,
|
370 |
-
torch_compile_backend = torch_compile_backend,
|
371 |
-
torch_compile_mode = torch_compile_mode,
|
372 |
-
dispatch_batches = dispatch_batches,
|
373 |
-
split_batches = split_batches,
|
374 |
-
include_tokens_per_second = include_tokens_per_second,
|
375 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
376 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
377 |
-
optim_target_modules = optim_target_modules,
|
378 |
-
batch_eval_metrics = batch_eval_metrics,
|
379 |
-
eval_on_start = eval_on_start,
|
380 |
-
use_liger_kernel = use_liger_kernel,
|
381 |
-
eval_use_gather_object = eval_use_gather_object,
|
382 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
383 |
-
model_init_kwargs = model_init_kwargs,
|
384 |
-
use_liger = use_liger,
|
385 |
-
dataset_text_field = dataset_text_field,
|
386 |
-
dataset_kwargs = dataset_kwargs,
|
387 |
-
dataset_num_proc = dataset_num_proc,
|
388 |
-
max_seq_length = max_seq_length,
|
389 |
-
packing = packing,
|
390 |
-
eval_packing = eval_packing,
|
391 |
-
dataset_batch_size = dataset_batch_size,
|
392 |
-
num_of_sequences = num_of_sequences,
|
393 |
-
chars_per_token = chars_per_token,**kwargs)
|
394 |
-
self.vllm_sampling_params = vllm_sampling_params
|
395 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
396 |
-
pass
|
397 |
-
|
398 |
-
class _UnslothSFTTrainer(Trainer):
|
399 |
-
""""""
|
400 |
-
|
401 |
-
_tag_names = ["trl", "sft"]
|
402 |
-
|
403 |
-
@deprecate_kwarg(
|
404 |
-
"tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
|
405 |
-
)
|
406 |
-
def __init__(
|
407 |
-
self,
|
408 |
-
model: Union[str, nn.Module, PreTrainedModel],
|
409 |
-
args: Optional[Union[SFTConfig, TrainingArguments]] = None,
|
410 |
-
data_collator: Optional[DataCollator] = None, # type: ignore
|
411 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
412 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
413 |
-
processing_class: Optional[
|
414 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
415 |
-
] = None,
|
416 |
-
compute_loss_func: Optional[Callable] = None,
|
417 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
418 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
419 |
-
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
420 |
-
optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
|
421 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
422 |
-
peft_config: Optional["PeftConfig"] = None,
|
423 |
-
formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
|
424 |
-
):
|
425 |
-
# Args
|
426 |
-
if args is None:
|
427 |
-
model_name = model if isinstance(model, str) else model.config._name_or_path
|
428 |
-
model_name = model_name.split("/")[-1]
|
429 |
-
args = SFTConfig(f"{model_name}-SFT")
|
430 |
-
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
|
431 |
-
dict_args = args.to_dict()
|
432 |
-
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
433 |
-
dict_args.pop("push_to_hub_token")
|
434 |
-
args = SFTConfig(**dict_args)
|
435 |
-
|
436 |
-
# Model
|
437 |
-
if args.model_init_kwargs is not None and not isinstance(model, str):
|
438 |
-
warnings.warn(
|
439 |
-
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
|
440 |
-
"The `model_init_kwargs` will be ignored."
|
441 |
-
)
|
442 |
-
if isinstance(model, str):
|
443 |
-
model = self._create_model_from_path(model, args)
|
444 |
-
|
445 |
-
# PEFT configuration and model wrapping
|
446 |
-
if False:
|
447 |
-
model = self._prepare_peft_model(model, peft_config, args)
|
448 |
-
|
449 |
-
# Handle the tokenizer
|
450 |
-
if processing_class is None:
|
451 |
-
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
|
452 |
-
if processing_class.pad_token is None:
|
453 |
-
processing_class.pad_token = processing_class.eos_token # required for padding when collating data
|
454 |
-
|
455 |
-
# Dataset
|
456 |
-
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
|
457 |
-
if preprocess_dataset:
|
458 |
-
train_dataset = self._prepare_dataset(
|
459 |
-
train_dataset, processing_class, args, args.packing, formatting_func, "train"
|
460 |
-
)
|
461 |
-
if eval_dataset is not None:
|
462 |
-
packing = args.packing if args.eval_packing is None else args.eval_packing
|
463 |
-
if isinstance(eval_dataset, dict):
|
464 |
-
eval_dataset = {
|
465 |
-
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
|
466 |
-
for key, dataset in eval_dataset.items()
|
467 |
-
}
|
468 |
-
else:
|
469 |
-
eval_dataset = self._prepare_dataset(
|
470 |
-
eval_dataset, processing_class, args, packing, formatting_func, "eval"
|
471 |
-
)
|
472 |
-
|
473 |
-
# Data collator
|
474 |
-
if data_collator is None:
|
475 |
-
data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
|
476 |
-
|
477 |
-
# Initialize the metrics
|
478 |
-
self._metrics = defaultdict(list)
|
479 |
-
|
480 |
-
# Initialize the Trainer. Parent class will handle:
|
481 |
-
# - DeepSpeed configuration (through create_accelerator_and_postprocess)
|
482 |
-
# - FSDP setup
|
483 |
-
# - Distributed training setup
|
484 |
-
# - Optimizer and scheduler creation
|
485 |
-
# Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
|
486 |
-
super_init_kwargs = {}
|
487 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
488 |
-
super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
|
489 |
-
else:
|
490 |
-
if optimizer_cls_and_kwargs is not None:
|
491 |
-
warnings.warn(
|
492 |
-
"The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
|
493 |
-
"The default optimizer will be used. "
|
494 |
-
"Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
|
495 |
-
)
|
496 |
-
super().__init__(
|
497 |
-
model=model,
|
498 |
-
args=args,
|
499 |
-
data_collator=data_collator,
|
500 |
-
train_dataset=train_dataset,
|
501 |
-
eval_dataset=eval_dataset,
|
502 |
-
processing_class=processing_class,
|
503 |
-
compute_loss_func=compute_loss_func,
|
504 |
-
compute_metrics=compute_metrics,
|
505 |
-
callbacks=callbacks,
|
506 |
-
optimizers=optimizers,
|
507 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
508 |
-
**super_init_kwargs,
|
509 |
-
)
|
510 |
-
|
511 |
-
# Add tags for models that have been loaded with the correct transformers version
|
512 |
-
if hasattr(self.model, "add_model_tags"):
|
513 |
-
self.model.add_model_tags(self._tag_names)
|
514 |
-
|
515 |
-
def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
|
516 |
-
"""Creates a model from a path or model identifier."""
|
517 |
-
model_init_kwargs = args.model_init_kwargs or {}
|
518 |
-
# Handle torch dtype
|
519 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
520 |
-
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
521 |
-
pass # torch_dtype is already a torch.dtype or "auto" or None
|
522 |
-
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
523 |
-
torch_dtype = getattr(torch, torch_dtype)
|
524 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
525 |
-
else:
|
526 |
-
raise ValueError(
|
527 |
-
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
|
528 |
-
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
529 |
-
)
|
530 |
-
# Disable caching if gradient checkpointing is enabled (not supported)
|
531 |
-
if args.gradient_checkpointing:
|
532 |
-
model_init_kwargs["use_cache"] = False
|
533 |
-
|
534 |
-
# Create model
|
535 |
-
if args.use_liger:
|
536 |
-
if not is_liger_kernel_available():
|
537 |
-
raise ImportError("Please install Liger-kernel for use_liger=True")
|
538 |
-
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
539 |
-
else:
|
540 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
541 |
-
return model
|
542 |
-
|
543 |
-
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
|
544 |
-
"""Prepares a model for PEFT training."""
|
545 |
-
if not is_peft_available():
|
546 |
-
raise ImportError("To use PeftModel, you need to install the `peft` library.")
|
547 |
-
|
548 |
-
if not isinstance(peft_config, PeftConfig):
|
549 |
-
raise ValueError(
|
550 |
-
f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
|
551 |
-
"to pass a PeftConfig object to the SFTTrainer."
|
552 |
-
)
|
553 |
-
|
554 |
-
if isinstance(model, PeftModel):
|
555 |
-
return model
|
556 |
-
|
557 |
-
# Handle quantized models (QLoRA)
|
558 |
-
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
|
559 |
-
|
560 |
-
is_sharded_qlora = False
|
561 |
-
if getattr(model, "is_loaded_in_4bit", False):
|
562 |
-
# Check if model is sharded (FSDP/DS-Zero3)
|
563 |
-
for _, param in model.named_parameters():
|
564 |
-
if param.__class__.__name__ == "Params4bit":
|
565 |
-
is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
|
566 |
-
break
|
567 |
-
|
568 |
-
# Prepare model for kbit training if needed
|
569 |
-
if is_qlora and not is_sharded_qlora:
|
570 |
-
model = self._prepare_model_for_kbit_training(model, args)
|
571 |
-
# Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
|
572 |
-
args = dataclasses.replace(args, gradient_checkpointing=False)
|
573 |
-
elif args.gradient_checkpointing:
|
574 |
-
model = self._enable_gradient_checkpointing(model, args)
|
575 |
-
|
576 |
-
# Create PEFT model
|
577 |
-
if (
|
578 |
-
version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
|
579 |
-
and getattr(model, "is_loaded_in_4bit", False)
|
580 |
-
and is_sharded_qlora
|
581 |
-
):
|
582 |
-
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
|
583 |
-
else:
|
584 |
-
model = get_peft_model(model, peft_config)
|
585 |
-
|
586 |
-
# Handle bf16 casting for 4-bit models
|
587 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
|
588 |
-
peft_module_casting_to_bf16(model)
|
589 |
-
|
590 |
-
return model
|
591 |
-
|
592 |
-
def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
593 |
-
"""Prepares a quantized model for kbit training."""
|
594 |
-
prepare_model_kwargs = {
|
595 |
-
"use_gradient_checkpointing": args.gradient_checkpointing,
|
596 |
-
"gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
|
597 |
-
}
|
598 |
-
|
599 |
-
return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
600 |
-
|
601 |
-
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
602 |
-
"""Enables gradient checkpointing for the model."""
|
603 |
-
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
604 |
-
use_reentrant = (
|
605 |
-
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
606 |
-
)
|
607 |
-
|
608 |
-
if use_reentrant:
|
609 |
-
if hasattr(model, "enable_input_require_grads"):
|
610 |
-
model.enable_input_require_grads()
|
611 |
-
else:
|
612 |
-
|
613 |
-
def make_inputs_require_grad(module, input, output):
|
614 |
-
output.requires_grad_(True)
|
615 |
-
|
616 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
617 |
-
|
618 |
-
return model
|
619 |
-
|
620 |
-
def _prepare_dataset(
|
621 |
-
self,
|
622 |
-
dataset: Union[Dataset, IterableDataset],
|
623 |
-
processing_class,
|
624 |
-
args,
|
625 |
-
packing: bool,
|
626 |
-
formatting_func: Optional[Callable[[dict], str]],
|
627 |
-
dataset_name: str,
|
628 |
-
) -> Union[Dataset, IterableDataset]:
|
629 |
-
# All Unsloth Zoo code licensed under LGPLv3
|
630 |
-
if isinstance(dataset, ConstantLengthDataset): return dataset
|
631 |
-
|
632 |
-
map_kwargs = {}
|
633 |
-
use_desc = isinstance(dataset, Dataset)
|
634 |
-
is_vlm = hasattr(processing_class, "tokenizer")
|
635 |
-
tokenizer = processing_class
|
636 |
-
if is_vlm: tokenizer = processing_class.tokenizer
|
637 |
-
|
638 |
-
# Get max length
|
639 |
-
max_seq_length = getattr(args, "max_length", 0)
|
640 |
-
if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
|
641 |
-
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
|
642 |
-
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
|
643 |
-
if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
|
644 |
-
dataset_text_field = getattr(args, "dataset_text_field", "text")
|
645 |
-
do_truncation = max_seq_length != 0
|
646 |
-
do_formatting_func = False
|
647 |
-
do_tokenize = True
|
648 |
-
|
649 |
-
# Get correct column names
|
650 |
-
column_names = set(next(iter(dataset)).keys())
|
651 |
-
used_column_names = ["input_ids"]
|
652 |
-
if "attention_mask" in column_names:
|
653 |
-
used_column_names.append("attention_mask")
|
654 |
-
|
655 |
-
# Check if already tokenized so skip
|
656 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
657 |
-
if "labels" in column_names:
|
658 |
-
# Most likely forgot data collator!
|
659 |
-
if is_vlm and not hasattr(tokenizer, "pad"):
|
660 |
-
# Check if processing_class has a .pad, if not, use tokenizer.tokenizer
|
661 |
-
raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
|
662 |
-
self.data_collator = DataCollatorForSeq2Seq(tokenizer)
|
663 |
-
used_column_names.append("labels")
|
664 |
-
do_tokenize = False
|
665 |
-
elif "input_ids" in column_names:
|
666 |
-
# Skip dataset prep, and set data collator
|
667 |
-
if is_vlm and not hasattr(tokenizer, "pad"):
|
668 |
-
# Check if processing_class has a .pad, if not, use tokenizer.tokenizer
|
669 |
-
raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
|
670 |
-
self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
671 |
-
do_tokenize = False
|
672 |
-
elif dataset_text_field not in column_names:
|
673 |
-
do_formatting_func = True
|
674 |
-
if formatting_func is None:
|
675 |
-
raise RuntimeError("Unsloth: You must specify a `formatting_func`")
|
676 |
-
pass
|
677 |
-
|
678 |
-
if do_tokenize:
|
679 |
-
# Check double BOS tokens
|
680 |
-
if do_formatting_func:
|
681 |
-
test_text = formatting_func(dataset[0])
|
682 |
-
if not isinstance(test_text, list):
|
683 |
-
raise ValueError(
|
684 |
-
"Unsloth: The `formatting_func` should return a list of processed strings."
|
685 |
-
)
|
686 |
-
test_text = test_text[0]
|
687 |
-
else:
|
688 |
-
test_text = dataset[0][dataset_text_field]
|
689 |
-
|
690 |
-
# Get chat template
|
691 |
-
chat_template = getattr(processing_class, 'chat_template', '')
|
692 |
-
if chat_template == '' and is_vlm:
|
693 |
-
chat_template = getattr(tokenizer, 'chat_template', '')
|
694 |
-
if chat_template is None:
|
695 |
-
chat_template = ''
|
696 |
-
|
697 |
-
# Get bos_token
|
698 |
-
add_special_tokens = True
|
699 |
-
bos_token_1 = getattr(processing_class, 'bos_token', None)
|
700 |
-
bos_token_2 = getattr(tokenizer, 'bos_token', None)
|
701 |
-
bos_token = bos_token_1 or bos_token_2
|
702 |
-
|
703 |
-
if bos_token is not None:
|
704 |
-
if test_text.startswith(bos_token) or bos_token in chat_template:
|
705 |
-
add_special_tokens = False
|
706 |
-
print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
|
707 |
-
pass
|
708 |
-
|
709 |
-
# Create tokenize function
|
710 |
-
def _tokenize(example):
|
711 |
-
return tokenizer(
|
712 |
-
example[dataset_text_field] if not do_formatting_func else formatting_func(example),
|
713 |
-
truncation = do_truncation,
|
714 |
-
max_length = max_seq_length,
|
715 |
-
return_token_type_ids = False,
|
716 |
-
add_special_tokens = add_special_tokens,
|
717 |
-
)
|
718 |
-
pass
|
719 |
-
|
720 |
-
map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
|
721 |
-
if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
|
722 |
-
dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
|
723 |
-
|
724 |
-
# If VLM, switch data collator since .pad is needed!
|
725 |
-
if is_vlm and not hasattr(processing_class, "pad"):
|
726 |
-
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
727 |
-
self.data_collator = data_collator
|
728 |
-
pass
|
729 |
-
pass
|
730 |
-
if packing:
|
731 |
-
print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
|
732 |
-
return dataset
|
733 |
-
|
734 |
-
if max_seq_length == 0:
|
735 |
-
raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
|
736 |
-
|
737 |
-
if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
|
738 |
-
dataset = dataset.select_columns(used_column_names).map(
|
739 |
-
pack_examples,
|
740 |
-
batched = True,
|
741 |
-
fn_kwargs = {"seq_length": max_seq_length,},
|
742 |
-
**map_kwargs,
|
743 |
-
)
|
744 |
-
pass
|
745 |
-
return dataset
|
746 |
-
|
747 |
-
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
|
748 |
-
outputs = super().compute_loss(
|
749 |
-
model,
|
750 |
-
inputs,
|
751 |
-
return_outputs = return_outputs,
|
752 |
-
num_items_in_batch = num_items_in_batch,
|
753 |
-
)
|
754 |
-
return outputs
|
755 |
-
|
756 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
757 |
-
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
758 |
-
|
759 |
-
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
760 |
-
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
761 |
-
if next(iter(logs.keys())).startswith("eval_"):
|
762 |
-
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
763 |
-
|
764 |
-
logs = {**logs, **metrics}
|
765 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
766 |
-
super().log(logs, start_time)
|
767 |
-
else: # transformers<=4.46
|
768 |
-
super().log(logs)
|
769 |
-
self._metrics.clear()
|
770 |
-
|
771 |
-
def create_model_card(
|
772 |
-
self,
|
773 |
-
model_name: Optional[str] = None,
|
774 |
-
dataset_name: Optional[str] = None,
|
775 |
-
tags: Union[str, list[str], None] = None,
|
776 |
-
):
|
777 |
-
"""
|
778 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
779 |
-
|
780 |
-
Args:
|
781 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
782 |
-
Name of the model.
|
783 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
784 |
-
Name of the dataset used for training.
|
785 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
786 |
-
Tags to be associated with the model card.
|
787 |
-
"""
|
788 |
-
if not self.is_world_process_zero():
|
789 |
-
return
|
790 |
-
|
791 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
792 |
-
base_model = self.model.config._name_or_path
|
793 |
-
else:
|
794 |
-
base_model = None
|
795 |
-
|
796 |
-
tags = tags or []
|
797 |
-
if isinstance(tags, str):
|
798 |
-
tags = [tags]
|
799 |
-
|
800 |
-
if hasattr(self.model.config, "unsloth_version"):
|
801 |
-
tags.append("unsloth")
|
802 |
-
|
803 |
-
model_card = generate_model_card(
|
804 |
-
base_model=base_model,
|
805 |
-
model_name=model_name,
|
806 |
-
hub_model_id=self.hub_model_id,
|
807 |
-
dataset_name=dataset_name,
|
808 |
-
tags=tags,
|
809 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
810 |
-
comet_url=get_comet_experiment_url(),
|
811 |
-
trainer_name="SFT",
|
812 |
-
)
|
813 |
-
|
814 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
815 |
-
class UnslothSFTTrainer(_UnslothSFTTrainer):
|
816 |
-
"""
|
817 |
-
|
818 |
-
Trainer for Supervised Fine-Tuning (SFT) method.
|
819 |
-
|
820 |
-
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
|
821 |
-
|
822 |
-
Example:
|
823 |
-
|
824 |
-
```python
|
825 |
-
from datasets import load_dataset
|
826 |
-
from trl import SFTTrainer
|
827 |
-
|
828 |
-
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
|
829 |
-
|
830 |
-
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
|
831 |
-
trainer.train()
|
832 |
-
```
|
833 |
-
|
834 |
-
Args:
|
835 |
-
model (`Union[str, PreTrainedModel]`):
|
836 |
-
Model to be trained. Can be either:
|
837 |
-
|
838 |
-
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
839 |
-
a path to a *directory* containing model weights saved using
|
840 |
-
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
841 |
-
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
842 |
-
in `args.model_init_kwargs`.
|
843 |
-
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
844 |
-
args ([`SFTConfig`], *optional*, defaults to `None`):
|
845 |
-
Configuration for this trainer. If `None`, a default configuration is used.
|
846 |
-
data_collator (`DataCollator`, *optional*):
|
847 |
-
Function to use to form a batch from a list of elements of the prcessed `train_dataset` or `eval_dataset`.
|
848 |
-
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
|
849 |
-
of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
|
850 |
-
tokenizer.
|
851 |
-
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
852 |
-
Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
|
853 |
-
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
|
854 |
-
|
855 |
-
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
856 |
-
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
857 |
-
and content).
|
858 |
-
|
859 |
-
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
|
860 |
-
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
861 |
-
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
862 |
-
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
863 |
-
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
|
864 |
-
with [`~transformers.AutoTokenizer.from_pretrained`].
|
865 |
-
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
866 |
-
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
867 |
-
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
868 |
-
|
869 |
-
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
870 |
-
method.
|
871 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
872 |
-
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
873 |
-
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
874 |
-
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
|
875 |
-
A tuple containing the optimizer class and keyword arguments to use.
|
876 |
-
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
|
877 |
-
|
878 |
-
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
|
879 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
|
880 |
-
A function that preprocess the logits right before caching them at each evaluation step. Must take two
|
881 |
-
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
|
882 |
-
by this function will be reflected in the predictions received by `compute_metrics`.
|
883 |
-
|
884 |
-
Note that the labels (second parameter) will be `None` if the dataset does not have them.
|
885 |
-
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
886 |
-
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
887 |
-
formatting_func (`Optional[Callable]`):
|
888 |
-
Formatting function applied to the dataset before tokenization.
|
889 |
-
|
890 |
-
"""
|
891 |
-
def __init__(
|
892 |
-
self,
|
893 |
-
model,
|
894 |
-
args = None,
|
895 |
-
data_collator = None,
|
896 |
-
train_dataset = None,
|
897 |
-
eval_dataset = None,
|
898 |
-
processing_class = None,
|
899 |
-
compute_loss_func = None,
|
900 |
-
compute_metrics = None,
|
901 |
-
callbacks = None,
|
902 |
-
optimizer_cls_and_kwargs = None,
|
903 |
-
preprocess_logits_for_metrics = None,
|
904 |
-
peft_config = None,
|
905 |
-
formatting_func = None,
|
906 |
-
**kwargs
|
907 |
-
):
|
908 |
-
if args is None: args = UnslothSFTConfig()
|
909 |
-
use_bf16 = getattr(args, 'bf16', False)
|
910 |
-
use_fp16 = getattr(args, 'fp16', False)
|
911 |
-
force_float32 = False
|
912 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
913 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
914 |
-
force_float32 = True
|
915 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
916 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
917 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
918 |
-
from unsloth_zoo.utils import _get_dtype
|
919 |
-
dtype = _get_dtype(dtype)
|
920 |
-
float16 = dtype == torch.float16
|
921 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
922 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
923 |
-
if force_float32:
|
924 |
-
args.fp16 = False
|
925 |
-
args.bf16 = False
|
926 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
927 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
928 |
-
args.fp16 = float16
|
929 |
-
args.bf16 = not float16
|
930 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
931 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
932 |
-
args.eval_strategy = 'steps'
|
933 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
934 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
935 |
-
if ga_steps is not None and ga_steps > 1:
|
936 |
-
from transformers import __version__ as transformers_version
|
937 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
938 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
939 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
940 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
941 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
942 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
943 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
944 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
945 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
946 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
947 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
948 |
-
if force_float32:
|
949 |
-
args.bf16_full_eval = False
|
950 |
-
args.fp16_full_eval = False
|
951 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
952 |
-
args.bf16_full_eval = True
|
953 |
-
args.fp16_full_eval = False
|
954 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
955 |
-
args.bf16_full_eval = args.bf16
|
956 |
-
args.fp16_full_eval = args.fp16
|
957 |
-
_output_logits = False
|
958 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
959 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
960 |
-
if _output_logits:
|
961 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
962 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
963 |
-
pass
|
964 |
-
else:
|
965 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
966 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
967 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
968 |
-
max_seq_length = model.max_seq_length
|
969 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
970 |
-
if model is not None and hasattr(model, 'for_training'):
|
971 |
-
model.for_training()
|
972 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
973 |
-
if 'processing_class' in locals():
|
974 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
975 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
976 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
977 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
978 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
979 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
980 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
981 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
982 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
983 |
-
else:
|
984 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
985 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
986 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
987 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
988 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
989 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
990 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
991 |
-
else:
|
992 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
993 |
-
other_metrics = []
|
994 |
-
|
995 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
996 |
-
PatchRLStatistics('sft_trainer', other_metrics)
|
997 |
-
IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
|
998 |
-
from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
|
999 |
-
from unsloth_zoo.training_utils import fix_zero_training_loss
|
1000 |
-
if 'tokenizer' not in locals(): tokenizer = processing_class
|
1001 |
-
fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
|
1002 |
-
fix_zero_training_loss(model, tokenizer, train_dataset)
|
1003 |
-
|
1004 |
-
super().__init__(
|
1005 |
-
model = model,
|
1006 |
-
args = args,
|
1007 |
-
data_collator = data_collator,
|
1008 |
-
train_dataset = train_dataset,
|
1009 |
-
eval_dataset = eval_dataset,
|
1010 |
-
processing_class = processing_class,
|
1011 |
-
compute_loss_func = compute_loss_func,
|
1012 |
-
compute_metrics = compute_metrics,
|
1013 |
-
callbacks = callbacks,
|
1014 |
-
optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
|
1015 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1016 |
-
peft_config = peft_config,
|
1017 |
-
formatting_func = formatting_func,**kwargs)
|
1018 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1019 |
-
self.neftune_hook_handle.remove()
|
1020 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1021 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1022 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1023 |
-
pass
|
1024 |
-
|
1025 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/UnslothXPOTrainer.py
DELETED
@@ -1,1008 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
2025.3.13
|
3 |
-
2025.3.15
|
4 |
-
4.48.3
|
5 |
-
0.15.2
|
6 |
-
__UNSLOTH_VERSIONING__
|
7 |
-
"""
|
8 |
-
from torch import Tensor
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from torch.nn import functional as F
|
12 |
-
from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
13 |
-
|
14 |
-
|
15 |
-
import os
|
16 |
-
from typing import *
|
17 |
-
from dataclasses import dataclass, field
|
18 |
-
from packaging.version import Version
|
19 |
-
import torch
|
20 |
-
import numpy as np
|
21 |
-
from contextlib import nullcontext
|
22 |
-
from torch.nn import functional as F
|
23 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
-
|
25 |
-
torch_compile_options = {
|
26 |
-
"epilogue_fusion" : True,
|
27 |
-
"max_autotune" : False,
|
28 |
-
"shape_padding" : True,
|
29 |
-
"trace.enabled" : False,
|
30 |
-
"triton.cudagraphs" : False,
|
31 |
-
}
|
32 |
-
|
33 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
-
def selective_log_softmax(logits, index):
|
35 |
-
logits = logits.to(torch.float32)
|
36 |
-
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
-
# loop to reduce peak mem consumption
|
38 |
-
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
-
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
-
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
-
return per_token_logps
|
42 |
-
@dataclass
|
43 |
-
class UnslothXPOConfig(XPOConfig):
|
44 |
-
"""
|
45 |
-
|
46 |
-
Configuration class for the [`XPOTrainer`].
|
47 |
-
|
48 |
-
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
49 |
-
|
50 |
-
Parameters:
|
51 |
-
alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
|
52 |
-
Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
|
53 |
-
and the last alpha is used for the rest of the epochs.
|
54 |
-
|
55 |
-
"""
|
56 |
-
vllm_sampling_params: Optional[Any] = field(
|
57 |
-
default = None,
|
58 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
59 |
-
)
|
60 |
-
unsloth_num_chunks : Optional[int] = field(
|
61 |
-
default = -1,
|
62 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
63 |
-
)
|
64 |
-
def __init__(
|
65 |
-
self,
|
66 |
-
output_dir = None,
|
67 |
-
overwrite_output_dir = None,
|
68 |
-
do_train = False,
|
69 |
-
do_eval = False,
|
70 |
-
do_predict = False,
|
71 |
-
eval_strategy = 'no',
|
72 |
-
prediction_loss_only = False,
|
73 |
-
per_device_train_batch_size = 4,
|
74 |
-
per_device_eval_batch_size = 4,
|
75 |
-
per_gpu_train_batch_size = None,
|
76 |
-
per_gpu_eval_batch_size = None,
|
77 |
-
gradient_accumulation_steps = 2,
|
78 |
-
eval_accumulation_steps = 2,
|
79 |
-
eval_delay = 0,
|
80 |
-
torch_empty_cache_steps = 250,
|
81 |
-
learning_rate = 5e-05,
|
82 |
-
weight_decay = 0.01,
|
83 |
-
adam_beta1 = 0.9,
|
84 |
-
adam_beta2 = 0.999,
|
85 |
-
adam_epsilon = 1e-08,
|
86 |
-
max_grad_norm = 1.0,
|
87 |
-
num_train_epochs = 3.0,
|
88 |
-
max_steps = -1,
|
89 |
-
lr_scheduler_type = 'linear',
|
90 |
-
warmup_ratio = 0.1,
|
91 |
-
warmup_steps = 0,
|
92 |
-
log_level = 'passive',
|
93 |
-
log_level_replica = 'warning',
|
94 |
-
log_on_each_node = True,
|
95 |
-
logging_dir = None,
|
96 |
-
logging_strategy = 'steps',
|
97 |
-
logging_first_step = False,
|
98 |
-
logging_steps = 1,
|
99 |
-
logging_nan_inf_filter = False,
|
100 |
-
save_strategy = 'steps',
|
101 |
-
save_steps = 500,
|
102 |
-
save_total_limit = None,
|
103 |
-
save_safetensors = True,
|
104 |
-
save_on_each_node = False,
|
105 |
-
save_only_model = False,
|
106 |
-
restore_callback_states_from_checkpoint = False,
|
107 |
-
no_cuda = False,
|
108 |
-
use_cpu = False,
|
109 |
-
use_mps_device = False,
|
110 |
-
seed = 3407,
|
111 |
-
data_seed = 3407,
|
112 |
-
jit_mode_eval = False,
|
113 |
-
use_ipex = False,
|
114 |
-
bf16 = False,
|
115 |
-
fp16 = False,
|
116 |
-
fp16_opt_level = 'O1',
|
117 |
-
half_precision_backend = 'auto',
|
118 |
-
bf16_full_eval = False,
|
119 |
-
fp16_full_eval = False,
|
120 |
-
tf32 = None,
|
121 |
-
local_rank = -1,
|
122 |
-
ddp_backend = None,
|
123 |
-
tpu_num_cores = None,
|
124 |
-
tpu_metrics_debug = False,
|
125 |
-
debug = '',
|
126 |
-
dataloader_drop_last = False,
|
127 |
-
eval_steps = None,
|
128 |
-
dataloader_num_workers = 0,
|
129 |
-
dataloader_prefetch_factor = None,
|
130 |
-
past_index = -1,
|
131 |
-
run_name = None,
|
132 |
-
disable_tqdm = None,
|
133 |
-
remove_unused_columns = True,
|
134 |
-
label_names = None,
|
135 |
-
load_best_model_at_end = False,
|
136 |
-
metric_for_best_model = None,
|
137 |
-
greater_is_better = None,
|
138 |
-
ignore_data_skip = False,
|
139 |
-
fsdp = '',
|
140 |
-
fsdp_min_num_params = 0,
|
141 |
-
fsdp_config = None,
|
142 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
143 |
-
accelerator_config = None,
|
144 |
-
deepspeed = None,
|
145 |
-
label_smoothing_factor = 0.0,
|
146 |
-
optim = 'adamw_8bit',
|
147 |
-
optim_args = None,
|
148 |
-
adafactor = False,
|
149 |
-
group_by_length = False,
|
150 |
-
length_column_name = 'length',
|
151 |
-
report_to = None,
|
152 |
-
ddp_find_unused_parameters = None,
|
153 |
-
ddp_bucket_cap_mb = None,
|
154 |
-
ddp_broadcast_buffers = None,
|
155 |
-
dataloader_pin_memory = True,
|
156 |
-
dataloader_persistent_workers = False,
|
157 |
-
skip_memory_metrics = True,
|
158 |
-
use_legacy_prediction_loop = False,
|
159 |
-
push_to_hub = False,
|
160 |
-
resume_from_checkpoint = None,
|
161 |
-
hub_model_id = None,
|
162 |
-
hub_strategy = 'every_save',
|
163 |
-
hub_token = None,
|
164 |
-
hub_private_repo = None,
|
165 |
-
hub_always_push = False,
|
166 |
-
gradient_checkpointing = False,
|
167 |
-
gradient_checkpointing_kwargs = None,
|
168 |
-
include_inputs_for_metrics = False,
|
169 |
-
eval_do_concat_batches = True,
|
170 |
-
fp16_backend = 'auto',
|
171 |
-
evaluation_strategy = None,
|
172 |
-
push_to_hub_model_id = None,
|
173 |
-
push_to_hub_organization = None,
|
174 |
-
push_to_hub_token = None,
|
175 |
-
mp_parameters = '',
|
176 |
-
auto_find_batch_size = False,
|
177 |
-
full_determinism = False,
|
178 |
-
torchdynamo = None,
|
179 |
-
ray_scope = 'last',
|
180 |
-
ddp_timeout = 1800,
|
181 |
-
torch_compile = False,
|
182 |
-
torch_compile_backend = None,
|
183 |
-
torch_compile_mode = None,
|
184 |
-
dispatch_batches = None,
|
185 |
-
split_batches = None,
|
186 |
-
include_tokens_per_second = False,
|
187 |
-
include_num_input_tokens_seen = False,
|
188 |
-
neftune_noise_alpha = None,
|
189 |
-
optim_target_modules = None,
|
190 |
-
batch_eval_metrics = False,
|
191 |
-
eval_on_start = False,
|
192 |
-
use_liger_kernel = False,
|
193 |
-
eval_use_gather_object = False,
|
194 |
-
average_tokens_across_devices = False,
|
195 |
-
reward_model_path = None,
|
196 |
-
judge = None,
|
197 |
-
max_new_tokens = 64,
|
198 |
-
max_length = 512,
|
199 |
-
temperature = 0.9,
|
200 |
-
missing_eos_penalty = None,
|
201 |
-
loss_type = 'sigmoid',
|
202 |
-
dataset_num_proc = None,
|
203 |
-
disable_dropout = True,
|
204 |
-
use_vllm = False,
|
205 |
-
ds3_gather_for_generation = True,
|
206 |
-
vllm_sampling_params = None,
|
207 |
-
unsloth_num_chunks = -1,
|
208 |
-
**kwargs,
|
209 |
-
):
|
210 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
211 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
212 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
213 |
-
output_dir = 'unsloth_training_checkpoints'
|
214 |
-
save_strategy = 'no'
|
215 |
-
if dataset_num_proc is None:
|
216 |
-
from multiprocessing import cpu_count
|
217 |
-
dataset_num_proc = cpu_count()
|
218 |
-
|
219 |
-
super().__init__(
|
220 |
-
output_dir = output_dir,
|
221 |
-
overwrite_output_dir = overwrite_output_dir,
|
222 |
-
do_train = do_train,
|
223 |
-
do_eval = do_eval,
|
224 |
-
do_predict = do_predict,
|
225 |
-
eval_strategy = eval_strategy,
|
226 |
-
prediction_loss_only = prediction_loss_only,
|
227 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
228 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
229 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
230 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
231 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
232 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
233 |
-
eval_delay = eval_delay,
|
234 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
235 |
-
learning_rate = learning_rate,
|
236 |
-
weight_decay = weight_decay,
|
237 |
-
adam_beta1 = adam_beta1,
|
238 |
-
adam_beta2 = adam_beta2,
|
239 |
-
adam_epsilon = adam_epsilon,
|
240 |
-
max_grad_norm = max_grad_norm,
|
241 |
-
num_train_epochs = num_train_epochs,
|
242 |
-
max_steps = max_steps,
|
243 |
-
lr_scheduler_type = lr_scheduler_type,
|
244 |
-
warmup_ratio = warmup_ratio,
|
245 |
-
warmup_steps = warmup_steps,
|
246 |
-
log_level = log_level,
|
247 |
-
log_level_replica = log_level_replica,
|
248 |
-
log_on_each_node = log_on_each_node,
|
249 |
-
logging_dir = logging_dir,
|
250 |
-
logging_strategy = logging_strategy,
|
251 |
-
logging_first_step = logging_first_step,
|
252 |
-
logging_steps = logging_steps,
|
253 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
254 |
-
save_strategy = save_strategy,
|
255 |
-
save_steps = save_steps,
|
256 |
-
save_total_limit = save_total_limit,
|
257 |
-
save_safetensors = save_safetensors,
|
258 |
-
save_on_each_node = save_on_each_node,
|
259 |
-
save_only_model = save_only_model,
|
260 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
261 |
-
no_cuda = no_cuda,
|
262 |
-
use_cpu = use_cpu,
|
263 |
-
use_mps_device = use_mps_device,
|
264 |
-
seed = seed,
|
265 |
-
data_seed = data_seed,
|
266 |
-
jit_mode_eval = jit_mode_eval,
|
267 |
-
use_ipex = use_ipex,
|
268 |
-
bf16 = bf16,
|
269 |
-
fp16 = fp16,
|
270 |
-
fp16_opt_level = fp16_opt_level,
|
271 |
-
half_precision_backend = half_precision_backend,
|
272 |
-
bf16_full_eval = bf16_full_eval,
|
273 |
-
fp16_full_eval = fp16_full_eval,
|
274 |
-
tf32 = tf32,
|
275 |
-
local_rank = local_rank,
|
276 |
-
ddp_backend = ddp_backend,
|
277 |
-
tpu_num_cores = tpu_num_cores,
|
278 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
279 |
-
debug = debug,
|
280 |
-
dataloader_drop_last = dataloader_drop_last,
|
281 |
-
eval_steps = eval_steps,
|
282 |
-
dataloader_num_workers = dataloader_num_workers,
|
283 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
284 |
-
past_index = past_index,
|
285 |
-
run_name = run_name,
|
286 |
-
disable_tqdm = disable_tqdm,
|
287 |
-
remove_unused_columns = remove_unused_columns,
|
288 |
-
label_names = label_names,
|
289 |
-
load_best_model_at_end = load_best_model_at_end,
|
290 |
-
metric_for_best_model = metric_for_best_model,
|
291 |
-
greater_is_better = greater_is_better,
|
292 |
-
ignore_data_skip = ignore_data_skip,
|
293 |
-
fsdp = fsdp,
|
294 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
295 |
-
fsdp_config = fsdp_config,
|
296 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
297 |
-
accelerator_config = accelerator_config,
|
298 |
-
deepspeed = deepspeed,
|
299 |
-
label_smoothing_factor = label_smoothing_factor,
|
300 |
-
optim = optim,
|
301 |
-
optim_args = optim_args,
|
302 |
-
adafactor = adafactor,
|
303 |
-
group_by_length = group_by_length,
|
304 |
-
length_column_name = length_column_name,
|
305 |
-
report_to = report_to,
|
306 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
307 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
308 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
309 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
310 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
311 |
-
skip_memory_metrics = skip_memory_metrics,
|
312 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
313 |
-
push_to_hub = push_to_hub,
|
314 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
315 |
-
hub_model_id = hub_model_id,
|
316 |
-
hub_strategy = hub_strategy,
|
317 |
-
hub_token = hub_token,
|
318 |
-
hub_private_repo = hub_private_repo,
|
319 |
-
hub_always_push = hub_always_push,
|
320 |
-
gradient_checkpointing = gradient_checkpointing,
|
321 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
322 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
323 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
324 |
-
fp16_backend = fp16_backend,
|
325 |
-
evaluation_strategy = evaluation_strategy,
|
326 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
327 |
-
push_to_hub_organization = push_to_hub_organization,
|
328 |
-
push_to_hub_token = push_to_hub_token,
|
329 |
-
mp_parameters = mp_parameters,
|
330 |
-
auto_find_batch_size = auto_find_batch_size,
|
331 |
-
full_determinism = full_determinism,
|
332 |
-
torchdynamo = torchdynamo,
|
333 |
-
ray_scope = ray_scope,
|
334 |
-
ddp_timeout = ddp_timeout,
|
335 |
-
torch_compile = torch_compile,
|
336 |
-
torch_compile_backend = torch_compile_backend,
|
337 |
-
torch_compile_mode = torch_compile_mode,
|
338 |
-
dispatch_batches = dispatch_batches,
|
339 |
-
split_batches = split_batches,
|
340 |
-
include_tokens_per_second = include_tokens_per_second,
|
341 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
342 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
343 |
-
optim_target_modules = optim_target_modules,
|
344 |
-
batch_eval_metrics = batch_eval_metrics,
|
345 |
-
eval_on_start = eval_on_start,
|
346 |
-
use_liger_kernel = use_liger_kernel,
|
347 |
-
eval_use_gather_object = eval_use_gather_object,
|
348 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
349 |
-
reward_model_path = reward_model_path,
|
350 |
-
judge = judge,
|
351 |
-
max_new_tokens = max_new_tokens,
|
352 |
-
max_length = max_length,
|
353 |
-
temperature = temperature,
|
354 |
-
missing_eos_penalty = missing_eos_penalty,
|
355 |
-
loss_type = loss_type,
|
356 |
-
dataset_num_proc = dataset_num_proc,
|
357 |
-
disable_dropout = disable_dropout,
|
358 |
-
use_vllm = use_vllm,
|
359 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
360 |
-
self.vllm_sampling_params = vllm_sampling_params
|
361 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
362 |
-
pass
|
363 |
-
|
364 |
-
class _UnslothXPOTrainer(OnlineDPOTrainer):
|
365 |
-
r""""""
|
366 |
-
|
367 |
-
_tag_names = ["trl", "xpo"]
|
368 |
-
|
369 |
-
def __init__(
|
370 |
-
self,
|
371 |
-
model: Union[PreTrainedModel, nn.Module] = None,
|
372 |
-
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
373 |
-
reward_model: Optional[nn.Module] = None,
|
374 |
-
judge: Optional[BasePairwiseJudge] = None,
|
375 |
-
args: Optional[XPOConfig] = None,
|
376 |
-
data_collator: Optional[Callable] = None,
|
377 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
378 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
379 |
-
processing_class: Optional[
|
380 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
381 |
-
] = None,
|
382 |
-
peft_config: Optional[dict] = None,
|
383 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
384 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
385 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
386 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
387 |
-
) -> None:
|
388 |
-
super().__init__(
|
389 |
-
model=model,
|
390 |
-
ref_model=ref_model,
|
391 |
-
judge=judge,
|
392 |
-
reward_model=reward_model,
|
393 |
-
args=args,
|
394 |
-
data_collator=data_collator,
|
395 |
-
train_dataset=train_dataset,
|
396 |
-
eval_dataset=eval_dataset,
|
397 |
-
processing_class=processing_class,
|
398 |
-
reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
|
399 |
-
peft_config=peft_config,
|
400 |
-
compute_metrics=compute_metrics,
|
401 |
-
callbacks=callbacks,
|
402 |
-
optimizers=optimizers,
|
403 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
404 |
-
)
|
405 |
-
|
406 |
-
self._alpha = self.args.alpha
|
407 |
-
|
408 |
-
# Overwrite the stats dictionary to include XPO specific statistics
|
409 |
-
self.stats = {
|
410 |
-
# Remove "non_score_reward", "rlhf_reward", "scores"
|
411 |
-
# Add "loss/dpo", "loss/xpo"
|
412 |
-
"loss/dpo": [],
|
413 |
-
"loss/xpo": [],
|
414 |
-
"objective/kl": [],
|
415 |
-
"objective/entropy": [],
|
416 |
-
"rewards/chosen": [],
|
417 |
-
"rewards/rejected": [],
|
418 |
-
"rewards/accuracies": [],
|
419 |
-
"rewards/margins": [],
|
420 |
-
"logps/chosen": [],
|
421 |
-
"logps/rejected": [],
|
422 |
-
# Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
|
423 |
-
"val/model_contain_eos_token": [],
|
424 |
-
"val/ref_contain_eos_token": [],
|
425 |
-
"alpha": [],
|
426 |
-
"beta": [],
|
427 |
-
}
|
428 |
-
if self.reward_model is not None:
|
429 |
-
# Replace "scores" by "model_scores" and "ref_scores"
|
430 |
-
self.stats["objective/model_scores"] = []
|
431 |
-
self.stats["objective/ref_scores"] = []
|
432 |
-
self.stats["objective/scores_margin"] = []
|
433 |
-
|
434 |
-
@property
|
435 |
-
def alpha(self):
|
436 |
-
if isinstance(self._alpha, list):
|
437 |
-
epoch = self.state.epoch
|
438 |
-
return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
|
439 |
-
else:
|
440 |
-
return self._alpha
|
441 |
-
|
442 |
-
def _generate_completions(self, prompts, model):
|
443 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
444 |
-
model_output = unwrapped_model.generate(
|
445 |
-
input_ids=prompts["input_ids"],
|
446 |
-
attention_mask=prompts["attention_mask"],
|
447 |
-
generation_config=self.generation_config,
|
448 |
-
)
|
449 |
-
|
450 |
-
ref_model = model if self.ref_model is None else self.ref_model
|
451 |
-
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
452 |
-
ref_output = unwrapped_ref_model.generate(
|
453 |
-
input_ids=prompts["input_ids"],
|
454 |
-
attention_mask=prompts["attention_mask"],
|
455 |
-
generation_config=self.generation_config,
|
456 |
-
)
|
457 |
-
|
458 |
-
return model_output, ref_output
|
459 |
-
|
460 |
-
def _process_completions(self, model_output, ref_output, prompts):
|
461 |
-
context_length = prompts["input_ids"].shape[1]
|
462 |
-
|
463 |
-
# Process model completions
|
464 |
-
model_completion_ids = model_output[:, context_length:]
|
465 |
-
model_completion_ids, model_completion_mask = truncate_right(
|
466 |
-
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
467 |
-
)
|
468 |
-
model_data = {
|
469 |
-
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
470 |
-
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
471 |
-
"raw": prompts["raw"],
|
472 |
-
}
|
473 |
-
|
474 |
-
# Process reference model completions
|
475 |
-
ref_completion_ids = ref_output[:, context_length:]
|
476 |
-
ref_completion_ids, ref_completion_mask = truncate_right(
|
477 |
-
ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
478 |
-
)
|
479 |
-
ref_data = {
|
480 |
-
"input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
|
481 |
-
"attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
|
482 |
-
"raw": prompts["raw"],
|
483 |
-
}
|
484 |
-
|
485 |
-
return model_data, ref_data
|
486 |
-
|
487 |
-
def _compute_rewards(self, model_data, ref_data, context_length):
|
488 |
-
with torch.no_grad():
|
489 |
-
_, model_scores, _ = get_reward(
|
490 |
-
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
491 |
-
)
|
492 |
-
_, ref_scores, _ = get_reward(
|
493 |
-
self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
|
494 |
-
)
|
495 |
-
|
496 |
-
# Apply EOS penalty if needed
|
497 |
-
if self.args.missing_eos_penalty is not None:
|
498 |
-
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
499 |
-
ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
500 |
-
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
501 |
-
ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
|
502 |
-
|
503 |
-
return model_scores, ref_scores
|
504 |
-
|
505 |
-
def _compute_judge(self, model_data, ref_data, context_length):
|
506 |
-
prompts = model_data["raw"]
|
507 |
-
model_data_completions = self.processing_class.batch_decode(
|
508 |
-
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
509 |
-
)
|
510 |
-
model_data_completions = [completion.strip() for completion in model_data_completions]
|
511 |
-
|
512 |
-
ref_data_completions = self.processing_class.batch_decode(
|
513 |
-
ref_data["input_ids"][:, context_length:], skip_special_tokens=True
|
514 |
-
)
|
515 |
-
ref_data_completions = [completion.strip() for completion in ref_data_completions]
|
516 |
-
|
517 |
-
if is_conversational({"prompt": prompts[0]}):
|
518 |
-
model_data_completions = [
|
519 |
-
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
520 |
-
]
|
521 |
-
environment = jinja2.Environment()
|
522 |
-
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
523 |
-
prompts = [template.render(messages=message) for message in prompts]
|
524 |
-
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
525 |
-
|
526 |
-
ref_data_completions = [
|
527 |
-
[{"role": "assistant", "content": completion}] for completion in ref_data_completions
|
528 |
-
]
|
529 |
-
ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
|
530 |
-
|
531 |
-
ranks_of_first_completion = self.judge.judge(
|
532 |
-
prompts,
|
533 |
-
list(zip(model_data_completions, ref_data_completions)),
|
534 |
-
)
|
535 |
-
# convert ranks to a True/False mask:
|
536 |
-
# when rank == 0, it means the first completion is the best
|
537 |
-
# when rank == 1, it means the second completion is the best
|
538 |
-
return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
|
539 |
-
|
540 |
-
def _compute_logprobs(self, model, model_data, ref_data, context_length):
|
541 |
-
def compute_logprobs_for_data(m, data):
|
542 |
-
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
543 |
-
logits = output.logits[:, context_length - 1 : -1]
|
544 |
-
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
545 |
-
return token_logprobs
|
546 |
-
|
547 |
-
# Compute logprobs for model completions
|
548 |
-
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
549 |
-
# Compute logprobs for model on reference completions (for XPO loss)
|
550 |
-
model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
|
551 |
-
|
552 |
-
# Compute logprobs for reference model completions
|
553 |
-
with torch.no_grad():
|
554 |
-
if self.ref_model is None:
|
555 |
-
with model.disable_adapter():
|
556 |
-
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
557 |
-
ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
|
558 |
-
else:
|
559 |
-
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
560 |
-
ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
|
561 |
-
|
562 |
-
# Mask padding tokens
|
563 |
-
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
564 |
-
ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
|
565 |
-
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
566 |
-
model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
|
567 |
-
ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
|
568 |
-
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
569 |
-
|
570 |
-
return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
|
571 |
-
|
572 |
-
def _compute_losses(
|
573 |
-
self,
|
574 |
-
model_logprobs_model_data,
|
575 |
-
model_logprobs_ref_data,
|
576 |
-
ref_logprobs_ref_data,
|
577 |
-
ref_logprobs_model_data,
|
578 |
-
chosen_mask,
|
579 |
-
):
|
580 |
-
# Compute log probs
|
581 |
-
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
582 |
-
model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
|
583 |
-
ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
|
584 |
-
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
585 |
-
|
586 |
-
chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
587 |
-
chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
588 |
-
chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
|
589 |
-
|
590 |
-
rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
591 |
-
rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
592 |
-
rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
|
593 |
-
|
594 |
-
# Compute logits as the difference between chosen and rejected log ratios
|
595 |
-
logits = chosen_log_ratios - rejected_log_ratios
|
596 |
-
|
597 |
-
if self.args.loss_type == "sigmoid":
|
598 |
-
dpo_losses = -F.logsigmoid(self.beta * logits)
|
599 |
-
elif self.args.loss_type == "ipo":
|
600 |
-
dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
|
601 |
-
else:
|
602 |
-
raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
|
603 |
-
|
604 |
-
# Compute XPO specific loss
|
605 |
-
xpo_losses = self.alpha * model_logprobs_ref_data_sum
|
606 |
-
|
607 |
-
# Total loss
|
608 |
-
loss = (dpo_losses + xpo_losses).mean()
|
609 |
-
|
610 |
-
return loss, dpo_losses, xpo_losses
|
611 |
-
|
612 |
-
def _log_statistics(
|
613 |
-
self,
|
614 |
-
model_data,
|
615 |
-
ref_data,
|
616 |
-
model_logprobs_model_data,
|
617 |
-
model_logprobs_ref_data,
|
618 |
-
ref_logprobs_ref_data,
|
619 |
-
ref_logprobs_model_data,
|
620 |
-
chosen_mask,
|
621 |
-
dpo_losses,
|
622 |
-
xpo_losses,
|
623 |
-
context_length,
|
624 |
-
model_scores=None,
|
625 |
-
ref_scores=None,
|
626 |
-
):
|
627 |
-
# Helper function to gather and compute mean
|
628 |
-
def gather_mean(tensor):
|
629 |
-
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
630 |
-
|
631 |
-
# Log losses
|
632 |
-
self.stats["loss/dpo"].append(gather_mean(dpo_losses))
|
633 |
-
self.stats["loss/xpo"].append(gather_mean(xpo_losses))
|
634 |
-
|
635 |
-
# Log scores
|
636 |
-
if self.reward_model is not None:
|
637 |
-
self.stats["objective/model_scores"].append(gather_mean(model_scores))
|
638 |
-
self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
|
639 |
-
self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
|
640 |
-
|
641 |
-
# Log logprobs
|
642 |
-
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
643 |
-
model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
|
644 |
-
ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
|
645 |
-
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
646 |
-
|
647 |
-
chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
648 |
-
chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
649 |
-
chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
|
650 |
-
|
651 |
-
rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
652 |
-
rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
653 |
-
rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
|
654 |
-
|
655 |
-
self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
|
656 |
-
self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
|
657 |
-
|
658 |
-
# Log rewards
|
659 |
-
# Compute various statistics
|
660 |
-
chosen_rewards = chosen_log_ratios * self.beta
|
661 |
-
rejected_rewards = rejected_log_ratios * self.beta
|
662 |
-
self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
|
663 |
-
self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
|
664 |
-
|
665 |
-
# Calculate KL divergence for model and ref data
|
666 |
-
kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
|
667 |
-
kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
|
668 |
-
mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
|
669 |
-
self.stats["objective/kl"].append(gather_mean(mean_kl))
|
670 |
-
|
671 |
-
# Calculate entropy for model and ref data
|
672 |
-
entropy_model_data = -model_logprobs_model_data.sum(1)
|
673 |
-
entropy_ref_data = -model_logprobs_ref_data.sum(1)
|
674 |
-
mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
|
675 |
-
self.stats["objective/entropy"].append(gather_mean(mean_entropy))
|
676 |
-
|
677 |
-
# Calculate margins
|
678 |
-
margin = chosen_rewards - rejected_rewards
|
679 |
-
self.stats["rewards/margins"].append(gather_mean(margin.mean()))
|
680 |
-
|
681 |
-
# Calculate accuracy
|
682 |
-
accuracy = (margin > 0).float()
|
683 |
-
self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
|
684 |
-
|
685 |
-
# Log EOS token statistics
|
686 |
-
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
687 |
-
ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
688 |
-
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
689 |
-
self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
|
690 |
-
|
691 |
-
# Log alpha and beta
|
692 |
-
self.stats["alpha"].append(self.alpha)
|
693 |
-
self.stats["beta"].append(self.beta)
|
694 |
-
|
695 |
-
def training_step(
|
696 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
697 |
-
) -> torch.Tensor:
|
698 |
-
model.train()
|
699 |
-
|
700 |
-
# Apply chat template and tokenize the input
|
701 |
-
batch_size = len(next(iter(inputs.values())))
|
702 |
-
prompts = inputs["prompt"]
|
703 |
-
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
704 |
-
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
705 |
-
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
706 |
-
inputs = self.data_collator(inputs)
|
707 |
-
|
708 |
-
# need the prompt_ only
|
709 |
-
inputs = self._prepare_inputs(inputs)
|
710 |
-
context_length = inputs["prompt_input_ids"].shape[1]
|
711 |
-
prompts = {
|
712 |
-
"input_ids": inputs["prompt_input_ids"],
|
713 |
-
"attention_mask": inputs["prompt_attention_mask"],
|
714 |
-
"raw": prompts,
|
715 |
-
}
|
716 |
-
del inputs
|
717 |
-
|
718 |
-
# Sample completions from both the model and the reference model
|
719 |
-
model_output, ref_output = self._generate_completions(prompts, model)
|
720 |
-
|
721 |
-
# Process model completions
|
722 |
-
model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
|
723 |
-
|
724 |
-
# Compute rewards
|
725 |
-
if self.reward_model is not None:
|
726 |
-
model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
|
727 |
-
chosen_mask = model_scores >= ref_scores
|
728 |
-
else:
|
729 |
-
model_scores, ref_scores = None, None
|
730 |
-
chosen_mask = self._compute_judge(model_data, ref_data, context_length)
|
731 |
-
|
732 |
-
# Compute logprobs
|
733 |
-
model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
|
734 |
-
self._compute_logprobs(model, model_data, ref_data, context_length)
|
735 |
-
)
|
736 |
-
|
737 |
-
# Compute loss
|
738 |
-
loss, dpo_losses, xpo_losses = self._compute_losses(
|
739 |
-
model_logprobs_model_data,
|
740 |
-
model_logprobs_ref_data,
|
741 |
-
ref_logprobs_ref_data,
|
742 |
-
ref_logprobs_model_data,
|
743 |
-
chosen_mask,
|
744 |
-
)
|
745 |
-
|
746 |
-
# Log everything
|
747 |
-
self._log_statistics(
|
748 |
-
model_data,
|
749 |
-
ref_data,
|
750 |
-
model_logprobs_model_data.detach(),
|
751 |
-
model_logprobs_ref_data.detach(),
|
752 |
-
ref_logprobs_ref_data,
|
753 |
-
ref_logprobs_model_data,
|
754 |
-
chosen_mask,
|
755 |
-
dpo_losses.detach(),
|
756 |
-
xpo_losses.detach(),
|
757 |
-
context_length,
|
758 |
-
model_scores,
|
759 |
-
ref_scores,
|
760 |
-
)
|
761 |
-
|
762 |
-
if (
|
763 |
-
self.args.torch_empty_cache_steps is not None
|
764 |
-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
765 |
-
):
|
766 |
-
empty_cache()
|
767 |
-
|
768 |
-
kwargs = {}
|
769 |
-
# For LOMO optimizers you need to explicitly use the learning rate
|
770 |
-
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
771 |
-
kwargs["learning_rate"] = self._get_learning_rate()
|
772 |
-
|
773 |
-
if self.args.n_gpu > 1:
|
774 |
-
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
775 |
-
|
776 |
-
if self.use_apex:
|
777 |
-
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
778 |
-
scaled_loss.backward()
|
779 |
-
else:
|
780 |
-
self.accelerator.backward(loss, **kwargs)
|
781 |
-
|
782 |
-
return loss.detach() / self.args.gradient_accumulation_steps
|
783 |
-
|
784 |
-
def create_model_card(
|
785 |
-
self,
|
786 |
-
model_name: Optional[str] = None,
|
787 |
-
dataset_name: Optional[str] = None,
|
788 |
-
tags: Union[str, list[str], None] = None,
|
789 |
-
):
|
790 |
-
"""
|
791 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
792 |
-
|
793 |
-
Args:
|
794 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
795 |
-
Name of the model.
|
796 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
797 |
-
Name of the dataset used for training.
|
798 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
799 |
-
Tags to be associated with the model card.
|
800 |
-
"""
|
801 |
-
if not self.is_world_process_zero():
|
802 |
-
return
|
803 |
-
|
804 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
805 |
-
base_model = self.model.config._name_or_path
|
806 |
-
else:
|
807 |
-
base_model = None
|
808 |
-
|
809 |
-
tags = tags or []
|
810 |
-
if isinstance(tags, str):
|
811 |
-
tags = [tags]
|
812 |
-
|
813 |
-
if hasattr(self.model.config, "unsloth_version"):
|
814 |
-
tags.append("unsloth")
|
815 |
-
|
816 |
-
citation = textwrap.dedent("""\
|
817 |
-
@article{jung2024binary,
|
818 |
-
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
|
819 |
-
author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
|
820 |
-
year = 2024,
|
821 |
-
eprint = {arXiv:2405.21046}
|
822 |
-
}""")
|
823 |
-
|
824 |
-
model_card = generate_model_card(
|
825 |
-
base_model=base_model,
|
826 |
-
model_name=model_name,
|
827 |
-
hub_model_id=self.hub_model_id,
|
828 |
-
dataset_name=dataset_name,
|
829 |
-
tags=tags,
|
830 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
831 |
-
comet_url=get_comet_experiment_url(),
|
832 |
-
trainer_name="XPO",
|
833 |
-
trainer_citation=citation,
|
834 |
-
paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
|
835 |
-
paper_id="2405.21046",
|
836 |
-
)
|
837 |
-
|
838 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
839 |
-
class UnslothXPOTrainer(_UnslothXPOTrainer):
|
840 |
-
"""
|
841 |
-
|
842 |
-
Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].
|
843 |
-
|
844 |
-
Args:
|
845 |
-
model (`transformers.PreTrainedModel`):
|
846 |
-
The model to train, preferably an `AutoModelForCausalLM`.
|
847 |
-
ref_model (`PreTrainedModelWrapper`):
|
848 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
849 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
850 |
-
reward_model (`transformers.PreTrainedModel`):
|
851 |
-
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
852 |
-
judge (`BasePairwiseJudge`):
|
853 |
-
The judge to use for pairwise comparison of model completions.
|
854 |
-
args (`XPOConfig`):
|
855 |
-
The XPO config arguments to use for training.
|
856 |
-
data_collator (`transformers.DataCollator`):
|
857 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
858 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
859 |
-
train_dataset (`datasets.Dataset`):
|
860 |
-
The dataset to use for training.
|
861 |
-
eval_dataset (`datasets.Dataset`):
|
862 |
-
The dataset to use for evaluation.
|
863 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
864 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
865 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
866 |
-
reuse the fine-tuned model.
|
867 |
-
peft_config (`dict`):
|
868 |
-
The peft config to use for training.
|
869 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
870 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
871 |
-
a dictionary string to metric values.
|
872 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
873 |
-
The callbacks to use for training.
|
874 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
875 |
-
The optimizer and scheduler to use for training.
|
876 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
877 |
-
The function to use to preprocess the logits before computing the metrics.
|
878 |
-
|
879 |
-
"""
|
880 |
-
def __init__(
|
881 |
-
self,
|
882 |
-
model = None,
|
883 |
-
ref_model = None,
|
884 |
-
reward_model = None,
|
885 |
-
judge = None,
|
886 |
-
args = None,
|
887 |
-
data_collator = None,
|
888 |
-
train_dataset = None,
|
889 |
-
eval_dataset = None,
|
890 |
-
processing_class = None,
|
891 |
-
peft_config = None,
|
892 |
-
compute_metrics = None,
|
893 |
-
callbacks = None,
|
894 |
-
preprocess_logits_for_metrics = None,
|
895 |
-
**kwargs
|
896 |
-
):
|
897 |
-
if args is None: args = UnslothXPOConfig()
|
898 |
-
use_bf16 = getattr(args, 'bf16', False)
|
899 |
-
use_fp16 = getattr(args, 'fp16', False)
|
900 |
-
force_float32 = False
|
901 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
902 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
903 |
-
force_float32 = True
|
904 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
905 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
906 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
907 |
-
from unsloth_zoo.utils import _get_dtype
|
908 |
-
dtype = _get_dtype(dtype)
|
909 |
-
float16 = dtype == torch.float16
|
910 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
911 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
912 |
-
if force_float32:
|
913 |
-
args.fp16 = False
|
914 |
-
args.bf16 = False
|
915 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
916 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
917 |
-
args.fp16 = float16
|
918 |
-
args.bf16 = not float16
|
919 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
920 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
921 |
-
args.eval_strategy = 'steps'
|
922 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
923 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
924 |
-
if ga_steps is not None and ga_steps > 1:
|
925 |
-
from transformers import __version__ as transformers_version
|
926 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
927 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
928 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
929 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
930 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
931 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
932 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
933 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
934 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
935 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
936 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
937 |
-
if force_float32:
|
938 |
-
args.bf16_full_eval = False
|
939 |
-
args.fp16_full_eval = False
|
940 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
941 |
-
args.bf16_full_eval = True
|
942 |
-
args.fp16_full_eval = False
|
943 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
944 |
-
args.bf16_full_eval = args.bf16
|
945 |
-
args.fp16_full_eval = args.fp16
|
946 |
-
_output_logits = False
|
947 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
948 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
949 |
-
if _output_logits:
|
950 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
951 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
952 |
-
pass
|
953 |
-
else:
|
954 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
955 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
956 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
957 |
-
max_seq_length = model.max_seq_length
|
958 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
959 |
-
if model is not None and hasattr(model, 'for_training'):
|
960 |
-
model.for_training()
|
961 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
962 |
-
if 'processing_class' in locals():
|
963 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
964 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
965 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
966 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
967 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
968 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
969 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
970 |
-
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
971 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
972 |
-
else:
|
973 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
974 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
975 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
976 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
977 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
978 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
979 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
980 |
-
else:
|
981 |
-
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
982 |
-
other_metrics = []
|
983 |
-
|
984 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
985 |
-
PatchRLStatistics('xpo_trainer', other_metrics)
|
986 |
-
|
987 |
-
super().__init__(
|
988 |
-
model = model,
|
989 |
-
ref_model = ref_model,
|
990 |
-
reward_model = reward_model,
|
991 |
-
judge = judge,
|
992 |
-
args = args,
|
993 |
-
data_collator = data_collator,
|
994 |
-
train_dataset = train_dataset,
|
995 |
-
eval_dataset = eval_dataset,
|
996 |
-
processing_class = processing_class,
|
997 |
-
peft_config = peft_config,
|
998 |
-
compute_metrics = compute_metrics,
|
999 |
-
callbacks = callbacks,
|
1000 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
1001 |
-
if hasattr(self, 'neftune_hook_handle'):
|
1002 |
-
self.neftune_hook_handle.remove()
|
1003 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1004 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1005 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1006 |
-
pass
|
1007 |
-
|
1008 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc
DELETED
Binary file (32.9 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc
DELETED
Binary file (91.7 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc
DELETED
Binary file (75.6 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
DELETED
Binary file (45.5 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:866cfa82efe9b576ed95d6c36ad4eb1e47599692a7b674f9acbb3e671b4d5178
|
3 |
-
size 103496
|
|
|
|
|
|
|
|
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc
DELETED
Binary file (37.7 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc
DELETED
Binary file (78.4 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc
DELETED
Binary file (87.3 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc
DELETED
Binary file (47.2 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc
DELETED
Binary file (75.5 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc
DELETED
Binary file (67.1 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc
DELETED
Binary file (62.6 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc
DELETED
Binary file (36.4 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc
DELETED
Binary file (54.2 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc
DELETED
Binary file (38.9 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc
DELETED
Binary file (47.8 kB)
|
|
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc
DELETED
Binary file (49.9 kB)
|
|