codewithdark committed
Commit a3e502c · verified · 1 Parent(s): 0c88ced

Delete unsloth_compiled_cache

Files changed (34)
  1. unsloth_compiled_cache/UnslothAlignPropTrainer.py +0 -637
  2. unsloth_compiled_cache/UnslothBCOTrainer.py +0 -1822
  3. unsloth_compiled_cache/UnslothCPOTrainer.py +0 -1555
  4. unsloth_compiled_cache/UnslothDDPOTrainer.py +0 -872
  5. unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
  6. unsloth_compiled_cache/UnslothGKDTrainer.py +0 -861
  7. unsloth_compiled_cache/UnslothGRPOTrainer.py +0 -1436
  8. unsloth_compiled_cache/UnslothKTOTrainer.py +0 -1838
  9. unsloth_compiled_cache/UnslothNashMDTrainer.py +0 -953
  10. unsloth_compiled_cache/UnslothORPOTrainer.py +0 -1541
  11. unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +0 -1267
  12. unsloth_compiled_cache/UnslothPPOTrainer.py +0 -1257
  13. unsloth_compiled_cache/UnslothPRMTrainer.py +0 -798
  14. unsloth_compiled_cache/UnslothRLOOTrainer.py +0 -1131
  15. unsloth_compiled_cache/UnslothRewardTrainer.py +0 -817
  16. unsloth_compiled_cache/UnslothSFTTrainer.py +0 -1025
  17. unsloth_compiled_cache/UnslothXPOTrainer.py +0 -1008
  18. unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
  19. unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
  20. unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
  21. unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
  22. unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc +0 -3
  23. unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
  24. unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
  25. unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
  26. unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
  27. unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
  28. unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
  29. unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
  30. unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
  31. unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
  32. unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
  33. unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
  34. unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
unsloth_compiled_cache/UnslothAlignPropTrainer.py DELETED
@@ -1,637 +0,0 @@
-"""
-2025.3.13
-2025.3.15
-4.48.3
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
-torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : False,
-    "shape_padding"     : True,
-    "trace.enabled"     : False,
-    "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def selective_log_softmax(logits, index):
-    logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-    # loop to reduce peak mem consumption
-    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-    logsumexp_values = torch.logsumexp(logits, dim = -1)
-    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-    return per_token_logps
-@dataclass
-class UnslothAlignPropConfig(AlignPropConfig):
-    """
-
-    Configuration class for the [`AlignPropTrainer`].
-
-    Using [`~transformers.HfArgumentParser`] we can turn this class into
-    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-    command line.
-
-    Parameters:
-        exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
-            Name of this experiment (defaults to the file name without the extension).
-        run_name (`str`, *optional*, defaults to `""`):
-            Name of this run.
-        seed (`int`, *optional*, defaults to `0`):
-            Random seed for reproducibility.
-        log_with (`str` or `None`, *optional*, defaults to `None`):
-            Log with either `"wandb"` or `"tensorboard"`. Check
-            [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
-        log_image_freq (`int`, *optional*, defaults to `1`):
-            Frequency for logging images.
-        tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
-            Keyword arguments for the tracker (e.g., `wandb_project`).
-        accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
-            Keyword arguments for the accelerator.
-        project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
-            Keyword arguments for the accelerator project config (e.g., `logging_dir`).
-        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
-            Name of project to use for tracking.
-        logdir (`str`, *optional*, defaults to `"logs"`):
-            Top-level logging directory for checkpoint saving.
-        num_epochs (`int`, *optional*, defaults to `100`):
-            Number of epochs to train.
-        save_freq (`int`, *optional*, defaults to `1`):
-            Number of epochs between saving model checkpoints.
-        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
-            Number of checkpoints to keep before overwriting old ones.
-        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
-            Mixed precision training.
-        allow_tf32 (`bool`, *optional*, defaults to `True`):
-            Allow `tf32` on Ampere GPUs.
-        resume_from (`str`, *optional*, defaults to `""`):
-            Path to resume training from a checkpoint.
-        sample_num_steps (`int`, *optional*, defaults to `50`):
-            Number of sampler inference steps.
-        sample_eta (`float`, *optional*, defaults to `1.0`):
-            Eta parameter for the DDIM sampler.
-        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
-            Classifier-free guidance weight.
-        train_batch_size (`int`, *optional*, defaults to `1`):
-            Batch size for training.
-        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
-            Whether to use the 8bit Adam optimizer from `bitsandbytes`.
-        train_learning_rate (`float`, *optional*, defaults to `1e-3`):
-            Learning rate.
-        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
-            Beta1 for Adam optimizer.
-        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
-            Beta2 for Adam optimizer.
-        train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
-            Weight decay for Adam optimizer.
-        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
-            Epsilon value for Adam optimizer.
-        train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
-            Number of gradient accumulation steps.
-        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
-            Maximum gradient norm for gradient clipping.
-        negative_prompts (`str` or `None`, *optional*, defaults to `None`):
-            Comma-separated list of prompts to use as negative examples.
-        truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
-            If `True`, randomized truncation to different diffusion timesteps is used.
-        truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
-            Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
-        truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
-            Range of diffusion timesteps for randomized truncated backpropagation.
-        push_to_hub (`bool`, *optional*, defaults to `False`):
-            Whether to push the final model to the Hub.
-
-    """
-    vllm_sampling_params: Optional[Any] = field(
-        default = None,
-        metadata = {'help': 'vLLM SamplingParams'},
-    )
-    unsloth_num_chunks : Optional[int] = field(
-        default = -1,
-        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-    )
-    def __init__(
-        self,
-        exp_name = 'main',
-        run_name = '',
-        seed = 3407,
-        log_with = None,
-        log_image_freq = 1,
-        tracker_project_name = 'trl',
-        logdir = 'logs',
-        num_epochs = 100,
-        save_freq = 1,
-        num_checkpoint_limit = 5,
-        mixed_precision = 'fp16',
-        allow_tf32 = True,
-        resume_from = '',
-        sample_num_steps = 50,
-        sample_eta = 1.0,
-        sample_guidance_scale = 5.0,
-        train_batch_size = 1,
-        train_use_8bit_adam = False,
-        train_learning_rate = 5e-05,
-        train_adam_beta1 = 0.9,
-        train_adam_beta2 = 0.999,
-        train_adam_weight_decay = 0.01,
-        train_adam_epsilon = 1e-08,
-        train_gradient_accumulation_steps = 2,
-        train_max_grad_norm = 1.0,
-        negative_prompts = None,
-        truncated_backprop_rand = True,
-        truncated_backprop_timestep = 49,
-        push_to_hub = False,
-        vllm_sampling_params = None,
-        unsloth_num_chunks = -1,
-        **kwargs,
-    ):
-
-        super().__init__(
-            exp_name = exp_name,
-            run_name = run_name,
-            seed = seed,
-            log_with = log_with,
-            log_image_freq = log_image_freq,
-            tracker_project_name = tracker_project_name,
-            logdir = logdir,
-            num_epochs = num_epochs,
-            save_freq = save_freq,
-            num_checkpoint_limit = num_checkpoint_limit,
-            mixed_precision = mixed_precision,
-            allow_tf32 = allow_tf32,
-            resume_from = resume_from,
-            sample_num_steps = sample_num_steps,
-            sample_eta = sample_eta,
-            sample_guidance_scale = sample_guidance_scale,
-            train_batch_size = train_batch_size,
-            train_use_8bit_adam = train_use_8bit_adam,
-            train_learning_rate = train_learning_rate,
-            train_adam_beta1 = train_adam_beta1,
-            train_adam_beta2 = train_adam_beta2,
-            train_adam_weight_decay = train_adam_weight_decay,
-            train_adam_epsilon = train_adam_epsilon,
-            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
-            train_max_grad_norm = train_max_grad_norm,
-            negative_prompts = negative_prompts,
-            truncated_backprop_rand = truncated_backprop_rand,
-            truncated_backprop_timestep = truncated_backprop_timestep,
-            push_to_hub = push_to_hub, **kwargs)
-        self.vllm_sampling_params = vllm_sampling_params
-        self.unsloth_num_chunks = unsloth_num_chunks
-    pass
-
-class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
-    """"""
-
-    _tag_names = ["trl", "alignprop"]
-
-    def __init__(
-        self,
-        config: AlignPropConfig,
-        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
-        prompt_function: Callable[[], tuple[str, Any]],
-        sd_pipeline: DDPOStableDiffusionPipeline,
-        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
-    ):
-        if image_samples_hook is None:
-            warn("No image_samples_hook provided; no images will be logged")
-
-        self.prompt_fn = prompt_function
-        self.reward_fn = reward_function
-        self.config = config
-        self.image_samples_callback = image_samples_hook
-
-        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
-
-        if self.config.resume_from:
-            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
-            if "checkpoint_" not in os.path.basename(self.config.resume_from):
-                # get the most recent checkpoint in this directory
-                checkpoints = list(
-                    filter(
-                        lambda x: "checkpoint_" in x,
-                        os.listdir(self.config.resume_from),
-                    )
-                )
-                if len(checkpoints) == 0:
-                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
-                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
-                self.config.resume_from = os.path.join(
-                    self.config.resume_from,
-                    f"checkpoint_{checkpoint_numbers[-1]}",
-                )
-
-                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
-
-        self.accelerator = Accelerator(
-            log_with=self.config.log_with,
-            mixed_precision=self.config.mixed_precision,
-            project_config=accelerator_project_config,
-            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
-            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
-            # the total number of optimizer steps to accumulate across.
-            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
-            **self.config.accelerator_kwargs,
-        )
-
-        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
-
-        if self.accelerator.is_main_process:
-            self.accelerator.init_trackers(
-                self.config.tracker_project_name,
-                config=dict(alignprop_trainer_config=config.to_dict())
-                if not is_using_tensorboard
-                else config.to_dict(),
-                init_kwargs=self.config.tracker_kwargs,
-            )
-
-        logger.info(f"\n{config}")
-
-        set_seed(self.config.seed, device_specific=True)
-
-        self.sd_pipeline = sd_pipeline
-
-        self.sd_pipeline.set_progress_bar_config(
-            position=1,
-            disable=not self.accelerator.is_local_main_process,
-            leave=False,
-            desc="Timestep",
-            dynamic_ncols=True,
-        )
-
-        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
-        # as these weights are only used for inference, keeping weights in full precision is not required.
-        if self.accelerator.mixed_precision == "fp16":
-            inference_dtype = torch.float16
-        elif self.accelerator.mixed_precision == "bf16":
-            inference_dtype = torch.bfloat16
-        else:
-            inference_dtype = torch.float32
-
-        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
-        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
-        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
-
-        trainable_layers = self.sd_pipeline.get_trainable_layers()
-
-        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
-        self.accelerator.register_load_state_pre_hook(self._load_model_hook)
-
-        # Enable TF32 for faster training on Ampere GPUs,
-        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-        if self.config.allow_tf32:
-            torch.backends.cuda.matmul.allow_tf32 = True
-
-        self.optimizer = self._setup_optimizer(
-            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
-        )
-
-        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
-            self.sd_pipeline.tokenizer(
-                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
-                return_tensors="pt",
-                padding="max_length",
-                truncation=True,
-                max_length=self.sd_pipeline.tokenizer.model_max_length,
-            ).input_ids.to(self.accelerator.device)
-        )[0]
-
-        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
-        # more memory
-        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
-
-        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
-            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
-        else:
-            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-
-        if config.resume_from:
-            logger.info(f"Resuming from {config.resume_from}")
-            self.accelerator.load_state(config.resume_from)
-            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
-        else:
-            self.first_epoch = 0
-
-    def compute_rewards(self, prompt_image_pairs):
-        reward, reward_metadata = self.reward_fn(
-            prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
-        )
-        return reward
-
-    def step(self, epoch: int, global_step: int):
-        """
-        Perform a single step of training.
-
-        Args:
-            epoch (int): The current epoch.
-            global_step (int): The current global step.
-
-        Side Effects:
-            - Model weights are updated
-            - Logs the statistics to the accelerator trackers.
-            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
-
-        Returns:
-            global_step (int): The updated global step.
-        """
-        info = defaultdict(list)
-
-        self.sd_pipeline.unet.train()
-
-        for _ in range(self.config.train_gradient_accumulation_steps):
-            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
-                prompt_image_pairs = self._generate_samples(
-                    batch_size=self.config.train_batch_size,
-                )
-
-                rewards = self.compute_rewards(prompt_image_pairs)
-
-                prompt_image_pairs["rewards"] = rewards
-
-                rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
-
-                loss = self.calculate_loss(rewards)
-
-                self.accelerator.backward(loss)
-
-                if self.accelerator.sync_gradients:
-                    self.accelerator.clip_grad_norm_(
-                        self.trainable_layers.parameters()
-                        if not isinstance(self.trainable_layers, list)
-                        else self.trainable_layers,
-                        self.config.train_max_grad_norm,
-                    )
-
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-
-            info["reward_mean"].append(rewards_vis.mean())
-            info["reward_std"].append(rewards_vis.std())
-            info["loss"].append(loss.item())
-
-        # Checks if the accelerator has performed an optimization step behind the scenes
-        if self.accelerator.sync_gradients:
-            # log training-related stuff
-            info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
-            info = self.accelerator.reduce(info, reduction="mean")
-            info.update({"epoch": epoch})
-            self.accelerator.log(info, step=global_step)
-            global_step += 1
-            info = defaultdict(list)
-        else:
-            raise ValueError(
-                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
-            )
-        # Logs generated images
-        if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
-            self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
-
-        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
-            self.accelerator.save_state()
-
-        return global_step
-
-    def calculate_loss(self, rewards):
-        """
-        Calculate the loss for a batch of an unpacked sample
-
-        Args:
-            rewards (torch.Tensor):
-                Differentiable reward scalars for each generated image, shape: [batch_size]
-
-        Returns:
-            loss (torch.Tensor)
-            (all of these are of shape (1,))
-        """
-        # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
-        loss = 10.0 - (rewards).mean()
-        return loss
-
-    def loss(
-        self,
-        advantages: torch.Tensor,
-        clip_range: float,
-        ratio: torch.Tensor,
-    ):
-        unclipped_loss = -advantages * ratio
-        clipped_loss = -advantages * torch.clamp(
-            ratio,
-            1.0 - clip_range,
-            1.0 + clip_range,
-        )
-        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
-
-    def _setup_optimizer(self, trainable_layers_parameters):
-        if self.config.train_use_8bit_adam:
-            import bitsandbytes
-
-            optimizer_cls = bitsandbytes.optim.AdamW8bit
-        else:
-            optimizer_cls = torch.optim.AdamW
-
-        return optimizer_cls(
-            trainable_layers_parameters,
-            lr=self.config.train_learning_rate,
-            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
-            weight_decay=self.config.train_adam_weight_decay,
-            eps=self.config.train_adam_epsilon,
-        )
-
-    def _save_model_hook(self, models, weights, output_dir):
-        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
-        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
-
-    def _load_model_hook(self, models, input_dir):
-        self.sd_pipeline.load_checkpoint(models, input_dir)
-        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
-
-    def _generate_samples(self, batch_size, with_grad=True, prompts=None):
-        """
-        Generate samples from the model
-
-        Args:
-            batch_size (int): Batch size to use for sampling
-            with_grad (bool): Whether the generated RGBs should have gradients attached to it.
-
-        Returns:
-            prompt_image_pairs (dict[Any])
-        """
-        prompt_image_pairs = {}
-
-        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
-
-        if prompts is None:
-            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
-        else:
-            prompt_metadata = [{} for _ in range(batch_size)]
-
-        prompt_ids = self.sd_pipeline.tokenizer(
-            prompts,
-            return_tensors="pt",
-            padding="max_length",
-            truncation=True,
-            max_length=self.sd_pipeline.tokenizer.model_max_length,
-        ).input_ids.to(self.accelerator.device)
-
-        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
-
-        if with_grad:
-            sd_output = self.sd_pipeline.rgb_with_grad(
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=sample_neg_prompt_embeds,
-                num_inference_steps=self.config.sample_num_steps,
-                guidance_scale=self.config.sample_guidance_scale,
-                eta=self.config.sample_eta,
-                truncated_backprop_rand=self.config.truncated_backprop_rand,
-                truncated_backprop_timestep=self.config.truncated_backprop_timestep,
-                truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
-                output_type="pt",
-            )
-        else:
-            sd_output = self.sd_pipeline(
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=sample_neg_prompt_embeds,
-                num_inference_steps=self.config.sample_num_steps,
-                guidance_scale=self.config.sample_guidance_scale,
-                eta=self.config.sample_eta,
-                output_type="pt",
-            )
-
-        images = sd_output.images
-
-        prompt_image_pairs["images"] = images
-        prompt_image_pairs["prompts"] = prompts
-        prompt_image_pairs["prompt_metadata"] = prompt_metadata
-
-        return prompt_image_pairs
-
-    def train(self, epochs: Optional[int] = None):
-        """
-        Train the model for a given number of epochs
-        """
-        global_step = 0
-        if epochs is None:
-            epochs = self.config.num_epochs
-        for epoch in range(self.first_epoch, epochs):
-            global_step = self.step(epoch, global_step)
-
-    def _save_pretrained(self, save_directory):
-        self.sd_pipeline.save_pretrained(save_directory)
-        self.create_model_card()
-
-    def create_model_card(
-        self,
-        model_name: Optional[str] = None,
-        dataset_name: Optional[str] = None,
-        tags: Union[str, list[str], None] = None,
-    ):
-        """
-        Creates a draft of a model card using the information available to the `Trainer`.
-
-        Args:
-            model_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the model.
-            dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the dataset used for training.
-            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                Tags to be associated with the model card.
-        """
-        if not self.is_world_process_zero():
-            return
-
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
-        else:
-            base_model = None
-
-        tags = tags or []
-        if isinstance(tags, str):
-            tags = [tags]
-
-        if hasattr(self.model.config, "unsloth_version"):
-            tags.append("unsloth")
-
-        citation = textwrap.dedent("""\
-        @article{prabhudesai2024aligning,
-            title  = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
-            author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
-            year   = 2024,
-            eprint = {arXiv:2310.03739}
-        }""")
-
-        model_card = generate_model_card(
-            base_model=base_model,
-            model_name=model_name,
-            hub_model_id=self.hub_model_id,
-            dataset_name=dataset_name,
-            tags=tags,
-            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-            comet_url=get_comet_experiment_url(),
-            trainer_name="AlignProp",
-            trainer_citation=citation,
-            paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
-            paper_id="2310.03739",
-        )
-
-        model_card.save(os.path.join(self.args.output_dir, "README.md"))
-class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
-    """
-
-    The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
-    Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
-    As of now only Stable Diffusion based pipelines are supported.
-
-    Attributes:
-        config (`AlignPropConfig`):
-            Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
-        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
-            Reward function to be used
-        prompt_function (`Callable[[], tuple[str, Any]]`):
-            Function to generate prompts to guide model
-        sd_pipeline (`DDPOStableDiffusionPipeline`):
-            Stable Diffusion pipeline to be used for training.
-        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
-            Hook to be called to log images
-
-    """
-    def __init__(
-        self,
-        config,
-        reward_function,
-        prompt_function,
-        sd_pipeline,
-        image_samples_hook = None,
-        **kwargs
-    ):
-        if config is None: config = UnslothAlignPropConfig()
-        other_metrics = []
-
-        from unsloth_zoo.logging_utils import PatchRLStatistics
-        PatchRLStatistics('alignprop_trainer', other_metrics)
-
-        super().__init__(
-            config = config,
-            reward_function = reward_function,
-            prompt_function = prompt_function,
-            sd_pipeline = sd_pipeline,
-            image_samples_hook = image_samples_hook, **kwargs)
-
-pass
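
Both deleted modules in this commit open with the same compiled `selective_log_softmax` helper, which scores the chosen token via the identity log_softmax(x_i) = x_i - logsumexp(x) rather than materializing the full log-softmax matrix. A minimal sketch (assuming only `torch` is installed, and dropping the `@torch.compile` decorator so it runs eagerly) checking that formulation against `F.log_softmax` followed by a gather:

import torch
import torch.nn.functional as F

def selective_log_softmax(logits, index):
    # Gather the selected logit, then subtract the row-wise logsumexp:
    # log_softmax(x_i) = x_i - logsumexp(x).
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

logits = torch.randn(2, 5, 32)      # (batch, seq, vocab)
ids = torch.randint(0, 32, (2, 5))  # token indices to score
reference = F.log_softmax(logits, dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(selective_log_softmax(logits, ids), reference, atol=1e-5)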
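The deleted AlignProp trainer above also carries a PPO-style clipped objective in its `loss` method. A small numeric sketch, assuming nothing beyond `torch`, of how the clamp caps the contribution of a probability ratio that drifts outside 1 ± clip_range:

import torch

def clipped_loss(advantages, clip_range, ratio):
    # Take the element-wise worse (maximum) of the unclipped and clipped
    # surrogate losses, so large ratio moves cannot claim large gains.
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.mean(torch.maximum(unclipped, clipped))

adv = torch.tensor([1.0, 1.0])
ratio = torch.tensor([1.1, 3.0])  # 3.0 is clamped to 1.2 before the product
print(clipped_loss(adv, 0.2, ratio))  # tensor(-1.1500)
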
unsloth_compiled_cache/UnslothBCOTrainer.py DELETED
@@ -1,1822 +0,0 @@
-"""
-2025.3.13
-2025.3.15
-4.48.3
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
-torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : False,
-    "shape_padding"     : True,
-    "trace.enabled"     : False,
-    "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def selective_log_softmax(logits, index):
-    logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-    # loop to reduce peak mem consumption
-    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-    logsumexp_values = torch.logsumexp(logits, dim = -1)
-    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-    return per_token_logps
-@dataclass
-class UnslothBCOConfig(BCOConfig):
-    """
-
-    Configuration class for the [`BCOTrainer`].
-
-    Using [`~transformers.HfArgumentParser`] we can turn this class into
-    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-    command line.
-
-    Parameters:
-        max_length (`int` or `None`, *optional*, defaults to `1024`):
-            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
-            to use the default data collator.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt. This argument is required if you want to use the default data collator.
-        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
-            Maximum length of the completion. This argument is required if you want to use the default data collator
-            and your model is an encoder-decoder.
-        beta (`float`, *optional*, defaults to `0.1`):
-            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
-            reference model.
-        label_pad_token_id (`int`, *optional*, defaults to `-100`):
-            Label pad token id. This argument is required if you want to use the default data collator.
-        padding_value (`int` or `None`, *optional*, defaults to `None`):
-            Padding value to use. If `None`, the padding value of the tokenizer is used.
-        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
-            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
-            This argument is required if you want to use the default data collator.
-        disable_dropout (`bool`, *optional*, defaults to `True`):
-            Whether to disable dropout in the model and reference model.
-        generate_during_eval (`bool`, *optional*, defaults to `False`):
-            If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
-            during evaluation.
-        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
-            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
-            you need to specify if the model returned by the callable is an encoder-decoder model.
-        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
-            Whether to precompute reference model log probabilities for training and evaluation datasets. This is
-            useful when training without the reference model to reduce the total GPU memory needed.
-        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
-            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
-            string.
-        ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
-            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
-            from a string.
-        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
-            Number of processes to use for processing the dataset.
-        prompt_sample_size (`int`, *optional*, defaults to `1024`):
-            Number of prompts that are fed to the density ratio classifier.
-        min_density_ratio (`float`, *optional*, defaults to `0.5`):
-            Minimum value of the density ratio. The estimated density ratio is clamped to this value.
-        max_density_ratio (`float`, *optional*, defaults to `10.0`):
-            Maximum value of the density ratio. The estimated density ratio is clamped to this value.
-
-    """
-    vllm_sampling_params: Optional[Any] = field(
-        default = None,
-        metadata = {'help': 'vLLM SamplingParams'},
-    )
-    unsloth_num_chunks : Optional[int] = field(
-        default = -1,
-        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-    )
-    def __init__(
-        self,
-        output_dir = None,
-        overwrite_output_dir = None,
-        do_train = False,
-        do_eval = False,
-        do_predict = False,
-        eval_strategy = 'no',
-        prediction_loss_only = False,
-        per_device_train_batch_size = 4,
-        per_device_eval_batch_size = 4,
-        per_gpu_train_batch_size = None,
-        per_gpu_eval_batch_size = None,
-        gradient_accumulation_steps = 2,
-        eval_accumulation_steps = 2,
-        eval_delay = 0,
-        torch_empty_cache_steps = 250,
-        learning_rate = 5e-05,
-        weight_decay = 0.01,
-        adam_beta1 = 0.9,
-        adam_beta2 = 0.999,
-        adam_epsilon = 1e-08,
-        max_grad_norm = 1.0,
-        num_train_epochs = 3.0,
-        max_steps = -1,
-        lr_scheduler_type = 'linear',
-        warmup_ratio = 0.1,
-        warmup_steps = 0,
-        log_level = 'passive',
-        log_level_replica = 'warning',
-        log_on_each_node = True,
-        logging_dir = None,
-        logging_strategy = 'steps',
-        logging_first_step = False,
-        logging_steps = 1,
-        logging_nan_inf_filter = False,
-        save_strategy = 'steps',
-        save_steps = 500,
-        save_total_limit = None,
-        save_safetensors = True,
-        save_on_each_node = False,
-        save_only_model = False,
-        restore_callback_states_from_checkpoint = False,
-        no_cuda = False,
-        use_cpu = False,
-        use_mps_device = False,
-        seed = 3407,
-        data_seed = 3407,
-        jit_mode_eval = False,
-        use_ipex = False,
-        bf16 = False,
-        fp16 = False,
-        fp16_opt_level = 'O1',
-        half_precision_backend = 'auto',
-        bf16_full_eval = False,
-        fp16_full_eval = False,
-        tf32 = None,
-        local_rank = -1,
-        ddp_backend = None,
-        tpu_num_cores = None,
-        tpu_metrics_debug = False,
-        debug = '',
-        dataloader_drop_last = False,
-        eval_steps = None,
-        dataloader_num_workers = 0,
-        dataloader_prefetch_factor = None,
-        past_index = -1,
-        run_name = None,
-        disable_tqdm = None,
-        remove_unused_columns = True,
-        label_names = None,
-        load_best_model_at_end = False,
-        metric_for_best_model = None,
-        greater_is_better = None,
-        ignore_data_skip = False,
-        fsdp = '',
-        fsdp_min_num_params = 0,
-        fsdp_config = None,
-        fsdp_transformer_layer_cls_to_wrap = None,
-        accelerator_config = None,
-        deepspeed = None,
-        label_smoothing_factor = 0.0,
-        optim = 'adamw_8bit',
-        optim_args = None,
-        adafactor = False,
-        group_by_length = False,
-        length_column_name = 'length',
-        report_to = None,
-        ddp_find_unused_parameters = None,
-        ddp_bucket_cap_mb = None,
-        ddp_broadcast_buffers = None,
-        dataloader_pin_memory = True,
-        dataloader_persistent_workers = False,
-        skip_memory_metrics = True,
-        use_legacy_prediction_loop = False,
-        push_to_hub = False,
-        resume_from_checkpoint = None,
-        hub_model_id = None,
-        hub_strategy = 'every_save',
-        hub_token = None,
-        hub_private_repo = None,
-        hub_always_push = False,
-        gradient_checkpointing = False,
-        gradient_checkpointing_kwargs = None,
-        include_inputs_for_metrics = False,
-        eval_do_concat_batches = True,
-        fp16_backend = 'auto',
-        evaluation_strategy = None,
-        push_to_hub_model_id = None,
-        push_to_hub_organization = None,
-        push_to_hub_token = None,
-        mp_parameters = '',
-        auto_find_batch_size = False,
-        full_determinism = False,
-        torchdynamo = None,
-        ray_scope = 'last',
-        ddp_timeout = 1800,
-        torch_compile = False,
-        torch_compile_backend = None,
-        torch_compile_mode = None,
-        dispatch_batches = None,
-        split_batches = None,
-        include_tokens_per_second = False,
-        include_num_input_tokens_seen = False,
-        neftune_noise_alpha = None,
-        optim_target_modules = None,
-        batch_eval_metrics = False,
-        eval_on_start = False,
-        use_liger_kernel = False,
-        eval_use_gather_object = False,
-        average_tokens_across_devices = False,
-        max_length = 1024,
-        max_prompt_length = 512,
-        max_completion_length = None,
-        beta = 0.1,
-        label_pad_token_id = -100,
-        padding_value = None,
-        truncation_mode = 'keep_end',
-        disable_dropout = True,
-        generate_during_eval = False,
-        is_encoder_decoder = None,
-        precompute_ref_log_probs = False,
-        model_init_kwargs = None,
-        ref_model_init_kwargs = None,
-        dataset_num_proc = None,
-        prompt_sample_size = 1024,
-        min_density_ratio = 0.5,
-        max_density_ratio = 10.0,
-        vllm_sampling_params = None,
-        unsloth_num_chunks = -1,
-        **kwargs,
-    ):
-        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-            output_dir = 'unsloth_training_checkpoints'
-            save_strategy = 'no'
-        if dataset_num_proc is None:
-            from multiprocessing import cpu_count
-            dataset_num_proc = cpu_count()
-
-        super().__init__(
-            output_dir = output_dir,
-            overwrite_output_dir = overwrite_output_dir,
-            do_train = do_train,
-            do_eval = do_eval,
-            do_predict = do_predict,
-            eval_strategy = eval_strategy,
-            prediction_loss_only = prediction_loss_only,
-            per_device_train_batch_size = per_device_train_batch_size,
-            per_device_eval_batch_size = per_device_eval_batch_size,
-            per_gpu_train_batch_size = per_gpu_train_batch_size,
-            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-            gradient_accumulation_steps = gradient_accumulation_steps,
-            eval_accumulation_steps = eval_accumulation_steps,
-            eval_delay = eval_delay,
-            torch_empty_cache_steps = torch_empty_cache_steps,
-            learning_rate = learning_rate,
-            weight_decay = weight_decay,
-            adam_beta1 = adam_beta1,
-            adam_beta2 = adam_beta2,
-            adam_epsilon = adam_epsilon,
-            max_grad_norm = max_grad_norm,
-            num_train_epochs = num_train_epochs,
-            max_steps = max_steps,
-            lr_scheduler_type = lr_scheduler_type,
-            warmup_ratio = warmup_ratio,
-            warmup_steps = warmup_steps,
-            log_level = log_level,
-            log_level_replica = log_level_replica,
-            log_on_each_node = log_on_each_node,
-            logging_dir = logging_dir,
-            logging_strategy = logging_strategy,
-            logging_first_step = logging_first_step,
-            logging_steps = logging_steps,
-            logging_nan_inf_filter = logging_nan_inf_filter,
-            save_strategy = save_strategy,
-            save_steps = save_steps,
-            save_total_limit = save_total_limit,
-            save_safetensors = save_safetensors,
-            save_on_each_node = save_on_each_node,
-            save_only_model = save_only_model,
-            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-            no_cuda = no_cuda,
-            use_cpu = use_cpu,
-            use_mps_device = use_mps_device,
-            seed = seed,
-            data_seed = data_seed,
-            jit_mode_eval = jit_mode_eval,
-            use_ipex = use_ipex,
-            bf16 = bf16,
-            fp16 = fp16,
-            fp16_opt_level = fp16_opt_level,
-            half_precision_backend = half_precision_backend,
-            bf16_full_eval = bf16_full_eval,
-            fp16_full_eval = fp16_full_eval,
-            tf32 = tf32,
-            local_rank = local_rank,
-            ddp_backend = ddp_backend,
-            tpu_num_cores = tpu_num_cores,
-            tpu_metrics_debug = tpu_metrics_debug,
-            debug = debug,
-            dataloader_drop_last = dataloader_drop_last,
-            eval_steps = eval_steps,
-            dataloader_num_workers = dataloader_num_workers,
-            dataloader_prefetch_factor = dataloader_prefetch_factor,
-            past_index = past_index,
-            run_name = run_name,
-            disable_tqdm = disable_tqdm,
-            remove_unused_columns = remove_unused_columns,
-            label_names = label_names,
-            load_best_model_at_end = load_best_model_at_end,
-            metric_for_best_model = metric_for_best_model,
-            greater_is_better = greater_is_better,
-            ignore_data_skip = ignore_data_skip,
-            fsdp = fsdp,
-            fsdp_min_num_params = fsdp_min_num_params,
-            fsdp_config = fsdp_config,
-            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-            accelerator_config = accelerator_config,
-            deepspeed = deepspeed,
-            label_smoothing_factor = label_smoothing_factor,
-            optim = optim,
-            optim_args = optim_args,
-            adafactor = adafactor,
-            group_by_length = group_by_length,
-            length_column_name = length_column_name,
-            report_to = report_to,
-            ddp_find_unused_parameters = ddp_find_unused_parameters,
-            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-            ddp_broadcast_buffers = ddp_broadcast_buffers,
-            dataloader_pin_memory = dataloader_pin_memory,
-            dataloader_persistent_workers = dataloader_persistent_workers,
-            skip_memory_metrics = skip_memory_metrics,
-            use_legacy_prediction_loop = use_legacy_prediction_loop,
-            push_to_hub = push_to_hub,
-            resume_from_checkpoint = resume_from_checkpoint,
-            hub_model_id = hub_model_id,
-            hub_strategy = hub_strategy,
-            hub_token = hub_token,
-            hub_private_repo = hub_private_repo,
-            hub_always_push = hub_always_push,
-            gradient_checkpointing = gradient_checkpointing,
-            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-            include_inputs_for_metrics = include_inputs_for_metrics,
-            eval_do_concat_batches = eval_do_concat_batches,
-            fp16_backend = fp16_backend,
-            evaluation_strategy = evaluation_strategy,
-            push_to_hub_model_id = push_to_hub_model_id,
-            push_to_hub_organization = push_to_hub_organization,
-            push_to_hub_token = push_to_hub_token,
-            mp_parameters = mp_parameters,
-            auto_find_batch_size = auto_find_batch_size,
-            full_determinism = full_determinism,
-            torchdynamo = torchdynamo,
-            ray_scope = ray_scope,
-            ddp_timeout = ddp_timeout,
-            torch_compile = torch_compile,
-            torch_compile_backend = torch_compile_backend,
-            torch_compile_mode = torch_compile_mode,
-            dispatch_batches = dispatch_batches,
-            split_batches = split_batches,
-            include_tokens_per_second = include_tokens_per_second,
-            include_num_input_tokens_seen = include_num_input_tokens_seen,
-            neftune_noise_alpha = neftune_noise_alpha,
-            optim_target_modules = optim_target_modules,
-            batch_eval_metrics = batch_eval_metrics,
-            eval_on_start = eval_on_start,
-            use_liger_kernel = use_liger_kernel,
-            eval_use_gather_object = eval_use_gather_object,
-            average_tokens_across_devices = average_tokens_across_devices,
-            max_length = max_length,
-            max_prompt_length = max_prompt_length,
-            max_completion_length = max_completion_length,
-            beta = beta,
-            label_pad_token_id = label_pad_token_id,
-            padding_value = padding_value,
-            truncation_mode = truncation_mode,
-            disable_dropout = disable_dropout,
-            generate_during_eval = generate_during_eval,
-            is_encoder_decoder = is_encoder_decoder,
-            precompute_ref_log_probs = precompute_ref_log_probs,
-            model_init_kwargs = model_init_kwargs,
-            ref_model_init_kwargs = ref_model_init_kwargs,
-            dataset_num_proc = dataset_num_proc,
-            prompt_sample_size = prompt_sample_size,
-            min_density_ratio = min_density_ratio,
-            max_density_ratio = max_density_ratio, **kwargs)
-        self.vllm_sampling_params = vllm_sampling_params
-        self.unsloth_num_chunks = unsloth_num_chunks
-    pass
-
418
- class _UnslothBCOTrainer(Trainer):
419
- r""""""
420
-
421
- _tag_names = ["trl", "bco"]
422
-
423
- def __init__(
424
- self,
425
- model: Union[PreTrainedModel, nn.Module, str] = None,
426
- ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
427
- args: BCOConfig = None,
428
- train_dataset: Optional[Dataset] = None,
429
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
430
- processing_class: Optional[
431
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
432
- ] = None,
433
- data_collator: Optional[DataCollator] = None,
434
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
435
- callbacks: Optional[list[TrainerCallback]] = None,
436
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
437
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
438
- peft_config: Optional[dict] = None,
439
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
440
- model_adapter_name: Optional[str] = None,
441
- ref_adapter_name: Optional[str] = None,
442
- embedding_func: Optional[Callable] = None,
443
- embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
444
- ):
445
- if not is_sklearn_available():
446
- raise ImportError(
447
- "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
448
- )
449
-
450
- if type(args) is TrainingArguments:
451
- raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
452
-
453
- if not isinstance(model, str) and ref_model is model:
454
- raise ValueError(
455
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
456
- "same as `model`, you must mass a copy of it, or `None` if you use peft."
457
- )
458
-
459
- if args.model_init_kwargs is None:
460
- model_init_kwargs = {}
461
- elif not isinstance(model, str):
462
- raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
463
- else:
464
- model_init_kwargs = args.model_init_kwargs
465
- torch_dtype = model_init_kwargs.get("torch_dtype")
466
- if torch_dtype is not None:
467
- # Convert to `torch.dtype` if an str is passed
468
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
469
- torch_dtype = getattr(torch, torch_dtype)
470
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
471
- raise ValueError(
472
- f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
473
- )
474
- model_init_kwargs["torch_dtype"] = torch_dtype
475
-
476
- if args.ref_model_init_kwargs is None:
477
- ref_model_init_kwargs = {}
478
- elif not isinstance(ref_model, str):
479
- raise ValueError(
480
- "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
481
- )
482
- else:
483
- ref_model_init_kwargs = args.ref_model_init_kwargs
484
- torch_dtype = ref_model_init_kwargs.get("torch_dtype")
485
- if torch_dtype is not None:
486
- # Convert to `torch.dtype` if an str is passed
487
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
488
- torch_dtype = getattr(torch, torch_dtype)
489
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
490
- raise ValueError(
491
- f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
492
- )
493
- ref_model_init_kwargs["torch_dtype"] = torch_dtype
494
-
495
- if isinstance(model, str):
496
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
497
-
498
- if isinstance(ref_model, str):
499
- ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
500
-
501
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
502
- # has been called in order to properly call autocast if needed.
503
- self._peft_has_been_casted_to_bf16 = False
504
-
505
- if not is_peft_available() and peft_config is not None:
506
- raise ValueError(
507
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
508
- )
509
- elif is_peft_available() and peft_config is not None:
510
- # if model is a peft model and we have a peft_config, we merge and unload it first
511
- if isinstance(model, PeftModel):
512
- model = model.merge_and_unload()
513
-
514
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
515
- _support_gc_kwargs = hasattr(
516
- args, "gradient_checkpointing_kwargs"
517
- ) and "gradient_checkpointing_kwargs" in list(
518
- inspect.signature(prepare_model_for_kbit_training).parameters
519
- )
520
-
521
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
522
-
523
- if _support_gc_kwargs:
524
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
525
-
526
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
527
- elif getattr(args, "gradient_checkpointing", False):
528
- # For backward compatibility with older versions of transformers
529
- if hasattr(model, "enable_input_require_grads"):
530
- model.enable_input_require_grads()
531
- else:
532
-
533
- def make_inputs_require_grad(module, input, output):
534
- output.requires_grad_(True)
535
-
536
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
537
-
538
- # get peft model with the given config
539
- model = model
540
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
541
- peft_module_casting_to_bf16(model)
542
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
543
- self._peft_has_been_casted_to_bf16 = True
544
-
545
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
546
-         # to explicitly have `requires_grad=True`, otherwise training will either silently
-         # fail or completely fail.
-         elif getattr(args, "gradient_checkpointing", False):
-             # For backward compatibility with older versions of transformers
-             if hasattr(model, "enable_input_require_grads"):
-                 model.enable_input_require_grads()
-             else:
-
-                 def make_inputs_require_grad(module, input, output):
-                     output.requires_grad_(True)
-
-                 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-         if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
-             raise ValueError(
-                 "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
-                 " Please install `wandb` or `comet-ml` to resolve."
-             )
-
-         if model is not None:
-             self.is_encoder_decoder = model.config.is_encoder_decoder
-         elif args.is_encoder_decoder is None:
-             raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
-         else:
-             self.is_encoder_decoder = args.is_encoder_decoder
-
-         self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
-         self.model_adapter_name = model_adapter_name
-         self.ref_adapter_name = ref_adapter_name
-
-         if ref_model:
-             self.ref_model = ref_model
-         elif self.is_peft_model or args.precompute_ref_log_probs:
-             # The `model` with adapters turned off will be used as the reference model
-             self.ref_model = None
-         else:
-             self.ref_model = create_reference_model(model)
-
-         if processing_class is None:
-             raise ValueError(
-                 "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
-             )
-         if args.max_length is None:
-             warnings.warn(
-                 "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
-                 "It will be set to `512` by default, but you should do it yourself in the future.",
-                 UserWarning,
-             )
-             max_length = 512
-         if args.max_length is not None:
-             max_length = args.max_length
-
-         if args.max_prompt_length is None:
-             warnings.warn(
-                 "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
-                 "It will be set to `128` by default, but you should do it yourself in the future.",
-                 UserWarning,
-             )
-             max_prompt_length = 128
-         if args.max_prompt_length is not None:
-             max_prompt_length = args.max_prompt_length
-
-         max_completion_length = None
-         if args.max_completion_length is None and self.is_encoder_decoder:
-             warnings.warn(
-                 "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
-                 " it will be set to `128` by default, but you should do it yourself in the future.",
-                 UserWarning,
-             )
-             max_completion_length = 128
-         if args.max_completion_length is not None and self.is_encoder_decoder:
-             max_completion_length = args.max_completion_length
-
-         if data_collator is None:
-             data_collator = DPODataCollatorWithPadding(
-                 pad_token_id=processing_class.pad_token_id,
-                 label_pad_token_id=args.label_pad_token_id,
-                 is_encoder_decoder=self.is_encoder_decoder,
-             )
-
-             if args.remove_unused_columns:
-                 args.remove_unused_columns = False
-                 # warn users
-                 warnings.warn(
-                     "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
-                     " we have set it for you, but you should do it yourself in the future.",
-                     UserWarning,
-                 )
-
-             self.use_dpo_data_collator = True
-         else:
-             self.use_dpo_data_collator = False
-
-         # Disable dropout in the model and reference model
-         if args.disable_dropout:
-             disable_dropout_in_model(model)
-             if self.ref_model is not None:
-                 disable_dropout_in_model(self.ref_model)
-
-         self.max_length = max_length
-         self.generate_during_eval = args.generate_during_eval
-         self.label_pad_token_id = args.label_pad_token_id
-         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
-         self.max_prompt_length = max_prompt_length
-         self.truncation_mode = args.truncation_mode
-         self.max_completion_length = max_completion_length
-         self.precompute_ref_log_probs = args.precompute_ref_log_probs
-
-         # Since ref log probs are precomputed on the first call to get_train/eval_dataloader,
-         # keep track of the first call to avoid recomputing them on subsequent calls
-         self._precomputed_train_ref_log_probs = False
-         self._precomputed_eval_ref_log_probs = False
-
-         # metric
-         self._stored_metrics = defaultdict(lambda: defaultdict(list))
-
-         # BCO parameter
-         self.beta = args.beta
-         self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
-         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
-         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
-             warnings.warn(
-                 "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
-                 "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
-                 "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
-                 "loss.",
-                 UserWarning,
-             )
-
-         # Underlying Distribution Matching argument
-         self.embedding_func = embedding_func
-         self.embedding_tokenizer = embedding_tokenizer
-
-         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-         # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
-         # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
-         # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-         # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-         # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-         # issued.
-         model.warnings_issued["estimate_tokens"] = True
-
-         with PartialState().local_main_process_first():
-             # Apply the chat template if needed
-             train_dataset = train_dataset.map(
-                 maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
-             )
-             if eval_dataset is not None:
-                 eval_dataset = eval_dataset.map(
-                     maybe_apply_chat_template,
-                     fn_kwargs={"tokenizer": processing_class},
-                     num_proc=args.dataset_num_proc,
-                 )
-             # Shuffle the datasets
-             train_dataset = train_dataset.shuffle(seed=args.data_seed)
-             if eval_dataset is not None:
-                 eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
-             # Tokenize and prepare the training datasets
-             train_dataset = train_dataset.map(
-                 _tokenize,
-                 batched=True,
-                 fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
-                 num_proc=args.dataset_num_proc,
-                 desc="Tokenizing train dataset",
-             )
-
-             # Prepare the datasets
-             fn_kwargs = {
-                 "prefix": "",
-                 "is_encoder_decoder": self.is_encoder_decoder,
-                 "tokenizer": processing_class,
-                 "max_length": self.max_length,
-                 "truncation_mode": self.truncation_mode,
-                 "label_pad_token_id": self.label_pad_token_id,
-                 "max_prompt_length": self.max_prompt_length,
-                 "max_completion_length": self.max_completion_length,
-             }
-             train_dataset = train_dataset.map(
-                 _process_tokens,
-                 fn_kwargs=fn_kwargs,
-                 num_proc=args.dataset_num_proc,
-                 desc="Processing tokenized train dataset",
-             )
-
-             if eval_dataset is not None:
-                 # Tokenize
-                 eval_dataset = eval_dataset.map(
-                     _tokenize,
-                     fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
-                     batched=True,
-                     num_proc=args.dataset_num_proc,
-                     desc="Tokenizing eval dataset",
-                 )
-
-                 # Process
-                 fn_kwargs = {
-                     "prefix": "",
-                     "is_encoder_decoder": self.is_encoder_decoder,
-                     "tokenizer": processing_class,
-                     "max_length": self.max_length,
-                     "truncation_mode": self.truncation_mode,
-                     "label_pad_token_id": self.label_pad_token_id,
-                     "max_prompt_length": self.max_prompt_length,
-                     "max_completion_length": self.max_completion_length,
-                 }
-                 eval_dataset = eval_dataset.map(
-                     _process_tokens,
-                     fn_kwargs=fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                     desc="Processing tokenized eval dataset",
-                 )
-
-             desirable = train_dataset.filter(
-                 lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
-             )
-             undesirable = train_dataset.filter(
-                 lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
-             )
-
-             desirable = desirable.shuffle(seed=args.data_seed)
-             undesirable = undesirable.shuffle(seed=args.data_seed)
-
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             model_init=model_init,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
-
-         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
-         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
-         # self.model_accepts_loss_kwargs to False to enable scaling.
-         self.model_accepts_loss_kwargs = False
-
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
-
-         if not hasattr(self, "accelerator"):
-             raise AttributeError(
-                 "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
-             )
-
-         # Deepspeed Zero-3 does not support precompute_ref_log_probs
-         if self.is_deepspeed_enabled:
-             if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
-                 raise ValueError(
-                     "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
-                 )
-
-         if self.ref_model is None:
-             if not (self.is_peft_model or self.precompute_ref_log_probs):
-                 raise ValueError(
-                     "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
-                 )
-         else:
-             if self.is_deepspeed_enabled:
-                 self.ref_model = self._prepare_deepspeed(self.ref_model)
-             else:
-                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
-
-         self.running = RunningMoments(accelerator=self.accelerator)
-
-         if self.embedding_func is None:
-             return
-
-         chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
-         rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
-
-         embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
-         labels = torch.cat(
-             (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
-         )
-
-         self.clf = LogisticRegression(class_weight="balanced").fit(
-             embeddings.cpu().float().numpy(), labels.cpu().numpy()
-         )
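-         # Editorial note: the fitted classifier estimates p(desirable | prompt embedding); its
-         # predicted probability is later turned into a clamped density ratio p / (1 - p) that
-         # reweights undesirable examples (see _get_udm_weight below).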
-
-     @property
-     def match_underlying_distribution(self):
-         return self.embedding_func is not None and self.embedding_tokenizer is not None
-
-     def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
- """
837
- Calculates the probability if the given prompt embedding is from desirable dataset.
838
- This function calculates the probability in the process and ensemble across processes.
839
- """
840
-         dtype = prompt_embeddings.dtype
-         device = prompt_embeddings.device
-         rank = self.accelerator.process_index
-
-         padded_prompt_embeddings = self.accelerator.pad_across_processes(
-             prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
-         )
-         sample_size = padded_prompt_embeddings.shape[0]
-         nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
-         prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
-
-         # cannot predict for all empty values
-         if prompt_embeddings.shape[0] == 0:
-             return torch.tensor([], device=device, dtype=dtype)
-
-         prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
-         prob = torch.as_tensor(prob, dtype=dtype, device=device)
-         prob = self.accelerator.reduce(prob, reduction="mean")
-
-         prob = prob[sample_size * rank : sample_size * (rank + 1)]
-         prob = prob[nonzero]
-
-         return prob
-
-     def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
- """
866
- Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
867
- and applies self.embedding_func
868
- """
869
-         input_ids = torch.where(
-             input_ids == self.processing_class.pad_token_id,
-             self.embedding_tokenizer.pad_token_id,
-             input_ids,
-         )
-
-         with torch.no_grad():
-             embeddings = self.embedding_func(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-             )
-
-         return embeddings
-
-     def _get_prompt_embeddings(
-         self, batch: dict[str, Union[list, torch.LongTensor]]
-     ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
-         """Extract embeddings from frozen embedding model"""
-
-         if not self.match_underlying_distribution:
-             return None, None
-
-         embeddings = self._vectorize_prompt(
-             input_ids=batch["embedding_input_ids"],
-             attention_mask=batch["embedding_attention_mask"],
-         )
-
-         chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
-         rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
-
-         chosen_embeddings = embeddings[chosen_idx, ...]
-         rejected_embeddings = embeddings[rejected_idx, ...]
-
-         return (chosen_embeddings, rejected_embeddings)
-
-     def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
-         """
-         Sample instances from dataset and get prompt embeddings.
-         Used for density ratio classifier training.
-         """
-         n_samples = min(len(dataset), sample_size)
-         rand_indices = np.random.choice(len(dataset), size=(n_samples,))
-
-         embedding_dataset = dataset.select(rand_indices)
-
-         dataloader_params = {
-             "batch_size": self.args.per_device_train_batch_size,
-             "collate_fn": self.data_collator,
-             "num_workers": self.args.dataloader_num_workers,
-             "pin_memory": self.args.dataloader_pin_memory,
-             "shuffle": False,
-         }
-
-         # prepare dataloader
-         data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
-
-         with torch.no_grad():
-             all_embeddings = torch.empty(0)
-             for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
-                 embeddings = self._vectorize_prompt(
-                     input_ids=padded_batch["embedding_input_ids"],
-                     attention_mask=padded_batch["embedding_attention_mask"],
-                 )
-                 embeddings = self.accelerator.gather_for_metrics(embeddings)
-                 all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
-
-         return all_embeddings
-
-     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
-         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-         config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-
-         if model is not None:
-             if hasattr(model, "config"):
-                 hidden_size = (
-                     max(model.config.hidden_sizes)
-                     if getattr(model.config, "hidden_sizes", None)
-                     else getattr(model.config, "hidden_size", None)
-                 )
-                 if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                     # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                     # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                     config_kwargs.update(
-                         {
-                             "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                             "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                             "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                         }
-                     )
-
-         # If ZeRO-3 is used, we shard both the active and reference model.
-         # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-         if config_kwargs["zero_optimization"]["stage"] != 3:
-             config_kwargs["zero_optimization"]["stage"] = 0
-         model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-         model.eval()
-         return model
-
-     def _save_optimizer_and_scheduler(self, output_dir):
-         super()._save_optimizer_and_scheduler(output_dir)
-
-         # When saving optimizer and scheduler to checkpoint, save also the running delta object.
-         output_dir = output_dir if output_dir is not None else self.args.output_dir
-
-         self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
-
-         if self.match_underlying_distribution:
-             torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
-
-     def _load_optimizer_and_scheduler(self, checkpoint):
-         super()._load_optimizer_and_scheduler(checkpoint)
-
-         if checkpoint is None:
-             return
-         # when loading optimizer and scheduler from checkpoint, also load the running delta object.
-         running_file = os.path.join(checkpoint, RUNNING_NAME)
-         if os.path.isfile(running_file):
-             self.running = RunningMoments.load_from_json(self.accelerator, running_file)
-
-         if self.match_underlying_distribution:
-             clf_file = os.path.join(checkpoint, CLF_NAME)
-             # check clf_file (not running_file) before loading the classifier parameters
-             if os.path.isfile(clf_file):
-                 self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
-
-     @contextmanager
-     def null_ref_context(self):
-         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
-         with (
-             self.accelerator.unwrap_model(self.model).disable_adapter()
-             if self.is_peft_model and not self.ref_adapter_name
-             else nullcontext()
-         ):
-             if self.ref_adapter_name:
-                 self.model.set_adapter(self.ref_adapter_name)
-             yield
-             if self.ref_adapter_name:
-                 self.model.set_adapter(self.model_adapter_name or "default")
-
-     def get_train_dataloader(self) -> DataLoader:
-         """
-         Returns the training [`~torch.utils.data.DataLoader`].
-
-         Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
- """
1014
-
1015
- if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
1016
- dataloader_params = {
1017
- "batch_size": self.args.per_device_train_batch_size,
1018
- "collate_fn": self.data_collator,
1019
- "num_workers": self.args.dataloader_num_workers,
1020
- "pin_memory": self.args.dataloader_pin_memory,
1021
- "shuffle": False,
1022
- }
1023
-
1024
- # prepare dataloader
1025
- data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
1026
- reference_completion_logps = []
1027
-
1028
- for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
1029
- reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1030
-
1031
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1032
- reference_completion_logps.append(reference_completion_logp.cpu())
1033
-
1034
- self.train_dataset = self.train_dataset.add_column(
1035
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1036
- )
1037
-
1038
- self._precomputed_train_ref_log_probs = True
1039
-
1040
- return super().get_train_dataloader()
1041
-
1042
- def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
1043
- """
1044
- Returns the evaluation [`~torch.utils.data.DataLoader`].
1045
-
1046
-         Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
-
-         Args:
-             eval_dataset (`torch.utils.data.Dataset`, *optional*):
-                 If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
-                 by the `model.forward()` method are automatically removed. It must implement `__len__`.
-         """
-         if eval_dataset is None and self.eval_dataset is None:
-             raise ValueError("Trainer: evaluation requires an eval_dataset.")
-         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
-
-         if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
-             dataloader_params = {
-                 "batch_size": self.args.per_device_eval_batch_size,
-                 "collate_fn": self.data_collator,
-                 "num_workers": self.args.dataloader_num_workers,
-                 "pin_memory": self.args.dataloader_pin_memory,
-                 "shuffle": False,
-             }
-
-             # prepare dataloader
-             data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
-
-             reference_completion_logps = []
-
-             for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
-                 reference_completion_logp = self.compute_reference_log_probs(padded_batch)
-
-                 reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
-                 reference_completion_logps.append(reference_completion_logp.cpu())
-
-             eval_dataset = eval_dataset.add_column(
-                 name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
-             )
-
-             # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
-             if self.eval_dataset is not None:
-                 self.eval_dataset = eval_dataset
-             self._precomputed_eval_ref_log_probs = True
-
-         return super().get_eval_dataloader(eval_dataset=eval_dataset)
-
-     def compute_reference_log_probs(self, padded_batch: dict) -> dict:
-         """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
-         with torch.no_grad():
-             if self.ref_model is None:
-                 with self.null_ref_context():
-                     if self.is_encoder_decoder:
-                         completion_logits = self.model(
-                             padded_batch["prompt_input_ids"],
-                             attention_mask=padded_batch["prompt_attention_mask"],
-                             decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
-                             labels=padded_batch["completion_labels"],
-                         ).logits
-
-                     else:
-                         completion_logits = self.model(
-                             padded_batch["completion_input_ids"],
-                             attention_mask=padded_batch["completion_attention_mask"],
-                         ).logits
-
-             else:
-                 if self.is_encoder_decoder:
-                     completion_logits = self.ref_model(
-                         padded_batch["prompt_input_ids"],
-                         attention_mask=padded_batch["prompt_attention_mask"],
-                         decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
-                         labels=padded_batch["completion_labels"],
-                     ).logits
-
-                 else:
-                     completion_logits = self.ref_model(
-                         padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
-                     ).logits
-
-         completion_logps = self.get_batch_logps(
-             completion_logits,
-             padded_batch["completion_labels"],
-             average_log_prob=False,
-             is_encoder_decoder=self.is_encoder_decoder,
-             label_pad_token_id=self.label_pad_token_id,
-         )
-
-         return completion_logps
-
-     @staticmethod
-     def get_batch_logps(
-         logits: torch.FloatTensor,
-         labels: torch.LongTensor,
-         average_log_prob: bool = False,
-         label_pad_token_id: int = -100,
-         is_encoder_decoder: bool = False,
-     ) -> torch.FloatTensor:
-         """Compute the log probabilities of the given labels under the given logits.
-
-         Args:
-             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
-             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
-             average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-
-         Returns:
-             A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
-         """
-         if logits.shape[:-1] != labels.shape:
-             raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
-         if not is_encoder_decoder:
-             labels = labels[:, 1:].clone()
-             logits = logits[:, :-1, :]
-         else:
-             # Fixes enc-dec RuntimeError
-             labels = labels.clone()
-
-         loss_mask = labels != label_pad_token_id
-
-         # dummy token; we'll ignore the losses on these tokens later
-         labels[labels == label_pad_token_id] = 0
-
-         per_token_logps = selective_log_softmax(logits, labels)
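-         # selective_log_softmax returns log p(label_t) at each position; the padded positions
-         # (set to 0 above) are excluded below via loss_mask before summing/averaging.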
-
-         if average_log_prob:
-             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-         else:
-             return (per_token_logps * loss_mask).sum(-1)
-
-     def forward(
-         self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
-     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-         model_kwargs = (
-             {
-                 "labels": batch["completion_labels"],
-                 "decoder_input_ids": batch.get("completion_decoder_input_ids"),
-             }
-             if self.is_encoder_decoder
-             else {}
-         )
-         if self.aux_loss_enabled:
-             model_kwargs["output_router_logits"] = True
-
-         outputs = model(
-             batch["completion_input_ids"],
-             attention_mask=batch["completion_attention_mask"],
-             **model_kwargs,
-         )
-         completion_logits = outputs.logits
-
-         completion_logps = self.get_batch_logps(
-             completion_logits,
-             batch["completion_labels"],
-             average_log_prob=False,
-             is_encoder_decoder=self.is_encoder_decoder,
-             label_pad_token_id=self.label_pad_token_id,
-         )
-
-         if completion_logps.shape[0] != len(batch["label"]):
-             raise ValueError(
-                 "There is a mismatch between the number of examples in this batch and the number of "
-                 "examples for which an output sequence was predicted."
-             )
-
-         chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
-         rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
-
-         chosen_logps = completion_logps[chosen_idx, ...]
-         rejected_logps = completion_logps[rejected_idx, ...]
-
-         chosen_logits = completion_logits[chosen_idx, ...]
-         rejected_logits = completion_logits[rejected_idx, ...]
-
-         if self.aux_loss_enabled:
-             return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
-         else:
-             return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
-
-     def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
-         prob_desirable = self._get_chosen_prob(rejected_embeddings)
-         min_ratio = self.args.min_density_ratio
-         max_ratio = self.args.max_density_ratio
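-         # The classifier probability p(desirable) is converted into a density ratio p / (1 - p);
-         # the 1e-8 guards against division by zero and the clamp bounds extreme ratios.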
-
-         weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
-
-         return weight
-
-     def bco_loss(
-         self,
-         policy_chosen_logps: torch.FloatTensor,
-         policy_rejected_logps: torch.FloatTensor,
-         reference_chosen_logps: torch.FloatTensor,
-         reference_rejected_logps: torch.FloatTensor,
-         chosen_embeddings: Optional[torch.FloatTensor],
-         rejected_embeddings: Optional[torch.FloatTensor],
-     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-         """Compute the BCO loss for a batch of policy and reference model log probabilities.
-
-         Args:
-             policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
-             policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
-             reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
-             reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
-             chosen_embeddings: embeddings of desirable prompts
-             rejected_embeddings: embeddings of undesirable prompts
-
-         Returns:
-             A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
-             The losses tensor contains the BCO loss for each example in the batch.
-             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
-             The delta value contains the moving average of all implicit rewards.
-         """
-
-         if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
-             chosen_logratios = policy_chosen_logps - reference_chosen_logps
-             chosen_rewards = self.beta * chosen_logratios
-         else:
-             # lists can't be empty -- if they are, then accelerate.gather will hang
-             chosen_losses = torch.Tensor([]).to(self.accelerator.device)
-             chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
-
-         if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
-             rejected_logratios = policy_rejected_logps - reference_rejected_logps
-             rejected_rewards = self.beta * rejected_logratios
-         else:
-             # lists can't be empty -- if they are, then accelerate.gather will hang
-             rejected_losses = torch.Tensor([]).to(self.accelerator.device)
-             rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
-
-         rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
-         self.running.update(rewards)
-         delta = self.running.mean
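-         # delta is the running mean of all implicit rewards seen so far; the losses below push
-         # chosen rewards above this moving threshold and rejected rewards below it.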
-
-         if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
-             chosen_losses = -F.logsigmoid(chosen_rewards - delta)
-
-         if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
-             rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
-
-         if self.match_underlying_distribution:
-             chosen_weight = torch.ones_like(chosen_losses)
-             rejected_weight = self._get_udm_weight(rejected_embeddings)
-
-             losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
-         else:
-             losses = torch.cat((chosen_losses, rejected_losses), dim=0)
-
-         return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
-
-     def get_batch_loss_metrics(
-         self,
-         model,
-         batch: dict[str, Union[list, torch.LongTensor]],
-     ):
-         """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
-         metrics = {}
-         batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
-
-         forward_output = self.forward(model, batch)
-         (
-             policy_chosen_logps,
-             policy_rejected_logps,
-             policy_chosen_logits,
-             policy_rejected_logits,
-         ) = forward_output[:4]
-         if self.aux_loss_enabled:
-             aux_loss = forward_output[4]
-
-         # if reference_logps in batch use them, otherwise use the reference model
-         if "reference_logps" in batch:
-             chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
-             rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
-
-             reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
-             reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
-         else:
-             with torch.no_grad():
-                 if self.ref_model is None:
-                     with self.null_ref_context():
-                         (
-                             reference_chosen_logps,
-                             reference_rejected_logps,
-                             _,
-                             _,
-                         ) = self.forward(self.model, batch)[:4]
-                 else:
-                     (
-                         reference_chosen_logps,
-                         reference_rejected_logps,
-                         _,
-                         _,
-                     ) = self.forward(self.ref_model, batch)[:4]
-
-         chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
-
-         losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
-             policy_chosen_logps,
-             policy_rejected_logps,
-             reference_chosen_logps,
-             reference_rejected_logps,
-             chosen_embeddings,
-             rejected_embeddings,
-         )
-         metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
-
-         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
-         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
-         all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
-         all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
-
-         if all_num_chosen > 0:
-             metrics["rewards/chosen_sum"] = (
-                 self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
-             )
-             metrics["logps/chosen_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
-             )
-             metrics["logits/chosen_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
-             )
-             metrics["count/chosen"] = all_num_chosen
-
-         if all_num_rejected > 0:
-             metrics["rewards/rejected_sum"] = (
-                 self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
-             )
-             metrics["logps/rejected_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
-             )
-             metrics["logits/rejected_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
-             )
-             metrics["count/rejected"] = all_num_rejected
-
-         loss = losses.nanmean()
-         if self.aux_loss_enabled:
-             loss += self.aux_loss_coef * aux_loss
-
-         return loss, metrics
-
-     def compute_loss(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         return_outputs=False,
-         num_items_in_batch=None,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-
-         with compute_loss_context_manager:
-             loss, metrics = self.get_batch_loss_metrics(model, inputs)
-
-         # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
-         loss = loss.to(self.args.device)
-         # force log the metrics
-         if self.accelerator.is_main_process:
-             self.store_metrics(metrics, train_eval="train")
-
-         if return_outputs:
-             return (loss, metrics)
-         return loss
-
-     def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
-         for key, value in metrics.items():
-             self._stored_metrics[train_eval][key].append(value)
-
-     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-         if self.train_dataset is None or not has_length(self.train_dataset):
-             return None
-         return SequentialSampler(self.train_dataset)
-
-     def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
-         """Generate samples from the model and reference model for the given batch of inputs."""
-
-         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
-         # the torch cuda amp context manager as some hidden states are silently casted to full precision.
-         generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-         with generate_context_manager:
-             policy_output = model.generate(
-                 input_ids=batch["prompt_input_ids"],
-                 attention_mask=batch["prompt_attention_mask"],
-                 max_length=self.max_length,
-                 do_sample=True,
-                 pad_token_id=self.processing_class.pad_token_id,
-             )
-
-             # if reference_output in batch use that otherwise use the reference model
-             if "reference_output" in batch:
-                 reference_output = batch["reference_output"]
-             else:
-                 if self.ref_model is None:
-                     with self.null_ref_context():
-                         reference_output = self.model.generate(
-                             input_ids=batch["prompt_input_ids"],
-                             attention_mask=batch["prompt_attention_mask"],
-                             max_length=self.max_length,
-                             do_sample=True,
-                             pad_token_id=self.processing_class.pad_token_id,
-                         )
-                 else:
-                     reference_output = self.ref_model.generate(
-                         input_ids=batch["prompt_input_ids"],
-                         attention_mask=batch["prompt_attention_mask"],
-                         max_length=self.max_length,
-                         do_sample=True,
-                         pad_token_id=self.processing_class.pad_token_id,
-                     )
-
-         policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
-         policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
-         reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
-         reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
-
-         return policy_output_decoded, reference_output_decoded
-
-     def prediction_step(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         prediction_loss_only: bool,
-         ignore_keys: Optional[list[str]] = None,
-     ):
-         if ignore_keys is None:
-             if hasattr(model, "config"):
-                 ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
-             else:
-                 ignore_keys = []
-
-         prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-         with torch.no_grad(), prediction_context_manager:
-             loss, metrics = self.get_batch_loss_metrics(model, inputs)
-
-         # force log the metrics
-         if self.accelerator.is_main_process:
-             self.store_metrics(metrics, train_eval="eval")
-
-         if prediction_loss_only:
-             return (loss.detach(), None, None)
-
-         # logits for the chosen and rejected samples from model
-         logits_dict = {
-             "eval_logits/chosen": metrics["logits/chosen"],
-             "eval_logits/rejected": metrics["logits/rejected"],
-         }
-         logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
-         logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
-         labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
-         return (loss.detach(), logits, labels)
-
-     def evaluation_loop(
-         self,
-         dataloader: DataLoader,
-         description: str,
-         prediction_loss_only: Optional[bool] = None,
-         ignore_keys: Optional[list[str]] = None,
-         metric_key_prefix: str = "eval",
-     ) -> EvalLoopOutput:
-         """
-         Overriding built-in evaluation loop to store metrics for each batch.
-         Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
-
-         Works both with or without labels.
-         """
-
-         # Sample and save to game log if requested (for one batch to save time)
-         if self.generate_during_eval:
-             # Generate random indices within the range of the total number of samples
-             num_samples = len(dataloader.dataset)
-             random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
-             # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
-             random_batch_dataset = dataloader.dataset.select(random_indices)
-             random_batch = self.data_collator(random_batch_dataset)
-             random_batch = self._prepare_inputs(random_batch)
-
-             target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
-             target_batch = {
-                 "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
-                 "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
-                 "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
-             }
-             policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
-
-             table = pd.DataFrame(
-                 columns=["Prompt", "Policy", "Ref Model"],
-                 data=[
-                     [prompt, pol[len(prompt) :], ref[len(prompt) :]]
-                     for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
-                 ],
-             )
-             if "wandb" in self.args.report_to:
-                 wandb.log({"game_log": wandb.Table(data=table)})
-
-             if "comet_ml" in self.args.report_to:
-                 log_table_to_comet_experiment(
-                     name="game_log.csv",
-                     table=table,
-                 )
-
-         # Base evaluation
-         initial_output = super().evaluation_loop(
-             dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
-         )
-
-         return initial_output
-
-     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-         """
-         Log `logs` on the various objects watching training, including stored metrics.
-
-         Args:
-             logs (`dict[str, float]`):
-                 The values to log.
-             start_time (`float` or `None`, *optional*, defaults to `None`):
-                 Start time of the training.
-         """
-         # logs either has 'loss' or 'eval_loss'
-         train_eval = "train" if "loss" in logs else "eval"
-         # train metrics should have no prefix, eval should have 'eval_'
-         prefix = "eval_" if train_eval == "eval" else ""
-         # accumulate average metrics from sums and lengths
-         for split in ["chosen", "rejected"]:
-             if f"count/{split}" in self._stored_metrics[train_eval]:
-                 count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
-                 for metric in ["rewards", "logps", "logits"]:
-                     logs[f"{prefix}{metric}/{split}"] = (
-                         torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
-                         / count_sum
-                     )
-                     # delete obsolete metric
-                     del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
-                 del self._stored_metrics[train_eval][f"count/{split}"]
-         # calculate reward margin
-         if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
-             logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
-         # Add averaged stored metrics to logs
-         for key, metrics in self._stored_metrics[train_eval].items():
-             logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
-         del self._stored_metrics[train_eval]
-
-         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
-             return super().log(logs, start_time)
-         else:  # transformers<=4.46
-             return super().log(logs)
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-             @article{jung2024binary,
-                 title = {{Binary Classifier Optimization for Large Language Model Alignment}},
-                 author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
-                 year = 2024,
-                 eprint = {arXiv:2404.04656}
-             }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="BCO",
-             trainer_citation=citation,
-             paper_title="Binary Classifier Optimization for Large Language Model Alignment",
-             paper_id="2404.04656",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothBCOTrainer(_UnslothBCOTrainer):
-     """
-
-     Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
-
-     Args:
-         model (`transformers.PreTrainedModel`):
-             The model to train, preferably an `AutoModelForSequenceClassification`.
-         ref_model (`PreTrainedModelWrapper`):
-             Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
-             reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
-         args (`BCOConfig`):
-             The arguments to use for training.
-         train_dataset (`datasets.Dataset`):
-             The dataset to use for training.
-         eval_dataset (`datasets.Dataset`):
-             The dataset to use for evaluation.
-         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-             Processing class used to process the data. If provided, will be used to automatically process the inputs
-             for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-             reuse the fine-tuned model.
-         data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
-             The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
-             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-         model_init (`Callable[[], transformers.PreTrainedModel]`):
-             The model initializer to use for training. If None is specified, the default model initializer will be used.
-         callbacks (`list[transformers.TrainerCallback]`):
-             The callbacks to use for training.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-             The optimizer and scheduler to use for training.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-             The function to use to preprocess the logits before computing the metrics.
-         peft_config (`dict`, defaults to `None`):
-             The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-             The function to use to compute the metrics. Must take a `EvalPrediction` and return
-             a dictionary string to metric values.
-         model_adapter_name (`str`, defaults to `None`):
-             Name of the train target PEFT adapter, when using LoRA with multiple adapters.
-         ref_adapter_name (`str`, defaults to `None`):
-             Name of the reference PEFT adapter, when using LoRA with multiple adapters.
-
-     """
-     def __init__(
-         self,
-         model = None,
-         ref_model = None,
-         args = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         data_collator = None,
-         model_init = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         peft_config = None,
-         compute_metrics = None,
-         model_adapter_name = None,
-         ref_adapter_name = None,
-         embedding_func = None,
-         embedding_tokenizer = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothBCOConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         use_fp16 = getattr(args, 'fp16', False)
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
-             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('bco_trainer', other_metrics)
-
-         super().__init__(
-             model = model,
-             ref_model = ref_model,
-             args = args,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             data_collator = data_collator,
-             model_init = model_init,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-             peft_config = peft_config,
-             compute_metrics = compute_metrics,
-             model_adapter_name = model_adapter_name,
-             ref_adapter_name = ref_adapter_name,
-             embedding_func = embedding_func,
-             embedding_tokenizer = embedding_tokenizer,**kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-     pass
-
- pass

unsloth_compiled_cache/UnslothCPOTrainer.py DELETED
@@ -1,1555 +0,0 @@
-"""
-2025.3.13
-2025.3.15
-4.48.3
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
-torch_compile_options = {
-    "epilogue_fusion" : True,
-    "max_autotune" : False,
-    "shape_padding" : True,
-    "trace.enabled" : False,
-    "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def selective_log_softmax(logits, index):
-    logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-    # loop to reduce peak mem consumption
-    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-    logsumexp_values = torch.logsumexp(logits, dim = -1)
-    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-    return per_token_logps
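Editor's note: `selective_log_softmax` above is equivalent to gathering from a full `log_softmax`, just without materializing the (batch, seq, vocab) log-softmax tensor. A minimal sketch of the equivalence (tensor shapes are illustrative, not taken from this file):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)          # (batch, seq_len, vocab_size)
index = torch.randint(0, 11, (2, 5))    # token ids whose log-probs we want

# identity used above: log_softmax(x_i) = x_i - logsumexp(x)
selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
manual = selected - torch.logsumexp(logits, dim=-1)
reference = torch.gather(F.log_softmax(logits, dim=-1), -1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(manual, reference, atol=1e-6)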
-@dataclass
-class UnslothCPOConfig(CPOConfig):
-    """
-
-    Configuration class for the [`CPOTrainer`].
-
-    Using [`~transformers.HfArgumentParser`] we can turn this class into
-    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-    command line.
-
-    Parameters:
-        learning_rate (`float`, *optional*, defaults to `1e-6`):
-            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
-            [`~transformers.TrainingArguments`].
-        max_length (`int` or `None`, *optional*, defaults to `1024`):
-            Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
-            to use the default data collator.
-        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-            Maximum length of the prompt. This argument is required if you want to use the default data collator.
-        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
-            Maximum length of the completion. This argument is required if you want to use the default data collator
-            and your model is an encoder-decoder.
-        beta (`float`, *optional*, defaults to `0.1`):
-            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
-            reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
-            the [paper](https://huggingface.co/papers/2310.12036).
-        label_smoothing (`float`, *optional*, defaults to `0.0`):
-            Label smoothing factor. This argument is required if you want to use the default data collator.
-        loss_type (`str`, *optional*, defaults to `"sigmoid"`):
-            Type of loss to use. Possible values are:
-
-            - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
-            - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
-            - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
-            - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
-
-        disable_dropout (`bool`, *optional*, defaults to `True`):
-            Whether to disable dropout in the model.
-        cpo_alpha (`float`, *optional*, defaults to `1.0`):
-            Weight of the BC regularizer in CPO training.
-        simpo_gamma (`float`, *optional*, defaults to `0.5`):
-            Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
-        label_pad_token_id (`int`, *optional*, defaults to `-100`):
-            Label pad token id. This argument is required if you want to use the default data collator.
-        padding_value (`int` or `None`, *optional*, defaults to `None`):
-            Padding value to use. If `None`, the padding value of the tokenizer is used.
-        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
-            Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
-            This argument is required if you want to use the default data collator.
-        generate_during_eval (`bool`, *optional*, defaults to `False`):
-            If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
-        is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
-            When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
-            you need to specify if the model returned by the callable is an encoder-decoder model.
-        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
-            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
-            string.
-        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
-            Number of processes to use for processing the dataset.
-
-    """
-    vllm_sampling_params: Optional[Any] = field(
-        default = None,
-        metadata = {'help': 'vLLM SamplingParams'},
-    )
-    unsloth_num_chunks : Optional[int] = field(
-        default = -1,
-        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-    )
-    def __init__(
-        self,
-        output_dir = None,
-        overwrite_output_dir = None,
-        do_train = False,
-        do_eval = False,
-        do_predict = False,
-        eval_strategy = 'no',
-        prediction_loss_only = False,
-        per_device_train_batch_size = 4,
-        per_device_eval_batch_size = 4,
-        per_gpu_train_batch_size = None,
-        per_gpu_eval_batch_size = None,
-        gradient_accumulation_steps = 2,
-        eval_accumulation_steps = 2,
-        eval_delay = 0,
-        torch_empty_cache_steps = 250,
-        learning_rate = 5e-05,
-        weight_decay = 0.01,
-        adam_beta1 = 0.9,
-        adam_beta2 = 0.999,
-        adam_epsilon = 1e-08,
-        max_grad_norm = 1.0,
-        num_train_epochs = 3.0,
-        max_steps = -1,
-        lr_scheduler_type = 'linear',
-        warmup_ratio = 0.1,
-        warmup_steps = 0,
-        log_level = 'passive',
-        log_level_replica = 'warning',
-        log_on_each_node = True,
-        logging_dir = None,
-        logging_strategy = 'steps',
-        logging_first_step = False,
-        logging_steps = 1,
-        logging_nan_inf_filter = False,
-        save_strategy = 'steps',
-        save_steps = 500,
-        save_total_limit = None,
-        save_safetensors = True,
-        save_on_each_node = False,
-        save_only_model = False,
-        restore_callback_states_from_checkpoint = False,
-        no_cuda = False,
-        use_cpu = False,
-        use_mps_device = False,
-        seed = 3407,
-        data_seed = 3407,
-        jit_mode_eval = False,
-        use_ipex = False,
-        bf16 = False,
-        fp16 = False,
-        fp16_opt_level = 'O1',
-        half_precision_backend = 'auto',
-        bf16_full_eval = False,
-        fp16_full_eval = False,
-        tf32 = None,
-        local_rank = -1,
-        ddp_backend = None,
-        tpu_num_cores = None,
-        tpu_metrics_debug = False,
-        debug = '',
-        dataloader_drop_last = False,
-        eval_steps = None,
-        dataloader_num_workers = 0,
-        dataloader_prefetch_factor = None,
-        past_index = -1,
-        run_name = None,
-        disable_tqdm = None,
-        remove_unused_columns = True,
-        label_names = None,
-        load_best_model_at_end = False,
-        metric_for_best_model = None,
-        greater_is_better = None,
-        ignore_data_skip = False,
-        fsdp = '',
-        fsdp_min_num_params = 0,
-        fsdp_config = None,
-        fsdp_transformer_layer_cls_to_wrap = None,
-        accelerator_config = None,
-        deepspeed = None,
-        label_smoothing_factor = 0.0,
-        optim = 'adamw_8bit',
-        optim_args = None,
-        adafactor = False,
-        group_by_length = False,
-        length_column_name = 'length',
-        report_to = None,
-        ddp_find_unused_parameters = None,
-        ddp_bucket_cap_mb = None,
-        ddp_broadcast_buffers = None,
-        dataloader_pin_memory = True,
-        dataloader_persistent_workers = False,
-        skip_memory_metrics = True,
-        use_legacy_prediction_loop = False,
-        push_to_hub = False,
-        resume_from_checkpoint = None,
-        hub_model_id = None,
-        hub_strategy = 'every_save',
-        hub_token = None,
-        hub_private_repo = None,
-        hub_always_push = False,
-        gradient_checkpointing = False,
-        gradient_checkpointing_kwargs = None,
-        include_inputs_for_metrics = False,
-        eval_do_concat_batches = True,
-        fp16_backend = 'auto',
-        evaluation_strategy = None,
-        push_to_hub_model_id = None,
-        push_to_hub_organization = None,
-        push_to_hub_token = None,
-        mp_parameters = '',
-        auto_find_batch_size = False,
-        full_determinism = False,
-        torchdynamo = None,
-        ray_scope = 'last',
-        ddp_timeout = 1800,
-        torch_compile = False,
-        torch_compile_backend = None,
-        torch_compile_mode = None,
-        dispatch_batches = None,
-        split_batches = None,
-        include_tokens_per_second = False,
-        include_num_input_tokens_seen = False,
-        neftune_noise_alpha = None,
-        optim_target_modules = None,
-        batch_eval_metrics = False,
-        eval_on_start = False,
-        use_liger_kernel = False,
-        eval_use_gather_object = False,
-        average_tokens_across_devices = False,
-        max_length = 1024,
-        max_prompt_length = 512,
-        max_completion_length = None,
-        beta = 0.1,
-        label_smoothing = 0.0,
-        loss_type = 'sigmoid',
-        disable_dropout = True,
-        cpo_alpha = 1.0,
-        simpo_gamma = 0.5,
-        label_pad_token_id = -100,
-        padding_value = None,
-        truncation_mode = 'keep_end',
-        generate_during_eval = False,
-        is_encoder_decoder = None,
-        model_init_kwargs = None,
-        dataset_num_proc = None,
-        vllm_sampling_params = None,
-        unsloth_num_chunks = -1,
-        **kwargs,
-    ):
-        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-            output_dir = 'unsloth_training_checkpoints'
-            save_strategy = 'no'
-        if dataset_num_proc is None:
-            from multiprocessing import cpu_count
-            dataset_num_proc = cpu_count()
-
-        super().__init__(
-            output_dir = output_dir,
-            overwrite_output_dir = overwrite_output_dir,
-            do_train = do_train,
-            do_eval = do_eval,
-            do_predict = do_predict,
-            eval_strategy = eval_strategy,
-            prediction_loss_only = prediction_loss_only,
-            per_device_train_batch_size = per_device_train_batch_size,
-            per_device_eval_batch_size = per_device_eval_batch_size,
-            per_gpu_train_batch_size = per_gpu_train_batch_size,
-            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-            gradient_accumulation_steps = gradient_accumulation_steps,
-            eval_accumulation_steps = eval_accumulation_steps,
-            eval_delay = eval_delay,
-            torch_empty_cache_steps = torch_empty_cache_steps,
-            learning_rate = learning_rate,
-            weight_decay = weight_decay,
-            adam_beta1 = adam_beta1,
-            adam_beta2 = adam_beta2,
-            adam_epsilon = adam_epsilon,
-            max_grad_norm = max_grad_norm,
-            num_train_epochs = num_train_epochs,
-            max_steps = max_steps,
-            lr_scheduler_type = lr_scheduler_type,
-            warmup_ratio = warmup_ratio,
-            warmup_steps = warmup_steps,
-            log_level = log_level,
-            log_level_replica = log_level_replica,
-            log_on_each_node = log_on_each_node,
-            logging_dir = logging_dir,
-            logging_strategy = logging_strategy,
-            logging_first_step = logging_first_step,
-            logging_steps = logging_steps,
-            logging_nan_inf_filter = logging_nan_inf_filter,
-            save_strategy = save_strategy,
-            save_steps = save_steps,
-            save_total_limit = save_total_limit,
-            save_safetensors = save_safetensors,
-            save_on_each_node = save_on_each_node,
-            save_only_model = save_only_model,
-            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-            no_cuda = no_cuda,
-            use_cpu = use_cpu,
-            use_mps_device = use_mps_device,
-            seed = seed,
-            data_seed = data_seed,
-            jit_mode_eval = jit_mode_eval,
-            use_ipex = use_ipex,
-            bf16 = bf16,
-            fp16 = fp16,
-            fp16_opt_level = fp16_opt_level,
-            half_precision_backend = half_precision_backend,
-            bf16_full_eval = bf16_full_eval,
-            fp16_full_eval = fp16_full_eval,
-            tf32 = tf32,
-            local_rank = local_rank,
-            ddp_backend = ddp_backend,
-            tpu_num_cores = tpu_num_cores,
-            tpu_metrics_debug = tpu_metrics_debug,
-            debug = debug,
-            dataloader_drop_last = dataloader_drop_last,
-            eval_steps = eval_steps,
-            dataloader_num_workers = dataloader_num_workers,
-            dataloader_prefetch_factor = dataloader_prefetch_factor,
-            past_index = past_index,
-            run_name = run_name,
-            disable_tqdm = disable_tqdm,
-            remove_unused_columns = remove_unused_columns,
-            label_names = label_names,
-            load_best_model_at_end = load_best_model_at_end,
-            metric_for_best_model = metric_for_best_model,
-            greater_is_better = greater_is_better,
-            ignore_data_skip = ignore_data_skip,
-            fsdp = fsdp,
-            fsdp_min_num_params = fsdp_min_num_params,
-            fsdp_config = fsdp_config,
-            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-            accelerator_config = accelerator_config,
-            deepspeed = deepspeed,
-            label_smoothing_factor = label_smoothing_factor,
-            optim = optim,
-            optim_args = optim_args,
-            adafactor = adafactor,
-            group_by_length = group_by_length,
-            length_column_name = length_column_name,
-            report_to = report_to,
-            ddp_find_unused_parameters = ddp_find_unused_parameters,
-            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-            ddp_broadcast_buffers = ddp_broadcast_buffers,
-            dataloader_pin_memory = dataloader_pin_memory,
-            dataloader_persistent_workers = dataloader_persistent_workers,
-            skip_memory_metrics = skip_memory_metrics,
-            use_legacy_prediction_loop = use_legacy_prediction_loop,
-            push_to_hub = push_to_hub,
-            resume_from_checkpoint = resume_from_checkpoint,
-            hub_model_id = hub_model_id,
-            hub_strategy = hub_strategy,
-            hub_token = hub_token,
-            hub_private_repo = hub_private_repo,
-            hub_always_push = hub_always_push,
-            gradient_checkpointing = gradient_checkpointing,
-            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-            include_inputs_for_metrics = include_inputs_for_metrics,
-            eval_do_concat_batches = eval_do_concat_batches,
-            fp16_backend = fp16_backend,
-            evaluation_strategy = evaluation_strategy,
-            push_to_hub_model_id = push_to_hub_model_id,
-            push_to_hub_organization = push_to_hub_organization,
-            push_to_hub_token = push_to_hub_token,
-            mp_parameters = mp_parameters,
-            auto_find_batch_size = auto_find_batch_size,
-            full_determinism = full_determinism,
-            torchdynamo = torchdynamo,
-            ray_scope = ray_scope,
-            ddp_timeout = ddp_timeout,
-            torch_compile = torch_compile,
-            torch_compile_backend = torch_compile_backend,
-            torch_compile_mode = torch_compile_mode,
-            dispatch_batches = dispatch_batches,
-            split_batches = split_batches,
-            include_tokens_per_second = include_tokens_per_second,
-            include_num_input_tokens_seen = include_num_input_tokens_seen,
-            neftune_noise_alpha = neftune_noise_alpha,
-            optim_target_modules = optim_target_modules,
-            batch_eval_metrics = batch_eval_metrics,
-            eval_on_start = eval_on_start,
-            use_liger_kernel = use_liger_kernel,
-            eval_use_gather_object = eval_use_gather_object,
-            average_tokens_across_devices = average_tokens_across_devices,
-            max_length = max_length,
-            max_prompt_length = max_prompt_length,
-            max_completion_length = max_completion_length,
-            beta = beta,
-            label_smoothing = label_smoothing,
-            loss_type = loss_type,
-            disable_dropout = disable_dropout,
-            cpo_alpha = cpo_alpha,
-            simpo_gamma = simpo_gamma,
-            label_pad_token_id = label_pad_token_id,
-            padding_value = padding_value,
-            truncation_mode = truncation_mode,
-            generate_during_eval = generate_during_eval,
-            is_encoder_decoder = is_encoder_decoder,
-            model_init_kwargs = model_init_kwargs,
-            dataset_num_proc = dataset_num_proc,**kwargs)
-        self.vllm_sampling_params = vllm_sampling_params
-        self.unsloth_num_chunks = unsloth_num_chunks
-pass
-
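Editor's note: for orientation only, a hypothetical instantiation of the deleted config might have looked like the sketch below (argument values are made up; `trl` and this cache module are assumed importable). The guard rails in `__init__` reject learning rates below 1e-7 or above 1 and silently disable checkpoint saving when `output_dir` is left unset with the default save settings.

# Hypothetical usage sketch, not part of the deleted file:
config = UnslothCPOConfig(
    output_dir = "outputs",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    learning_rate = 5e-5,      # must satisfy 1e-7 <= lr <= 1 per the checks above
    loss_type = "sigmoid",     # or "hinge", "ipo", "simpo"
    beta = 0.1,
    max_length = 1024,
)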
-class _UnslothCPOTrainer(Trainer):
-    r""""""
-
-    _tag_names = ["trl", "cpo"]
-
-    def __init__(
-        self,
-        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
-        args: Optional[CPOConfig] = None,
-        data_collator: Optional[DataCollator] = None,
-        train_dataset: Optional[Dataset] = None,
-        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-        processing_class: Optional[
-            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-        ] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
-        callbacks: Optional[list[TrainerCallback]] = None,
-        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
-        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-        peft_config: Optional[dict] = None,
-        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
-    ):
-        if args.model_init_kwargs is None:
-            model_init_kwargs = {}
-        elif not isinstance(model, str):
-            raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
-        else:
-            model_init_kwargs = args.model_init_kwargs
-            torch_dtype = model_init_kwargs.get("torch_dtype")
-            if torch_dtype is not None:
-                # Convert to `torch.dtype` if an str is passed
-                if isinstance(torch_dtype, str) and torch_dtype != "auto":
-                    torch_dtype = getattr(torch, torch_dtype)
-                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
-                    raise ValueError(
-                        f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
-                    )
-                model_init_kwargs["torch_dtype"] = torch_dtype
-
-        if isinstance(model, str):
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
-
-        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
-        # has been called in order to properly call autocast if needed.
-        self._peft_has_been_casted_to_bf16 = False
-
-        if not is_peft_available() and peft_config is not None:
-            raise ValueError(
-                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
-            )
-        elif is_peft_available() and peft_config is not None:
-            # if model is a peft model and we have a peft_config, we merge and unload it first
-            if isinstance(model, PeftModel):
-                model = model.merge_and_unload()
-
-            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                _support_gc_kwargs = hasattr(
-                    args, "gradient_checkpointing_kwargs"
-                ) and "gradient_checkpointing_kwargs" in list(
-                    inspect.signature(prepare_model_for_kbit_training).parameters
-                )
-
-                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
-                if _support_gc_kwargs:
-                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
-                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-            elif getattr(args, "gradient_checkpointing", False):
-                # For backward compatibility with older versions of transformers
-                if hasattr(model, "enable_input_require_grads"):
-                    model.enable_input_require_grads()
-                else:
-
-                    def make_inputs_require_grad(module, input, output):
-                        output.requires_grad_(True)
-
-                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-            # get peft model with the given config
-            model = model
-            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
-                peft_module_casting_to_bf16(model)
-                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
-                self._peft_has_been_casted_to_bf16 = True
-
-        # For models that use gradient_checkpointing, we need to attach a hook that enables input
-        # to explicitly have `requires_grad=True`, otherwise training will either silently
-        # fail or completely fail.
-        elif getattr(args, "gradient_checkpointing", False):
-            # For backward compatibility with older versions of transformers
-            if hasattr(model, "enable_input_require_grads"):
-                model.enable_input_require_grads()
-            else:
-
-                def make_inputs_require_grad(module, input, output):
-                    output.requires_grad_(True)
-
-                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
-            raise ValueError(
-                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
-                " Please install `wandb` or `comet-ml` to resolve."
-            )
-
-        if model is not None:
-            self.is_encoder_decoder = model.config.is_encoder_decoder
-        elif args.is_encoder_decoder is None:
-            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
-        else:
-            self.is_encoder_decoder = args.is_encoder_decoder
-
-        if self.is_encoder_decoder:
-            self.decoder_start_token_id = model.config.decoder_start_token_id
-            self.pad_token_id = model.config.pad_token_id
-
-        if processing_class is None:
-            raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
-        if args.max_length is None:
-            warnings.warn(
-                "`max_length` is not set in the CPOConfig's init"
-                " it will default to `512` by default, but you should do it yourself in the future.",
-                UserWarning,
-            )
-            max_length = 512
-        else:
-            max_length = args.max_length
-        if args.max_prompt_length is None:
-            warnings.warn(
-                "`max_prompt_length` is not set in the CPOConfig's init"
-                " it will default to `128` by default, but you should do it yourself in the future.",
-                UserWarning,
-            )
-            max_prompt_length = 128
-        else:
-            max_prompt_length = args.max_prompt_length
-
-        if args.max_completion_length is None and self.is_encoder_decoder:
-            warnings.warn(
-                "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
-                " it will default to `128` by default, but you should do it yourself in the future.",
-                UserWarning,
-            )
-            max_completion_length = 128
-        else:
-            max_completion_length = args.max_completion_length
-
-        if data_collator is None:
-            data_collator = DPODataCollatorWithPadding(
-                pad_token_id=processing_class.pad_token_id,
-                label_pad_token_id=args.label_pad_token_id,
-                is_encoder_decoder=self.is_encoder_decoder,
-            )
-
-            if args.remove_unused_columns:
-                args.remove_unused_columns = False
-                # warn users
-                warnings.warn(
-                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
-                    " we have set it for you, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-
-            self.use_dpo_data_collator = True
-        else:
-            self.use_dpo_data_collator = False
-
-        # Disable dropout in the model
-        if args.disable_dropout:
-            disable_dropout_in_model(model)
-
-        self.max_length = max_length
-        self.generate_during_eval = args.generate_during_eval
-        self.label_pad_token_id = args.label_pad_token_id
-        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
-        self.max_prompt_length = max_prompt_length
-        self.truncation_mode = args.truncation_mode
-        self.max_completion_length = max_completion_length
-        self.processing_class = processing_class
-
-        if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
-            warnings.warn(
-                f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
-                "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
-                UserWarning,
-            )
-        if args.loss_type == "kto_pair":
-            raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
-
-        self.beta = args.beta
-        self.label_smoothing = args.label_smoothing
-        self.loss_type = args.loss_type
-        self.cpo_alpha = args.cpo_alpha
-        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
-        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
-        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
-            warnings.warn(
-                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
-                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
-                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
-                "loss.",
-                UserWarning,
-            )
-
-        if args.loss_type == "simpo":
-            self.simpo_gamma = args.simpo_gamma
-
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
-
-        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-        # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
-        # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
-        # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
-        # of the input, floating-point operations will not be computed." To suppress this warning, we set the
-        # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
-        # that the warning has already been issued.
-        model.warnings_issued["estimate_tokens"] = True
-
-        # Compute that only on the main process for faster data processing.
-        # see: https://github.com/huggingface/trl/pull/1255
-        with PartialState().local_main_process_first():
-            # Extract the prompt if needed, and apply the chat template if needed
-            train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
-            train_dataset = train_dataset.map(
-                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
-            )
-            if eval_dataset is not None:
-                eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
-                eval_dataset = eval_dataset.map(
-                    maybe_apply_chat_template,
-                    fn_kwargs={"tokenizer": processing_class},
-                    num_proc=args.dataset_num_proc,
-                )
-
-            # tokenize the dataset
-            train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
-            if eval_dataset is not None:
-                eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
-
-        super().__init__(
-            model=model,
-            args=args,
-            data_collator=data_collator,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            processing_class=processing_class,
-            model_init=model_init,
-            compute_metrics=compute_metrics,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        )
-
-        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
-        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
-        # self.model_accepts_loss_kwargs to False to enable scaling.
-        self.model_accepts_loss_kwargs = False
-
-        # Add tags for models that have been loaded with the correct transformers version
-        if hasattr(self.model, "add_model_tags"):
-            self.model.add_model_tags(self._tag_names)
-
-        if not hasattr(self, "accelerator"):
-            raise AttributeError(
-                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
-            )
-
-    def build_tokenized_answer(self, prompt, answer):
-        """
-        Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
-        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
-        Reference:
-        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
-        """
-
-        full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
-        prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
-
-        answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
-        answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
-
-        # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
-        full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
-
-        # Prepare input tokens for token by token comparison
-        full_input_ids = np.array(full_tokenized["input_ids"])
-
-        if len(full_input_ids) != len(full_concat_input_ids):
-            raise ValueError("Prompt input ids and answer input ids should have the same length.")
-
-        # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
-        # can be merged together when tokenizing prompt+answer. This could result
-        # on the last token from the prompt being different when tokenized on its own
-        # vs when done as prompt+answer.
-        response_token_ids_start_idx = len(prompt_input_ids)
-
-        # If tokenized prompt is different than both prompt+answer, then it means the
-        # last token has changed due to merging.
-        if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
-            response_token_ids_start_idx -= 1
-
-        prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
-        prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
-
-        if len(prompt_input_ids) != len(prompt_attention_mask):
-            raise ValueError("Prompt input ids and attention mask should have the same length.")
-
-        answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
-        answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
-
-        return dict(
-            prompt_input_ids=prompt_input_ids,
-            prompt_attention_mask=prompt_attention_mask,
-            input_ids=answer_input_ids,
-            attention_mask=answer_attention_mask,
-        )
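Editor's note: the boundary-merge case this method guards against can be reproduced in a few lines; the checkpoint name below is an assumption, and any SentencePiece-style tokenizer shows the effect.

# Exploratory sketch (hypothetical checkpoint; requires transformers):
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model
prompt, answer = "The capital of France is", " Paris."
full_ids = tok(prompt + answer, add_special_tokens=False)["input_ids"]
prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]

# If the boundary token merged, the tokenized prompt is no longer a strict
# prefix of enc(prompt + answer); build_tokenized_answer walks back one token.
print(full_ids[: len(prompt_ids)] == prompt_ids)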
-
-    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
-        """Tokenize a single row from a CPO specific dataset.
-
-        At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
-        in case the prompt + chosen or prompt + rejected responses is/are too long. First
-        we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
-
-        We also create the labels for the chosen/rejected responses, which are of length equal to
-        the sum of the length of the prompt and the chosen/rejected response, with
-        label_pad_token_id for the prompt tokens.
-        """
-        batch = {}
-        prompt = feature["prompt"]
-        chosen = feature["chosen"]
-        rejected = feature["rejected"]
-
-        if not self.is_encoder_decoder:
-            # Check issues below for more details
-            #  1. https://github.com/huggingface/trl/issues/907
-            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
-            #  3. https://github.com/LianjiaTech/BELLE/issues/337
-
-            if not isinstance(prompt, str):
-                raise ValueError(f"prompt should be an str but got {type(prompt)}")
-            prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
-            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
-
-            if not isinstance(chosen, str):
-                raise ValueError(f"chosen should be an str but got {type(chosen)}")
-            chosen_tokens = self.build_tokenized_answer(prompt, chosen)
-
-            if not isinstance(rejected, str):
-                raise ValueError(f"rejected should be an str but got {type(rejected)}")
-            rejected_tokens = self.build_tokenized_answer(prompt, rejected)
-
-            # Last prompt token might get merged by tokenizer and
-            # it should not be included for generation if that happens
-            prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
-
-            chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
-            rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
-            prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
-
-            for k, v in prompt_tokens.items():
-                prompt_tokens[k] = v[:prompt_len_input_ids]
-
-            # Make sure prompts only have one different token at most
-            # and length only differs by 1 at most
-            num_diff_tokens = sum(
-                [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
-            )
-            num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
-            if num_diff_tokens > 1 or num_diff_len > 1:
-                raise ValueError(
-                    "Chosen and rejected prompt_input_ids might only differ on the "
-                    "last token due to tokenizer merge ops."
-                )
-
-            # add BOS token to head of prompt. Avoid adding if it's already there
-            prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
-                self.processing_class.bos_token_id,
-                prompt_len_input_ids,
-                prompt_tokens,
-                chosen_prompt_len_input_ids,
-                chosen_tokens,
-                rejected_prompt_len_input_ids,
-                rejected_tokens,
-            )
-
-            # add EOS token to end of answer. Avoid adding if it's already there
-            chosen_tokens, rejected_tokens = add_eos_token_if_needed(
-                self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
-            )
-
-            longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
-
-            # if combined sequence is too long, truncate the prompt
-            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
-                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
-                    if self.truncation_mode == "keep_start":
-                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                            answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
-                    elif self.truncation_mode == "keep_end":
-                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
-                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
-                    else:
-                        raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
-
-            # if that's still too long, truncate the response
-            for answer_tokens in [chosen_tokens, rejected_tokens]:
-                if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
-                    for k in ["input_ids", "attention_mask"]:
-                        answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
-
-            # Create labels
-            chosen_sequence_tokens = {
-                k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
-            }
-            rejected_sequence_tokens = {
-                k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
-            }
-            chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
-            chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
-                self.label_pad_token_id
-            ] * len(chosen_tokens["prompt_input_ids"])
-            rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
-            rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
-                self.label_pad_token_id
-            ] * len(rejected_tokens["prompt_input_ids"])
-
-            for k, toks in {
-                "chosen_": chosen_sequence_tokens,
-                "rejected_": rejected_sequence_tokens,
-                "": prompt_tokens,
-            }.items():
-                for type_key, tokens in toks.items():
-                    if type_key == "token_type_ids":
-                        continue
-                    batch[f"{k}{type_key}"] = tokens
-
-        else:
-            chosen_tokens = self.processing_class(
-                chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
-            )
-            rejected_tokens = self.processing_class(
-                rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
-            )
-            prompt_tokens = self.processing_class(
-                prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
-            )
-
-            batch["chosen_labels"] = chosen_tokens["input_ids"]
-            batch["rejected_labels"] = rejected_tokens["input_ids"]
-            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
-            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
-
-            if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
-                batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-                    labels=torch.tensor(batch["rejected_labels"])
-                )
-                batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-                    labels=torch.tensor(batch["chosen_labels"])
-                )
-
-        return batch
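Editor's note: a toy run of the two-stage truncation implemented above (truncate the prompt first, then the response); all numbers are made up.

# Toy illustration, mirroring the truncation order above:
max_length, max_prompt_length = 10, 6
prompt_ids = list(range(9))            # 9 prompt tokens: too long with the response
response_ids = [100, 101, 102, 103]    # 4 response tokens

if len(prompt_ids) + len(response_ids) > max_length:
    prompt_ids = prompt_ids[-max_prompt_length:]          # truncation_mode == "keep_end"
if len(prompt_ids) + len(response_ids) > max_length:
    response_ids = response_ids[: max_length - max_prompt_length]
print(len(prompt_ids), len(response_ids))                 # 6 4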
-
-    @staticmethod
-    def concatenated_inputs(
-        batch: dict[str, Union[list, torch.LongTensor]],
-        is_encoder_decoder: bool = False,
-        label_pad_token_id: int = -100,
-        padding_value: int = 0,
-        device: Optional[torch.device] = None,
-    ) -> dict[str, torch.LongTensor]:
-        """Concatenate the chosen and rejected inputs into a single tensor.
-
-        Args:
-            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
-            is_encoder_decoder: Whether the model is an encoder-decoder model.
-            label_pad_token_id: The label pad token id.
-            padding_value: The padding value to use for the concatenated inputs_ids.
-            device: The device for the concatenated inputs.
-
-        Returns:
-            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
-        """
-        concatenated_batch = {}
-
-        if is_encoder_decoder:
-            max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
-        else:
-            max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
-
-        for k in batch:
-            if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
-                if "labels" in k or is_encoder_decoder:
-                    pad_value = label_pad_token_id
-                elif k.endswith("_input_ids"):
-                    pad_value = padding_value
-                elif k.endswith("_attention_mask"):
-                    pad_value = 0
-                concatenated_key = k.replace("chosen", "concatenated")
-                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
-        for k in batch:
-            if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
-                if "labels" in k or is_encoder_decoder:
-                    pad_value = label_pad_token_id
-                elif k.endswith("_input_ids"):
-                    pad_value = padding_value
-                elif k.endswith("_attention_mask"):
-                    pad_value = 0
-                concatenated_key = k.replace("rejected", "concatenated")
-                concatenated_batch[concatenated_key] = torch.cat(
-                    (
-                        concatenated_batch[concatenated_key],
-                        pad_to_length(batch[k], max_length, pad_value=pad_value),
-                    ),
-                    dim=0,
-                ).to(device=device)
-
-        if is_encoder_decoder:
-            concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
-            concatenated_batch["concatenated_attention_mask"] = (
-                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
-            )
-
-        return concatenated_batch
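Editor's note: the core of `concatenated_inputs` is pad-to-common-length followed by row-wise concatenation, so chosen examples occupy the first half of the batch and rejected examples the second. A minimal sketch with a stand-in for trl's `pad_to_length` helper:

import torch
import torch.nn.functional as F

def pad_to_len(t, length, pad_value):   # stand-in for trl's pad_to_length
    return F.pad(t, (0, length - t.shape[1]), value=pad_value)

chosen = torch.tensor([[5, 6, 7]])
rejected = torch.tensor([[5, 6, 7, 8, 9]])
max_len = max(chosen.shape[1], rejected.shape[1])
concatenated = torch.cat(
    [pad_to_len(chosen, max_len, 0), pad_to_len(rejected, max_len, 0)], dim=0
)
print(concatenated)   # tensor([[5, 6, 7, 0, 0], [5, 6, 7, 8, 9]])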
-
-    def cpo_loss(
-        self,
-        policy_chosen_logps: torch.FloatTensor,
-        policy_rejected_logps: torch.FloatTensor,
-    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-        """Compute the CPO loss for a batch of policy and reference model log probabilities.
-
-        Args:
-            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-
-        Returns:
-            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
-            The losses tensor contains the CPO loss for each example in the batch.
-            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
-        """
-        logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
-
-        # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
-        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
-        # calculates a conservative CPO loss.
-
-        if self.loss_type == "simpo":
-            gamma_logratios = self.simpo_gamma / self.beta
-            logits = logits - gamma_logratios
-            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
-            losses = (
-                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
-            )
-        elif self.loss_type == "sigmoid":
-            # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
-            losses = (
-                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
-            )
-        elif self.loss_type == "hinge":
-            losses = torch.relu(1 - self.beta * logits)
-        elif self.loss_type == "ipo":
-            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
-            losses = (logits - 1 / (2 * self.beta)) ** 2
-        else:
-            raise ValueError(
-                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
-            )
-
-        chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
-        rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
-
-        return losses, chosen_rewards, rejected_rewards
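Editor's note: a worked toy evaluation of the loss branches above, with beta = 0.1 and label_smoothing = 0 (all values invented for illustration):

import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen_logps = torch.tensor([-10.0])
policy_rejected_logps = torch.tensor([-12.0])
logits = policy_chosen_logps - policy_rejected_logps   # 2.0

sigmoid_loss = -F.logsigmoid(beta * logits)            # -log(sigmoid(0.2)) ~ 0.598
hinge_loss = torch.relu(1 - beta * logits)             # max(0, 1 - 0.2) = 0.8
ipo_loss = (logits - 1 / (2 * beta)) ** 2              # (2 - 5)**2 = 9.0
print(sigmoid_loss.item(), hinge_loss.item(), ipo_loss.item())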
-
-    @staticmethod
-    def get_batch_logps(
-        logits: torch.FloatTensor,
-        labels: torch.LongTensor,
-        average_log_prob: bool = False,
-        label_pad_token_id: int = -100,
-        is_encoder_decoder: bool = False,
-    ) -> torch.FloatTensor:
-        """Compute the log probabilities of the given labels under the given logits.
-
-        Args:
-            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
-            labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
-            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-            label_pad_token_id: The label pad token id.
-            is_encoder_decoder: Whether the model is an encoder-decoder model.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
-        """
-        if logits.shape[:-1] != labels.shape:
-            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
-        if not is_encoder_decoder:
-            labels = labels[:, 1:].clone()
-            logits = logits[:, :-1, :]
-        loss_mask = labels != label_pad_token_id
-
-        # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == label_pad_token_id] = 0
-
-        per_token_logps = selective_log_softmax(logits, labels)
-
-        if average_log_prob:
-            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-        else:
-            return (per_token_logps * loss_mask).sum(-1)
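Editor's note: the mask arithmetic at the end of `get_batch_logps`, on toy values (label_pad_token_id = -100):

import torch

labels = torch.tensor([[-100, -100, 42, 7]])
per_token_logps = torch.tensor([[-1.0, -2.0, -0.5, -0.25]])
loss_mask = labels != -100

summed = (per_token_logps * loss_mask).sum(-1)    # -0.75  (prompt tokens ignored)
average = summed / loss_mask.sum(-1)              # -0.375 (used for "ipo" / "simpo")
print(summed.item(), average.item())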
-
-    def concatenated_forward(
-        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
-    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
-
-        We do this to avoid doing two forward passes, because it's faster for FSDP.
-        """
-        concatenated_batch = self.concatenated_inputs(
-            batch,
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-            padding_value=self.padding_value,
-            device=self.accelerator.device,
-        )
-        len_chosen = batch["chosen_labels"].shape[0]
-
-        model_kwargs = (
-            {
-                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
-            }
-            if self.is_encoder_decoder
-            else {}
-        )
-
-        if self.aux_loss_enabled:
-            model_kwargs["output_router_logits"] = True
-
-        outputs = model(
-            concatenated_batch["concatenated_input_ids"],
-            attention_mask=concatenated_batch["concatenated_attention_mask"],
-            use_cache=False,
-            **model_kwargs,
-        )
-        all_logits = outputs.logits
-
-        def cross_entropy_loss(logits, labels):
-            if not self.is_encoder_decoder:
-                # Shift so that tokens < n predict n
-                logits = logits[..., :-1, :].contiguous()
-                labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            logits = logits.view(-1, logits.shape[-1])
-            labels = labels.view(-1)
-            # Enable model parallelism
-            labels = labels.to(logits.device)
-            loss = loss_fct(logits, labels)
-            return loss
-
-        labels = concatenated_batch["concatenated_labels"].clone()
-
-        if self.cpo_alpha == 0:
-            nll_loss = torch.tensor(0.0).to(self.accelerator.device)
-        else:
-            nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
-
-        all_logps = self.get_batch_logps(
-            all_logits,
-            concatenated_batch["concatenated_labels"],
-            average_log_prob=self.loss_type in ["ipo", "simpo"],
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-        )
-
-        chosen_logps = all_logps[:len_chosen]
-        rejected_logps = all_logps[len_chosen:]
-
-        chosen_logits = all_logits[:len_chosen]
-        rejected_logits = all_logits[len_chosen:]
-
-        if self.aux_loss_enabled:
-            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
-
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
-
-    def get_batch_loss_metrics(
-        self,
-        model,
-        batch: dict[str, Union[list, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
-        metrics = {}
-
-        forward_output = self.concatenated_forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
-
-        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
-            policy_chosen_logps,
-            policy_rejected_logps,
-        )
-
-        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
-        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
-        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
-        metrics[f"{prefix}rewards/margins"] = (
-            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
-        )
-        metrics[f"{prefix}logps/rejected"] = (
-            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
-        )
-        metrics[f"{prefix}logps/chosen"] = (
-            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
-        )
-        metrics[f"{prefix}logits/rejected"] = (
-            self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
-        )
-        metrics[f"{prefix}logits/chosen"] = (
-            self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
-        )
-        metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
-
-        if self.aux_loss_enabled:
-            loss += self.aux_loss_coef * aux_loss
-
-        return loss, metrics
-
-    def compute_loss(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-        num_items_in_batch=None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-
-        with compute_loss_context_manager:
-            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
-
-        # force log the metrics
-        self.store_metrics(metrics, train_eval="train")
-
-        if return_outputs:
-            return (loss, metrics)
-        return loss
-
-    def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
-        """Generate samples from the model and reference model for the given batch of inputs."""
-
-        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
-        # the torch cuda amp context manager as some hidden states are silently casted to full precision.
-        generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-
-        with generate_context_manager:
-            policy_output = model.generate(
-                input_ids=batch["prompt_input_ids"],
-                attention_mask=batch["prompt_attention_mask"],
-                max_length=self.max_length,
-                do_sample=True,
-                pad_token_id=self.processing_class.pad_token_id,
-            )
-
-        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
-        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
-
-        return policy_output_decoded
-
-    def prediction_step(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[list[str]] = None,
-    ):
-        if ignore_keys is None:
-            if hasattr(model, "config"):
-                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
-
-        prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-
-        with torch.no_grad(), prediction_context_manager:
-            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
-
-        # force log the metrics
-        self.store_metrics(metrics, train_eval="eval")
-
-        if prediction_loss_only:
-            return (loss.detach(), None, None)
-
-        # logits for the chosen and rejected samples from model
-        logits_dict = {
-            "eval_logits/chosen": metrics["eval_logits/chosen"],
-            "eval_logits/rejected": metrics["eval_logits/rejected"],
-        }
-        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
-        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
-        return (loss.detach(), logits, labels)
-
-    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
-        for key, value in metrics.items():
-            self._stored_metrics[train_eval][key].append(value)
-
-    def evaluation_loop(
-        self,
-        dataloader: DataLoader,
-        description: str,
-        prediction_loss_only: Optional[bool] = None,
-        ignore_keys: Optional[list[str]] = None,
-        metric_key_prefix: str = "eval",
-    ) -> EvalLoopOutput:
-        """
-        Overriding built-in evaluation loop to store metrics for each batch.
-        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
-
-        Works both with or without labels.
-        """
-
-        # Sample and save to game log if requested (for one batch to save time)
-        if self.generate_during_eval:
-            # Generate random indices within the range of the total number of samples
-            num_samples = len(dataloader.dataset)
-            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
-
-            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
-            random_batch_dataset = dataloader.dataset.select(random_indices)
-            random_batch = self.data_collator(random_batch_dataset)
-            random_batch = self._prepare_inputs(random_batch)
-
-            policy_output_decoded = self.generate_from_model(self.model, random_batch)
-
-            table = pd.DataFrame(
-                columns=["Prompt", "Policy"],
-                data=[
-                    [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
-                ],
-            )
-            if "wandb" in self.args.report_to:
-                wandb.log({"game_log": wandb.Table(data=table)})
-
-            if "comet_ml" in self.args.report_to:
-                log_table_to_comet_experiment(
-                    name="game_log.csv",
-                    table=table,
-                )
-
-        # Base evaluation
-        initial_output = super().evaluation_loop(
-            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
-        )
-
-        return initial_output
1293
-
1294
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1295
- """
1296
- Log `logs` on the various objects watching training, including stored metrics.
1297
-
1298
- Args:
1299
- logs (`dict[str, float]`):
1300
- The values to log.
1301
- start_time (`float` or `None`, *optional*, defaults to `None`):
1302
- Start time of the training.
1303
- """
1304
- # logs either has 'loss' or 'eval_loss'
1305
- train_eval = "train" if "loss" in logs else "eval"
1306
- # Add averaged stored metrics to logs
1307
- for key, metrics in self._stored_metrics[train_eval].items():
1308
- logs[key] = torch.tensor(metrics).mean().item()
1309
- del self._stored_metrics[train_eval]
1310
-
1311
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1312
- return super().log(logs, start_time)
1313
- else: # transformers<=4.46
1314
- return super().log(logs)
1315
-
1316
- def _shift_right(self, input_ids):
1317
- if self.decoder_start_token_id is None:
1318
- raise ValueError(
1319
- "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1320
- )
1321
-
1322
- # shift inputs to the right
1323
- if is_torch_fx_proxy(input_ids):
1324
- # Item assignment is not supported natively for proxies.
1325
- shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1326
- shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1327
- else:
1328
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1329
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1330
- shifted_input_ids[..., 0] = self.decoder_start_token_id
1331
-
1332
- if self.pad_token_id is None:
1333
- raise ValueError("model.config.pad_token_id has to be defined.")
1334
- # replace possible -100 values in labels by `pad_token_id`
1335
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1336
-
1337
- return shifted_input_ids
1338
-
1339
- def create_model_card(
1340
- self,
1341
- model_name: Optional[str] = None,
1342
- dataset_name: Optional[str] = None,
1343
- tags: Union[str, list[str], None] = None,
1344
- ):
1345
- """
1346
- Creates a draft of a model card using the information available to the `Trainer`.
1347
-
1348
- Args:
1349
- model_name (`str` or `None`, *optional*, defaults to `None`):
1350
- Name of the model.
1351
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1352
- Name of the dataset used for training.
1353
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1354
- Tags to be associated with the model card.
1355
- """
1356
- if not self.is_world_process_zero():
1357
- return
1358
-
1359
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1360
- base_model = self.model.config._name_or_path
1361
- else:
1362
- base_model = None
1363
-
1364
- tags = tags or []
1365
- if isinstance(tags, str):
1366
- tags = [tags]
1367
-
1368
- if hasattr(self.model.config, "unsloth_version"):
1369
- tags.append("unsloth")
1370
-
1371
- citation = textwrap.dedent("""\
1372
- @inproceedings{xu2024contrastive,
1373
- title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
1374
- author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
1375
- year = 2024,
1376
- booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
1377
- publisher = {OpenReview.net},
1378
- url = {https://openreview.net/forum?id=51iwkioZpn}
1379
- }""")
1380
-
1381
- model_card = generate_model_card(
1382
- base_model=base_model,
1383
- model_name=model_name,
1384
- hub_model_id=self.hub_model_id,
1385
- dataset_name=dataset_name,
1386
- tags=tags,
1387
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1388
- comet_url=get_comet_experiment_url(),
1389
- trainer_name="CPO",
1390
- trainer_citation=citation,
1391
- paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
1392
- paper_id="2401.08417",
1393
- )
1394
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1395
- class UnslothCPOTrainer(_UnslothCPOTrainer):
1396
- """
1397
-
1398
- Initialize CPOTrainer.
1399
-
1400
- Args:
1401
- model (`transformers.PreTrainedModel`):
1402
- The model to train, preferably an `AutoModelForSequenceClassification`.
1403
- args (`CPOConfig`):
1404
- The CPO config arguments to use for training.
1405
- data_collator (`transformers.DataCollator`):
1406
- The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1407
- which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1408
- train_dataset (`datasets.Dataset`):
1409
- The dataset to use for training.
1410
- eval_dataset (`datasets.Dataset`):
1411
- The dataset to use for evaluation.
1412
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1413
- Processing class used to process the data. If provided, will be used to automatically process the inputs
1414
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1415
- reuse the fine-tuned model.
1416
- model_init (`Callable[[], transformers.PreTrainedModel]`):
1417
- The model initializer to use for training. If None is specified, the default model initializer will be used.
1418
- callbacks (`list[transformers.TrainerCallback]`):
1419
- The callbacks to use for training.
1420
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1421
- The optimizer and scheduler to use for training.
1422
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1423
- The function to use to preprocess the logits before computing the metrics.
1424
- peft_config (`dict`, defaults to `None`):
1425
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1426
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1427
- The function to use to compute the metrics. Must take a `EvalPrediction` and return
1428
- a dictionary string to metric values.
1429
-
1430
- """
1431
- def __init__(
1432
- self,
1433
- model = None,
1434
- args = None,
1435
- data_collator = None,
1436
- train_dataset = None,
1437
- eval_dataset = None,
1438
- processing_class = None,
1439
- model_init = None,
1440
- callbacks = None,
1441
- preprocess_logits_for_metrics = None,
1442
- peft_config = None,
1443
- compute_metrics = None,
1444
- **kwargs
1445
- ):
1446
- if args is None: args = UnslothCPOConfig()
1447
- use_bf16 = getattr(args, 'bf16', False)
1448
- use_fp16 = getattr(args, 'fp16', False)
1449
- force_float32 = False
1450
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1451
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1452
- force_float32 = True
1453
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1454
- dtype = getattr(model.config, 'torch_dtype', None)
1455
- if dtype is None: dtype = model.get_input_embeddings().dtype
1456
- from unsloth_zoo.utils import _get_dtype
1457
- dtype = _get_dtype(dtype)
1458
- float16 = dtype == torch.float16
1459
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1460
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1461
- if force_float32:
1462
- args.fp16 = False
1463
- args.bf16 = False
1464
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1465
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1466
- args.fp16 = float16
1467
- args.bf16 = not float16
1468
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1469
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1470
- args.eval_strategy = 'steps'
1471
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1472
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1473
- if ga_steps is not None and ga_steps > 1:
1474
- from transformers import __version__ as transformers_version
1475
- if Version(transformers_version) <= Version('4.45.2'):
1476
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1477
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1478
- if getattr(args, 'eval_strategy', 'no') != 'no':
1479
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1480
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1481
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1482
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1483
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1484
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1485
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1486
- if force_float32:
1487
- args.bf16_full_eval = False
1488
- args.fp16_full_eval = False
1489
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1490
- args.bf16_full_eval = True
1491
- args.fp16_full_eval = False
1492
- elif not bf16_full_eval and not fp16_full_eval:
1493
- args.bf16_full_eval = args.bf16
1494
- args.fp16_full_eval = args.fp16
1495
- _output_logits = False
1496
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1497
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1498
- if _output_logits:
1499
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1500
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1501
- pass
1502
- else:
1503
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1504
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1505
- if args_max_seq_length is None and model_max_seq_length is not None:
1506
- max_seq_length = model.max_seq_length
1507
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1508
- if model is not None and hasattr(model, 'for_training'):
1509
- model.for_training()
1510
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1511
- if 'processing_class' in locals():
1512
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1513
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1514
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1515
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1516
- if not isinstance(data_collator, UnslothVisionDataCollator):
1517
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1518
- data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1519
- elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1520
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1521
- else:
1522
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1523
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1524
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1525
- if not isinstance(data_collator, UnslothVisionDataCollator):
1526
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1527
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1528
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1529
- else:
1530
- data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1531
- other_metrics = []
1532
-
1533
- from unsloth_zoo.logging_utils import PatchRLStatistics
1534
- PatchRLStatistics('cpo_trainer', other_metrics)
1535
-
1536
- super().__init__(
1537
- model = model,
1538
- args = args,
1539
- data_collator = data_collator,
1540
- train_dataset = train_dataset,
1541
- eval_dataset = eval_dataset,
1542
- processing_class = processing_class,
1543
- model_init = model_init,
1544
- callbacks = callbacks,
1545
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1546
- peft_config = peft_config,
1547
- compute_metrics = compute_metrics,**kwargs)
1548
- if hasattr(self, 'neftune_hook_handle'):
1549
- self.neftune_hook_handle.remove()
1550
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1551
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1552
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1553
- pass
1554
-
1555
- pass
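For readers skimming this deletion: the `_shift_right` method in the deleted CPO trainer above is the standard encoder-decoder label shift. A minimal standalone sketch of the same behaviour, assuming illustrative token ids (`decoder_start_token_id=2`, `pad_token_id=0` are made-up values, and `shift_right` is a hypothetical helper, not part of the deleted file):

import torch

def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()       # move every token one slot right
    shifted[..., 0] = decoder_start_token_id             # prepend the decoder start token
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 label padding becomes the pad id
    return shifted

labels = torch.tensor([[5, 6, -100, -100]])
print(shift_right(labels, decoder_start_token_id=2, pad_token_id=0))  # tensor([[2, 5, 6, 0]])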
unsloth_compiled_cache/UnslothDDPOTrainer.py DELETED
@@ -1,872 +0,0 @@
-"""
-2025.3.13
-2025.3.15
-4.48.3
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
-torch_compile_options = {
-    "epilogue_fusion" : True,
-    "max_autotune" : False,
-    "shape_padding" : True,
-    "trace.enabled" : False,
-    "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def selective_log_softmax(logits, index):
-    logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-    # loop to reduce peak mem consumption
-    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-    logsumexp_values = torch.logsumexp(logits, dim = -1)
-    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-    return per_token_logps
-@dataclass
-class UnslothDDPOConfig(DDPOConfig):
-    """
-
-    Configuration class for the [`DDPOTrainer`].
-
-    Using [`~transformers.HfArgumentParser`] we can turn this class into
-    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-    command line.
-
-    Parameters:
-        exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
-            Name of this experiment (by default is the file name without the extension name).
-        run_name (`str`, *optional*, defaults to `""`):
-            Name of this run.
-        seed (`int`, *optional*, defaults to `0`):
-            Random seed.
-        log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
-            Log with either 'wandb' or 'tensorboard', check
-            https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
-        tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
-            Keyword arguments for the tracker (e.g. wandb_project).
-        accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
-            Keyword arguments for the accelerator.
-        project_kwargs (`Dict`, *optional*, defaults to `{}`):
-            Keyword arguments for the accelerator project config (e.g. `logging_dir`).
-        tracker_project_name (`str`, *optional*, defaults to `"trl"`):
-            Name of project to use for tracking.
-        logdir (`str`, *optional*, defaults to `"logs"`):
-            Top-level logging directory for checkpoint saving.
-        num_epochs (`int`, *optional*, defaults to `100`):
-            Number of epochs to train.
-        save_freq (`int`, *optional*, defaults to `1`):
-            Number of epochs between saving model checkpoints.
-        num_checkpoint_limit (`int`, *optional*, defaults to `5`):
-            Number of checkpoints to keep before overwriting old ones.
-        mixed_precision (`str`, *optional*, defaults to `"fp16"`):
-            Mixed precision training.
-        allow_tf32 (`bool`, *optional*, defaults to `True`):
-            Allow `tf32` on Ampere GPUs.
-        resume_from (`str`, *optional*, defaults to `""`):
-            Resume training from a checkpoint.
-        sample_num_steps (`int`, *optional*, defaults to `50`):
-            Number of sampler inference steps.
-        sample_eta (`float`, *optional*, defaults to `1.0`):
-            Eta parameter for the DDIM sampler.
-        sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
-            Classifier-free guidance weight.
-        sample_batch_size (`int`, *optional*, defaults to `1`):
-            Batch size (per GPU) to use for sampling.
-        sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
-            Number of batches to sample per epoch.
-        train_batch_size (`int`, *optional*, defaults to `1`):
-            Batch size (per GPU) to use for training.
-        train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
-            Use 8bit Adam optimizer from bitsandbytes.
-        train_learning_rate (`float`, *optional*, defaults to `3e-4`):
-            Learning rate.
-        train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
-            Adam beta1.
-        train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
-            Adam beta2.
-        train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
-            Adam weight decay.
-        train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
-            Adam epsilon.
-        train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
-            Number of gradient accumulation steps.
-        train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
-            Maximum gradient norm for gradient clipping.
-        train_num_inner_epochs (`int`, *optional*, defaults to `1`):
-            Number of inner epochs per outer epoch.
-        train_cfg (`bool`, *optional*, defaults to `True`):
-            Whether to use classifier-free guidance during training.
-        train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
-            Clip advantages to the range.
-        train_clip_range (`float`, *optional*, defaults to `1e-4`):
-            PPO clip range.
-        train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
-            Fraction of timesteps to train on.
-        per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
-            Whether to track statistics for each prompt separately.
-        per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
-            Number of reward values to store in the buffer for each prompt.
-        per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
-            Minimum number of reward values to store in the buffer.
-        async_reward_computation (`bool`, *optional*, defaults to `False`):
-            Whether to compute rewards asynchronously.
-        max_workers (`int`, *optional*, defaults to `2`):
-            Maximum number of workers to use for async reward computation.
-        negative_prompts (`str`, *optional*, defaults to `""`):
-            Comma-separated list of prompts to use as negative examples.
-        push_to_hub (`bool`, *optional*, defaults to `False`):
-            Whether to push the final model checkpoint to the Hub.
-
-    """
-    vllm_sampling_params: Optional[Any] = field(
-        default = None,
-        metadata = {'help': 'vLLM SamplingParams'},
-    )
-    unsloth_num_chunks : Optional[int] = field(
-        default = -1,
-        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-    )
-    def __init__(
-        self,
-        exp_name = 'main',
-        run_name = '',
-        seed = 3407,
-        log_with = None,
-        tracker_project_name = 'trl',
-        logdir = 'logs',
-        num_epochs = 100,
-        save_freq = 1,
-        num_checkpoint_limit = 5,
-        mixed_precision = 'fp16',
-        allow_tf32 = True,
-        resume_from = '',
-        sample_num_steps = 50,
-        sample_eta = 1.0,
-        sample_guidance_scale = 5.0,
-        sample_batch_size = 1,
-        sample_num_batches_per_epoch = 2,
-        train_batch_size = 1,
-        train_use_8bit_adam = False,
-        train_learning_rate = 5e-05,
-        train_adam_beta1 = 0.9,
-        train_adam_beta2 = 0.999,
-        train_adam_weight_decay = 0.01,
-        train_adam_epsilon = 1e-08,
-        train_gradient_accumulation_steps = 2,
-        train_max_grad_norm = 1.0,
-        train_num_inner_epochs = 1,
-        train_cfg = True,
-        train_adv_clip_max = 5.0,
-        train_clip_range = 0.0001,
-        train_timestep_fraction = 1.0,
-        per_prompt_stat_tracking = False,
-        per_prompt_stat_tracking_buffer_size = 16,
-        per_prompt_stat_tracking_min_count = 16,
-        async_reward_computation = False,
-        max_workers = 2,
-        negative_prompts = '',
-        push_to_hub = False,
-        vllm_sampling_params = None,
-        unsloth_num_chunks = -1,
-        **kwargs,
-    ):
-
-        super().__init__(
-            exp_name = exp_name,
-            run_name = run_name,
-            seed = seed,
-            log_with = log_with,
-            tracker_project_name = tracker_project_name,
-            logdir = logdir,
-            num_epochs = num_epochs,
-            save_freq = save_freq,
-            num_checkpoint_limit = num_checkpoint_limit,
-            mixed_precision = mixed_precision,
-            allow_tf32 = allow_tf32,
-            resume_from = resume_from,
-            sample_num_steps = sample_num_steps,
-            sample_eta = sample_eta,
-            sample_guidance_scale = sample_guidance_scale,
-            sample_batch_size = sample_batch_size,
-            sample_num_batches_per_epoch = sample_num_batches_per_epoch,
-            train_batch_size = train_batch_size,
-            train_use_8bit_adam = train_use_8bit_adam,
-            train_learning_rate = train_learning_rate,
-            train_adam_beta1 = train_adam_beta1,
-            train_adam_beta2 = train_adam_beta2,
-            train_adam_weight_decay = train_adam_weight_decay,
-            train_adam_epsilon = train_adam_epsilon,
-            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
-            train_max_grad_norm = train_max_grad_norm,
-            train_num_inner_epochs = train_num_inner_epochs,
-            train_cfg = train_cfg,
-            train_adv_clip_max = train_adv_clip_max,
-            train_clip_range = train_clip_range,
-            train_timestep_fraction = train_timestep_fraction,
-            per_prompt_stat_tracking = per_prompt_stat_tracking,
-            per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
-            per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
-            async_reward_computation = async_reward_computation,
-            max_workers = max_workers,
-            negative_prompts = negative_prompts,
-            push_to_hub = push_to_hub,**kwargs)
-        self.vllm_sampling_params = vllm_sampling_params
-        self.unsloth_num_chunks = unsloth_num_chunks
-pass
-
-class _UnslothDDPOTrainer(PyTorchModelHubMixin):
-    """"""
-
-    _tag_names = ["trl", "ddpo"]
-
-    def __init__(
-        self,
-        config: DDPOConfig,
-        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
-        prompt_function: Callable[[], tuple[str, Any]],
-        sd_pipeline: DDPOStableDiffusionPipeline,
-        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
-    ):
-        if image_samples_hook is None:
-            warn("No image_samples_hook provided; no images will be logged")
-
-        self.prompt_fn = prompt_function
-        self.reward_fn = reward_function
-        self.config = config
-        self.image_samples_callback = image_samples_hook
-
-        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
-
-        if self.config.resume_from:
-            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
-            if "checkpoint_" not in os.path.basename(self.config.resume_from):
-                # get the most recent checkpoint in this directory
-                checkpoints = list(
-                    filter(
-                        lambda x: "checkpoint_" in x,
-                        os.listdir(self.config.resume_from),
-                    )
-                )
-                if len(checkpoints) == 0:
-                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
-                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
-                self.config.resume_from = os.path.join(
-                    self.config.resume_from,
-                    f"checkpoint_{checkpoint_numbers[-1]}",
-                )
-
-                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
-
-        # number of timesteps within each trajectory to train on
-        self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
-
-        self.accelerator = Accelerator(
-            log_with=self.config.log_with,
-            mixed_precision=self.config.mixed_precision,
-            project_config=accelerator_project_config,
-            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
-            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
-            # the total number of optimizer steps to accumulate across.
-            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
-            **self.config.accelerator_kwargs,
-        )
-
-        is_okay, message = self._config_check()
-        if not is_okay:
-            raise ValueError(message)
-
-        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
-
-        if self.accelerator.is_main_process:
-            self.accelerator.init_trackers(
-                self.config.tracker_project_name,
-                config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
-                init_kwargs=self.config.tracker_kwargs,
-            )
-
-        logger.info(f"\n{config}")
-
-        set_seed(self.config.seed, device_specific=True)
-
-        self.sd_pipeline = sd_pipeline
-
-        self.sd_pipeline.set_progress_bar_config(
-            position=1,
-            disable=not self.accelerator.is_local_main_process,
-            leave=False,
-            desc="Timestep",
-            dynamic_ncols=True,
-        )
-
-        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
-        # as these weights are only used for inference, keeping weights in full precision is not required.
-        if self.accelerator.mixed_precision == "fp16":
-            inference_dtype = torch.float16
-        elif self.accelerator.mixed_precision == "bf16":
-            inference_dtype = torch.bfloat16
-        else:
-            inference_dtype = torch.float32
-
-        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
-        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
-        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
-
-        trainable_layers = self.sd_pipeline.get_trainable_layers()
-
-        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
-        self.accelerator.register_load_state_pre_hook(self._load_model_hook)
-
-        # Enable TF32 for faster training on Ampere GPUs,
-        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-        if self.config.allow_tf32:
-            torch.backends.cuda.matmul.allow_tf32 = True
-
-        self.optimizer = self._setup_optimizer(
-            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
-        )
-
-        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
-            self.sd_pipeline.tokenizer(
-                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
-                return_tensors="pt",
-                padding="max_length",
-                truncation=True,
-                max_length=self.sd_pipeline.tokenizer.model_max_length,
-            ).input_ids.to(self.accelerator.device)
-        )[0]
-
-        if config.per_prompt_stat_tracking:
-            self.stat_tracker = PerPromptStatTracker(
-                config.per_prompt_stat_tracking_buffer_size,
-                config.per_prompt_stat_tracking_min_count,
-            )
-
-        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
-        # more memory
-        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
-
-        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
-            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
-        else:
-            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-
-        if self.config.async_reward_computation:
-            self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
-
-        if config.resume_from:
-            logger.info(f"Resuming from {config.resume_from}")
-            self.accelerator.load_state(config.resume_from)
-            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
-        else:
-            self.first_epoch = 0
-
-    def compute_rewards(self, prompt_image_pairs, is_async=False):
-        if not is_async:
-            rewards = []
-            for images, prompts, prompt_metadata in prompt_image_pairs:
-                reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
-                rewards.append(
-                    (
-                        torch.as_tensor(reward, device=self.accelerator.device),
-                        reward_metadata,
-                    )
-                )
-        else:
-            rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
-            rewards = [
-                (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
-                for reward, reward_metadata in rewards
-            ]
-
-        return zip(*rewards)
-
-    def step(self, epoch: int, global_step: int):
-        """
-        Perform a single step of training.
-
-        Args:
-            epoch (int): The current epoch.
-            global_step (int): The current global step.
-
-        Side Effects:
-            - Model weights are updated
-            - Logs the statistics to the accelerator trackers.
-            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
-
-        Returns:
-            global_step (int): The updated global step.
-
-        """
-        samples, prompt_image_data = self._generate_samples(
-            iterations=self.config.sample_num_batches_per_epoch,
-            batch_size=self.config.sample_batch_size,
-        )
-
-        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
-        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
-        rewards, rewards_metadata = self.compute_rewards(
-            prompt_image_data, is_async=self.config.async_reward_computation
-        )
-
-        for i, image_data in enumerate(prompt_image_data):
-            image_data.extend([rewards[i], rewards_metadata[i]])
-
-        if self.image_samples_callback is not None:
-            self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
-
-        rewards = torch.cat(rewards)
-        rewards = self.accelerator.gather(rewards).cpu().numpy()
-
-        self.accelerator.log(
-            {
-                "reward": rewards,
-                "epoch": epoch,
-                "reward_mean": rewards.mean(),
-                "reward_std": rewards.std(),
-            },
-            step=global_step,
-        )
-
-        if self.config.per_prompt_stat_tracking:
-            # gather the prompts across processes
-            prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
-            prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
-            advantages = self.stat_tracker.update(prompts, rewards)
-        else:
-            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
-
-        # ungather advantages; keep the entries corresponding to the samples on this process
-        samples["advantages"] = (
-            torch.as_tensor(advantages)
-            .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
-            .to(self.accelerator.device)
-        )
-
-        del samples["prompt_ids"]
-
-        total_batch_size, num_timesteps = samples["timesteps"].shape
-
-        for inner_epoch in range(self.config.train_num_inner_epochs):
-            # shuffle samples along batch dimension
-            perm = torch.randperm(total_batch_size, device=self.accelerator.device)
-            samples = {k: v[perm] for k, v in samples.items()}
-
-            # shuffle along time dimension independently for each sample
-            # still trying to understand the code below
-            perms = torch.stack(
-                [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
-            )
-
-            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
-                samples[key] = samples[key][
-                    torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
-                    perms,
-                ]
-
-            original_keys = samples.keys()
-            original_values = samples.values()
-            # rebatch them as user defined train_batch_size is different from sample_batch_size
-            reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
-
-            # Transpose the list of original values
-            transposed_values = zip(*reshaped_values)
-            # Create new dictionaries for each row of transposed values
-            samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
-
-            self.sd_pipeline.unet.train()
-            global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
-            # ensure optimization step at the end of the inner epoch
-            if not self.accelerator.sync_gradients:
-                raise ValueError(
-                    "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
-                )
-
-        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
-            self.accelerator.save_state()
-
-        return global_step
-
-    def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
-        """
-        Calculate the loss for a batch of an unpacked sample
-
-        Args:
-            latents (torch.Tensor):
-                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
-            timesteps (torch.Tensor):
-                The timesteps sampled from the diffusion model, shape: [batch_size]
-            next_latents (torch.Tensor):
-                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
-            log_probs (torch.Tensor):
-                The log probabilities of the latents, shape: [batch_size]
-            advantages (torch.Tensor):
-                The advantages of the latents, shape: [batch_size]
-            embeds (torch.Tensor):
-                The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
-                Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
-
-        Returns:
-            loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
-            (all of these are of shape (1,))
-        """
-        with self.autocast():
-            if self.config.train_cfg:
-                noise_pred = self.sd_pipeline.unet(
-                    torch.cat([latents] * 2),
-                    torch.cat([timesteps] * 2),
-                    embeds,
-                ).sample
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
-                    noise_pred_text - noise_pred_uncond
-                )
-            else:
-                noise_pred = self.sd_pipeline.unet(
-                    latents,
-                    timesteps,
-                    embeds,
-                ).sample
-            # compute the log prob of next_latents given latents under the current model
-
-            scheduler_step_output = self.sd_pipeline.scheduler_step(
-                noise_pred,
-                timesteps,
-                latents,
-                eta=self.config.sample_eta,
-                prev_sample=next_latents,
-            )
-
-            log_prob = scheduler_step_output.log_probs
-
-        advantages = torch.clamp(
-            advantages,
-            -self.config.train_adv_clip_max,
-            self.config.train_adv_clip_max,
-        )
-
-        ratio = torch.exp(log_prob - log_probs)
-
-        loss = self.loss(advantages, self.config.train_clip_range, ratio)
-
-        approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
-
-        clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
-
-        return loss, approx_kl, clipfrac
-
-    def loss(
-        self,
-        advantages: torch.Tensor,
-        clip_range: float,
-        ratio: torch.Tensor,
-    ):
-        unclipped_loss = -advantages * ratio
-        clipped_loss = -advantages * torch.clamp(
-            ratio,
-            1.0 - clip_range,
-            1.0 + clip_range,
-        )
-        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
-
-    def _setup_optimizer(self, trainable_layers_parameters):
-        if self.config.train_use_8bit_adam:
-            import bitsandbytes
-
-            optimizer_cls = bitsandbytes.optim.AdamW8bit
-        else:
-            optimizer_cls = torch.optim.AdamW
-
-        return optimizer_cls(
-            trainable_layers_parameters,
-            lr=self.config.train_learning_rate,
-            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
-            weight_decay=self.config.train_adam_weight_decay,
-            eps=self.config.train_adam_epsilon,
-        )
-
-    def _save_model_hook(self, models, weights, output_dir):
-        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
-        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
-
-    def _load_model_hook(self, models, input_dir):
-        self.sd_pipeline.load_checkpoint(models, input_dir)
-        models.pop()  # ensures that accelerate doesn't try to handle loading of the model
-
-    def _generate_samples(self, iterations, batch_size):
-        """
-        Generate samples from the model
-
-        Args:
-            iterations (int): Number of iterations to generate samples for
-            batch_size (int): Batch size to use for sampling
-
-        Returns:
-            samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
-        """
-        samples = []
-        prompt_image_pairs = []
-        self.sd_pipeline.unet.eval()
-
-        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
-
-        for _ in range(iterations):
-            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
-
-            prompt_ids = self.sd_pipeline.tokenizer(
-                prompts,
-                return_tensors="pt",
-                padding="max_length",
-                truncation=True,
-                max_length=self.sd_pipeline.tokenizer.model_max_length,
-            ).input_ids.to(self.accelerator.device)
-            prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
-
-            with self.autocast():
-                sd_output = self.sd_pipeline(
-                    prompt_embeds=prompt_embeds,
-                    negative_prompt_embeds=sample_neg_prompt_embeds,
-                    num_inference_steps=self.config.sample_num_steps,
-                    guidance_scale=self.config.sample_guidance_scale,
-                    eta=self.config.sample_eta,
-                    output_type="pt",
-                )
-
-                images = sd_output.images
-                latents = sd_output.latents
-                log_probs = sd_output.log_probs
-
-            latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, ...)
-            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
-            timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)
-
-            samples.append(
-                {
-                    "prompt_ids": prompt_ids,
-                    "prompt_embeds": prompt_embeds,
-                    "timesteps": timesteps,
-                    "latents": latents[:, :-1],  # each entry is the latent before timestep t
-                    "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
-                    "log_probs": log_probs,
-                    "negative_prompt_embeds": sample_neg_prompt_embeds,
-                }
-            )
-            prompt_image_pairs.append([images, prompts, prompt_metadata])
-
-        return samples, prompt_image_pairs
-
-    def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
-        """
-        Train on a batch of samples. Main training segment
-
-        Args:
-            inner_epoch (int): The current inner epoch
-            epoch (int): The current epoch
-            global_step (int): The current global step
-            batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
-
-        Side Effects:
-            - Model weights are updated
-            - Logs the statistics to the accelerator trackers.
-
-        Returns:
-            global_step (int): The updated global step
-        """
-        info = defaultdict(list)
-        for _i, sample in enumerate(batched_samples):
-            if self.config.train_cfg:
-                # concat negative prompts to sample prompts to avoid two forward passes
-                embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
-            else:
-                embeds = sample["prompt_embeds"]
-
-            for j in range(self.num_train_timesteps):
-                with self.accelerator.accumulate(self.sd_pipeline.unet):
-                    loss, approx_kl, clipfrac = self.calculate_loss(
-                        sample["latents"][:, j],
-                        sample["timesteps"][:, j],
-                        sample["next_latents"][:, j],
-                        sample["log_probs"][:, j],
-                        sample["advantages"],
-                        embeds,
-                    )
-                    info["approx_kl"].append(approx_kl)
-                    info["clipfrac"].append(clipfrac)
-                    info["loss"].append(loss)
-
-                    self.accelerator.backward(loss)
-                    if self.accelerator.sync_gradients:
-                        self.accelerator.clip_grad_norm_(
-                            self.trainable_layers.parameters()
-                            if not isinstance(self.trainable_layers, list)
-                            else self.trainable_layers,
-                            self.config.train_max_grad_norm,
-                        )
-                    self.optimizer.step()
-                    self.optimizer.zero_grad()
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if self.accelerator.sync_gradients:
-                    # log training-related stuff
-                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
-                    info = self.accelerator.reduce(info, reduction="mean")
-                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
-                    self.accelerator.log(info, step=global_step)
-                    global_step += 1
-                    info = defaultdict(list)
-        return global_step
-
-    def _config_check(self) -> tuple[bool, str]:
-        samples_per_epoch = (
-            self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
-        )
-        total_train_batch_size = (
-            self.config.train_batch_size
-            * self.accelerator.num_processes
-            * self.config.train_gradient_accumulation_steps
-        )
-
-        if not self.config.sample_batch_size >= self.config.train_batch_size:
-            return (
-                False,
-                f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
-            )
-        if not self.config.sample_batch_size % self.config.train_batch_size == 0:
-            return (
-                False,
-                f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
-            )
-        if not samples_per_epoch % total_train_batch_size == 0:
-            return (
-                False,
-                f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
-            )
-        return True, ""
-
-    def train(self, epochs: Optional[int] = None):
-        """
-        Train the model for a given number of epochs
-        """
-        global_step = 0
-        if epochs is None:
-            epochs = self.config.num_epochs
-        for epoch in range(self.first_epoch, epochs):
-            global_step = self.step(epoch, global_step)
-
-    def _save_pretrained(self, save_directory):
-        self.sd_pipeline.save_pretrained(save_directory)
-        self.create_model_card()
-
-    def create_model_card(
-        self,
-        model_name: Optional[str] = None,
-        dataset_name: Optional[str] = None,
-        tags: Union[str, list[str], None] = None,
-    ):
-        """
-        Creates a draft of a model card using the information available to the `Trainer`.
-
-        Args:
-            model_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the model.
-            dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the dataset used for training.
-            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                Tags to be associated with the model card.
-        """
-        if not self.is_world_process_zero():
-            return
-
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
-        else:
-            base_model = None
-
-        tags = tags or []
-        if isinstance(tags, str):
-            tags = [tags]
-
-        if hasattr(self.model.config, "unsloth_version"):
-            tags.append("unsloth")
-
-        citation = textwrap.dedent("""\
-        @inproceedings{black2024training,
-            title = {{Training Diffusion Models with Reinforcement Learning}},
-            author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
-            year = 2024,
-            booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
-            publisher = {OpenReview.net},
-            url = {https://openreview.net/forum?id=YCWjhGrJFD},
-        }""")
-
-        model_card = generate_model_card(
-            base_model=base_model,
-            model_name=model_name,
-            hub_model_id=self.hub_model_id,
-            dataset_name=dataset_name,
-            tags=tags,
-            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-            comet_url=get_comet_experiment_url(),
-            trainer_name="DDPO",
-            trainer_citation=citation,
-            paper_title="Training Diffusion Models with Reinforcement Learning",
-            paper_id="2305.13301",
-        )
-
-        model_card.save(os.path.join(self.args.output_dir, "README.md"))
-class UnslothDDPOTrainer(_UnslothDDPOTrainer):
-    """
-
-    The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
-    Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
-    As of now only Stable Diffusion based pipelines are supported
-
-    Attributes:
-        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
-            details.
-        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
-        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
-        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
-        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
-
-    """
-    def __init__(
-        self,
-        config,
-        reward_function,
-        prompt_function,
-        sd_pipeline,
-        image_samples_hook = None,
-        **kwargs
-    ):
-        if config is None: config = UnslothDDPOConfig()
-        other_metrics = []
-
-        from unsloth_zoo.logging_utils import PatchRLStatistics
-        PatchRLStatistics('ddpo_trainer', other_metrics)
-
-        super().__init__(
-            config = config,
-            reward_function = reward_function,
-            prompt_function = prompt_function,
-            sd_pipeline = sd_pipeline,
-            image_samples_hook = image_samples_hook,**kwargs)
-
-pass
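The `loss` method deleted above is the standard PPO clipped surrogate objective, applied per denoising step in `calculate_loss`. A minimal standalone sketch of the same objective (the function name and sample tensors are illustrative; `clip_range` corresponds to `train_clip_range`, `1e-4` by default in the config above):

import torch

def ppo_clipped_loss(advantages: torch.Tensor, ratio: torch.Tensor, clip_range: float) -> torch.Tensor:
    # ratio = exp(new_log_prob - old_log_prob) from the scheduler step
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # take the pessimistic (larger) loss per element, then average over the batch
    return torch.mean(torch.maximum(unclipped, clipped))

adv = torch.tensor([1.0, -0.5])
ratio = torch.tensor([1.2, 0.9])
print(ppo_clipped_loss(adv, ratio, clip_range=0.2))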
unsloth_compiled_cache/UnslothDPOTrainer.py DELETED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothGKDTrainer.py DELETED
@@ -1,861 +0,0 @@
- """
- 2025.3.13
- 2025.3.15
- 4.48.3
- 0.15.2
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion"   : True,
-     "max_autotune"      : False,
-     "shape_padding"     : True,
-     "trace.enabled"     : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def selective_log_softmax(logits, index):
-     logits = logits.to(torch.float32)
-     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-     # loop to reduce peak mem consumption
-     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-     logsumexp_values = torch.logsumexp(logits, dim = -1)
-     per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
-     return per_token_logps
- @dataclass
- class UnslothGKDConfig(GKDConfig):
-     """
-
-     Configuration class for [`GKDTrainer`].
-
-     Args:
-         temperature (`float`, *optional*, defaults to `0.9`):
-             Temperature for sampling. The higher the temperature, the more random the completions.
-         lmbda (`float`, *optional*, defaults to `0.5`):
-             Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
-             student-generated outputs).
-         beta (`float`, *optional*, defaults to `0.5`):
-             Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
-             beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
-         max_new_tokens (`int`, *optional*, defaults to `128`):
-             Maximum number of tokens to generate per completion.
-         teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
-             Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
-             being trained.
-         teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
-             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
-             from a string.
-         disable_dropout (`bool`, *optional*, defaults to `True`):
-             Whether to disable dropout in the model.
-         seq_kd (`bool`, *optional*, defaults to `False`):
-             Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
-             on teacher-generated output).
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = True,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         evaluation_strategy = None,
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = False,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         dispatch_batches = None,
-         split_batches = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = False,
-         model_init_kwargs = None,
-         use_liger = False,
-         dataset_text_field = 'text',
-         dataset_kwargs = None,
-         dataset_num_proc = None,
-         max_seq_length = None,
-         packing = False,
-         eval_packing = None,
-         dataset_batch_size = None,
-         num_of_sequences = None,
-         chars_per_token = None,
-         temperature = 0.9,
-         lmbda = 0.5,
-         beta = 0.5,
-         max_new_tokens = 128,
-         teacher_model_name_or_path = None,
-         teacher_model_init_kwargs = None,
-         disable_dropout = True,
-         seq_kd = False,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = cpu_count()
-
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             evaluation_strategy = evaluation_strategy,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             dispatch_batches = dispatch_batches,
-             split_batches = split_batches,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             model_init_kwargs = model_init_kwargs,
-             use_liger = use_liger,
-             dataset_text_field = dataset_text_field,
-             dataset_kwargs = dataset_kwargs,
-             dataset_num_proc = dataset_num_proc,
-             max_seq_length = max_seq_length,
-             packing = packing,
-             eval_packing = eval_packing,
-             dataset_batch_size = dataset_batch_size,
-             num_of_sequences = num_of_sequences,
-             chars_per_token = chars_per_token,
-             temperature = temperature,
-             lmbda = lmbda,
-             beta = beta,
-             max_new_tokens = max_new_tokens,
-             teacher_model_name_or_path = teacher_model_name_or_path,
-             teacher_model_init_kwargs = teacher_model_init_kwargs,
-             disable_dropout = disable_dropout,
-             seq_kd = seq_kd,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-     pass
-
- class _UnslothGKDTrainer(SFTTrainer):
-     _tag_names = ["trl", "gkd"]
-
-     def __init__(
-         self,
-         model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
-         teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
-         args: Optional[GKDConfig] = None,
-         data_collator: Optional[DataCollator] = None,  # type: ignore
-         train_dataset: Optional[Dataset] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional["PeftConfig"] = None,
-         formatting_func: Optional[Callable] = None,
-     ):
-         # add remove_unused_columns=False to the dataclass args
-         args.remove_unused_columns = False
-         data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
-
-         super().__init__(
-             model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-             peft_config=peft_config,
-             formatting_func=formatting_func,
-         )
-
-         if args.teacher_model_init_kwargs is None:
-             teacher_model_init_kwargs = {}
-         elif not isinstance(teacher_model, str):
-             raise ValueError(
-                 "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
-             )
-         else:
-             teacher_model_init_kwargs = args.teacher_model_init_kwargs
-             teacher_model_init_kwargs["torch_dtype"] = (
-                 teacher_model_init_kwargs["torch_dtype"]
-                 if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
-                 else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
-             )
-
-         if isinstance(teacher_model, str):
-             if args.use_liger:
-                 teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
-             else:
-                 teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
-
-         # Disable dropout in the model
-         if args.disable_dropout:
-             disable_dropout_in_model(self.model)
-
-         if self.is_deepspeed_enabled:
-             self.teacher_model = self._prepare_deepspeed(teacher_model)
-         else:
-             self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
-
-         self.lmbda = args.lmbda
-         self.beta = args.beta
-         self.temperature = args.temperature
-         self.seq_kd = args.seq_kd
-
-         self.generation_config = GenerationConfig(
-             max_new_tokens=args.max_new_tokens,
-             temperature=args.temperature,
-             do_sample=True,
-             top_k=0,
-             use_cache=False if args.gradient_checkpointing else True,
-             pad_token_id=self.processing_class.pad_token_id,
-         )
-         # Set custom EOS tokens if they are specified by the model's generation
-         # config. This is important for models with the Llama 3 chat template,
-         # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
-         # turns or messages.
-         if (
-             hasattr(self.model.generation_config, "eos_token_id")
-             and self.model.generation_config.eos_token_id is not None
-         ):
-             self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
-
-     def _prepare_dataset(self, dataset, *args):
-         # SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
-         # need to keep the messages column as it is. We use the following workaround to keep the messages column.
-         dataset = dataset.add_column("_messages", dataset["messages"])
-         dataset = super()._prepare_dataset(dataset, *args)
-         dataset = dataset.rename_column("_messages", "messages")
-         return dataset
-
-     @staticmethod
-     def generalized_jsd_loss(
-         student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
-     ):
-         """
-         Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
-         of https://huggingface.co/papers/2306.13649 for the definition.
-
-         Args:
-             student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
-             teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
-             labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
-             beta: Interpolation coefficient between 0 and 1 (default: 0.5)
-             temperature: Softmax temperature (default: 1.0)
-             reduction: Specifies the reduction to apply to the output (default: 'batchmean')
-
-         Returns:
-             loss: Scalar tensor with the generalized JSD loss
-         """
-
-         # Apply temperature scaling
-         student_logits = student_logits / temperature
-         teacher_logits = teacher_logits / temperature
-
-         # Compute log probabilities for student and probabilities for teacher
-         student_log_probs = F.log_softmax(student_logits, dim=-1)
-         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
-
-         # Compute the log of the mixture distribution
-         # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
-         beta = torch.tensor(beta, dtype=student_log_probs.dtype)
-         mixture_log_probs = torch.logsumexp(
-             torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
-             dim=0,
-         )
-
-         # Compute KL divergences using F.kl_div
-         # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
-         kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
-         kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
-
-         # Compute the Generalized Jensen-Shannon Divergence
-         jsd = beta * kl_teacher + (1 - beta) * kl_student
-
-         # Masking
-         if labels is not None:
-             mask = labels != -100
-             jsd = jsd[mask]
-
-         # Apply reduction
-         if reduction == "batchmean":
-             return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
-         elif reduction == "sum":
-             return jsd.sum()
-         elif reduction == "mean":
-             return jsd.mean()
-         else:
-             return jsd
-
-     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
-         # compute student output
-         outputs_student = model(
-             input_ids=inputs["input_ids"],
-             attention_mask=inputs["attention_mask"],
-         )
-
-         # compute teacher output in eval mode
-         self.teacher_model.eval()
-         with torch.no_grad():
-             outputs_teacher = self.teacher_model(
-                 input_ids=inputs["input_ids"],
-                 attention_mask=inputs["attention_mask"],
-             )
-
-         # slice the logits for the generated tokens using the inputs["prompts"] lengths
-         prompt_lengths = inputs["prompts"].shape[1]
-         shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
-         shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
-         shifted_labels = inputs["labels"][:, prompt_lengths:]
-
-         # compute loss
-         loss = self.generalized_jsd_loss(
-             student_logits=shifted_student_logits,
-             teacher_logits=shifted_teacher_logits,
-             labels=shifted_labels,
-             beta=self.beta,
-         )
-
-         # empty cache
-         empty_cache()
-
-         # Return loss
-         return (loss, outputs_student) if return_outputs else loss
-
-     @staticmethod
-     def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
-         # Generate output with respect to the prompt only
-         generated_outputs = model.generate(
-             input_ids=inputs["prompts"],
-             attention_mask=inputs.get("prompt_attention_mask", None),
-             generation_config=generation_config,
-             return_dict_in_generate=True,
-         )
-
-         # Get the generated token IDs
-         generated_tokens = generated_outputs.sequences
-         # Calculate new attention mask
-         new_attention_mask = torch.ones_like(generated_tokens)
-         new_labels = generated_tokens.clone()
-
-         # If there's pad_token_id, set attention mask to 0 for padding tokens
-         if pad_token_id is not None:
-             new_labels[new_labels == pad_token_id] = -100
-             new_attention_mask[generated_tokens == pad_token_id] = 0
-
-         return generated_tokens, new_attention_mask, new_labels
-
-     def training_step(
-         self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
-     ) -> torch.Tensor:
-         """
-         Perform a training step for the Generalized Knowledge Distillation (GKD) model.
-
-         This method implements the on-policy learning approach described in the GKD paper.
-         With probability `self.lmbda`, it generates new responses using the student model,
-         which are then used for training instead of the original inputs.
-         """
-         if self.seq_kd:
-             with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
-                 new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
-                     unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
-                 )
-             inputs["input_ids"] = new_input_ids
-             inputs["attention_mask"] = new_attention_mask
-             inputs["labels"] = new_labels
-         if random.random() <= self.lmbda:
-             with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-                 new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
-                     unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
-                 )
-             inputs["input_ids"] = new_input_ids
-             inputs["attention_mask"] = new_attention_mask
-             inputs["labels"] = new_labels
-
-         loss = super().training_step(model, inputs, num_items_in_batch)
-         return loss
-
-     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
-         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-         config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-
-         if model is not None:
-             if hasattr(model, "config"):
-                 hidden_size = (
-                     max(model.config.hidden_sizes)
-                     if getattr(model.config, "hidden_sizes", None)
-                     else getattr(model.config, "hidden_size", None)
-                 )
-                 if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                     # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                     # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                     config_kwargs.update(
-                         {
-                             "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                             "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                             "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                         }
-                     )
-
-         # If ZeRO-3 is used, we shard both the active and reference model.
-         # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-         if config_kwargs["zero_optimization"]["stage"] != 3:
-             config_kwargs["zero_optimization"]["stage"] = 0
-         model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-         model.eval()
-         return model
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-         @inproceedings{agarwal2024on-policy,
-             title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
-             author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
-             year = 2024,
-             booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
-             publisher = {OpenReview.net},
-             url = {https://openreview.net/forum?id=3zKtaqxLhW},
-         }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="GKD",
-             trainer_citation=citation,
-             paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
-             paper_id="2306.13649",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothGKDTrainer(_UnslothGKDTrainer):
-     """
-
-     """
-     def __init__(
-         self,
-         model = None,
-         teacher_model = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         peft_config = None,
-         formatting_func = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothGKDConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         use_fp16 = getattr(args, 'fp16', False)
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                     '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
-             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('gkd_trainer', other_metrics)
-
-         super().__init__(
-             model = model,
-             teacher_model = teacher_model,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-             peft_config = peft_config,
-             formatting_func = formatting_func,**kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-
- pass
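
For reference, the deleted `generalized_jsd_loss` above implements Eq. (1) of the GKD paper (https://huggingface.co/papers/2306.13649). As a sketch of what the masked-and-reduced code computes (notation is mine, not from the file): with temperature-scaled student distribution $p_S$, teacher distribution $p_T$, and mixture $m = \beta\, p_S + (1 - \beta)\, p_T$, the per-token quantity is

$$
\mathrm{GJSD}_\beta(p_S, p_T) \;=\; \beta\, \mathrm{KL}\!\left(p_T \,\|\, m\right) \;+\; (1 - \beta)\, \mathrm{KL}\!\left(p_S \,\|\, m\right).
$$

In the limits, $\beta \to 0$ reduces this to $\mathrm{KL}(p_S \,\|\, p_T)$ and $\beta \to 1$ to $\mathrm{KL}(p_T \,\|\, p_S)$, i.e. the two plain KL divergences the `beta` docstring refers to; note the in-code comment that PyTorch's `F.kl_div` takes its arguments in the opposite order to the paper's notation.
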
unsloth_compiled_cache/UnslothGRPOTrainer.py DELETED
@@ -1,1436 +0,0 @@
1
- """
2
- 2025.3.13
3
- 2025.3.15
4
- 4.48.3
5
- 0.15.2
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
13
-
14
-
15
- import os
16
- from typing import *
17
- from dataclasses import dataclass, field
18
- from packaging.version import Version
19
- import torch
20
- import numpy as np
21
- from contextlib import nullcontext
22
- from torch.nn import functional as F
23
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
-
25
- torch_compile_options = {
26
- "epilogue_fusion" : True,
27
- "max_autotune" : False,
28
- "shape_padding" : True,
29
- "trace.enabled" : False,
30
- "triton.cudagraphs" : False,
31
- }
32
-
33
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
- def selective_log_softmax(logits, index):
35
- logits = logits.to(torch.float32)
36
- selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
- # loop to reduce peak mem consumption
38
- # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
- logsumexp_values = torch.logsumexp(logits, dim = -1)
40
- per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
- return per_token_logps
42
-
43
- def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
44
- # All Unsloth Zoo code licensed under LGPLv3
45
- old_logits = old_logits.to(torch.float32)
46
- new_logits = new_logits.to(torch.float32)
47
- input_ids = input_ids.unsqueeze(-1)
48
-
49
- # x_i - logsumexp(x_i)
50
- old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
51
- new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
52
- old = old_x - torch.logsumexp(old_logits, dim = -1)
53
- new = new_x - torch.logsumexp(new_logits, dim = -1)
54
-
55
- # Reverse KL
56
- kl_i = torch.exp(old - new) - (old - new) - 1.0
57
- # Full correct reverse KL divergence?? Missing term maybe?
58
- # kl_i = torch.exp(new) * kl_i
59
-
60
- # Below is forward KL (normal KL)
61
- # kl_i = torch.exp(old) * (old - new)
62
-
63
- # Must detach - otherwise gradients are not propagated correctly!
64
- # exp(x - x) == 1
65
- loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
66
- loss_i = -(loss_i - beta * kl_i)
67
-
68
- mask = mask.to(torch.float32)
69
- n_mask_per_reward = mask.sum(1)
70
-
71
- # See https://github.com/huggingface/trl/pull/2881
72
- loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
73
- loss = loss_per_reward.mean()
74
- # loss = (loss_i * mask).sum() / mask.sum()
75
-
76
- # Get metrics as well which are folded
77
- with torch.inference_mode():
78
- completion_length = n_mask_per_reward.mean()
79
- mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
80
- mean_kl = mean_kl_per_reward.mean()
81
- pass
82
- return loss, completion_length, mean_kl
83
-
84
- class UnslothEfficientGRPO(torch.autograd.Function):
85
- # All Unsloth Zoo code licensed under LGPLv3
86
- @staticmethod
87
- def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
88
- def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
89
- new_logits = torch.matmul(new_hidden_states, lm_head.t())
90
- new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
91
- old_logits = torch.matmul(old_hidden_states, lm_head.t())
92
- old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
93
- loss, completion_length, mean_kl = grpo_compute_loss(
94
- old_logits, new_logits, input_ids, mask, beta, advantages,
95
- )
96
- # Scale loss if needed for mixed precision training
97
- scaled_loss = loss * scaling
98
- # Must add .loss.detach otherwise autograd uses 2x VRAM
99
- return scaled_loss, (loss.detach(), completion_length, mean_kl,)
100
- pass
101
-
102
- device =_new_hidden_states.device
103
- grad_inputs = torch.empty_like(_new_hidden_states)
104
- accumulated_loss = torch.zeros(1, device = device)
105
- accumulated_completion_length = torch.zeros(1, device = device)
106
- accumulated_mean_kl = torch.zeros(1, device = device)
107
-
108
- def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
109
- (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
110
- compute_loss,
111
- argnums = (0,),
112
- has_aux = True,
113
- )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
114
- accumulated_loss .add_(unscaled_loss)
115
- accumulated_completion_length.add_(chunk_completion_length)
116
- accumulated_mean_kl .add_(chunk_mean_kl)
117
- return chunk_grad_input
118
- pass
119
-
120
- accumulate_chunk = torch.compile(
121
- accumulate_chunk,
122
- fullgraph = True,
123
- options = torch_compile_options,
124
- )
125
-
126
- grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
127
- new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
128
- old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
129
- input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
130
- mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
131
- advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
132
-
133
- # Get mixed precision scaling if seen
134
- scaling = scaler.get_scale() if scaler is not None else 1.0
135
-
136
- # Force torch.compile to use dynamic shapes for seqlen dim
137
- mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
138
-
139
- for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
140
- zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
141
-
142
- mark_dynamic(new_hidden_states_j)
143
- mark_dynamic(old_hidden_states_j)
144
- mark_dynamic(input_ids_j)
145
- mark_dynamic(mask_j)
146
-
147
- grad_inputs_j.copy_(
148
- accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
149
- )
150
- pass
151
-
152
- grad_inputs .div_(n_chunks)
153
- accumulated_loss .div_(n_chunks)
154
- accumulated_completion_length.div_(n_chunks)
155
- accumulated_mean_kl .div_(n_chunks)
156
- ctx.save_for_backward(grad_inputs)
157
-
158
- return (
159
- accumulated_loss,
160
- accumulated_completion_length,
161
- accumulated_mean_kl,
162
- )
163
- pass
164
-
165
- @staticmethod
166
- def backward(ctx, grad_output, dcompletion_length, dmean_kl):
167
- (grad_input,) = ctx.saved_tensors
168
- return (grad_input, None, None, None, None, None, None, None, None,)
169
- pass
170
-
171
- def grpo_accumulated_loss(
172
- trainer,
173
- input_ids,
174
- logits_to_keep,
175
- completion_mask,
176
- advantages,
177
- n_chunks = -1,
178
- ):
179
- # All Unsloth Zoo code licensed under LGPLv3
180
- bsz, qlen = input_ids.shape
181
- # Find closest multiple
182
- factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
183
- if n_chunks == -1: n_chunks = bsz
184
- n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
185
-
186
- mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
187
- os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
188
-
189
- completion_input_ids = input_ids[:, -logits_to_keep:]
190
- lm_head = trainer.model.get_output_embeddings().weight
191
-
192
- with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
193
- with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
194
- old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
195
- pass
196
-
197
- new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
198
-
199
- loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
200
- new_hidden_states, old_hidden_states, lm_head,
201
- completion_input_ids, completion_mask, advantages, trainer.beta,
202
- trainer.accelerator.scaler,
203
- n_chunks,
204
- )
205
- return loss, completion_length, mean_kl
206
-
207
- # Old non efficient code path
208
- new_logits = torch.matmul(new_hidden_states, lm_head.t())
209
- new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
210
- old_logits = torch.matmul(old_hidden_states, lm_head.t())
211
- old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
212
- loss, completion_length, mean_kl = grpo_compute_loss(
213
- old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
214
- )
215
- return loss, completion_length, mean_kl
216
- pass
217
-
218
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
219
- def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
220
- # All Unsloth Zoo code licensed under LGPLv3
221
- old_logits = old_logits.to(torch.float32)
222
- new_logits = new_logits.to(torch.float32)
223
- input_ids = input_ids.unsqueeze(-1)
224
-
225
- # x_i - logsumexp(x_i)
226
- old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
227
- new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
228
- old = old_x - torch.logsumexp(old_logits, dim = -1)
229
- new = new_x - torch.logsumexp(new_logits, dim = -1)
230
-
231
- # Reverse KL
232
- kl_i = torch.exp(old - new) - (old - new) - 1.0
233
- # Full correct reverse KL divergence?? Missing term maybe?
234
- # kl_i = torch.exp(new) * kl_i
235
-
236
- # Below is forward KL (normal KL)
237
- # kl_i = torch.exp(old) * (old - new)
238
-
239
- # Must detach - otherwise gradients are not propagated correctly!
240
- # exp(x - x) == 1
241
- loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
242
- loss_i = -(loss_i - beta * kl_i)
243
-
244
- mask = mask.to(torch.float32)
245
- n_mask_per_reward = mask.sum(1)
246
-
247
- # See https://github.com/huggingface/trl/pull/2881
248
- loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
249
- loss = loss_per_reward.mean()
250
- # loss = (loss_i * mask).sum() / mask.sum()
251
-
252
- # Get metrics as well which are folded
253
- with torch.inference_mode():
254
- completion_length = n_mask_per_reward.mean()
255
- mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
256
- mean_kl = mean_kl_per_reward.mean()
257
- pass
258
- return loss, completion_length, mean_kl
259
-
260
- def vLLMSamplingParams(**kwargs):
261
- from vllm import SamplingParams
262
- sampling_params = SamplingParams(**kwargs)
263
- sampling_params._set_kwargs = kwargs
264
- return sampling_params
265
- @dataclass
266
- class UnslothGRPOConfig(GRPOConfig):
267
- """
268
-
269
- Configuration class for the [`GRPOTrainer`].
270
-
271
- Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
272
- [`~transformers.TrainingArguments`] documentation.
273
-
274
- Using [`~transformers.HfArgumentParser`] we can turn this class into
275
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
276
- command line.
277
-
278
- Parameters:
279
- > Parameters that control the model and reference model
280
-
281
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
282
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
283
- argument of the [`GRPOTrainer`] is provided as a string.
284
-
285
- > Parameters that control the data preprocessing
286
-
287
- remove_unused_columns (`bool`, *optional*, defaults to `False`):
288
- Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
289
- requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
290
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
291
- Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
292
- num_generations (`int` or `None`, *optional*, defaults to `8`):
293
- Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
294
- must be divisible by this value.
295
- temperature (`float`, *optional*, defaults to `0.9`):
296
- Temperature for sampling. The higher the temperature, the more random the completions.
297
- max_completion_length (`int` or `None`, *optional*, defaults to `256`):
298
- Maximum length of the generated completion.
299
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
300
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
301
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
302
- capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
303
- with vLLM generation.
304
-
305
- > Parameters that control generation acceleration powered by vLLM
306
-
307
- use_vllm (`bool`, *optional*, defaults to `False`):
308
- Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
309
- training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
310
- vllm_device (`str`, *optional*, defaults to `"auto"`):
311
- Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
312
- automatically select the next available GPU after the last one used for training. This assumes that
313
- training has not already occupied all available GPUs. If only one device is available, the device will be
314
- shared between both training and vLLM.
315
- vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
316
- Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
317
- device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
318
- improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
319
- during initialization.
320
- vllm_dtype (`str`, *optional*, defaults to `"auto"`):
321
- Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
322
- based on the model configuration. Find the supported values in the vLLM documentation.
323
- vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
324
- If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
325
- `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
326
- context size, which might be much larger than the KV cache, leading to inefficiencies.
327
-
328
- > Parameters that control the training
329
-
330
- learning_rate (`float`, *optional*, defaults to `1e-6`):
331
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
332
- [`~transformers.TrainingArguments`].
333
- beta (`float`, *optional*, defaults to `0.04`):
334
- KL coefficient.
335
- reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
336
- Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
337
- weighted equally with weight `1.0`.
338
- sync_ref_model (`bool`, *optional*, defaults to `False`):
339
- Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
340
- the `ref_model_mixup_alpha` parameter. This synchronization originites from the
341
- [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
342
- ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
343
- α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
344
- between the current policy and the previous reference policy during updates. The reference policy is
345
- updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
346
- must set `sync_ref_model=True`.
347
- ref_model_sync_steps (`int`, *optional*, defaults to `64`):
348
- τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
349
- frequently the current policy is synchronized with the reference policy. To use this parameter, you must
350
- set `sync_ref_model=True`.
351
-
352
- > Parameters that control the logging
353
-
354
- log_completions (`bool`, *optional*, defaults to `False`):
355
- Whether to log the completions during training.
356
-
357
- """
358
- vllm_sampling_params: Optional[Any] = field(
359
- default = None,
360
- metadata = {'help': 'vLLM SamplingParams'},
361
- )
362
- unsloth_num_chunks : Optional[int] = field(
363
- default = -1,
364
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
365
- )
366
- def __init__(
367
- self,
368
- output_dir = None,
369
- overwrite_output_dir = None,
370
- do_train = False,
371
- do_eval = False,
372
- do_predict = False,
373
- eval_strategy = 'no',
374
- prediction_loss_only = False,
375
- per_device_train_batch_size = 4,
376
- per_device_eval_batch_size = 4,
377
- per_gpu_train_batch_size = None,
378
- per_gpu_eval_batch_size = None,
379
- gradient_accumulation_steps = 2,
380
- eval_accumulation_steps = 2,
381
- eval_delay = 0,
382
- torch_empty_cache_steps = 250,
383
- learning_rate = 5e-05,
384
- weight_decay = 0.01,
385
- adam_beta1 = 0.9,
386
- adam_beta2 = 0.999,
387
- adam_epsilon = 1e-08,
388
- max_grad_norm = 1.0,
389
- num_train_epochs = 3.0,
390
- max_steps = -1,
391
- lr_scheduler_type = 'linear',
392
- warmup_ratio = 0.1,
393
- warmup_steps = 0,
394
- log_level = 'passive',
395
- log_level_replica = 'warning',
396
- log_on_each_node = True,
397
- logging_dir = None,
398
- logging_strategy = 'steps',
399
- logging_first_step = False,
400
- logging_steps = 1,
401
- logging_nan_inf_filter = False,
402
- save_strategy = 'steps',
403
- save_steps = 500,
404
- save_total_limit = None,
405
- save_safetensors = True,
406
- save_on_each_node = False,
407
- save_only_model = False,
408
- restore_callback_states_from_checkpoint = False,
409
- no_cuda = False,
410
- use_cpu = False,
411
- use_mps_device = False,
412
- seed = 3407,
413
- data_seed = 3407,
414
- jit_mode_eval = False,
415
- use_ipex = False,
416
- bf16 = False,
417
- fp16 = False,
418
- fp16_opt_level = 'O1',
419
- half_precision_backend = 'auto',
420
- bf16_full_eval = False,
421
- fp16_full_eval = False,
422
- tf32 = None,
423
- local_rank = -1,
424
- ddp_backend = None,
425
- tpu_num_cores = None,
426
- tpu_metrics_debug = False,
427
- debug = '',
428
- dataloader_drop_last = False,
429
- eval_steps = None,
430
- dataloader_num_workers = 0,
431
- dataloader_prefetch_factor = None,
432
- past_index = -1,
433
- run_name = None,
434
- disable_tqdm = None,
435
- remove_unused_columns = False,
436
- label_names = None,
437
- load_best_model_at_end = False,
438
- metric_for_best_model = None,
439
- greater_is_better = None,
440
- ignore_data_skip = False,
441
- fsdp = '',
442
- fsdp_min_num_params = 0,
443
- fsdp_config = None,
444
- fsdp_transformer_layer_cls_to_wrap = None,
445
- accelerator_config = None,
446
- deepspeed = None,
447
- label_smoothing_factor = 0.0,
448
- optim = 'adamw_8bit',
449
- optim_args = None,
450
- adafactor = False,
451
- group_by_length = False,
452
- length_column_name = 'length',
453
- report_to = None,
454
- ddp_find_unused_parameters = None,
455
- ddp_bucket_cap_mb = None,
456
- ddp_broadcast_buffers = None,
457
- dataloader_pin_memory = True,
458
- dataloader_persistent_workers = False,
459
- skip_memory_metrics = True,
460
- use_legacy_prediction_loop = False,
461
- push_to_hub = False,
462
- resume_from_checkpoint = None,
463
- hub_model_id = None,
464
- hub_strategy = 'every_save',
465
- hub_token = None,
466
- hub_private_repo = None,
467
- hub_always_push = False,
468
- gradient_checkpointing = False,
469
- gradient_checkpointing_kwargs = None,
470
- include_inputs_for_metrics = False,
471
- eval_do_concat_batches = True,
472
- fp16_backend = 'auto',
473
- evaluation_strategy = None,
474
- push_to_hub_model_id = None,
475
- push_to_hub_organization = None,
476
- push_to_hub_token = None,
477
- mp_parameters = '',
478
- auto_find_batch_size = False,
479
- full_determinism = False,
480
- torchdynamo = None,
481
- ray_scope = 'last',
482
- ddp_timeout = 1800,
483
- torch_compile = False,
484
- torch_compile_backend = None,
485
- torch_compile_mode = None,
486
- dispatch_batches = None,
487
- split_batches = None,
488
- include_tokens_per_second = False,
489
- include_num_input_tokens_seen = False,
490
- neftune_noise_alpha = None,
491
- optim_target_modules = None,
492
- batch_eval_metrics = False,
493
- eval_on_start = False,
494
- use_liger_kernel = False,
495
- eval_use_gather_object = False,
496
- average_tokens_across_devices = False,
497
- model_init_kwargs = None,
498
- max_prompt_length = 512,
499
- num_generations = 8,
500
- temperature = 0.9,
501
- max_completion_length = 256,
502
- ds3_gather_for_generation = True,
503
- use_vllm = False,
504
- vllm_device = 'auto',
505
- vllm_gpu_memory_utilization = 0.9,
506
- vllm_dtype = 'auto',
507
- vllm_max_model_len = None,
508
- beta = 0.04,
509
- reward_weights = None,
510
- sync_ref_model = False,
511
- ref_model_mixup_alpha = 0.9,
512
- ref_model_sync_steps = 64,
513
- log_completions = False,
514
- vllm_sampling_params = None,
515
- unsloth_num_chunks = -1,
516
- **kwargs,
517
- ):
518
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
519
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
520
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
521
- output_dir = 'unsloth_training_checkpoints'
522
- save_strategy = 'no'
523
- div = per_device_train_batch_size // num_generations
524
- if div * num_generations != per_device_train_batch_size:
525
- print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size from ' + str(per_device_train_batch_size) + ' to match `num_generations` = ' + str(num_generations))
526
- per_device_train_batch_size = num_generations
527
-
528
- super().__init__(
529
- output_dir = output_dir,
530
- overwrite_output_dir = overwrite_output_dir,
531
- do_train = do_train,
532
- do_eval = do_eval,
533
- do_predict = do_predict,
534
- eval_strategy = eval_strategy,
535
- prediction_loss_only = prediction_loss_only,
536
- per_device_train_batch_size = per_device_train_batch_size,
537
- per_device_eval_batch_size = per_device_eval_batch_size,
538
- per_gpu_train_batch_size = per_gpu_train_batch_size,
539
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
540
- gradient_accumulation_steps = gradient_accumulation_steps,
541
- eval_accumulation_steps = eval_accumulation_steps,
542
- eval_delay = eval_delay,
543
- torch_empty_cache_steps = torch_empty_cache_steps,
544
- learning_rate = learning_rate,
545
- weight_decay = weight_decay,
546
- adam_beta1 = adam_beta1,
547
- adam_beta2 = adam_beta2,
548
- adam_epsilon = adam_epsilon,
549
- max_grad_norm = max_grad_norm,
550
- num_train_epochs = num_train_epochs,
551
- max_steps = max_steps,
552
- lr_scheduler_type = lr_scheduler_type,
553
- warmup_ratio = warmup_ratio,
554
- warmup_steps = warmup_steps,
555
- log_level = log_level,
556
- log_level_replica = log_level_replica,
557
- log_on_each_node = log_on_each_node,
558
- logging_dir = logging_dir,
559
- logging_strategy = logging_strategy,
560
- logging_first_step = logging_first_step,
561
- logging_steps = logging_steps,
562
- logging_nan_inf_filter = logging_nan_inf_filter,
563
- save_strategy = save_strategy,
564
- save_steps = save_steps,
565
- save_total_limit = save_total_limit,
566
- save_safetensors = save_safetensors,
567
- save_on_each_node = save_on_each_node,
568
- save_only_model = save_only_model,
569
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
570
- no_cuda = no_cuda,
571
- use_cpu = use_cpu,
572
- use_mps_device = use_mps_device,
573
- seed = seed,
574
- data_seed = data_seed,
575
- jit_mode_eval = jit_mode_eval,
576
- use_ipex = use_ipex,
577
- bf16 = bf16,
578
- fp16 = fp16,
579
- fp16_opt_level = fp16_opt_level,
580
- half_precision_backend = half_precision_backend,
581
- bf16_full_eval = bf16_full_eval,
582
- fp16_full_eval = fp16_full_eval,
583
- tf32 = tf32,
584
- local_rank = local_rank,
585
- ddp_backend = ddp_backend,
586
- tpu_num_cores = tpu_num_cores,
587
- tpu_metrics_debug = tpu_metrics_debug,
588
- debug = debug,
589
- dataloader_drop_last = dataloader_drop_last,
590
- eval_steps = eval_steps,
591
- dataloader_num_workers = dataloader_num_workers,
592
- dataloader_prefetch_factor = dataloader_prefetch_factor,
593
- past_index = past_index,
594
- run_name = run_name,
595
- disable_tqdm = disable_tqdm,
596
- remove_unused_columns = remove_unused_columns,
597
- label_names = label_names,
598
- load_best_model_at_end = load_best_model_at_end,
599
- metric_for_best_model = metric_for_best_model,
600
- greater_is_better = greater_is_better,
601
- ignore_data_skip = ignore_data_skip,
602
- fsdp = fsdp,
603
- fsdp_min_num_params = fsdp_min_num_params,
604
- fsdp_config = fsdp_config,
605
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
606
- accelerator_config = accelerator_config,
607
- deepspeed = deepspeed,
608
- label_smoothing_factor = label_smoothing_factor,
609
- optim = optim,
610
- optim_args = optim_args,
611
- adafactor = adafactor,
612
- group_by_length = group_by_length,
613
- length_column_name = length_column_name,
614
- report_to = report_to,
615
- ddp_find_unused_parameters = ddp_find_unused_parameters,
616
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
617
- ddp_broadcast_buffers = ddp_broadcast_buffers,
618
- dataloader_pin_memory = dataloader_pin_memory,
619
- dataloader_persistent_workers = dataloader_persistent_workers,
620
- skip_memory_metrics = skip_memory_metrics,
621
- use_legacy_prediction_loop = use_legacy_prediction_loop,
622
- push_to_hub = push_to_hub,
623
- resume_from_checkpoint = resume_from_checkpoint,
624
- hub_model_id = hub_model_id,
625
- hub_strategy = hub_strategy,
626
- hub_token = hub_token,
627
- hub_private_repo = hub_private_repo,
628
- hub_always_push = hub_always_push,
629
- gradient_checkpointing = gradient_checkpointing,
630
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
631
- include_inputs_for_metrics = include_inputs_for_metrics,
632
- eval_do_concat_batches = eval_do_concat_batches,
633
- fp16_backend = fp16_backend,
634
- evaluation_strategy = evaluation_strategy,
635
- push_to_hub_model_id = push_to_hub_model_id,
636
- push_to_hub_organization = push_to_hub_organization,
637
- push_to_hub_token = push_to_hub_token,
638
- mp_parameters = mp_parameters,
639
- auto_find_batch_size = auto_find_batch_size,
640
- full_determinism = full_determinism,
641
- torchdynamo = torchdynamo,
642
- ray_scope = ray_scope,
643
- ddp_timeout = ddp_timeout,
644
- torch_compile = torch_compile,
645
- torch_compile_backend = torch_compile_backend,
646
- torch_compile_mode = torch_compile_mode,
647
- dispatch_batches = dispatch_batches,
648
- split_batches = split_batches,
649
- include_tokens_per_second = include_tokens_per_second,
650
- include_num_input_tokens_seen = include_num_input_tokens_seen,
651
- neftune_noise_alpha = neftune_noise_alpha,
652
- optim_target_modules = optim_target_modules,
653
- batch_eval_metrics = batch_eval_metrics,
654
- eval_on_start = eval_on_start,
655
- use_liger_kernel = use_liger_kernel,
656
- eval_use_gather_object = eval_use_gather_object,
657
- average_tokens_across_devices = average_tokens_across_devices,
658
- model_init_kwargs = model_init_kwargs,
659
- max_prompt_length = max_prompt_length,
660
- num_generations = num_generations,
661
- temperature = temperature,
662
- max_completion_length = max_completion_length,
663
- ds3_gather_for_generation = ds3_gather_for_generation,
664
- use_vllm = use_vllm,
665
- vllm_device = vllm_device,
666
- vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
667
- vllm_dtype = vllm_dtype,
668
- vllm_max_model_len = vllm_max_model_len,
669
- beta = beta,
670
- reward_weights = reward_weights,
671
- sync_ref_model = sync_ref_model,
672
- ref_model_mixup_alpha = ref_model_mixup_alpha,
673
- ref_model_sync_steps = ref_model_sync_steps,
674
- log_completions = log_completions,**kwargs)
675
- self.vllm_sampling_params = vllm_sampling_params
676
- self.unsloth_num_chunks = unsloth_num_chunks
677
- pass
678
-
679
- class _UnslothGRPOTrainer(Trainer):
680
- """"""
681
-
682
- _tag_names = ["trl", "grpo"]
683
-
684
- def __init__(
685
- self,
686
- model: Union[str, PreTrainedModel],
687
- reward_funcs: Union[RewardFunc, list[RewardFunc]],
688
- args: GRPOConfig = None,
689
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
690
- eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
691
- processing_class: Optional[PreTrainedTokenizerBase] = None,
692
- reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
693
- callbacks: Optional[list[TrainerCallback]] = None,
694
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
695
- peft_config: Optional["PeftConfig"] = None,
696
- ):
697
-
698
- if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
699
- # Args
700
- if args is None:
701
- model_name = model if isinstance(model, str) else model.config._name_or_path
702
- model_name = model_name.split("/")[-1]
703
- args = GRPOConfig(f"{model_name}-GRPO")
704
-
705
- # Models
706
- # Trained model
707
- model_init_kwargs = args.model_init_kwargs or {}
708
- if isinstance(model, str):
709
- model_id = model
710
- torch_dtype = model_init_kwargs.get("torch_dtype")
711
- if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
712
- pass # torch_dtype is already a torch.dtype or "auto" or None
713
- elif isinstance(torch_dtype, str): # it's a str, but not "auto"
714
- torch_dtype = getattr(torch, torch_dtype)
715
- model_init_kwargs["torch_dtype"] = torch_dtype
716
- else:
717
- raise ValueError(
718
- "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
719
- f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
720
- )
721
- # Disable caching if gradient checkpointing is enabled (not supported)
722
- model_init_kwargs["use_cache"] = (
723
- False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
724
- )
725
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
726
- else:
727
- model_id = model.config._name_or_path
728
- if args.model_init_kwargs is not None:
729
- raise ValueError(
730
- "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
731
- "This argument can only be used when the `model` argument is a string."
732
- )
733
-
734
- if False:
735
- model = model
736
-
737
- # Reference model
738
- if is_deepspeed_zero3_enabled():
739
- self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
740
- elif not is_peft_model(model):
741
- # If PEFT configuration is not provided, create a reference model based on the initial model.
742
- self.ref_model = create_reference_model(model)
743
- else:
744
- # If PEFT is used, the reference model is not needed since the adapter can be disabled
745
- # to revert to the initial model.
746
- self.ref_model = None
747
-
748
- # Processing class
749
- if processing_class is None:
750
- processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
751
-
752
- # Reward functions
753
- if not isinstance(reward_funcs, list):
754
- reward_funcs = [reward_funcs]
755
- for i, reward_func in enumerate(reward_funcs):
756
- if isinstance(reward_func, str):
757
- reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
758
- reward_func, num_labels=1, **model_init_kwargs
759
- )
760
- self.reward_funcs = reward_funcs
761
-
762
- # Reward weights
763
- if args.reward_weights is not None:
764
- if len(args.reward_weights) != len(reward_funcs):
765
- raise ValueError(
766
- f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
767
- f"functions ({len(reward_funcs)})"
768
- )
769
- self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
770
- else:
771
- self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
772
-
773
- # Reward processing class
774
- if reward_processing_classes is None:
775
- reward_processing_classes = [None] * len(reward_funcs)
776
- elif not isinstance(reward_processing_classes, list):
777
- reward_processing_classes = [reward_processing_classes]
778
- else:
779
- if len(reward_processing_classes) != len(reward_funcs):
780
- raise ValueError("The number of reward processing classes must match the number of reward functions.")
781
-
782
- for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
783
- if isinstance(reward_func, PreTrainedModel):
784
- if reward_processing_class is None:
785
- reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
786
- if reward_processing_class.pad_token_id is None:
787
- reward_processing_class.pad_token = reward_processing_class.eos_token
788
- # The reward model computes the reward for the latest non-padded token in the input sequence.
789
- # So it's important to set the pad token ID to the padding token ID of the processing class.
790
- reward_func.config.pad_token_id = reward_processing_class.pad_token_id
791
- reward_processing_classes[i] = reward_processing_class
792
- self.reward_processing_classes = reward_processing_classes
793
-
794
- # Data collator
795
- def data_collator(features): # No data collation is needed in GRPO
796
- return features
797
-
798
- # Training arguments
799
- self.max_prompt_length = args.max_prompt_length
800
- self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
801
- self.num_generations = args.num_generations # = G in the GRPO paper
802
- self.use_vllm = args.use_vllm
803
-
804
- self.beta = args.beta
805
-
806
- # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
807
- # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
808
- # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
809
- # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
810
- # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
811
- # This acts as a flag to indicate that the warning has already been issued.
812
- model.warnings_issued["estimate_tokens"] = True
813
-
814
- # Initialize the metrics
815
- self._metrics = defaultdict(list)
816
- self.log_completions = args.log_completions
817
-
818
- super().__init__(
819
- model=model,
820
- args=args,
821
- data_collator=data_collator,
822
- train_dataset=train_dataset,
823
- eval_dataset=eval_dataset,
824
- processing_class=processing_class,
825
- callbacks=callbacks,
826
- optimizers=optimizers,
827
- )
828
-
829
- # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
830
- num_processes = self.accelerator.num_processes
831
- global_batch_size = args.per_device_train_batch_size * num_processes
832
- possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
833
- if self.num_generations not in possible_values:
834
- raise ValueError(
835
- f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
836
- f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
837
- f"batch size, the valid values for the number of generations are: {possible_values}."
838
- )
839
- if self.args.eval_strategy != "no":
840
- global_batch_size = args.per_device_eval_batch_size * num_processes
841
- possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
842
- if self.num_generations not in possible_values:
843
- raise ValueError(
844
- f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
845
- f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
846
- f"eval batch size, the valid values for the number of generations are: {possible_values}."
847
- )
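A worked example of the divisibility rule enforced above, with hypothetical numbers (2 processes, `per_device_train_batch_size=4`):

```python
# Global batch of 2 * 4 = 8 -> num_generations must be a divisor of 8 that is >= 2.
global_batch_size = 2 * 4
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if global_batch_size % n_gen == 0]
print(possible_values)  # [2, 4, 8]
```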
848
-
849
- # Ensure each process receives a unique seed to prevent duplicate completions when generating with
850
- # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
851
- # it's safer to set it in all cases.
852
- set_seed(args.seed, device_specific=True)
853
-
854
- if self.use_vllm:
855
- self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
856
- temperature=args.temperature,
857
- max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
858
- else:
859
- self.generation_config = GenerationConfig(
860
- max_new_tokens=self.max_completion_length,
861
- do_sample=True,
862
- temperature=args.temperature,
863
- pad_token_id=processing_class.pad_token_id,
864
- )
865
-
866
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
867
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
868
- # self.model_accepts_loss_kwargs to False to enable scaling.
869
- self.model_accepts_loss_kwargs = False
870
-
871
- # Add tags to the model
872
- self.model.add_model_tags(self._tag_names)
873
-
874
- if self.ref_model is not None:
875
- if self.is_deepspeed_enabled:
876
- self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
877
- else:
878
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
879
-
880
- if args.sync_ref_model:
881
- self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
882
-
883
- for i, reward_func in enumerate(self.reward_funcs):
884
- if isinstance(reward_func, PreTrainedModel):
885
- self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
886
-
887
- def _set_signature_columns_if_needed(self):
888
- # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
889
- # By default, this method sets `self._signature_columns` to the model's expected inputs.
890
- # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
891
- # Instead, we set them to the columns expected by the `training_step` method, hence the override.
892
- if self._signature_columns is None:
893
- self._signature_columns = ["prompt"]
894
-
895
- def _get_train_sampler(self) -> Sampler:
896
- # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
897
- # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
898
- # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
899
- # preventing discrepancies in group formation.
900
- return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
901
-
902
- def _get_eval_sampler(self, eval_dataset) -> Sampler:
903
- # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
904
- # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
905
- # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
906
- # preventing discrepancies in group formation.
907
- return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
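The behaviour both samplers rely on can be sketched as follows. This is an assumption about what `RepeatRandomSampler` does, based on the comments above, not TRL's actual implementation:

```python
import torch
from torch.utils.data import Sampler

class SimpleRepeatRandomSampler(Sampler):
    """Yields each shuffled dataset index `repeat_count` times in a row."""
    def __init__(self, data_source, repeat_count: int, seed: int = 0):
        self.num_samples = len(data_source)
        self.repeat_count = repeat_count
        self.seed = seed

    def __iter__(self):
        g = torch.Generator().manual_seed(self.seed)  # same seed on every process
        for idx in torch.randperm(self.num_samples, generator=g).tolist():
            yield from [idx] * self.repeat_count      # each prompt appears repeat_count times

    def __len__(self):
        return self.num_samples * self.repeat_count

print(list(SimpleRepeatRandomSampler(range(3), repeat_count=2)))  # e.g. [2, 2, 0, 0, 1, 1]
```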
908
-
909
- # Get the per-token log probabilities for the completions for the model and the reference model
910
- def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
911
- if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
912
- return None # Unsloth efficient GRPO
913
- # Otherwise, calculate normally:
914
- if not hasattr(self, '_autocast_dtype'):
915
- self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
916
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
917
- with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
918
- # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
919
- logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
920
- logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
921
-
922
- input_ids = input_ids[:, -logits_to_keep:]
923
- # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
924
- # See https://github.com/huggingface/trl/issues/2770
925
- logits = logits[:, -logits_to_keep:]
926
- return logits
927
- # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
928
- pass
929
-
930
- def _move_model_to_vllm(self, *args, **kwargs): return None
931
-
932
- def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
933
- device = self.accelerator.device
934
- prompts = [x["prompt"] for x in inputs]
935
- prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
936
- prompt_inputs = self.processing_class(
937
- prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
938
- )
939
- prompt_inputs = super()._prepare_inputs(prompt_inputs)
940
- prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
941
-
942
- if self.max_prompt_length is not None:
943
- prompt_ids = prompt_ids[:, -self.max_prompt_length :]
944
- prompt_mask = prompt_mask[:, -self.max_prompt_length :]
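The negative slice keeps the right-hand end of the prompt, which matters because prompts are tokenized with `padding_side="left"` and generation continues from the last prompt token. A toy illustration:

```python
import torch

prompt_ids = torch.tensor([[1, 2, 3, 4, 5]])
max_prompt_length = 3
print(prompt_ids[:, -max_prompt_length:])  # tensor([[3, 4, 5]]) -- the most recent tokens survive
```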
945
-
946
- # Generate completions using either vLLM or regular generation
947
- if self.args.use_vllm:
948
- # First, have main process load weights if needed
949
- if self.state.global_step != self._last_loaded_step:
950
- self._move_model_to_vllm()
951
- self._last_loaded_step = self.state.global_step
952
-
953
- # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
954
- all_prompts_text = gather_object(prompts_text)
955
- if self.accelerator.is_main_process:
956
- outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
957
- completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
958
- else:
959
- completion_ids = [None] * len(all_prompts_text)
960
- # Broadcast the completions from the main process to all processes, ensuring each process receives its
961
- # corresponding slice.
962
- completion_ids = broadcast_object_list(completion_ids, from_process=0)
963
- process_slice = slice(
964
- self.accelerator.process_index * len(prompts),
965
- (self.accelerator.process_index + 1) * len(prompts),
966
- )
967
- completion_ids = completion_ids[process_slice]
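A toy illustration of the broadcast-then-slice pattern above: all prompts are generated on the main process, then each rank slices back its own contiguous chunk (values are placeholders):

```python
all_completion_ids = [[10], [11], [12], [20], [21], [22]]  # gathered from 2 processes, 3 prompts each
num_local_prompts = 3
for process_index in range(2):
    process_slice = slice(process_index * num_local_prompts, (process_index + 1) * num_local_prompts)
    print(process_index, all_completion_ids[process_slice])
# 0 [[10], [11], [12]]
# 1 [[20], [21], [22]]
```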
968
-
969
- # Pad the completions, and concatenate them with the prompts
970
- completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
971
- completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
972
- prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
973
- else:
974
- # Regular generation path
975
- with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
976
- prompt_completion_ids = unwrapped_model.generate(
977
- prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
978
- )
979
-
980
- # Compute prompt length and extract completion ids
981
- prompt_length = prompt_ids.size(1)
982
- prompt_ids = prompt_completion_ids[:, :prompt_length]
983
- completion_ids = prompt_completion_ids[:, prompt_length:]
984
-
985
- # Mask everything after the first EOS token
986
- is_eos = completion_ids == self.processing_class.eos_token_id
987
- eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
988
- eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
989
- sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
990
- completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
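A toy run of the EOS-masking logic above, assuming `eos_token_id=2`; the second row has no EOS, so all of its tokens stay unmasked:

```python
import torch

completion_ids = torch.tensor([[5, 7, 2, 9],
                               [4, 6, 8, 1]])
is_eos = completion_ids == 2
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
print(completion_mask)  # tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
```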
991
-
992
- # Concatenate prompt_mask with completion_mask for logit computation
993
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
994
-
995
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
996
-
997
- with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
998
- if self.ref_model is not None:
999
- ref_per_token_logps = self._get_per_token_logps(
1000
- self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
1001
- )
1002
- else:
1003
- with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
1004
- ref_per_token_logps = self._get_per_token_logps(
1005
- self.model, prompt_completion_ids, attention_mask, logits_to_keep
1006
- )
1007
-
1008
- # Decode the generated completions
1009
- completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1010
- if is_conversational(inputs[0]):
1011
- completions = []
1012
- for prompt, completion in zip(prompts, completions_text):
1013
- bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
1014
- completions.append([{"role": "assistant", "content": bootstrap + completion}])
1015
- else:
1016
- completions = completions_text
1017
-
1018
- rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
1019
- for i, (reward_func, reward_processing_class) in enumerate(
1020
- zip(self.reward_funcs, self.reward_processing_classes)
1021
- ):
1022
- if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
1023
- if is_conversational(inputs[0]):
1024
- messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
1025
- texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
1026
- else:
1027
- texts = [p + c for p, c in zip(prompts, completions)]
1028
- reward_inputs = reward_processing_class(
1029
- texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
1030
- )
1031
- reward_inputs = super()._prepare_inputs(reward_inputs)
1032
- with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
1033
- rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
1034
- else:
1035
- # Repeat all input columns (but "prompt" and "completion") to match the number of generations
1036
- keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
1037
- reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
1038
- output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
1039
- rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
1040
-
1041
- # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
1042
- # completions may be distributed across processes
1043
- rewards_per_func = gather(rewards_per_func)
1044
-
1045
- # Apply weights to each reward function's output and sum
1046
- rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
1047
-
1048
- # Compute grouped-wise rewards
1049
- mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
1050
- std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
1051
-
1052
- # Normalize the rewards to compute the advantages
1053
- mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1054
- std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1055
- advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
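A worked example of the group normalization above, with `num_generations=2` and made-up rewards for two prompts laid out as `[r_1a, r_1b, r_2a, r_2b]`:

```python
import torch

rewards = torch.tensor([1.0, 3.0, 0.0, 4.0])
num_generations = 2
mean_grouped = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations)
std_grouped = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)
print(advantages)  # ~[-0.71, 0.71, -0.71, 0.71]: each completion is ranked within its own group
```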
1056
-
1057
- # Slice to keep only the local part of the data
1058
- process_slice = slice(
1059
- self.accelerator.process_index * len(prompts),
1060
- (self.accelerator.process_index + 1) * len(prompts),
1061
- )
1062
- advantages = advantages[process_slice]
1063
-
1064
- # Log the metrics
1065
- reward_per_func = rewards_per_func.mean(0)
1066
- for i, reward_func in enumerate(self.reward_funcs):
1067
- if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
1068
- reward_func_name = reward_func.config._name_or_path.split("/")[-1]
1069
- else:
1070
- reward_func_name = reward_func.__name__
1071
- self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
1072
-
1073
- self._metrics["reward"].append(rewards.mean().item())
1074
- self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
1075
-
1076
- if (
1077
- self.log_completions
1078
- and self.state.global_step % self.args.logging_steps == 0
1079
- and "wandb" in self.args.report_to
1080
- ):
1081
- import pandas as pd
1082
-
1083
- # For logging
1084
- table = {
1085
- "step": [str(self.state.global_step)] * len(rewards),
1086
- "prompt": gather_object(prompts_text),
1087
- "completion": gather_object(completions_text),
1088
- "reward": rewards.tolist(),
1089
- }
1090
- df = pd.DataFrame(table)
1091
-
1092
- if wandb.run is not None and self.accelerator.is_main_process:
1093
- wandb.log({"completions": wandb.Table(dataframe=df)})
1094
-
1095
- return {
1096
- "prompt_ids": prompt_ids,
1097
- "prompt_mask": prompt_mask,
1098
- "completion_ids": completion_ids,
1099
- "completion_mask": completion_mask,
1100
- "ref_per_token_logps": ref_per_token_logps,
1101
- "advantages": advantages,
1102
- }
1103
-
1104
- def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
1105
- if return_outputs:
1106
- raise ValueError("The GRPOTrainer does not support returning outputs")
1107
- # Compute the per-token log probabilities for the model
1108
-
1109
- prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
1110
- completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
1111
- input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
1112
- bsz, qlen = input_ids.shape
1113
- attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
1114
- # attention_mask = None
1115
- logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
1116
- _input_ids = input_ids
1117
- _logits_to_keep = logits_to_keep
1118
-
1119
- per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
1120
-
1121
- # Compute the KL divergence between the model and the reference model
1122
- ref_per_token_logps = inputs["ref_per_token_logps"]
1123
- # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
1124
-
1125
- # x - x.detach() allows for preserving gradients from x
1126
- advantages = inputs["advantages"]
1127
- # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
1128
- # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
1129
- # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1130
- input_ids = input_ids[:, -logits_to_keep:]
1131
- if per_token_logps is not None:
1132
- loss, completion_length, mean_kl = grpo_compute_loss_slow(
1133
- ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
1134
- )
1135
- else:
1136
- loss, completion_length, mean_kl = grpo_accumulated_loss(
1137
- self, _input_ids, logits_to_keep, completion_mask, advantages,
1138
- n_chunks = self.args.unsloth_num_chunks,
1139
- )
1140
-
1141
- # Log the metrics
1142
- # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
1143
-
1144
- # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1145
- # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
1146
-
1147
- if "train" in self._metrics:
1148
- mode = "eval" if self.control.should_evaluate else "train"
1149
- self._metrics[mode]["completion_length"].append(completion_length.item())
1150
- self._metrics[mode]["kl"].append(mean_kl.item())
1151
- else:
1152
- self._metrics["completion_length"].append(completion_length.item())
1153
- self._metrics["kl"].append(mean_kl.item())
1154
- return loss
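For readers following the math, the commented-out lines above correspond to the following reference computation (a sketch; the actual work is delegated to `grpo_compute_loss_slow` or `grpo_accumulated_loss`):

```python
import torch

def grpo_per_token_loss(per_token_logps, ref_per_token_logps, advantages, completion_mask, beta):
    delta = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(delta) - delta - 1                    # k3 KL estimator, always >= 0
    ratio = torch.exp(per_token_logps - per_token_logps.detach())  # equals 1, but keeps the gradient
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * per_token_kl)
    return ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```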
1155
-
1156
- def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
1157
- inputs = self._prepare_inputs(inputs)
1158
- with torch.no_grad():
1159
- with self.compute_loss_context_manager():
1160
- loss = self.compute_loss(model, inputs)
1161
- loss = loss.mean().detach()
1162
- return loss, None, None
1163
-
1164
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1165
- metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
1166
-
1167
- # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1168
- # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1169
- if next(iter(logs.keys())).startswith("eval_"):
1170
- metrics = {f"eval_{key}": val for key, val in metrics.items()}
1171
-
1172
- logs = {**logs, **metrics}
1173
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1174
- super().log(logs, start_time)
1175
- else: # transformers<=4.46
1176
- super().log(logs)
1177
- self._metrics.clear()
1178
-
1179
- def create_model_card(
1180
- self,
1181
- model_name: Optional[str] = None,
1182
- dataset_name: Optional[str] = None,
1183
- tags: Union[str, list[str], None] = None,
1184
- ):
1185
- """
1186
- Creates a draft of a model card using the information available to the `Trainer`.
1187
-
1188
- Args:
1189
- model_name (`str` or `None`, *optional*, defaults to `None`):
1190
- Name of the model.
1191
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1192
- Name of the dataset used for training.
1193
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1194
- Tags to be associated with the model card.
1195
- """
1196
- if not self.is_world_process_zero():
1197
- return
1198
-
1199
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1200
- base_model = self.model.config._name_or_path
1201
- else:
1202
- base_model = None
1203
-
1204
- tags = tags or []
1205
- if isinstance(tags, str):
1206
- tags = [tags]
1207
-
1208
- if hasattr(self.model.config, "unsloth_version"):
1209
- tags.append("unsloth")
1210
-
1211
- citation = textwrap.dedent(
1212
- """\
1213
- @article{zhihong2024deepseekmath,
1214
- title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
1215
- author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
1216
- year = 2024,
1217
- eprint = {arXiv:2402.03300},
1218
- }
1219
- """
1220
- )
1221
-
1222
- model_card = generate_model_card(
1223
- base_model=base_model,
1224
- model_name=model_name,
1225
- hub_model_id=self.hub_model_id,
1226
- dataset_name=dataset_name,
1227
- tags=tags,
1228
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1229
- comet_url=get_comet_experiment_url(),
1230
- trainer_name="GRPO",
1231
- trainer_citation=citation,
1232
- paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
1233
- paper_id="2402.03300",
1234
- )
1235
-
1236
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1237
- class UnslothGRPOTrainer(_UnslothGRPOTrainer):
1238
- """
1239
-
1240
- Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
1241
- paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
1242
-
1243
- Example:
1244
-
1245
- ```python
1246
- from datasets import load_dataset
1247
- from trl import GRPOTrainer
1248
-
1249
- dataset = load_dataset("trl-lib/tldr", split="train")
1250
-
1251
- def reward_func(completions, **kwargs):
1252
- # Dummy reward function that rewards completions with more unique letters.
1253
- return [float(len(set(completion))) for completion in completions]
1254
-
1255
- trainer = GRPOTrainer(
1256
- model="Qwen/Qwen2-0.5B-Instruct",
1257
- reward_funcs=reward_func,
1258
- train_dataset=dataset,
1259
- )
1260
-
1261
- trainer.train()
1262
- ```
1263
-
1264
- Args:
1265
- model (`Union[str, PreTrainedModel]`):
1266
- Model to be trained. Can be either:
1267
-
1268
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
1269
- a path to a *directory* containing model weights saved using
1270
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
1271
- loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
1272
- in `args.model_init_kwargs`.
1273
- - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
1274
- reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
1275
- Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
1276
- functions with the prompts and completions and sum the rewards. Can be either:
1277
-
1278
- - A single reward function, such as:
1279
- - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
1280
- path to a *directory* containing model weights saved using
1281
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
1282
- using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
1283
- keyword arguments in `args.model_init_kwargs`.
1284
- - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
1285
- - A custom reward function: The function is provided with the prompts and the generated completions,
1286
- plus any additional columns in the dataset. It should return a list of rewards. For more details, see
1287
- [Using a custom reward function](#using-a-custom-reward-function).
1288
- - A list of reward functions, where each item can independently be any of the above types. Mixing different
1289
- types within the list (e.g., a string model ID and a custom reward function) is allowed.
1290
- args ([`GRPOConfig`], *optional*, defaults to `None`):
1291
- Configuration for this trainer. If `None`, a default configuration is used.
1292
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
1293
- Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
1294
- ignored. The format of the samples can be either:
1295
-
1296
- - [Standard](dataset_formats#standard): Each sample contains plain text.
1297
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
1298
- and content).
1299
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
1300
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
1301
- processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
1302
- Processing class used to process the data. The padding side must be set to "left". If `None`, the
1303
- processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
1304
- reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
1305
- Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
1306
-
1307
- - A single processing class: Used when `reward_funcs` contains only one reward function.
1308
- - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
1309
- If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
1310
- `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
1311
- For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
1312
- the corresponding entries in `reward_processing_classes` are ignored.
1313
- callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
1314
- List of callbacks to customize the training loop. Will add those to the list of default callbacks
1315
- detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
1316
-
1317
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
1318
- method.
1319
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
1320
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
1321
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
1322
- peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
1323
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
1324
-
1325
- """
1326
- def __init__(
1327
- self,
1328
- model,
1329
- reward_funcs,
1330
- args = None,
1331
- train_dataset = None,
1332
- eval_dataset = None,
1333
- processing_class = None,
1334
- reward_processing_classes = None,
1335
- callbacks = None,
1336
- peft_config = None,
1337
- **kwargs
1338
- ):
1339
- if args is None: args = UnslothGRPOConfig()
1340
- use_bf16 = getattr(args, 'bf16', False)
1341
- use_fp16 = getattr(args, 'fp16', False)
1342
- force_float32 = False
1343
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1344
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1345
- force_float32 = True
1346
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1347
- dtype = getattr(model.config, 'torch_dtype', None)
1348
- if dtype is None: dtype = model.get_input_embeddings().dtype
1349
- from unsloth_zoo.utils import _get_dtype
1350
- dtype = _get_dtype(dtype)
1351
- float16 = dtype == torch.float16
1352
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1353
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1354
- if force_float32:
1355
- args.fp16 = False
1356
- args.bf16 = False
1357
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1358
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1359
- args.fp16 = float16
1360
- args.bf16 = not float16
1361
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1362
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1363
- args.eval_strategy = 'steps'
1364
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1365
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1366
- if ga_steps is not None and ga_steps > 1:
1367
- from transformers import __version__ as transformers_version
1368
- if Version(transformers_version) <= Version('4.45.2'):
1369
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1370
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1371
- if getattr(args, 'eval_strategy', 'no') != 'no':
1372
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1373
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1374
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1375
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1376
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1377
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1378
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1379
- if force_float32:
1380
- args.bf16_full_eval = False
1381
- args.fp16_full_eval = False
1382
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1383
- args.bf16_full_eval = True
1384
- args.fp16_full_eval = False
1385
- elif not bf16_full_eval and not fp16_full_eval:
1386
- args.bf16_full_eval = args.bf16
1387
- args.fp16_full_eval = args.fp16
1388
- _output_logits = False
1389
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1390
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1391
- if _output_logits:
1392
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1393
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1394
- pass
1395
- else:
1396
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1397
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1398
- if args_max_seq_length is None and model_max_seq_length is not None:
1399
- max_seq_length = model.max_seq_length
1400
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1401
- if model is not None and hasattr(model, 'for_training'):
1402
- model.for_training()
1403
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1404
- if 'processing_class' in locals():
1405
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1406
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1407
- other_metrics = []
1408
- if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
1409
- else: _reward_funcs = reward_funcs
1410
- for reward_func in _reward_funcs:
1411
- try:
1412
- reward_func_name = reward_func.__name__
1413
- other_metrics.append(f'rewards/{reward_func_name}')
1414
- except: pass
1415
-
1416
- from unsloth_zoo.logging_utils import PatchRLStatistics
1417
- PatchRLStatistics('grpo_trainer', other_metrics)
1418
-
1419
- super().__init__(
1420
- model = model,
1421
- reward_funcs = reward_funcs,
1422
- args = args,
1423
- train_dataset = train_dataset,
1424
- eval_dataset = eval_dataset,
1425
- processing_class = processing_class,
1426
- reward_processing_classes = reward_processing_classes,
1427
- callbacks = callbacks,
1428
- peft_config = peft_config,**kwargs)
1429
- if hasattr(self, 'neftune_hook_handle'):
1430
- self.neftune_hook_handle.remove()
1431
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1432
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1433
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1434
- pass
1435
-
1436
- pass
unsloth_compiled_cache/UnslothKTOTrainer.py DELETED
@@ -1,1838 +0,0 @@
1
- """
2
- 2025.3.13
3
- 2025.3.15
4
- 4.48.3
5
- 0.15.2
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
13
-
14
-
15
- import os
16
- from typing import *
17
- from dataclasses import dataclass, field
18
- from packaging.version import Version
19
- import torch
20
- import numpy as np
21
- from contextlib import nullcontext
22
- from torch.nn import functional as F
23
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
-
25
- torch_compile_options = {
26
- "epilogue_fusion" : True,
27
- "max_autotune" : False,
28
- "shape_padding" : True,
29
- "trace.enabled" : False,
30
- "triton.cudagraphs" : False,
31
- }
32
-
33
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
- def selective_log_softmax(logits, index):
35
- logits = logits.to(torch.float32)
36
- selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
- # loop to reduce peak mem consumption
38
- # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
- logsumexp_values = torch.logsumexp(logits, dim = -1)
40
- per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
- return per_token_logps
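A quick self-check of the identity used above, `log_softmax(x)_i = x_i - logsumexp(x)`, on small random tensors (illustrative only):

```python
import torch

logits = torch.randn(2, 5, 10)        # (batch, seq_len, vocab)
index = torch.randint(0, 10, (2, 5))  # token ids to score
via_log_softmax = torch.gather(logits.log_softmax(-1), -1, index.unsqueeze(-1)).squeeze(-1)
via_logsumexp = torch.gather(logits, -1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, -1)
assert torch.allclose(via_log_softmax, via_logsumexp, atol=1e-5)
```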
42
- @dataclass
43
- class UnslothKTOConfig(KTOConfig):
44
- """
45
-
46
- Configuration class for the [`KTOTrainer`].
47
-
48
- Using [`~transformers.HfArgumentParser`] we can turn this class into
49
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
- command line.
51
-
52
- Parameters:
53
- learning_rate (`float`, *optional*, defaults to `5e-7`):
54
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
- [`~transformers.TrainingArguments`].
56
- max_length (`int` or `None`, *optional*, defaults to `1024`):
57
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
- to use the default data collator.
59
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
- max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
- Maximum length of the completion. This argument is required if you want to use the default data collator
63
- and your model is an encoder-decoder.
64
- beta (`float`, *optional*, defaults to `0.1`):
65
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
- reference model.
67
- loss_type (`str`, *optional*, defaults to `"kto"`):
68
- Type of loss to use. Possible values are:
69
-
70
- - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
71
- - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
72
-
73
- desirable_weight (`float`, *optional*, defaults to `1.0`):
74
- Desirable losses are weighed by this factor to counter the unequal number of desirable and undesirable pairs.
75
- undesirable_weight (`float`, *optional*, defaults to `1.0`):
76
- Undesirable losses are weighed by this factor to counter the unequal number of desirable and undesirable pairs.
77
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
78
- Label pad token id. This argument is required if you want to use the default data collator.
79
- padding_value (`int` or `None`, *optional*, defaults to `None`):
80
- Padding value to use. If `None`, the padding value of the tokenizer is used.
81
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
82
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
83
- This argument is required if you want to use the default data collator.
84
- generate_during_eval (`bool`, *optional*, defaults to `False`):
85
- If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
86
- evaluation.
87
- is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
88
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
89
- you need to specify if the model returned by the callable is an encoder-decoder model.
90
- precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
91
- Whether to precompute reference model log probabilities for training and evaluation datasets. This is
92
- useful when training without the reference model to reduce the total GPU memory needed.
93
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
94
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
95
- string.
96
- ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
98
- from a string.
99
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
100
- Number of processes to use for processing the dataset.
101
- disable_dropout (`bool`, *optional*, defaults to `True`):
102
- Whether to disable dropout in the model and reference model.
103
-
104
- """
105
- vllm_sampling_params: Optional[Any] = field(
106
- default = None,
107
- metadata = {'help': 'vLLM SamplingParams'},
108
- )
109
- unsloth_num_chunks : Optional[int] = field(
110
- default = -1,
111
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
112
- )
113
- def __init__(
114
- self,
115
- output_dir = None,
116
- overwrite_output_dir = None,
117
- do_train = False,
118
- do_eval = False,
119
- do_predict = False,
120
- eval_strategy = 'no',
121
- prediction_loss_only = False,
122
- per_device_train_batch_size = 4,
123
- per_device_eval_batch_size = 4,
124
- per_gpu_train_batch_size = None,
125
- per_gpu_eval_batch_size = None,
126
- gradient_accumulation_steps = 2,
127
- eval_accumulation_steps = 2,
128
- eval_delay = 0,
129
- torch_empty_cache_steps = 250,
130
- learning_rate = 5e-05,
131
- weight_decay = 0.01,
132
- adam_beta1 = 0.9,
133
- adam_beta2 = 0.999,
134
- adam_epsilon = 1e-08,
135
- max_grad_norm = 1.0,
136
- num_train_epochs = 3.0,
137
- max_steps = -1,
138
- lr_scheduler_type = 'linear',
139
- warmup_ratio = 0.1,
140
- warmup_steps = 0,
141
- log_level = 'passive',
142
- log_level_replica = 'warning',
143
- log_on_each_node = True,
144
- logging_dir = None,
145
- logging_strategy = 'steps',
146
- logging_first_step = False,
147
- logging_steps = 1,
148
- logging_nan_inf_filter = False,
149
- save_strategy = 'steps',
150
- save_steps = 500,
151
- save_total_limit = None,
152
- save_safetensors = True,
153
- save_on_each_node = False,
154
- save_only_model = False,
155
- restore_callback_states_from_checkpoint = False,
156
- no_cuda = False,
157
- use_cpu = False,
158
- use_mps_device = False,
159
- seed = 3407,
160
- data_seed = 3407,
161
- jit_mode_eval = False,
162
- use_ipex = False,
163
- bf16 = False,
164
- fp16 = False,
165
- fp16_opt_level = 'O1',
166
- half_precision_backend = 'auto',
167
- bf16_full_eval = False,
168
- fp16_full_eval = False,
169
- tf32 = None,
170
- local_rank = -1,
171
- ddp_backend = None,
172
- tpu_num_cores = None,
173
- tpu_metrics_debug = False,
174
- debug = '',
175
- dataloader_drop_last = False,
176
- eval_steps = None,
177
- dataloader_num_workers = 0,
178
- dataloader_prefetch_factor = None,
179
- past_index = -1,
180
- run_name = None,
181
- disable_tqdm = None,
182
- remove_unused_columns = True,
183
- label_names = None,
184
- load_best_model_at_end = False,
185
- metric_for_best_model = None,
186
- greater_is_better = None,
187
- ignore_data_skip = False,
188
- fsdp = '',
189
- fsdp_min_num_params = 0,
190
- fsdp_config = None,
191
- fsdp_transformer_layer_cls_to_wrap = None,
192
- accelerator_config = None,
193
- deepspeed = None,
194
- label_smoothing_factor = 0.0,
195
- optim = 'adamw_8bit',
196
- optim_args = None,
197
- adafactor = False,
198
- group_by_length = False,
199
- length_column_name = 'length',
200
- report_to = None,
201
- ddp_find_unused_parameters = None,
202
- ddp_bucket_cap_mb = None,
203
- ddp_broadcast_buffers = None,
204
- dataloader_pin_memory = True,
205
- dataloader_persistent_workers = False,
206
- skip_memory_metrics = True,
207
- use_legacy_prediction_loop = False,
208
- push_to_hub = False,
209
- resume_from_checkpoint = None,
210
- hub_model_id = None,
211
- hub_strategy = 'every_save',
212
- hub_token = None,
213
- hub_private_repo = None,
214
- hub_always_push = False,
215
- gradient_checkpointing = False,
216
- gradient_checkpointing_kwargs = None,
217
- include_inputs_for_metrics = False,
218
- eval_do_concat_batches = True,
219
- fp16_backend = 'auto',
220
- evaluation_strategy = None,
221
- push_to_hub_model_id = None,
222
- push_to_hub_organization = None,
223
- push_to_hub_token = None,
224
- mp_parameters = '',
225
- auto_find_batch_size = False,
226
- full_determinism = False,
227
- torchdynamo = None,
228
- ray_scope = 'last',
229
- ddp_timeout = 1800,
230
- torch_compile = False,
231
- torch_compile_backend = None,
232
- torch_compile_mode = None,
233
- dispatch_batches = None,
234
- split_batches = None,
235
- include_tokens_per_second = False,
236
- include_num_input_tokens_seen = False,
237
- neftune_noise_alpha = None,
238
- optim_target_modules = None,
239
- batch_eval_metrics = False,
240
- eval_on_start = False,
241
- use_liger_kernel = False,
242
- eval_use_gather_object = False,
243
- average_tokens_across_devices = False,
244
- max_length = 1024,
245
- max_prompt_length = 512,
246
- max_completion_length = None,
247
- beta = 0.1,
248
- loss_type = 'kto',
249
- desirable_weight = 1.0,
250
- undesirable_weight = 1.0,
251
- label_pad_token_id = -100,
252
- padding_value = None,
253
- truncation_mode = 'keep_end',
254
- generate_during_eval = False,
255
- is_encoder_decoder = None,
256
- disable_dropout = True,
257
- precompute_ref_log_probs = False,
258
- model_init_kwargs = None,
259
- ref_model_init_kwargs = None,
260
- dataset_num_proc = None,
261
- vllm_sampling_params = None,
262
- unsloth_num_chunks = -1,
263
- **kwargs,
264
- ):
265
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
266
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
267
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
268
- output_dir = 'unsloth_training_checkpoints'
269
- save_strategy = 'no'
270
- if dataset_num_proc is None:
271
- from multiprocessing import cpu_count
272
- dataset_num_proc = cpu_count()
273
-
274
- super().__init__(
275
- output_dir = output_dir,
276
- overwrite_output_dir = overwrite_output_dir,
277
- do_train = do_train,
278
- do_eval = do_eval,
279
- do_predict = do_predict,
280
- eval_strategy = eval_strategy,
281
- prediction_loss_only = prediction_loss_only,
282
- per_device_train_batch_size = per_device_train_batch_size,
283
- per_device_eval_batch_size = per_device_eval_batch_size,
284
- per_gpu_train_batch_size = per_gpu_train_batch_size,
285
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
286
- gradient_accumulation_steps = gradient_accumulation_steps,
287
- eval_accumulation_steps = eval_accumulation_steps,
288
- eval_delay = eval_delay,
289
- torch_empty_cache_steps = torch_empty_cache_steps,
290
- learning_rate = learning_rate,
291
- weight_decay = weight_decay,
292
- adam_beta1 = adam_beta1,
293
- adam_beta2 = adam_beta2,
294
- adam_epsilon = adam_epsilon,
295
- max_grad_norm = max_grad_norm,
296
- num_train_epochs = num_train_epochs,
297
- max_steps = max_steps,
298
- lr_scheduler_type = lr_scheduler_type,
299
- warmup_ratio = warmup_ratio,
300
- warmup_steps = warmup_steps,
301
- log_level = log_level,
302
- log_level_replica = log_level_replica,
303
- log_on_each_node = log_on_each_node,
304
- logging_dir = logging_dir,
305
- logging_strategy = logging_strategy,
306
- logging_first_step = logging_first_step,
307
- logging_steps = logging_steps,
308
- logging_nan_inf_filter = logging_nan_inf_filter,
309
- save_strategy = save_strategy,
310
- save_steps = save_steps,
311
- save_total_limit = save_total_limit,
312
- save_safetensors = save_safetensors,
313
- save_on_each_node = save_on_each_node,
314
- save_only_model = save_only_model,
315
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
316
- no_cuda = no_cuda,
317
- use_cpu = use_cpu,
318
- use_mps_device = use_mps_device,
319
- seed = seed,
320
- data_seed = data_seed,
321
- jit_mode_eval = jit_mode_eval,
322
- use_ipex = use_ipex,
323
- bf16 = bf16,
324
- fp16 = fp16,
325
- fp16_opt_level = fp16_opt_level,
326
- half_precision_backend = half_precision_backend,
327
- bf16_full_eval = bf16_full_eval,
328
- fp16_full_eval = fp16_full_eval,
329
- tf32 = tf32,
330
- local_rank = local_rank,
331
- ddp_backend = ddp_backend,
332
- tpu_num_cores = tpu_num_cores,
333
- tpu_metrics_debug = tpu_metrics_debug,
334
- debug = debug,
335
- dataloader_drop_last = dataloader_drop_last,
336
- eval_steps = eval_steps,
337
- dataloader_num_workers = dataloader_num_workers,
338
- dataloader_prefetch_factor = dataloader_prefetch_factor,
339
- past_index = past_index,
340
- run_name = run_name,
341
- disable_tqdm = disable_tqdm,
342
- remove_unused_columns = remove_unused_columns,
343
- label_names = label_names,
344
- load_best_model_at_end = load_best_model_at_end,
345
- metric_for_best_model = metric_for_best_model,
346
- greater_is_better = greater_is_better,
347
- ignore_data_skip = ignore_data_skip,
348
- fsdp = fsdp,
349
- fsdp_min_num_params = fsdp_min_num_params,
350
- fsdp_config = fsdp_config,
351
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
352
- accelerator_config = accelerator_config,
353
- deepspeed = deepspeed,
354
- label_smoothing_factor = label_smoothing_factor,
355
- optim = optim,
356
- optim_args = optim_args,
357
- adafactor = adafactor,
358
- group_by_length = group_by_length,
359
- length_column_name = length_column_name,
360
- report_to = report_to,
361
- ddp_find_unused_parameters = ddp_find_unused_parameters,
362
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
363
- ddp_broadcast_buffers = ddp_broadcast_buffers,
364
- dataloader_pin_memory = dataloader_pin_memory,
365
- dataloader_persistent_workers = dataloader_persistent_workers,
366
- skip_memory_metrics = skip_memory_metrics,
367
- use_legacy_prediction_loop = use_legacy_prediction_loop,
368
- push_to_hub = push_to_hub,
369
- resume_from_checkpoint = resume_from_checkpoint,
370
- hub_model_id = hub_model_id,
371
- hub_strategy = hub_strategy,
372
- hub_token = hub_token,
373
- hub_private_repo = hub_private_repo,
374
- hub_always_push = hub_always_push,
375
- gradient_checkpointing = gradient_checkpointing,
376
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
377
- include_inputs_for_metrics = include_inputs_for_metrics,
378
- eval_do_concat_batches = eval_do_concat_batches,
379
- fp16_backend = fp16_backend,
380
- evaluation_strategy = evaluation_strategy,
381
- push_to_hub_model_id = push_to_hub_model_id,
382
- push_to_hub_organization = push_to_hub_organization,
383
- push_to_hub_token = push_to_hub_token,
384
- mp_parameters = mp_parameters,
385
- auto_find_batch_size = auto_find_batch_size,
386
- full_determinism = full_determinism,
387
- torchdynamo = torchdynamo,
388
- ray_scope = ray_scope,
389
- ddp_timeout = ddp_timeout,
390
- torch_compile = torch_compile,
391
- torch_compile_backend = torch_compile_backend,
392
- torch_compile_mode = torch_compile_mode,
393
- dispatch_batches = dispatch_batches,
394
- split_batches = split_batches,
395
- include_tokens_per_second = include_tokens_per_second,
396
- include_num_input_tokens_seen = include_num_input_tokens_seen,
397
- neftune_noise_alpha = neftune_noise_alpha,
398
- optim_target_modules = optim_target_modules,
399
- batch_eval_metrics = batch_eval_metrics,
400
- eval_on_start = eval_on_start,
401
- use_liger_kernel = use_liger_kernel,
402
- eval_use_gather_object = eval_use_gather_object,
403
- average_tokens_across_devices = average_tokens_across_devices,
404
- max_length = max_length,
405
- max_prompt_length = max_prompt_length,
406
- max_completion_length = max_completion_length,
407
- beta = beta,
408
- loss_type = loss_type,
409
- desirable_weight = desirable_weight,
410
- undesirable_weight = undesirable_weight,
411
- label_pad_token_id = label_pad_token_id,
412
- padding_value = padding_value,
413
- truncation_mode = truncation_mode,
414
- generate_during_eval = generate_during_eval,
415
- is_encoder_decoder = is_encoder_decoder,
416
- disable_dropout = disable_dropout,
417
- precompute_ref_log_probs = precompute_ref_log_probs,
418
- model_init_kwargs = model_init_kwargs,
419
- ref_model_init_kwargs = ref_model_init_kwargs,
420
- dataset_num_proc = dataset_num_proc,**kwargs)
421
- self.vllm_sampling_params = vllm_sampling_params
422
- self.unsloth_num_chunks = unsloth_num_chunks
423
- pass
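# --- Editorial usage sketch (not part of the deleted file) ---
# Constructing the config with the KTO-specific knobs documented above; the values and the
# import path are illustrative assumptions, not recommendations:
# from unsloth_compiled_cache.UnslothKTOTrainer import UnslothKTOConfig  # hypothetical import
# config = UnslothKTOConfig(
#     output_dir = "kto-run",
#     beta = 0.1,               # higher beta = less deviation from the reference model
#     loss_type = "kto",        # or "apo_zero_unpaired"
#     desirable_weight = 1.0,   # rebalance these if desirable/undesirable counts differ
#     undesirable_weight = 1.0,
#     max_length = 1024,
#     max_prompt_length = 512,
# )
# --- end editorial sketch ---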
424
-
425
- class _UnslothKTOTrainer(Trainer):
426
- r""""""
427
-
428
- _tag_names = ["trl", "kto"]
429
-
430
- def __init__(
431
- self,
432
- model: Union[PreTrainedModel, nn.Module, str] = None,
433
- ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
434
- args: KTOConfig = None,
435
- train_dataset: Optional[Dataset] = None,
436
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
437
- processing_class: Optional[
438
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
439
- ] = None,
440
- data_collator: Optional[DataCollator] = None,
441
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
442
- callbacks: Optional[list[TrainerCallback]] = None,
443
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
444
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
445
- peft_config: Optional[dict] = None,
446
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
447
- model_adapter_name: Optional[str] = None,
448
- ref_adapter_name: Optional[str] = None,
449
- ):
450
- if type(args) is TrainingArguments:
451
- raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
452
-
453
- if not isinstance(model, str) and ref_model is model:
454
- raise ValueError(
455
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
456
- "same as `model`, you must mass a copy of it, or `None` if you use peft."
457
- )
458
-
459
- if args.model_init_kwargs is None:
460
- model_init_kwargs = {}
461
- elif not isinstance(model, str):
462
- raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
463
- else:
464
- model_init_kwargs = args.model_init_kwargs
465
- torch_dtype = model_init_kwargs.get("torch_dtype")
466
- if torch_dtype is not None:
467
- # Convert to `torch.dtype` if a str is passed
468
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
469
- torch_dtype = getattr(torch, torch_dtype)
470
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
471
- raise ValueError(
472
- f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
473
- )
474
- model_init_kwargs["torch_dtype"] = torch_dtype
475
-
476
- if args.ref_model_init_kwargs is None:
477
- ref_model_init_kwargs = {}
478
- elif not isinstance(ref_model, str):
479
- raise ValueError(
480
- "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
481
- )
482
- else:
483
- ref_model_init_kwargs = args.ref_model_init_kwargs
484
- torch_dtype = ref_model_init_kwargs.get("torch_dtype")
485
- if torch_dtype is not None:
486
- # Convert to `torch.dtype` if an str is passed
487
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
488
- torch_dtype = getattr(torch, torch_dtype)
489
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
490
- raise ValueError(
491
- f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
492
- )
493
- ref_model_init_kwargs["torch_dtype"] = torch_dtype
494
-
495
- if isinstance(model, str):
496
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
497
-
498
- if isinstance(ref_model, str):
499
- ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
500
-
501
- # Initialize this variable to False. This helps track the case where `peft_module_casting_to_bf16`
502
- # has been called in order to properly call autocast if needed.
503
- self._peft_has_been_casted_to_bf16 = False
504
-
505
- if not is_peft_available() and peft_config is not None:
506
- raise ValueError(
507
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
508
- )
509
- elif is_peft_available() and peft_config is not None:
510
- # if model is a peft model and we have a peft_config, we merge and unload it first
511
- if isinstance(model, PeftModel):
512
- model = model.merge_and_unload()
513
-
514
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
515
- _support_gc_kwargs = hasattr(
516
- args, "gradient_checkpointing_kwargs"
517
- ) and "gradient_checkpointing_kwargs" in list(
518
- inspect.signature(prepare_model_for_kbit_training).parameters
519
- )
520
-
521
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
522
-
523
- if _support_gc_kwargs:
524
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
525
-
526
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
527
- elif getattr(args, "gradient_checkpointing", False):
528
- # For backward compatibility with older versions of transformers
529
- if hasattr(model, "enable_input_require_grads"):
530
- model.enable_input_require_grads()
531
- else:
532
-
533
- def make_inputs_require_grad(module, input, output):
534
- output.requires_grad_(True)
535
-
536
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
537
-
538
- # get peft model with the given config
539
- model = model
540
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
541
- peft_module_casting_to_bf16(model)
542
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
543
- self._peft_has_been_casted_to_bf16 = True
544
-
545
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
546
- # to explicitly have `requires_grad=True`, otherwise training will either silently
547
- # fail or completely fail.
548
- elif getattr(args, "gradient_checkpointing", False):
549
- # For backward compatibility with older versions of transformers
550
- if hasattr(model, "enable_input_require_grads"):
551
- model.enable_input_require_grads()
552
- else:
553
-
554
- def make_inputs_require_grad(module, input, output):
555
- output.requires_grad_(True)
556
-
557
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
558
-
559
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
560
- raise ValueError(
561
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
562
- " Please install `wandb` or `comet-ml` to resolve."
563
- )
564
-
565
- if model is not None:
566
- self.is_encoder_decoder = model.config.is_encoder_decoder
567
- elif args.is_encoder_decoder is None:
568
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
569
- else:
570
- self.is_encoder_decoder = args.is_encoder_decoder
571
-
572
- self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
573
- self.model_adapter_name = model_adapter_name
574
- self.ref_adapter_name = ref_adapter_name
575
-
576
- if ref_model:
577
- self.ref_model = ref_model
578
- elif self.is_peft_model or args.precompute_ref_log_probs:
579
- # The `model` with adapters turned off will be used as the reference model
580
- self.ref_model = None
581
- else:
582
- self.ref_model = create_reference_model(model)
583
-
584
- if processing_class is None:
585
- raise ValueError(
586
- "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
587
- )
588
- if args.max_length is None:
589
- warnings.warn(
590
- "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
591
- " it will be set to `512` by default, but you should do it yourself in the future.",
592
- UserWarning,
593
- )
594
- max_length = 512
595
- if args.max_length is not None:
596
- max_length = args.max_length
597
-
598
- if args.max_prompt_length is None:
599
- warnings.warn(
600
- "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
601
- " it will be set to `128` by default, but you should do it yourself in the future.",
602
- UserWarning,
603
- )
604
- max_prompt_length = 128
605
- if args.max_prompt_length is not None:
606
- max_prompt_length = args.max_prompt_length
607
-
608
- max_completion_length = None
609
- if args.max_completion_length is None and self.is_encoder_decoder:
610
- warnings.warn(
611
- "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
612
- " it will be set to `128` by default, but you should do it yourself in the future.",
613
- UserWarning,
614
- )
615
- max_completion_length = 128
616
- if args.max_completion_length is not None and self.is_encoder_decoder:
617
- max_completion_length = args.max_completion_length
618
-
619
- if data_collator is None:
620
- data_collator = DPODataCollatorWithPadding(
621
- pad_token_id=processing_class.pad_token_id,
622
- label_pad_token_id=args.label_pad_token_id,
623
- is_encoder_decoder=self.is_encoder_decoder,
624
- )
625
-
626
- if args.remove_unused_columns:
627
- args.remove_unused_columns = False
628
- # warn users
629
- warnings.warn(
630
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
631
- " we have set it for you, but you should do it yourself in the future.",
632
- UserWarning,
633
- )
634
-
635
- self.use_dpo_data_collator = True
636
- else:
637
- self.use_dpo_data_collator = False
638
-
639
- # Disable dropout in the model and reference model
640
- if args.disable_dropout:
641
- disable_dropout_in_model(model)
642
- if self.ref_model is not None:
643
- disable_dropout_in_model(self.ref_model)
644
-
645
- self.loss_type = args.loss_type
646
- self.max_length = max_length
647
- self.generate_during_eval = args.generate_during_eval
648
- self.label_pad_token_id = args.label_pad_token_id
649
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
650
- self.max_prompt_length = max_prompt_length
651
- self.truncation_mode = args.truncation_mode
652
- self.max_completion_length = max_completion_length
653
- self.processing_class = processing_class
654
- self.precompute_ref_log_probs = args.precompute_ref_log_probs
655
-
656
- # Not all losses require a KL calculation
657
- self.calculate_KL = True
658
- if self.loss_type in ["apo_zero_unpaired"]:
659
- self.calculate_KL = False
660
-
661
- # Since ref log probs are precomputed on the first call to get_train/eval_dataloader,
662
- # keep track of the first call to avoid recomputing them on future calls
663
- self._precomputed_train_ref_log_probs = False
664
- self._precomputed_eval_ref_log_probs = False
665
-
666
- # metric
667
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
668
-
669
- # KTO parameter
670
- self.beta = args.beta
671
- self.desirable_weight = args.desirable_weight
672
- self.undesirable_weight = args.undesirable_weight
673
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
674
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
675
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
676
- warnings.warn(
677
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
678
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
679
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
680
- "loss.",
681
- UserWarning,
682
- )
683
-
684
- # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
685
- # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
686
- # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
687
- # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
688
- # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
689
- # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
690
- # issued.
691
- model.warnings_issued["estimate_tokens"] = True
692
-
693
- # Compute that only on the main process for faster data processing.
694
- # see: https://github.com/huggingface/trl/pull/1255
695
- with PartialState().local_main_process_first():
696
- # Extract the prompt if needed
697
- train_dataset = train_dataset.map(
698
- maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
699
- )
700
- # Unpair the dataset if needed
701
- train_dataset = maybe_unpair_preference_dataset(
702
- train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
703
- )
704
- # Apply the chat template if needed
705
- train_dataset = train_dataset.map(
706
- maybe_apply_chat_template,
707
- fn_kwargs={"tokenizer": processing_class},
708
- num_proc=args.dataset_num_proc,
709
- desc="Applying chat template to train dataset",
710
- )
711
- if eval_dataset is not None:
712
- eval_dataset = eval_dataset.map(
713
- maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
714
- )
715
- eval_dataset = maybe_unpair_preference_dataset(
716
- eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
717
- )
718
- eval_dataset = eval_dataset.map(
719
- maybe_apply_chat_template,
720
- fn_kwargs={"tokenizer": processing_class},
721
- num_proc=args.dataset_num_proc,
722
- desc="Applying chat template to eval dataset",
723
- )
724
-
725
- # Tokenize and prepare the training datasets
726
- train_dataset = train_dataset.map(
727
- _tokenize,
728
- batched=True,
729
- fn_kwargs={"tokenizer": self.processing_class},
730
- num_proc=args.dataset_num_proc,
731
- desc="Tokenizing train dataset",
732
- )
733
-
734
- fn_kwargs = {
735
- "prefix": "",
736
- "is_encoder_decoder": self.is_encoder_decoder,
737
- "tokenizer": self.processing_class,
738
- "max_length": self.max_length,
739
- "truncation_mode": self.truncation_mode,
740
- "label_pad_token_id": self.label_pad_token_id,
741
- "max_prompt_length": self.max_prompt_length,
742
- "max_completion_length": self.max_completion_length,
743
- }
744
-
745
- train_dataset = train_dataset.map(
746
- _process_tokens,
747
- fn_kwargs=fn_kwargs,
748
- num_proc=args.dataset_num_proc,
749
- desc="Processing tokenized train dataset",
750
- )
751
-
752
- # Tokenize and prepare the eval datasets
753
- if eval_dataset is not None:
754
- eval_dataset = eval_dataset.map(
755
- _tokenize,
756
- fn_kwargs={"tokenizer": self.processing_class},
757
- batched=True,
758
- num_proc=args.dataset_num_proc,
759
- desc="Tokenizing eval dataset",
760
- )
761
-
762
- eval_dataset = eval_dataset.map(
763
- _process_tokens,
764
- fn_kwargs=fn_kwargs,
765
- num_proc=args.dataset_num_proc,
766
- desc="Processing tokenized eval dataset",
767
- )
768
-
769
- # Get KL datasets if needed
770
- if self.calculate_KL:
771
- if args.per_device_train_batch_size <= 1:
772
- raise ValueError(
773
- "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
774
- )
775
-
776
- # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
777
- # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
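# Editorial illustration (added note): with a batch ["Hi" -> "A", "Yo" -> "B"], the
# mismatched pairs become ["Hi" -> "B", "Yo" -> "A"]; scoring completions against prompts
# they were not written for is what lets KTO estimate the reference KL term.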
778
- train_kl_dataset = train_dataset.map(
779
- _get_kl_dataset,
780
- batched=True,
781
- batch_size=args.per_device_train_batch_size,
782
- num_proc=args.dataset_num_proc,
783
- desc="Extracting KL train dataset",
784
- )
785
-
786
- fn_kwargs["prefix"] = "KL_"
787
- train_kl_dataset = train_kl_dataset.map(
788
- _process_tokens,
789
- fn_kwargs=fn_kwargs,
790
- num_proc=args.dataset_num_proc,
791
- remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
792
- desc="Processing tokenized train KL dataset",
793
- )
794
-
795
- # merge the datasets
796
- train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
797
-
798
- if eval_dataset is not None:
799
- # Get KL dataset
800
- eval_kl_dataset = eval_dataset.map(
801
- _get_kl_dataset,
802
- batched=True,
803
- batch_size=args.per_device_train_batch_size,
804
- num_proc=args.dataset_num_proc,
805
- desc="Extracting eval KL dataset",
806
- )
807
-
808
- eval_kl_dataset = eval_kl_dataset.map(
809
- _process_tokens,
810
- fn_kwargs=fn_kwargs,
811
- num_proc=args.dataset_num_proc,
812
- remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
813
- desc="Processing tokenized eval KL dataset",
814
- )
815
-
816
- # merge the datasets
817
- eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
818
-
819
- # calculate dataset desirability balance
820
- num_desirable = max(sum(train_dataset["label"]), 1)
821
- num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
822
-
823
- if num_desirable != num_undesirable:
824
- # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
825
- des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
826
- des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
827
- und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
828
- und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
829
-
830
- des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
831
- und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
832
-
833
- if not (des_weight_in_range or und_weight_in_range):
834
- warnings.warn(
835
- "You have different amounts of desirable/positive and undesirable/negative examples but the "
836
- "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
837
- f"on your data, we recommend EITHER "
838
- f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
839
- f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
840
- "See the documentation on how to optimally set these weights.",
841
- UserWarning,
842
- )
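# Worked example (editorial): with 100 desirable and 300 undesirable examples and
# undesirable_weight = 1.0, the bounds above give desirable_weight in
# [300 * 1.0 / 100 * 1, 300 * 1.0 / 100 * 1.33] = [3.0, 3.99]; alternatively,
# undesirable_weight in [100 * 1.0 / 300 / 1.33, 100 * 1.0 / 300] ~= [0.25, 0.33].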
843
-
844
- super().__init__(
845
- model=model,
846
- args=args,
847
- data_collator=data_collator,
848
- train_dataset=train_dataset,
849
- eval_dataset=eval_dataset,
850
- processing_class=processing_class,
851
- model_init=model_init,
852
- compute_metrics=compute_metrics,
853
- callbacks=callbacks,
854
- optimizers=optimizers,
855
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
856
- )
857
-
858
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
859
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
860
- # self.model_accepts_loss_kwargs to False to enable scaling.
861
- self.model_accepts_loss_kwargs = False
862
-
863
- # Add tags for models that have been loaded with the correct transformers version
864
- if hasattr(self.model, "add_model_tags"):
865
- self.model.add_model_tags(self._tag_names)
866
-
867
- if not hasattr(self, "accelerator"):
868
- raise AttributeError(
869
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
870
- )
871
-
872
- # Deepspeed Zero-3 does not support precompute_ref_log_probs
873
- if self.is_deepspeed_enabled:
874
- if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
875
- raise ValueError(
876
- "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
877
- )
878
-
879
- if self.ref_model is None:
880
- if not (self.is_peft_model or self.precompute_ref_log_probs):
881
- raise ValueError(
882
- "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
883
- )
884
- else:
885
- if self.is_deepspeed_enabled:
886
- self.ref_model = self._prepare_deepspeed(self.ref_model)
887
- else:
888
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
889
-
890
- def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
891
- # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
892
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
893
- config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
894
-
895
- if model is not None:
896
- if hasattr(model, "config"):
897
- hidden_size = (
898
- max(model.config.hidden_sizes)
899
- if getattr(model.config, "hidden_sizes", None)
900
- else getattr(model.config, "hidden_size", None)
901
- )
902
- if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
903
- # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
904
- # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
905
- config_kwargs.update(
906
- {
907
- "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
908
- "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
909
- "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
910
- }
911
- )
912
-
913
- # If ZeRO-3 is used, we shard both the active and reference model.
914
- # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
915
- if config_kwargs["zero_optimization"]["stage"] != 3:
916
- config_kwargs["zero_optimization"]["stage"] = 0
917
- model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
918
- model.eval()
919
- return model
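# Worked example (editorial): for hidden_size = 4096 under ZeRO-3, the overrides above give
# reduce_bucket_size = 4096**2 = 16_777_216, stage3_param_persistence_threshold = 40_960,
# and stage3_prefetch_bucket_size = 0.9 * 4096**2 ~= 15_099_494.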
920
-
921
- @contextmanager
922
- def null_ref_context(self):
923
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
924
- with (
925
- self.accelerator.unwrap_model(self.model).disable_adapter()
926
- if self.is_peft_model and not self.ref_adapter_name
927
- else nullcontext()
928
- ):
929
- if self.ref_adapter_name:
930
- self.model.set_adapter(self.ref_adapter_name)
931
- yield
932
- if self.ref_adapter_name:
933
- self.model.set_adapter(self.model_adapter_name or "default")
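# Editorial usage sketch: inside this context the PEFT adapters are disabled (or the
# reference adapter is swapped in), so `self.model` scores text as the frozen reference:
#     with self.null_ref_context():
#         ref_logits = self.model(input_ids).logits  # `input_ids` is an illustrative name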
934
-
935
- def get_train_dataloader(self) -> DataLoader:
936
- """
937
- Returns the training [`~torch.utils.data.DataLoader`].
938
-
939
- Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
940
- """
941
-
942
- if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
943
- dataloader_params = {
944
- "batch_size": self.args.per_device_train_batch_size,
945
- "collate_fn": self.data_collator,
946
- "num_workers": self.args.dataloader_num_workers,
947
- "pin_memory": self.args.dataloader_pin_memory,
948
- "shuffle": False,
949
- }
950
-
951
- # prepare dataloader
952
- data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
953
- reference_completion_logps = []
954
- reference_KL_logps = []
955
-
956
- for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
957
- reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
958
-
959
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
960
- reference_completion_logps.append(reference_completion_logp.cpu())
961
-
962
- if self.calculate_KL:
963
- reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
964
- reference_KL_logps.append(reference_KL_logp.cpu())
965
-
966
- self.train_dataset = self.train_dataset.add_column(
967
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
968
- )
969
-
970
- if self.calculate_KL:
971
- self.train_dataset = self.train_dataset.add_column(
972
- name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
973
- )
974
-
975
- self._precomputed_train_ref_log_probs = True
976
-
977
- return super().get_train_dataloader()
978
-
979
- def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
980
- """
981
- Returns the evaluation [`~torch.utils.data.DataLoader`].
982
-
983
- Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
984
-
985
- Args:
986
- eval_dataset (`torch.utils.data.Dataset`, *optional*):
987
- If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
988
- by the `model.forward()` method are automatically removed. It must implement `__len__`.
989
- """
990
- if eval_dataset is None and self.eval_dataset is None:
991
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
992
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
993
-
994
- if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
995
- dataloader_params = {
996
- "batch_size": self.args.per_device_eval_batch_size,
997
- "collate_fn": self.data_collator,
998
- "num_workers": self.args.dataloader_num_workers,
999
- "pin_memory": self.args.dataloader_pin_memory,
1000
- "shuffle": False,
1001
- }
1002
-
1003
- # prepare dataloader
1004
- data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1005
-
1006
- reference_completion_logps = []
1007
- reference_KL_logps = []
1008
-
1009
- for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1010
- reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
1011
-
1012
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1013
- reference_completion_logps.append(reference_completion_logp.cpu())
1014
-
1015
- if self.calculate_KL:
1016
- reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
1017
- reference_KL_logps.append(reference_KL_logp.cpu())
1018
-
1019
- eval_dataset = eval_dataset.add_column(
1020
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1021
- )
1022
- if self.calculate_KL:
1023
- eval_dataset = eval_dataset.add_column(
1024
- name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
1025
- )
1026
-
1027
- # Save the computed reference_logps and reference_KL_logps to the eval_dataset for subsequent runs
1028
- if self.eval_dataset is not None:
1029
- self.eval_dataset = eval_dataset
1030
- self._precomputed_eval_ref_log_probs = True
1031
-
1032
- return super().get_eval_dataloader(eval_dataset=eval_dataset)
1033
-
1034
- def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1035
- """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
1036
- with torch.no_grad():
1037
- if self.ref_model is None:
1038
- with self.null_ref_context():
1039
- if self.is_encoder_decoder:
1040
- completion_logits = self.model(
1041
- padded_batch["prompt_input_ids"],
1042
- attention_mask=padded_batch["prompt_attention_mask"],
1043
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1044
- labels=padded_batch["completion_labels"],
1045
- ).logits
1046
-
1047
- if self.calculate_KL:
1048
- KL_logits = self.model(
1049
- padded_batch["KL_prompt_input_ids"],
1050
- attention_mask=padded_batch["KL_prompt_attention_mask"],
1051
- decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1052
- labels=padded_batch["KL_completion_labels"],
1053
- ).logits
1054
- else:
1055
- completion_logits = self.model(
1056
- padded_batch["completion_input_ids"],
1057
- attention_mask=padded_batch["completion_attention_mask"],
1058
- ).logits
1059
-
1060
- if self.calculate_KL:
1061
- KL_logits = self.model(
1062
- padded_batch["KL_completion_input_ids"],
1063
- attention_mask=padded_batch["KL_completion_attention_mask"],
1064
- ).logits
1065
- else:
1066
- if self.is_encoder_decoder:
1067
- completion_logits = self.ref_model(
1068
- padded_batch["prompt_input_ids"],
1069
- attention_mask=padded_batch["prompt_attention_mask"],
1070
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1071
- labels=padded_batch["completion_labels"],
1072
- ).logits
1073
-
1074
- if self.calculate_KL:
1075
- KL_logits = self.ref_model(
1076
- padded_batch["KL_prompt_input_ids"],
1077
- attention_mask=padded_batch["KL_prompt_attention_mask"],
1078
- decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1079
- labels=padded_batch["KL_completion_labels"],
1080
- ).logits
1081
- else:
1082
- completion_logits = self.ref_model(
1083
- padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1084
- ).logits
1085
-
1086
- if self.calculate_KL:
1087
- KL_logits = self.ref_model(
1088
- padded_batch["KL_completion_input_ids"],
1089
- attention_mask=padded_batch["KL_completion_attention_mask"],
1090
- ).logits
1091
-
1092
- completion_logps = self.get_batch_logps(
1093
- completion_logits,
1094
- padded_batch["completion_labels"],
1095
- average_log_prob=False,
1096
- is_encoder_decoder=self.is_encoder_decoder,
1097
- label_pad_token_id=self.label_pad_token_id,
1098
- )
1099
-
1100
- if self.calculate_KL:
1101
- KL_logps = self.get_batch_logps(
1102
- KL_logits,
1103
- padded_batch["KL_completion_labels"],
1104
- average_log_prob=False,
1105
- is_encoder_decoder=self.is_encoder_decoder,
1106
- label_pad_token_id=self.label_pad_token_id,
1107
- )
1108
- else:
1109
- KL_logps = None
1110
-
1111
- return completion_logps, KL_logps
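# Editorial note: both returned tensors have shape (batch_size,) and hold summed log-probs;
# when precompute_ref_log_probs=True they are attached to the dataset as the
# "reference_logps" / "reference_KL_logps" columns, so the reference model is not needed
# during training and GPU memory is saved.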
1112
-
1113
- @staticmethod
1114
- def get_batch_logps(
1115
- logits: torch.FloatTensor,
1116
- labels: torch.LongTensor,
1117
- average_log_prob: bool = False,
1118
- label_pad_token_id: int = -100,
1119
- is_encoder_decoder: bool = False,
1120
- ) -> torch.FloatTensor:
1121
- """Compute the log probabilities of the given labels under the given logits.
1122
-
1123
- Args:
1124
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1125
- labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1126
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1127
-
1128
- Returns:
1129
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1130
- """
1131
- if logits.shape[:-1] != labels.shape:
1132
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1133
-
1134
- if not is_encoder_decoder:
1135
- labels = labels[:, 1:].clone()
1136
- logits = logits[:, :-1, :]
1137
- else:
1138
- # Fixes an encoder-decoder RuntimeError
1139
- labels = labels.clone()
1140
-
1141
- loss_mask = labels != label_pad_token_id
1142
-
1143
- # dummy token; we'll ignore the losses on these tokens later
1144
- labels[labels == label_pad_token_id] = 0
1145
-
1146
- per_token_logps = selective_log_softmax(logits, labels)
1147
-
1148
- if average_log_prob:
1149
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1150
- else:
1151
- return (per_token_logps * loss_mask).sum(-1)
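# --- Editorial sketch (not part of the deleted file) ---
# Toy check of the shifting/masking above; all tensors here are made up:
import torch
pad_id = -100
logits = torch.randn(1, 4, 7)                    # (batch, seq_len, vocab)
labels = torch.tensor([[pad_id, 3, 5, pad_id]])  # positions equal to -100 are ignored
shifted_labels = labels[:, 1:].clone()           # decoder-only: drop the first label
shifted_logits = logits[:, :-1, :]               # logits[t] now scores labels[t + 1]
mask = shifted_labels != pad_id
safe_labels = shifted_labels.masked_fill(~mask, 0)
per_token = torch.log_softmax(shifted_logits.float(), dim=-1).gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
summed_logp = (per_token * mask).sum(-1)         # matches average_log_prob=False
# --- end editorial sketch ---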
1152
-
1153
- def forward(
1154
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1155
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1156
- if self.calculate_KL:
1157
- KL_logps = None
1158
- KL_model_kwargs = (
1159
- {
1160
- "input_ids": batch["KL_prompt_input_ids"],
1161
- "attention_mask": batch["KL_prompt_attention_mask"],
1162
- "labels": batch["KL_completion_labels"],
1163
- "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
1164
- }
1165
- if self.is_encoder_decoder
1166
- else {
1167
- "input_ids": batch["KL_completion_input_ids"],
1168
- "attention_mask": batch["KL_completion_attention_mask"],
1169
- }
1170
- )
1171
- with torch.no_grad():
1172
- KL_logits = model(
1173
- **KL_model_kwargs,
1174
- ).logits
1175
-
1176
- KL_logps = self.get_batch_logps(
1177
- KL_logits,
1178
- batch["KL_completion_labels"],
1179
- average_log_prob=False,
1180
- is_encoder_decoder=self.is_encoder_decoder,
1181
- label_pad_token_id=self.label_pad_token_id,
1182
- )
1183
- else:
1184
- KL_logps = None
1185
-
1186
- model_kwargs = (
1187
- {
1188
- "labels": batch["completion_labels"],
1189
- "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1190
- }
1191
- if self.is_encoder_decoder
1192
- else {}
1193
- )
1194
- if self.aux_loss_enabled:
1195
- model_kwargs["output_router_logits"] = True
1196
-
1197
- outputs = model(
1198
- batch["completion_input_ids"],
1199
- attention_mask=batch["completion_attention_mask"],
1200
- **model_kwargs,
1201
- )
1202
- completion_logits = outputs.logits
1203
-
1204
- completion_logps = self.get_batch_logps(
1205
- completion_logits,
1206
- batch["completion_labels"],
1207
- average_log_prob=False,
1208
- is_encoder_decoder=self.is_encoder_decoder,
1209
- label_pad_token_id=self.label_pad_token_id,
1210
- )
1211
-
1212
- if completion_logps.shape[0] != len(batch["label"]):
1213
- raise ValueError(
1214
- "There is a mismatch between the number of examples in this batch and the number of "
1215
- "examples for which an output sequence was predicted."
1216
- )
1217
-
1218
- chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1219
- rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1220
-
1221
- chosen_logps = completion_logps[chosen_idx, ...]
1222
- rejected_logps = completion_logps[rejected_idx, ...]
1223
-
1224
- chosen_logits = completion_logits[chosen_idx, ...]
1225
- rejected_logits = completion_logits[rejected_idx, ...]
1226
-
1227
- if self.aux_loss_enabled:
1228
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
1229
- else:
1230
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
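# Editorial illustration: with batch["label"] == [True, False, True], chosen_idx is [0, 2]
# and rejected_idx is [1], so chosen_logps has two rows and rejected_logps has one; the
# two first dimensions always sum to the batch size.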
1231
-
1232
- def kto_loss(
1233
- self,
1234
- policy_chosen_logps: torch.FloatTensor,
1235
- policy_rejected_logps: torch.FloatTensor,
1236
- policy_KL_logps: torch.FloatTensor,
1237
- reference_chosen_logps: torch.FloatTensor,
1238
- reference_rejected_logps: torch.FloatTensor,
1239
- reference_KL_logps: torch.FloatTensor,
1240
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1241
- """Compute the KTO loss for a batch of policy and reference model log probabilities.
1242
-
1243
- Args:
1244
- policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1245
- policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1246
- policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
1247
- reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1248
- reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1249
- reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
1250
-
1251
- Returns:
1252
- A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
1253
- The losses tensor contains the KTO loss for each example in the batch.
1254
- The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1255
- The KL tensor contains the detached KL divergence estimate between the policy and reference models.
1256
- """
1257
- if self.calculate_KL:
1258
- kl = (policy_KL_logps - reference_KL_logps).mean().detach()
1259
- kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
1260
- else:
1261
- kl = torch.zeros(1).to(policy_chosen_logps.device)
1262
-
1263
- # Chosen losses
1264
- if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1265
- chosen_logratios = policy_chosen_logps - reference_chosen_logps
1266
-
1267
- if self.loss_type == "kto":
1268
- # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
1269
- chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
1270
- elif self.loss_type == "apo_zero_unpaired":
1271
- # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
1272
- # Use this loss when you believe the chosen outputs are better than your model's default output
1273
- chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
1274
-
1275
- chosen_rewards = self.beta * chosen_logratios.detach()
1276
-
1277
- else:
1278
- # lists can't be empty -- if they are, then accelerate.gather will hang
1279
- chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1280
- chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1281
-
1282
- # Rejected losses
1283
- if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1284
- rejected_logratios = policy_rejected_logps - reference_rejected_logps
1285
-
1286
- if self.loss_type == "kto":
1287
- rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
1288
- elif self.loss_type == "apo_zero_unpaired":
1289
- rejected_losses = F.sigmoid(self.beta * rejected_logratios)
1290
-
1291
- rejected_rewards = self.beta * rejected_logratios.detach()
1292
- else:
1293
- # lists can't be empty -- if they are, then accelerate.gather will hang
1294
- rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1295
- rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1296
-
1297
- losses = torch.cat(
1298
- (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
1299
- 0,
1300
- )
1301
-
1302
- return losses, chosen_rewards, rejected_rewards, kl
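# --- Editorial sketch (not part of the deleted file) ---
# Eqn (7) of the KTO paper as implemented above, evaluated on made-up scalars:
import torch
beta = 0.1
kl = torch.tensor(0.2)                  # detached KL estimate
chosen_logratio = torch.tensor(1.5)     # policy logp - reference logp, desirable example
rejected_logratio = torch.tensor(-0.8)  # same quantity, undesirable example
chosen_loss = 1 - torch.sigmoid(beta * (chosen_logratio - kl))      # desirable branch
rejected_loss = 1 - torch.sigmoid(beta * (kl - rejected_logratio))  # undesirable branch
chosen_reward = beta * chosen_logratio                              # mirrors the code above
# --- end editorial sketch ---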
1303
-
1304
- def get_batch_loss_metrics(
1305
- self,
1306
- model,
1307
- batch: dict[str, Union[list, torch.LongTensor]],
1308
- ):
1309
- """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
1310
- metrics = {}
1311
- batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1312
-
1313
- forward_output = self.forward(model, batch)
1314
- (
1315
- policy_chosen_logps,
1316
- policy_rejected_logps,
1317
- policy_chosen_logits,
1318
- policy_rejected_logits,
1319
- policy_KL_logps,
1320
- ) = forward_output[:5]
1321
- if self.aux_loss_enabled:
1322
- aux_loss = forward_output[5]
1323
-
1324
- # if reference_logps in batch use them, otherwise use the reference model
1325
- if "reference_logps" in batch:
1326
- chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1327
- rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1328
-
1329
- reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1330
- reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1331
- if self.calculate_KL:
1332
- reference_KL_logps = batch["reference_KL_logps"]
1333
- else:
1334
- reference_KL_logps = None
1335
- else:
1336
- with torch.no_grad():
1337
- if self.ref_model is None:
1338
- with self.null_ref_context():
1339
- (
1340
- reference_chosen_logps,
1341
- reference_rejected_logps,
1342
- _,
1343
- _,
1344
- reference_KL_logps,
1345
- ) = self.forward(self.model, batch)[:5]
1346
- else:
1347
- (
1348
- reference_chosen_logps,
1349
- reference_rejected_logps,
1350
- _,
1351
- _,
1352
- reference_KL_logps,
1353
- ) = self.forward(self.ref_model, batch)[:5]
1354
-
1355
- losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
1356
- policy_chosen_logps,
1357
- policy_rejected_logps,
1358
- policy_KL_logps,
1359
- reference_chosen_logps,
1360
- reference_rejected_logps,
1361
- reference_KL_logps,
1362
- )
1363
- metrics["kl"] = kl.item()
1364
-
1365
- num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1366
- num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1367
-
1368
- all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1369
- all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1370
-
1371
- if all_num_chosen > 0:
1372
- metrics["rewards/chosen_sum"] = (
1373
- self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1374
- )
1375
- metrics["logps/chosen_sum"] = (
1376
- self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1377
- )
1378
- metrics["logits/chosen_sum"] = (
1379
- self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1380
- )
1381
- metrics["count/chosen"] = all_num_chosen
1382
-
1383
- if all_num_rejected > 0:
1384
- metrics["rewards/rejected_sum"] = (
1385
- self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1386
- )
1387
- metrics["logps/rejected_sum"] = (
1388
- self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1389
- )
1390
- metrics["logits/rejected_sum"] = (
1391
- self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1392
- )
1393
- metrics["count/rejected"] = all_num_rejected
1394
-
1395
- loss = losses.nanmean()
1396
- if self.aux_loss_enabled:
1397
- loss += self.aux_loss_coef * aux_loss
1398
-
1399
- return loss, metrics
1400
-
1401
- def compute_loss(
1402
- self,
1403
- model: Union[PreTrainedModel, nn.Module],
1404
- inputs: dict[str, Union[torch.Tensor, Any]],
1405
- return_outputs=False,
1406
- num_items_in_batch=None,
1407
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1408
- compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1409
-
1410
- with compute_loss_context_manager:
1411
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
1412
-
1413
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1414
- loss = loss.to(self.args.device)
1415
- # force log the metrics
1416
- if self.accelerator.is_main_process:
1417
- self.store_metrics(metrics, train_eval="train")
1418
-
1419
- if return_outputs:
1420
- return (loss, metrics)
1421
- return loss
1422
-
1423
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1424
- for key, value in metrics.items():
1425
- self._stored_metrics[train_eval][key].append(value)
1426
-
1427
- def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1428
- if self.train_dataset is None or not has_length(self.train_dataset):
1429
- return None
1430
- return SequentialSampler(self.train_dataset)
1431
-
1432
- def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1433
- """Generate samples from the model and reference model for the given batch of inputs."""
1434
-
1435
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1436
- # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1437
- generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1438
-
1439
- with generate_context_manager:
1440
- policy_output = model.generate(
1441
- input_ids=batch["prompt_input_ids"],
1442
- attention_mask=batch["prompt_attention_mask"],
1443
- max_length=self.max_length,
1444
- do_sample=True,
1445
- pad_token_id=self.processing_class.pad_token_id,
1446
- )
1447
-
1448
- # if reference_output in batch use that otherwise use the reference model
1449
- if "reference_output" in batch:
1450
- reference_output = batch["reference_output"]
1451
- else:
1452
- if self.ref_model is None:
1453
- with self.null_ref_context():
1454
- reference_output = self.model.generate(
1455
- input_ids=batch["prompt_input_ids"],
1456
- attention_mask=batch["prompt_attention_mask"],
1457
- max_length=self.max_length,
1458
- do_sample=True,
1459
- pad_token_id=self.processing_class.pad_token_id,
1460
- )
1461
- else:
1462
- reference_output = self.ref_model.generate(
1463
- input_ids=batch["prompt_input_ids"],
1464
- attention_mask=batch["prompt_attention_mask"],
1465
- max_length=self.max_length,
1466
- do_sample=True,
1467
- pad_token_id=self.processing_class.pad_token_id,
1468
- )
1469
-
1470
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1471
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1472
-
1473
- reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1474
- reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1475
-
1476
- return policy_output_decoded, reference_output_decoded
1477
-
1478
- def prediction_step(
1479
- self,
1480
- model: Union[PreTrainedModel, nn.Module],
1481
- inputs: dict[str, Union[torch.Tensor, Any]],
1482
- prediction_loss_only: bool,
1483
- ignore_keys: Optional[list[str]] = None,
1484
- ):
1485
- if ignore_keys is None:
1486
- if hasattr(model, "config"):
1487
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1488
- else:
1489
- ignore_keys = []
1490
-
1491
- prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1492
- with torch.no_grad(), prediction_context_manager:
1493
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
1494
-
1495
- # force log the metrics
1496
- if self.accelerator.is_main_process:
1497
- self.store_metrics(metrics, train_eval="eval")
1498
-
1499
- if prediction_loss_only:
1500
- return (loss.detach(), None, None)
1501
-
1502
- # logits for the chosen and rejected samples from model
1503
- logits_dict = {
1504
- "eval_logits/chosen": metrics["logits/chosen"],
1505
- "eval_logits/rejected": metrics["logits/rejected"],
1506
- }
1507
- logits = torch.tensor(
1508
- [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
1509
- )
1510
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1511
-
1512
- return (loss.detach(), logits, labels)
1513
-
1514
- def evaluation_loop(
1515
- self,
1516
- dataloader: DataLoader,
1517
- description: str,
1518
- prediction_loss_only: Optional[bool] = None,
1519
- ignore_keys: Optional[list[str]] = None,
1520
- metric_key_prefix: str = "eval",
1521
- ) -> EvalLoopOutput:
1522
- """
1523
- Overriding built-in evaluation loop to store metrics for each batch.
1524
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1525
-
1526
- Works both with or without labels.
1527
- """
1528
-
1529
- # Sample and save to game log if requested (for one batch to save time)
1530
- if self.generate_during_eval:
1531
- # Generate random indices within the range of the total number of samples
1532
- num_samples = len(dataloader.dataset)
1533
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1534
-
1535
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1536
- random_batch_dataset = dataloader.dataset.select(random_indices)
1537
- random_batch = self.data_collator(random_batch_dataset)
1538
- random_batch = self._prepare_inputs(random_batch)
1539
-
1540
- target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1541
- target_batch = {
1542
- "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
1543
- "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
1544
- "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
1545
- }
1546
- policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1547
-
1548
- table = pd.DataFrame(
1549
- columns=["Prompt", "Policy", "Ref Model"],
1550
- data=[
1551
- [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1552
- for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1553
- ],
1554
- )
1555
- if "wandb" in self.args.report_to:
1556
- wandb.log({"game_log": wandb.Table(data=table)})
1557
-
1558
- if "comet_ml" in self.args.report_to:
1559
- log_table_to_comet_experiment(
1560
- name="game_log.csv",
1561
- table=table,
1562
- )
1563
-
1564
- # Base evaluation
1565
- initial_output = super().evaluation_loop(
1566
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1567
- )
1568
-
1569
- return initial_output
1570
-
1571
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1572
- """
1573
- Log `logs` on the various objects watching training, including stored metrics.
1574
-
1575
- Args:
1576
- logs (`dict[str, float]`):
1577
- The values to log.
1578
- start_time (`float` or `None`, *optional*, defaults to `None`):
1579
- Start time of the training.
1580
- """
1581
- # logs either has 'loss' or 'eval_loss'
1582
- train_eval = "train" if "loss" in logs else "eval"
1583
- # train metrics should have no prefix, eval should have 'eval_'
1584
- prefix = "eval_" if train_eval == "eval" else ""
1585
- # accumulate average metrics from sums and lengths
1586
- for split in ["chosen", "rejected"]:
1587
- if f"count/{split}" in self._stored_metrics[train_eval]:
1588
- count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1589
- for metric in ["rewards", "logps", "logits"]:
1590
- logs[f"{prefix}{metric}/{split}"] = (
1591
- torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1592
- / count_sum
1593
- )
1594
- # delete obsolete metric
1595
- del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1596
- del self._stored_metrics[train_eval][f"count/{split}"]
1597
- # calculate reward margin
1598
- if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1599
- logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1600
- # Add averaged stored metrics to logs
1601
- for key, metrics in self._stored_metrics[train_eval].items():
1602
- logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1603
- del self._stored_metrics[train_eval]
1604
-
1605
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1606
- return super().log(logs, start_time)
1607
- else: # transformers<=4.46
1608
- return super().log(logs)
1609
-
1610
- def create_model_card(
1611
- self,
1612
- model_name: Optional[str] = None,
1613
- dataset_name: Optional[str] = None,
1614
- tags: Union[str, list[str], None] = None,
1615
- ):
1616
- """
1617
- Creates a draft of a model card using the information available to the `Trainer`.
1618
-
1619
- Args:
1620
- model_name (`str` or `None`, *optional*, defaults to `None`):
1621
- Name of the model.
1622
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1623
- Name of the dataset used for training.
1624
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1625
- Tags to be associated with the model card.
1626
- """
1627
- if not self.is_world_process_zero():
1628
- return
1629
-
1630
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1631
- base_model = self.model.config._name_or_path
1632
- else:
1633
- base_model = None
1634
-
1635
- tags = tags or []
1636
- if isinstance(tags, str):
1637
- tags = [tags]
1638
-
1639
- if hasattr(self.model.config, "unsloth_version"):
1640
- tags.append("unsloth")
1641
-
1642
- citation = textwrap.dedent("""\
1643
- @article{ethayarajh2024kto,
1644
- title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
1645
- author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
1646
- year = 2024,
1647
- eprint = {arXiv:2402.01306},
1648
- }""")
1649
-
1650
- model_card = generate_model_card(
1651
- base_model=base_model,
1652
- model_name=model_name,
1653
- hub_model_id=self.hub_model_id,
1654
- dataset_name=dataset_name,
1655
- tags=tags,
1656
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1657
- comet_url=get_comet_experiment_url(),
1658
- trainer_name="KTO",
1659
- trainer_citation=citation,
1660
- paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
1661
- paper_id="2402.01306",
1662
- )
1663
-
1664
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1665
- class UnslothKTOTrainer(_UnslothKTOTrainer):
1666
- """
1667
-
1668
- Initialize KTOTrainer.
1669
-
1670
- Args:
1671
- model (`transformers.PreTrainedModel`):
1672
- The model to train, preferably an `AutoModelForSequenceClassification`.
1673
- ref_model (`PreTrainedModelWrapper`):
1674
- Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
1675
- reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1676
- args (`KTOConfig`):
1677
- The arguments to use for training.
1678
- train_dataset (`datasets.Dataset`):
1679
- The dataset to use for training.
1680
- eval_dataset (`datasets.Dataset`):
1681
- The dataset to use for evaluation.
1682
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1683
- Processing class used to process the data. If provided, will be used to automatically process the inputs
1684
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1685
- reuse the fine-tuned model.
1686
- data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1687
- The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1688
- which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1689
- model_init (`Callable[[], transformers.PreTrainedModel]`):
1690
- The model initializer to use for training. If None is specified, the default model initializer will be used.
1691
- callbacks (`list[transformers.TrainerCallback]`):
1692
- The callbacks to use for training.
1693
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1694
- The optimizer and scheduler to use for training.
1695
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1696
- The function to use to preprocess the logits before computing the metrics.
1697
- peft_config (`dict`, defaults to `None`):
1698
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1699
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1700
- The function to use to compute the metrics. Must take a `EvalPrediction` and return
1701
- a dictionary string to metric values.
1702
- model_adapter_name (`str`, defaults to `None`):
1703
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1704
- ref_adapter_name (`str`, defaults to `None`):
1705
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1706
-
1707
- """
1708
- def __init__(
1709
- self,
1710
- model = None,
1711
- ref_model = None,
1712
- args = None,
1713
- train_dataset = None,
1714
- eval_dataset = None,
1715
- processing_class = None,
1716
- data_collator = None,
1717
- model_init = None,
1718
- callbacks = None,
1719
- preprocess_logits_for_metrics = None,
1720
- peft_config = None,
1721
- compute_metrics = None,
1722
- model_adapter_name = None,
1723
- ref_adapter_name = None,
1724
- **kwargs
1725
- ):
1726
- if args is None: args = UnslothKTOConfig()
1727
- use_bf16 = getattr(args, 'bf16', False)
1728
- use_fp16 = getattr(args, 'fp16', False)
1729
- force_float32 = False
1730
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1731
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1732
- force_float32 = True
1733
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1734
- dtype = getattr(model.config, 'torch_dtype', None)
1735
- if dtype is None: dtype = model.get_input_embeddings().dtype
1736
- from unsloth_zoo.utils import _get_dtype
1737
- dtype = _get_dtype(dtype)
1738
- float16 = dtype == torch.float16
1739
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1740
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1741
- if force_float32:
1742
- args.fp16 = False
1743
- args.bf16 = False
1744
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1745
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1746
- args.fp16 = float16
1747
- args.bf16 = not float16
1748
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1749
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1750
- args.eval_strategy = 'steps'
1751
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1752
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1753
- if ga_steps is not None and ga_steps > 1:
1754
- from transformers import __version__ as transformers_version
1755
- if Version(transformers_version) <= Version('4.45.2'):
1756
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1757
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1758
- if getattr(args, 'eval_strategy', 'no') != 'no':
1759
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1760
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1761
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1762
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1763
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1764
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1765
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1766
- if force_float32:
1767
- args.bf16_full_eval = False
1768
- args.fp16_full_eval = False
1769
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1770
- args.bf16_full_eval = True
1771
- args.fp16_full_eval = False
1772
- elif not bf16_full_eval and not fp16_full_eval:
1773
- args.bf16_full_eval = args.bf16
1774
- args.fp16_full_eval = args.fp16
1775
- _output_logits = False
1776
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1777
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1778
- if _output_logits:
1779
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1780
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1781
- pass
1782
- else:
1783
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1784
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1785
- if args_max_seq_length is None and model_max_seq_length is not None:
1786
- max_seq_length = model.max_seq_length
1787
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1788
- if model is not None and hasattr(model, 'for_training'):
1789
- model.for_training()
1790
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1791
- if 'processing_class' in locals():
1792
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1793
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1794
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1795
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1796
- if not isinstance(data_collator, UnslothVisionDataCollator):
1797
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1798
- data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1799
- elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1800
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1801
- else:
1802
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1803
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1804
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1805
- if not isinstance(data_collator, UnslothVisionDataCollator):
1806
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1807
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1808
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1809
- else:
1810
- data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1811
- other_metrics = []
1812
-
1813
- from unsloth_zoo.logging_utils import PatchRLStatistics
1814
- PatchRLStatistics('kto_trainer', other_metrics)
1815
-
1816
- super().__init__(
1817
- model = model,
1818
- ref_model = ref_model,
1819
- args = args,
1820
- train_dataset = train_dataset,
1821
- eval_dataset = eval_dataset,
1822
- processing_class = processing_class,
1823
- data_collator = data_collator,
1824
- model_init = model_init,
1825
- callbacks = callbacks,
1826
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1827
- peft_config = peft_config,
1828
- compute_metrics = compute_metrics,
1829
- model_adapter_name = model_adapter_name,
1830
- ref_adapter_name = ref_adapter_name,**kwargs)
1831
- if hasattr(self, 'neftune_hook_handle'):
1832
- self.neftune_hook_handle.remove()
1833
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1834
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1835
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1836
- pass
1837
-
1838
- pass
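
For context on the file removed above: a minimal, hypothetical sketch of how a KTO trainer like the deleted wrapper is typically driven through TRL. The model and dataset names are placeholders, not taken from this repository.

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import KTOConfig, KTOTrainer

    # Placeholder identifiers -- substitute your own base model and an unpaired
    # preference dataset with "prompt", "completion" and boolean "label" columns.
    model_name = "your-org/your-base-model"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    train_dataset = load_dataset("your-org/your-kto-dataset", split="train")

    args = KTOConfig(output_dir="kto-out", per_device_train_batch_size=4)
    trainer = KTOTrainer(
        model=model,                 # ref_model omitted: the trainer builds one itself
        args=args,
        train_dataset=train_dataset,
        processing_class=tokenizer,  # pads batches via the default DPODataCollatorWithPadding
    )
    trainer.train()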
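
The wrapper's `__init__` above also encodes a precision-selection rule. A condensed sketch of that rule's default branch (the helper name is hypothetical; the real code mutates `args` and `ACCELERATE_MIXED_PRECISION` in place):

    import torch

    def pick_mixed_precision(model_dtype: torch.dtype, force_float32: bool) -> dict:
        """Mirror the wrapper's dtype checks: fp16 models train with fp16 mixed
        precision, everything else with bf16, unless float32 is forced."""
        if force_float32:
            return {"fp16": False, "bf16": False}  # ACCELERATE_MIXED_PRECISION = 'no'
        float16 = model_dtype == torch.float16
        return {"fp16": float16, "bf16": not float16}
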
unsloth_compiled_cache/UnslothNashMDTrainer.py DELETED
@@ -1,953 +0,0 @@
-"""
-2025.3.13
-2025.3.15
-4.48.3
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
-torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : False,
-    "shape_padding"     : True,
-    "trace.enabled"     : False,
-    "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def selective_log_softmax(logits, index):
-    logits = logits.to(torch.float32)
-    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-    # loop to reduce peak mem consumption
-    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-    logsumexp_values = torch.logsumexp(logits, dim = -1)
-    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-    return per_token_logps
-@dataclass
-class UnslothNashMDConfig(NashMDConfig):
-    """
-
-    Configuration class for the [`NashMDTrainer`].
-
-    Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
-
-    Parameters:
-        mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
-            Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
-            mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
-            epochs.
-
-    """
-    vllm_sampling_params: Optional[Any] = field(
-        default = None,
-        metadata = {'help': 'vLLM SamplingParams'},
-    )
-    unsloth_num_chunks : Optional[int] = field(
-        default = -1,
-        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-    )
-    def __init__(
-        self,
-        output_dir = None,
-        overwrite_output_dir = None,
-        do_train = False,
-        do_eval = False,
-        do_predict = False,
-        eval_strategy = 'no',
-        prediction_loss_only = False,
-        per_device_train_batch_size = 4,
-        per_device_eval_batch_size = 4,
-        per_gpu_train_batch_size = None,
-        per_gpu_eval_batch_size = None,
-        gradient_accumulation_steps = 2,
-        eval_accumulation_steps = 2,
-        eval_delay = 0,
-        torch_empty_cache_steps = 250,
-        learning_rate = 5e-05,
-        weight_decay = 0.01,
-        adam_beta1 = 0.9,
-        adam_beta2 = 0.999,
-        adam_epsilon = 1e-08,
-        max_grad_norm = 1.0,
-        num_train_epochs = 3.0,
-        max_steps = -1,
-        lr_scheduler_type = 'linear',
-        warmup_ratio = 0.1,
-        warmup_steps = 0,
-        log_level = 'passive',
-        log_level_replica = 'warning',
-        log_on_each_node = True,
-        logging_dir = None,
-        logging_strategy = 'steps',
-        logging_first_step = False,
-        logging_steps = 1,
-        logging_nan_inf_filter = False,
-        save_strategy = 'steps',
-        save_steps = 500,
-        save_total_limit = None,
-        save_safetensors = True,
-        save_on_each_node = False,
-        save_only_model = False,
-        restore_callback_states_from_checkpoint = False,
-        no_cuda = False,
-        use_cpu = False,
-        use_mps_device = False,
-        seed = 3407,
-        data_seed = 3407,
-        jit_mode_eval = False,
-        use_ipex = False,
-        bf16 = False,
-        fp16 = False,
-        fp16_opt_level = 'O1',
-        half_precision_backend = 'auto',
-        bf16_full_eval = False,
-        fp16_full_eval = False,
-        tf32 = None,
-        local_rank = -1,
-        ddp_backend = None,
-        tpu_num_cores = None,
-        tpu_metrics_debug = False,
-        debug = '',
-        dataloader_drop_last = False,
-        eval_steps = None,
-        dataloader_num_workers = 0,
-        dataloader_prefetch_factor = None,
-        past_index = -1,
-        run_name = None,
-        disable_tqdm = None,
-        remove_unused_columns = True,
-        label_names = None,
-        load_best_model_at_end = False,
-        metric_for_best_model = None,
-        greater_is_better = None,
-        ignore_data_skip = False,
-        fsdp = '',
-        fsdp_min_num_params = 0,
-        fsdp_config = None,
-        fsdp_transformer_layer_cls_to_wrap = None,
-        accelerator_config = None,
-        deepspeed = None,
-        label_smoothing_factor = 0.0,
-        optim = 'adamw_8bit',
-        optim_args = None,
-        adafactor = False,
-        group_by_length = False,
-        length_column_name = 'length',
-        report_to = None,
-        ddp_find_unused_parameters = None,
-        ddp_bucket_cap_mb = None,
-        ddp_broadcast_buffers = None,
-        dataloader_pin_memory = True,
-        dataloader_persistent_workers = False,
-        skip_memory_metrics = True,
-        use_legacy_prediction_loop = False,
-        push_to_hub = False,
-        resume_from_checkpoint = None,
-        hub_model_id = None,
-        hub_strategy = 'every_save',
-        hub_token = None,
-        hub_private_repo = None,
-        hub_always_push = False,
-        gradient_checkpointing = False,
-        gradient_checkpointing_kwargs = None,
-        include_inputs_for_metrics = False,
-        eval_do_concat_batches = True,
-        fp16_backend = 'auto',
-        evaluation_strategy = None,
-        push_to_hub_model_id = None,
-        push_to_hub_organization = None,
-        push_to_hub_token = None,
-        mp_parameters = '',
-        auto_find_batch_size = False,
-        full_determinism = False,
-        torchdynamo = None,
-        ray_scope = 'last',
-        ddp_timeout = 1800,
-        torch_compile = False,
-        torch_compile_backend = None,
-        torch_compile_mode = None,
-        dispatch_batches = None,
-        split_batches = None,
-        include_tokens_per_second = False,
-        include_num_input_tokens_seen = False,
-        neftune_noise_alpha = None,
-        optim_target_modules = None,
-        batch_eval_metrics = False,
-        eval_on_start = False,
-        use_liger_kernel = False,
-        eval_use_gather_object = False,
-        average_tokens_across_devices = False,
-        reward_model_path = None,
-        judge = None,
-        max_new_tokens = 64,
-        max_length = 512,
-        temperature = 0.9,
-        missing_eos_penalty = None,
-        loss_type = 'sigmoid',
-        dataset_num_proc = None,
-        disable_dropout = True,
-        use_vllm = False,
-        ds3_gather_for_generation = True,
-        vllm_sampling_params = None,
-        unsloth_num_chunks = -1,
-        **kwargs,
-    ):
-        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-            output_dir = 'unsloth_training_checkpoints'
-            save_strategy = 'no'
-        if dataset_num_proc is None:
-            from multiprocessing import cpu_count
-            dataset_num_proc = cpu_count()
-
-        super().__init__(
-            output_dir = output_dir,
-            overwrite_output_dir = overwrite_output_dir,
-            do_train = do_train,
-            do_eval = do_eval,
-            do_predict = do_predict,
-            eval_strategy = eval_strategy,
-            prediction_loss_only = prediction_loss_only,
-            per_device_train_batch_size = per_device_train_batch_size,
-            per_device_eval_batch_size = per_device_eval_batch_size,
-            per_gpu_train_batch_size = per_gpu_train_batch_size,
-            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-            gradient_accumulation_steps = gradient_accumulation_steps,
-            eval_accumulation_steps = eval_accumulation_steps,
-            eval_delay = eval_delay,
-            torch_empty_cache_steps = torch_empty_cache_steps,
-            learning_rate = learning_rate,
-            weight_decay = weight_decay,
-            adam_beta1 = adam_beta1,
-            adam_beta2 = adam_beta2,
-            adam_epsilon = adam_epsilon,
-            max_grad_norm = max_grad_norm,
-            num_train_epochs = num_train_epochs,
-            max_steps = max_steps,
-            lr_scheduler_type = lr_scheduler_type,
-            warmup_ratio = warmup_ratio,
-            warmup_steps = warmup_steps,
-            log_level = log_level,
-            log_level_replica = log_level_replica,
-            log_on_each_node = log_on_each_node,
-            logging_dir = logging_dir,
-            logging_strategy = logging_strategy,
-            logging_first_step = logging_first_step,
-            logging_steps = logging_steps,
-            logging_nan_inf_filter = logging_nan_inf_filter,
-            save_strategy = save_strategy,
-            save_steps = save_steps,
-            save_total_limit = save_total_limit,
-            save_safetensors = save_safetensors,
-            save_on_each_node = save_on_each_node,
-            save_only_model = save_only_model,
-            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-            no_cuda = no_cuda,
-            use_cpu = use_cpu,
-            use_mps_device = use_mps_device,
-            seed = seed,
-            data_seed = data_seed,
-            jit_mode_eval = jit_mode_eval,
-            use_ipex = use_ipex,
-            bf16 = bf16,
-            fp16 = fp16,
-            fp16_opt_level = fp16_opt_level,
-            half_precision_backend = half_precision_backend,
-            bf16_full_eval = bf16_full_eval,
-            fp16_full_eval = fp16_full_eval,
-            tf32 = tf32,
-            local_rank = local_rank,
-            ddp_backend = ddp_backend,
-            tpu_num_cores = tpu_num_cores,
-            tpu_metrics_debug = tpu_metrics_debug,
-            debug = debug,
-            dataloader_drop_last = dataloader_drop_last,
-            eval_steps = eval_steps,
-            dataloader_num_workers = dataloader_num_workers,
-            dataloader_prefetch_factor = dataloader_prefetch_factor,
-            past_index = past_index,
-            run_name = run_name,
-            disable_tqdm = disable_tqdm,
-            remove_unused_columns = remove_unused_columns,
-            label_names = label_names,
-            load_best_model_at_end = load_best_model_at_end,
-            metric_for_best_model = metric_for_best_model,
-            greater_is_better = greater_is_better,
-            ignore_data_skip = ignore_data_skip,
-            fsdp = fsdp,
-            fsdp_min_num_params = fsdp_min_num_params,
-            fsdp_config = fsdp_config,
-            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-            accelerator_config = accelerator_config,
-            deepspeed = deepspeed,
-            label_smoothing_factor = label_smoothing_factor,
-            optim = optim,
-            optim_args = optim_args,
-            adafactor = adafactor,
-            group_by_length = group_by_length,
-            length_column_name = length_column_name,
-            report_to = report_to,
-            ddp_find_unused_parameters = ddp_find_unused_parameters,
-            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-            ddp_broadcast_buffers = ddp_broadcast_buffers,
-            dataloader_pin_memory = dataloader_pin_memory,
-            dataloader_persistent_workers = dataloader_persistent_workers,
-            skip_memory_metrics = skip_memory_metrics,
-            use_legacy_prediction_loop = use_legacy_prediction_loop,
-            push_to_hub = push_to_hub,
-            resume_from_checkpoint = resume_from_checkpoint,
-            hub_model_id = hub_model_id,
-            hub_strategy = hub_strategy,
-            hub_token = hub_token,
-            hub_private_repo = hub_private_repo,
-            hub_always_push = hub_always_push,
-            gradient_checkpointing = gradient_checkpointing,
-            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-            include_inputs_for_metrics = include_inputs_for_metrics,
-            eval_do_concat_batches = eval_do_concat_batches,
-            fp16_backend = fp16_backend,
-            evaluation_strategy = evaluation_strategy,
-            push_to_hub_model_id = push_to_hub_model_id,
-            push_to_hub_organization = push_to_hub_organization,
-            push_to_hub_token = push_to_hub_token,
-            mp_parameters = mp_parameters,
-            auto_find_batch_size = auto_find_batch_size,
-            full_determinism = full_determinism,
-            torchdynamo = torchdynamo,
-            ray_scope = ray_scope,
-            ddp_timeout = ddp_timeout,
-            torch_compile = torch_compile,
-            torch_compile_backend = torch_compile_backend,
-            torch_compile_mode = torch_compile_mode,
-            dispatch_batches = dispatch_batches,
-            split_batches = split_batches,
-            include_tokens_per_second = include_tokens_per_second,
-            include_num_input_tokens_seen = include_num_input_tokens_seen,
-            neftune_noise_alpha = neftune_noise_alpha,
-            optim_target_modules = optim_target_modules,
-            batch_eval_metrics = batch_eval_metrics,
-            eval_on_start = eval_on_start,
-            use_liger_kernel = use_liger_kernel,
-            eval_use_gather_object = eval_use_gather_object,
-            average_tokens_across_devices = average_tokens_across_devices,
-            reward_model_path = reward_model_path,
-            judge = judge,
-            max_new_tokens = max_new_tokens,
-            max_length = max_length,
-            temperature = temperature,
-            missing_eos_penalty = missing_eos_penalty,
-            loss_type = loss_type,
-            dataset_num_proc = dataset_num_proc,
-            disable_dropout = disable_dropout,
-            use_vllm = use_vllm,
-            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
-        self.vllm_sampling_params = vllm_sampling_params
-        self.unsloth_num_chunks = unsloth_num_chunks
-        pass
-
-class _UnslothNashMDTrainer(OnlineDPOTrainer):
-    r""""""
-
-    _tag_names = ["trl", "nash-md"]
-
-    def __init__(
-        self,
-        model: Union[PreTrainedModel, nn.Module] = None,
-        ref_model: Union[PreTrainedModel, nn.Module] = None,
-        reward_model: Union[PreTrainedModel, nn.Module, None] = None,
-        judge: Optional[BasePairwiseJudge] = None,
-        args: Optional[NashMDConfig] = None,
-        data_collator: Optional[Callable] = None,
-        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-        processing_class: Optional[
-            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-        ] = None,
-        peft_config: Optional[dict] = None,
-        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-        callbacks: Optional[list[TrainerCallback]] = None,
-        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
-        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-    ) -> None:
-        super().__init__(
-            model=model,
-            ref_model=ref_model,
-            reward_model=reward_model,
-            judge=judge,
-            args=args,
-            data_collator=data_collator,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            processing_class=processing_class,
-            reward_processing_class=processing_class,  # for now, NashMDTrainer can't use any reward model
-            peft_config=peft_config,
-            compute_metrics=compute_metrics,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        )
-
-        self._mixture_coef = self.args.mixture_coef
-
-        # Overwrite the stats dictionary to include NashMD specific statistics
-        self.stats = {
-            # Remove "non_score_reward", "rlhf_reward", "scores_margin"
-            # Add "mixture_coef"
-            "loss/kl": [],
-            "objective/entropy": [],
-            "loss/score": [],
-            "rewards/probabilities": [],
-            "rewards/accuracies": [],
-            "rewards/margins": [],
-            "logps/chosen": [],
-            "logps/rejected": [],
-            "val/model_contain_eos_token": [],
-            "val/ref_contain_eos_token": [],
-            "beta": [],
-            "mixture_coef": [],
-        }
-        if self.reward_model is not None:
-            self.stats["rewards/chosen"] = []
-            self.stats["rewards/rejected"] = []
-
-    @property
-    def mixture_coef(self):
-        if isinstance(self._mixture_coef, list):
-            epoch = self.state.epoch
-            return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
-        else:
-            return self._mixture_coef
-
-    def _generate_completions(self, model, prompts):
-        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-            model_output = unwrapped_model.generate(
-                input_ids=prompts["input_ids"],
-                attention_mask=prompts["attention_mask"],
-                generation_config=self.generation_config,
-            )
-
-            ref_model = model if self.ref_model is None else self.ref_model
-            with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
-                mixture_model = GeometricMixtureWrapper(
-                    model=unwrapped_model,
-                    ref_model=unwrapped_ref_model,
-                    generation_config=self.generation_config,
-                    mixture_coef=self.mixture_coef,
-                    device=self.accelerator.device,
-                )
-
-                mixture_output = mixture_model.generate(
-                    input_ids=prompts["input_ids"],
-                    attention_mask=prompts["attention_mask"],
-                    generation_config=self.generation_config,
-                )
-
-        return model_output, mixture_output
-
-    def _process_completions(self, model_output, mixture_output, prompts):
-        context_length = prompts["input_ids"].shape[1]
-
-        # Process model completions
-        model_completion_ids = model_output[:, context_length:]
-        model_completion_ids, model_completion_mask = truncate_right(
-            model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
-        )
-        model_data = {
-            "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
-            "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
-            "raw": prompts["raw"],
-        }
-
-        # Process reference model completions
-        mixture_completion_ids = mixture_output[:, context_length:]
-        mixture_completion_ids, mixture_completion_mask = truncate_right(
-            mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
-        )
-        mixture_data = {
-            "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
-            "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
-            "raw": prompts["raw"],
-        }
-
-        return model_data, mixture_data
-
-    def _compute_rewards(self, model_data, mixture_data, context_length):
-        with torch.no_grad():
-            _, model_scores, _ = get_reward(
-                self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
-            )
-            _, mixture_scores, _ = get_reward(
-                self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
-            )
-
-        # Apply EOS penalty if needed
-        if self.args.missing_eos_penalty is not None:
-            model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
-            mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
-            model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
-            mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
-
-        return model_scores, mixture_scores
-
-    def _compute_judge(self, model_data, mixture_data, context_length):
-        prompts = model_data["raw"]
-        model_data_completions = self.processing_class.batch_decode(
-            model_data["input_ids"][:, context_length:], skip_special_tokens=True
-        )
-        model_data_completions = [completion.strip() for completion in model_data_completions]
-
-        mixture_data_completions = self.processing_class.batch_decode(
-            mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
-        )
-        mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
-        if is_conversational({"prompt": prompts[0]}):
-            model_data_completions = [
-                [{"role": "assistant", "content": completion}] for completion in model_data_completions
-            ]
-            environment = jinja2.Environment()
-            template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
-            prompts = [template.render(messages=message) for message in prompts]
-            model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
-
-            mixture_data_completions = [
-                [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
-            ]
-            mixture_data_completions = [
-                template.render(messages=completion) for completion in mixture_data_completions
-            ]
-
-        probability = self.judge.judge(
-            prompts,
-            list(zip(model_data_completions, mixture_data_completions)),
-            return_scores=True,
-        )
-        return torch.tensor(probability, device=model_data["input_ids"].device)
-
-    def _compute_logprobs(self, model, model_data, context_length):
-        def compute_logprobs_for_data(m, data):
-            output = m(data["input_ids"], attention_mask=data["attention_mask"])
-            logits = output.logits[:, context_length - 1 : -1]
-            token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
-            return token_logprobs
-
-        # Compute logprobs for model completions under the model
-        model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
-
-        # Compute logprobs of model completions under the reference model
-        with torch.no_grad():
-            if self.ref_model is None:
-                with model.disable_adapter():
-                    ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
-            else:
-                ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
-
-        # Mask padding tokens
-        model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
-        model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
-        ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
-
-        return (model_logprobs_model_data, ref_logprobs_model_data)
-
-    def _compute_losses(
-        self,
-        model_logprobs_model_data,
-        ref_logprobs_model_data,
-        probability,
-    ):
-        # reinforce score where 0.5 is a control variate
-        score = (probability - 0.5) * model_logprobs_model_data.sum(1)
-
-        # kl divergence via reinforce
-        with torch.no_grad():
-            log_ratio = model_logprobs_model_data - ref_logprobs_model_data
-            kl_div_log = log_ratio.sum(1)
-        kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
-
-        # final loss
-        loss = self.beta * kl_div_loss - score
-
-        return loss.mean(), score, kl_div_log
-
-    def _log_statistics(
-        self,
-        model_data,
-        mixture_data,
-        model_logprobs_model_data,
-        ref_logprobs_model_data,
-        probability,
-        score,
-        kl_div,
-        context_length,
-        model_scores=None,
-        mixture_scores=None,
-    ):
-        # Helper function to gather and compute mean
-        def gather_mean(tensor):
-            return self.accelerator.gather_for_metrics(tensor).mean().item()
-
-        # Log score
-        self.stats["loss/score"].append(gather_mean(score))
-        # Log KL divergence
-        self.stats["loss/kl"].append(gather_mean(kl_div))
-
-        # Log logprobs
-        model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
-        ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
-
-        self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
-        self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
-
-        # Log rewards
-        if self.reward_model is not None:
-            self.stats["rewards/chosen"].append(gather_mean(model_scores))
-            self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
-
-        # Log probabilities
-        self.stats["rewards/probabilities"].append(gather_mean(probability))
-
-        # Calculate entropy for model data
-        entropy_model_data = -model_logprobs_model_data.sum(1)
-        self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
-
-        # Calculate margins
-        margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
-        self.stats["rewards/margins"].append(gather_mean(margin))
-
-        # Calculate accuracy
-        accuracy = (margin > 0).float()
-        self.stats["rewards/accuracies"].append(gather_mean(accuracy))
-
-        # Log EOS token statistics
-        model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
-        mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
-        self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
-        self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
-
-        # Log beta and mixture coef
-        self.stats["beta"].append(self.beta)
-        self.stats["mixture_coef"].append(self.mixture_coef)
-
-    def training_step(
-        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
-    ) -> torch.Tensor:
-        model.train()
-
-        # Apply chat template and tokenize the input
-        batch_size = len(next(iter(inputs.values())))
-        prompts = inputs["prompt"]
-        inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
-        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
-        inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
-        inputs = self.data_collator(inputs)
-
-        # need the prompt_ only
-        inputs = self._prepare_inputs(inputs)
-        context_length = inputs["prompt_input_ids"].shape[1]
-        prompts = {
-            "input_ids": inputs["prompt_input_ids"],
-            "attention_mask": inputs["prompt_attention_mask"],
-            "raw": prompts,
-        }
-        del inputs
-
-        # Sample completions from both the model and the reference model
-        model_output, mixture_output = self._generate_completions(model, prompts)
-
-        # Process model completions
-        model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
-
-        # Compute rewards
-        if self.reward_model is not None:
-            model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
-            # probability of the model data vs the mixture data
-            probability = F.sigmoid(model_scores - mixture_scores)
-        else:
-            model_scores, mixture_scores = None, None
-            probability = self._compute_judge(model_data, mixture_data, context_length)
-
-        # Compute logprobs
-        model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
-
-        # Compute loss
-        loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
-
-        # Log everything
-        self._log_statistics(
-            model_data,
-            mixture_data,
-            model_logprobs_model_data.detach(),
-            ref_logprobs_model_data,
-            probability,
-            score.detach(),
-            kl_div.detach(),
-            context_length,
-            model_scores,
-            mixture_scores,
-        )
-
-        if (
-            self.args.torch_empty_cache_steps is not None
-            and self.state.global_step % self.args.torch_empty_cache_steps == 0
-        ):
-            empty_cache()
-
-        kwargs = {}
-        # For LOMO optimizers you need to explicitly use the learning rate
-        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-            kwargs["learning_rate"] = self._get_learning_rate()
-
-        if self.args.n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu parallel training
-
-        if self.use_apex:
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            self.accelerator.backward(loss, **kwargs)
-
-        return loss.detach() / self.args.gradient_accumulation_steps
-
-    def create_model_card(
-        self,
-        model_name: Optional[str] = None,
-        dataset_name: Optional[str] = None,
-        tags: Union[str, list[str], None] = None,
-    ):
-        """
-        Creates a draft of a model card using the information available to the `Trainer`.
-
-        Args:
-            model_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the model.
-            dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the dataset used for training.
-            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                Tags to be associated with the model card.
-        """
-        if not self.is_world_process_zero():
-            return
-
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
-        else:
-            base_model = None
-
-        tags = tags or []
-        if isinstance(tags, str):
-            tags = [tags]
-
-        if hasattr(self.model.config, "unsloth_version"):
-            tags.append("unsloth")
-
-        citation = textwrap.dedent("""\
-        @inproceedings{munos2024nash,
-            title = {{Nash Learning from Human Feedback}},
-            author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
-            year = 2024,
-            booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
-            publisher = {OpenReview.net},
-            url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
-        }""")
-
-        model_card = generate_model_card(
-            base_model=base_model,
-            model_name=model_name,
-            hub_model_id=self.hub_model_id,
-            dataset_name=dataset_name,
-            tags=tags,
-            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-            comet_url=get_comet_experiment_url(),
-            trainer_name="Nash-MD",
-            trainer_citation=citation,
-            paper_title="Nash Learning from Human Feedback",
-            paper_id="2312.00886",
-        )
-
-        model_card.save(os.path.join(self.args.output_dir, "README.md"))
-class UnslothNashMDTrainer(_UnslothNashMDTrainer):
-    """
-
-    Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
-
-    Args:
-        model (`transformers.PreTrainedModel`):
-            The model to train, preferably an `AutoModelForCausalLM`.
-        ref_model (`PreTrainedModelWrapper`):
-            Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
-            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
-        reward_model (`transformers.PreTrainedModel`):
-            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
-        judge (`BasePairwiseJudge`):
-            The judge to use for pairwise comparison of model completions.
-        args (`NashMDConfig`):
-            The NashMD config arguments to use for training.
-        data_collator (`transformers.DataCollator`):
-            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
-            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-        train_dataset (`datasets.Dataset`):
-            The dataset to use for training.
-        eval_dataset (`datasets.Dataset`):
-            The dataset to use for evaluation.
-        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-            Processing class used to process the data. If provided, will be used to automatically process the inputs
-            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-            reuse the fine-tuned model.
-        peft_config (`dict`):
-            The peft config to use for training.
-        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-            The function to use to compute the metrics. Must take a `EvalPrediction` and return
-            a dictionary string to metric values.
-        callbacks (`list[transformers.TrainerCallback]`):
-            The callbacks to use for training.
-        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-            The optimizer and scheduler to use for training.
-        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-            The function to use to preprocess the logits before computing the metrics.
-
-    """
-    def __init__(
-        self,
-        model = None,
-        ref_model = None,
-        reward_model = None,
-        judge = None,
-        args = None,
-        data_collator = None,
-        train_dataset = None,
-        eval_dataset = None,
-        processing_class = None,
-        peft_config = None,
-        compute_metrics = None,
-        callbacks = None,
-        preprocess_logits_for_metrics = None,
-        **kwargs
-    ):
-        if args is None: args = UnslothNashMDConfig()
-        use_bf16 = getattr(args, 'bf16', False)
-        use_fp16 = getattr(args, 'fp16', False)
-        force_float32 = False
-        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-            print('Unsloth: Switching to float32 training since model cannot work with float16')
-            force_float32 = True
-        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-        dtype = getattr(model.config, 'torch_dtype', None)
-        if dtype is None: dtype = model.get_input_embeddings().dtype
-        from unsloth_zoo.utils import _get_dtype
-        dtype = _get_dtype(dtype)
-        float16 = dtype == torch.float16
-        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-        if force_float32:
-            args.fp16 = False
-            args.bf16 = False
-            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
862
- args.fp16 = float16
863
- args.bf16 = not float16
864
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
865
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
866
- args.eval_strategy = 'steps'
867
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
868
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
869
- if ga_steps is not None and ga_steps > 1:
870
- from transformers import __version__ as transformers_version
871
- if Version(transformers_version) <= Version('4.45.2'):
872
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
873
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
874
- if getattr(args, 'eval_strategy', 'no') != 'no':
875
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
876
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
877
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
878
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
879
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
880
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
881
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
882
- if force_float32:
883
- args.bf16_full_eval = False
884
- args.fp16_full_eval = False
885
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
886
- args.bf16_full_eval = True
887
- args.fp16_full_eval = False
888
- elif not bf16_full_eval and not fp16_full_eval:
889
- args.bf16_full_eval = args.bf16
890
- args.fp16_full_eval = args.fp16
891
- _output_logits = False
892
- if locals().get('compute_metrics', None) is not None: _output_logits = True
893
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
894
- if _output_logits:
895
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
896
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
897
- pass
898
- else:
899
- model_max_seq_length = getattr(model, 'max_seq_length', None)
900
- args_max_seq_length = getattr(args, 'max_seq_length', None)
901
- if args_max_seq_length is None and model_max_seq_length is not None:
902
- max_seq_length = model.max_seq_length
903
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
904
- if model is not None and hasattr(model, 'for_training'):
905
- model.for_training()
906
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
907
- if 'processing_class' in locals():
908
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
909
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
910
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
911
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
912
- if not isinstance(data_collator, UnslothVisionDataCollator):
913
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
914
- data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
915
- elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
916
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
917
- else:
918
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
919
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
920
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
921
- if not isinstance(data_collator, UnslothVisionDataCollator):
922
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
923
- if isinstance(data_collator, DataCollatorForSeq2Seq):
924
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
925
- else:
926
- data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
927
- other_metrics = []
928
-
929
- from unsloth_zoo.logging_utils import PatchRLStatistics
930
- PatchRLStatistics('nash_md_trainer', other_metrics)
931
-
932
- super().__init__(
933
- model = model,
934
- ref_model = ref_model,
935
- reward_model = reward_model,
936
- judge = judge,
937
- args = args,
938
- data_collator = data_collator,
939
- train_dataset = train_dataset,
940
- eval_dataset = eval_dataset,
941
- processing_class = processing_class,
942
- peft_config = peft_config,
943
- compute_metrics = compute_metrics,
944
- callbacks = callbacks,
945
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
946
- if hasattr(self, 'neftune_hook_handle'):
947
- self.neftune_hook_handle.remove()
948
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
949
- if getattr(args, 'neftune_noise_alpha', None) is not None:
950
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
951
- pass
952
-
953
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unsloth_compiled_cache/UnslothORPOTrainer.py DELETED
@@ -1,1541 +0,0 @@
1
- """
2
- 2025.3.13
3
- 2025.3.15
4
- 4.48.3
5
- 0.15.2
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
13
-
14
-
15
- import os
16
- from typing import *
17
- from dataclasses import dataclass, field
18
- from packaging.version import Version
19
- import torch
20
- import numpy as np
21
- from contextlib import nullcontext
22
- from torch.nn import functional as F
23
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
-
25
- torch_compile_options = {
26
- "epilogue_fusion" : True,
27
- "max_autotune" : False,
28
- "shape_padding" : True,
29
- "trace.enabled" : False,
30
- "triton.cudagraphs" : False,
31
- }
32
-
33
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
- def selective_log_softmax(logits, index):
35
- logits = logits.to(torch.float32)
36
- selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
- # loop to reduce peak mem consumption
38
- # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
- logsumexp_values = torch.logsumexp(logits, dim = -1)
40
- per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
- return per_token_logps
42
- @dataclass
43
- class UnslothORPOConfig(ORPOConfig):
44
- """
45
-
46
- Configuration class for the [`ORPOTrainer`].
47
-
48
- Using [`~transformers.HfArgumentParser`] we can turn this class into
49
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
- command line.
51
-
52
- Parameters:
53
- learning_rate (`float`, *optional*, defaults to `1e-6`):
54
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
- [`~transformers.TrainingArguments`].
56
- max_length (`int` or `None`, *optional*, defaults to `1024`):
57
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
- to use the default data collator.
59
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
- max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
- Maximum length of the completion. This argument is required if you want to use the default data collator
63
- and your model is an encoder-decoder.
64
- beta (`float`, *optional*, defaults to `0.1`):
65
- Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
66
- it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
67
- disable_dropout (`bool`, *optional*, defaults to `True`):
68
- Whether to disable dropout in the model.
69
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
70
- Label pad token id. This argument is required if you want to use the default data collator.
71
- padding_value (`int` or `None`, *optional*, defaults to `None`):
72
- Padding value to use. If `None`, the padding value of the tokenizer is used.
73
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
74
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
75
- This argument is required if you want to use the default data collator.
76
- generate_during_eval (`bool`, *optional*, defaults to `False`):
77
- If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
78
- is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
79
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
80
- you need to specify if the model returned by the callable is an encoder-decoder model.
81
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
82
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
83
- string.
84
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
85
- Number of processes to use for processing the dataset.
86
-
87
- """
88
- vllm_sampling_params: Optional[Any] = field(
89
- default = None,
90
- metadata = {'help': 'vLLM SamplingParams'},
91
- )
92
- unsloth_num_chunks : Optional[int] = field(
93
- default = -1,
94
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
95
- )
96
- def __init__(
97
- self,
98
- output_dir = None,
99
- overwrite_output_dir = None,
100
- do_train = False,
101
- do_eval = False,
102
- do_predict = False,
103
- eval_strategy = 'no',
104
- prediction_loss_only = False,
105
- per_device_train_batch_size = 4,
106
- per_device_eval_batch_size = 4,
107
- per_gpu_train_batch_size = None,
108
- per_gpu_eval_batch_size = None,
109
- gradient_accumulation_steps = 2,
110
- eval_accumulation_steps = 2,
111
- eval_delay = 0,
112
- torch_empty_cache_steps = 250,
113
- learning_rate = 5e-05,
114
- weight_decay = 0.01,
115
- adam_beta1 = 0.9,
116
- adam_beta2 = 0.999,
117
- adam_epsilon = 1e-08,
118
- max_grad_norm = 1.0,
119
- num_train_epochs = 3.0,
120
- max_steps = -1,
121
- lr_scheduler_type = 'linear',
122
- warmup_ratio = 0.1,
123
- warmup_steps = 0,
124
- log_level = 'passive',
125
- log_level_replica = 'warning',
126
- log_on_each_node = True,
127
- logging_dir = None,
128
- logging_strategy = 'steps',
129
- logging_first_step = False,
130
- logging_steps = 1,
131
- logging_nan_inf_filter = False,
132
- save_strategy = 'steps',
133
- save_steps = 500,
134
- save_total_limit = None,
135
- save_safetensors = True,
136
- save_on_each_node = False,
137
- save_only_model = False,
138
- restore_callback_states_from_checkpoint = False,
139
- no_cuda = False,
140
- use_cpu = False,
141
- use_mps_device = False,
142
- seed = 3407,
143
- data_seed = 3407,
144
- jit_mode_eval = False,
145
- use_ipex = False,
146
- bf16 = False,
147
- fp16 = False,
148
- fp16_opt_level = 'O1',
149
- half_precision_backend = 'auto',
150
- bf16_full_eval = False,
151
- fp16_full_eval = False,
152
- tf32 = None,
153
- local_rank = -1,
154
- ddp_backend = None,
155
- tpu_num_cores = None,
156
- tpu_metrics_debug = False,
157
- debug = '',
158
- dataloader_drop_last = False,
159
- eval_steps = None,
160
- dataloader_num_workers = 0,
161
- dataloader_prefetch_factor = None,
162
- past_index = -1,
163
- run_name = None,
164
- disable_tqdm = None,
165
- remove_unused_columns = True,
166
- label_names = None,
167
- load_best_model_at_end = False,
168
- metric_for_best_model = None,
169
- greater_is_better = None,
170
- ignore_data_skip = False,
171
- fsdp = '',
172
- fsdp_min_num_params = 0,
173
- fsdp_config = None,
174
- fsdp_transformer_layer_cls_to_wrap = None,
175
- accelerator_config = None,
176
- deepspeed = None,
177
- label_smoothing_factor = 0.0,
178
- optim = 'adamw_8bit',
179
- optim_args = None,
180
- adafactor = False,
181
- group_by_length = False,
182
- length_column_name = 'length',
183
- report_to = None,
184
- ddp_find_unused_parameters = None,
185
- ddp_bucket_cap_mb = None,
186
- ddp_broadcast_buffers = None,
187
- dataloader_pin_memory = True,
188
- dataloader_persistent_workers = False,
189
- skip_memory_metrics = True,
190
- use_legacy_prediction_loop = False,
191
- push_to_hub = False,
192
- resume_from_checkpoint = None,
193
- hub_model_id = None,
194
- hub_strategy = 'every_save',
195
- hub_token = None,
196
- hub_private_repo = None,
197
- hub_always_push = False,
198
- gradient_checkpointing = False,
199
- gradient_checkpointing_kwargs = None,
200
- include_inputs_for_metrics = False,
201
- eval_do_concat_batches = True,
202
- fp16_backend = 'auto',
203
- evaluation_strategy = None,
204
- push_to_hub_model_id = None,
205
- push_to_hub_organization = None,
206
- push_to_hub_token = None,
207
- mp_parameters = '',
208
- auto_find_batch_size = False,
209
- full_determinism = False,
210
- torchdynamo = None,
211
- ray_scope = 'last',
212
- ddp_timeout = 1800,
213
- torch_compile = False,
214
- torch_compile_backend = None,
215
- torch_compile_mode = None,
216
- dispatch_batches = None,
217
- split_batches = None,
218
- include_tokens_per_second = False,
219
- include_num_input_tokens_seen = False,
220
- neftune_noise_alpha = None,
221
- optim_target_modules = None,
222
- batch_eval_metrics = False,
223
- eval_on_start = False,
224
- use_liger_kernel = False,
225
- eval_use_gather_object = False,
226
- average_tokens_across_devices = False,
227
- max_length = 1024,
228
- max_prompt_length = 512,
229
- max_completion_length = None,
230
- beta = 0.1,
231
- disable_dropout = True,
232
- label_pad_token_id = -100,
233
- padding_value = None,
234
- truncation_mode = 'keep_end',
235
- generate_during_eval = False,
236
- is_encoder_decoder = None,
237
- model_init_kwargs = None,
238
- dataset_num_proc = None,
239
- vllm_sampling_params = None,
240
- unsloth_num_chunks = -1,
241
- **kwargs,
242
- ):
243
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
244
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
245
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
246
- output_dir = 'unsloth_training_checkpoints'
247
- save_strategy = 'no'
248
- if dataset_num_proc is None:
249
- from multiprocessing import cpu_count
250
- dataset_num_proc = cpu_count()
251
-
252
- super().__init__(
253
- output_dir = output_dir,
254
- overwrite_output_dir = overwrite_output_dir,
255
- do_train = do_train,
256
- do_eval = do_eval,
257
- do_predict = do_predict,
258
- eval_strategy = eval_strategy,
259
- prediction_loss_only = prediction_loss_only,
260
- per_device_train_batch_size = per_device_train_batch_size,
261
- per_device_eval_batch_size = per_device_eval_batch_size,
262
- per_gpu_train_batch_size = per_gpu_train_batch_size,
263
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
264
- gradient_accumulation_steps = gradient_accumulation_steps,
265
- eval_accumulation_steps = eval_accumulation_steps,
266
- eval_delay = eval_delay,
267
- torch_empty_cache_steps = torch_empty_cache_steps,
268
- learning_rate = learning_rate,
269
- weight_decay = weight_decay,
270
- adam_beta1 = adam_beta1,
271
- adam_beta2 = adam_beta2,
272
- adam_epsilon = adam_epsilon,
273
- max_grad_norm = max_grad_norm,
274
- num_train_epochs = num_train_epochs,
275
- max_steps = max_steps,
276
- lr_scheduler_type = lr_scheduler_type,
277
- warmup_ratio = warmup_ratio,
278
- warmup_steps = warmup_steps,
279
- log_level = log_level,
280
- log_level_replica = log_level_replica,
281
- log_on_each_node = log_on_each_node,
282
- logging_dir = logging_dir,
283
- logging_strategy = logging_strategy,
284
- logging_first_step = logging_first_step,
285
- logging_steps = logging_steps,
286
- logging_nan_inf_filter = logging_nan_inf_filter,
287
- save_strategy = save_strategy,
288
- save_steps = save_steps,
289
- save_total_limit = save_total_limit,
290
- save_safetensors = save_safetensors,
291
- save_on_each_node = save_on_each_node,
292
- save_only_model = save_only_model,
293
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
294
- no_cuda = no_cuda,
295
- use_cpu = use_cpu,
296
- use_mps_device = use_mps_device,
297
- seed = seed,
298
- data_seed = data_seed,
299
- jit_mode_eval = jit_mode_eval,
300
- use_ipex = use_ipex,
301
- bf16 = bf16,
302
- fp16 = fp16,
303
- fp16_opt_level = fp16_opt_level,
304
- half_precision_backend = half_precision_backend,
305
- bf16_full_eval = bf16_full_eval,
306
- fp16_full_eval = fp16_full_eval,
307
- tf32 = tf32,
308
- local_rank = local_rank,
309
- ddp_backend = ddp_backend,
310
- tpu_num_cores = tpu_num_cores,
311
- tpu_metrics_debug = tpu_metrics_debug,
312
- debug = debug,
313
- dataloader_drop_last = dataloader_drop_last,
314
- eval_steps = eval_steps,
315
- dataloader_num_workers = dataloader_num_workers,
316
- dataloader_prefetch_factor = dataloader_prefetch_factor,
317
- past_index = past_index,
318
- run_name = run_name,
319
- disable_tqdm = disable_tqdm,
320
- remove_unused_columns = remove_unused_columns,
321
- label_names = label_names,
322
- load_best_model_at_end = load_best_model_at_end,
323
- metric_for_best_model = metric_for_best_model,
324
- greater_is_better = greater_is_better,
325
- ignore_data_skip = ignore_data_skip,
326
- fsdp = fsdp,
327
- fsdp_min_num_params = fsdp_min_num_params,
328
- fsdp_config = fsdp_config,
329
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
330
- accelerator_config = accelerator_config,
331
- deepspeed = deepspeed,
332
- label_smoothing_factor = label_smoothing_factor,
333
- optim = optim,
334
- optim_args = optim_args,
335
- adafactor = adafactor,
336
- group_by_length = group_by_length,
337
- length_column_name = length_column_name,
338
- report_to = report_to,
339
- ddp_find_unused_parameters = ddp_find_unused_parameters,
340
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
341
- ddp_broadcast_buffers = ddp_broadcast_buffers,
342
- dataloader_pin_memory = dataloader_pin_memory,
343
- dataloader_persistent_workers = dataloader_persistent_workers,
344
- skip_memory_metrics = skip_memory_metrics,
345
- use_legacy_prediction_loop = use_legacy_prediction_loop,
346
- push_to_hub = push_to_hub,
347
- resume_from_checkpoint = resume_from_checkpoint,
348
- hub_model_id = hub_model_id,
349
- hub_strategy = hub_strategy,
350
- hub_token = hub_token,
351
- hub_private_repo = hub_private_repo,
352
- hub_always_push = hub_always_push,
353
- gradient_checkpointing = gradient_checkpointing,
354
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
355
- include_inputs_for_metrics = include_inputs_for_metrics,
356
- eval_do_concat_batches = eval_do_concat_batches,
357
- fp16_backend = fp16_backend,
358
- evaluation_strategy = evaluation_strategy,
359
- push_to_hub_model_id = push_to_hub_model_id,
360
- push_to_hub_organization = push_to_hub_organization,
361
- push_to_hub_token = push_to_hub_token,
362
- mp_parameters = mp_parameters,
363
- auto_find_batch_size = auto_find_batch_size,
364
- full_determinism = full_determinism,
365
- torchdynamo = torchdynamo,
366
- ray_scope = ray_scope,
367
- ddp_timeout = ddp_timeout,
368
- torch_compile = torch_compile,
369
- torch_compile_backend = torch_compile_backend,
370
- torch_compile_mode = torch_compile_mode,
371
- dispatch_batches = dispatch_batches,
372
- split_batches = split_batches,
373
- include_tokens_per_second = include_tokens_per_second,
374
- include_num_input_tokens_seen = include_num_input_tokens_seen,
375
- neftune_noise_alpha = neftune_noise_alpha,
376
- optim_target_modules = optim_target_modules,
377
- batch_eval_metrics = batch_eval_metrics,
378
- eval_on_start = eval_on_start,
379
- use_liger_kernel = use_liger_kernel,
380
- eval_use_gather_object = eval_use_gather_object,
381
- average_tokens_across_devices = average_tokens_across_devices,
382
- max_length = max_length,
383
- max_prompt_length = max_prompt_length,
384
- max_completion_length = max_completion_length,
385
- beta = beta,
386
- disable_dropout = disable_dropout,
387
- label_pad_token_id = label_pad_token_id,
388
- padding_value = padding_value,
389
- truncation_mode = truncation_mode,
390
- generate_during_eval = generate_during_eval,
391
- is_encoder_decoder = is_encoder_decoder,
392
- model_init_kwargs = model_init_kwargs,
393
- dataset_num_proc = dataset_num_proc,**kwargs)
394
- self.vllm_sampling_params = vllm_sampling_params
395
- self.unsloth_num_chunks = unsloth_num_chunks
396
- pass
397
-
398
- class _UnslothORPOTrainer(Trainer):
399
- r""""""
400
-
401
- _tag_names = ["trl", "orpo"]
402
-
403
- def __init__(
404
- self,
405
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
406
- args: Optional[ORPOConfig] = None,
407
- data_collator: Optional[DataCollator] = None,
408
- train_dataset: Optional[Dataset] = None,
409
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
410
- processing_class: Optional[
411
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
412
- ] = None,
413
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
414
- callbacks: Optional[list[TrainerCallback]] = None,
415
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
416
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
417
- peft_config: Optional[dict] = None,
418
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
419
- ):
420
- if args.model_init_kwargs is None:
421
- model_init_kwargs = {}
422
- elif not isinstance(model, str):
423
- raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
424
- else:
425
- model_init_kwargs = args.model_init_kwargs
426
- torch_dtype = model_init_kwargs.get("torch_dtype")
427
- if torch_dtype is not None:
428
- # Convert to `torch.dtype` if an str is passed
429
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
430
- torch_dtype = getattr(torch, torch_dtype)
431
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
432
- raise ValueError(
433
- f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
434
- )
435
- model_init_kwargs["torch_dtype"] = torch_dtype
436
-
437
- if isinstance(model, str):
438
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
439
-
440
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
441
- # has been called in order to properly call autocast if needed.
442
- self._peft_has_been_casted_to_bf16 = False
443
-
444
- if not is_peft_available() and peft_config is not None:
445
- raise ValueError(
446
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
447
- )
448
- elif is_peft_available() and peft_config is not None:
449
- # if model is a peft model and we have a peft_config, we merge and unload it first
450
- if isinstance(model, PeftModel):
451
- model = model.merge_and_unload()
452
-
453
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
454
- _support_gc_kwargs = hasattr(
455
- args, "gradient_checkpointing_kwargs"
456
- ) and "gradient_checkpointing_kwargs" in list(
457
- inspect.signature(prepare_model_for_kbit_training).parameters
458
- )
459
-
460
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
461
-
462
- if _support_gc_kwargs:
463
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
464
-
465
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
466
- elif getattr(args, "gradient_checkpointing", False):
467
- # For backward compatibility with older versions of transformers
468
- if hasattr(model, "enable_input_require_grads"):
469
- model.enable_input_require_grads()
470
- else:
471
-
472
- def make_inputs_require_grad(module, input, output):
473
- output.requires_grad_(True)
474
-
475
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
476
-
477
- # get peft model with the given config
478
- model = model
479
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
480
- peft_module_casting_to_bf16(model)
481
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
482
- self._peft_has_been_casted_to_bf16 = True
483
-
484
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
485
- # to explicitly have `requires_grad=True`, otherwise training will either silently
486
- # fail or completely fail.
487
- elif getattr(args, "gradient_checkpointing", False):
488
- # For backward compatibility with older versions of transformers
489
- if hasattr(model, "enable_input_require_grads"):
490
- model.enable_input_require_grads()
491
- else:
492
-
493
- def make_inputs_require_grad(module, input, output):
494
- output.requires_grad_(True)
495
-
496
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
497
-
498
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
499
- raise ValueError(
500
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
501
- " Please install `wandb` or `comet-ml` to resolve."
502
- )
503
-
504
- if model is not None:
505
- self.is_encoder_decoder = model.config.is_encoder_decoder
506
- elif args.is_encoder_decoder is None:
507
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
508
- else:
509
- self.is_encoder_decoder = args.is_encoder_decoder
510
-
511
- if self.is_encoder_decoder:
512
- self.decoder_start_token_id = model.config.decoder_start_token_id
513
- self.pad_token_id = model.config.pad_token_id
514
-
515
- if processing_class is None:
516
- raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
517
- if args.max_length is None:
518
- warnings.warn(
519
- "`max_length` is not set in the ORPOConfig's init"
520
- " it will default to `512` by default, but you should do it yourself in the future.",
521
- UserWarning,
522
- )
523
- max_length = 512
524
- else:
525
- max_length = args.max_length
526
- if args.max_prompt_length is None:
527
- warnings.warn(
528
- "`max_prompt_length` is not set in the ORPOConfig's init"
529
- " it will default to `128` by default, but you should do it yourself in the future.",
530
- UserWarning,
531
- )
532
- max_prompt_length = 128
533
- else:
534
- max_prompt_length = args.max_prompt_length
535
-
536
- if args.max_completion_length is None and self.is_encoder_decoder:
537
- warnings.warn(
538
- "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
539
- " it will default to `128` by default, but you should do it yourself in the future.",
540
- UserWarning,
541
- )
542
- self.max_completion_length = 128
543
- else:
544
- self.max_completion_length = args.max_completion_length
545
-
546
- if data_collator is None:
547
- data_collator = DPODataCollatorWithPadding(
548
- pad_token_id=processing_class.pad_token_id,
549
- label_pad_token_id=args.label_pad_token_id,
550
- is_encoder_decoder=self.is_encoder_decoder,
551
- )
552
-
553
- if args.remove_unused_columns:
554
- args.remove_unused_columns = False
555
- # warn users
556
- warnings.warn(
557
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
558
- " we have set it for you, but you should do it yourself in the future.",
559
- UserWarning,
560
- )
561
-
562
- self.use_dpo_data_collator = True
563
- else:
564
- self.use_dpo_data_collator = False
565
-
566
- # Disable dropout in the model and reference model
567
- if args.disable_dropout:
568
- disable_dropout_in_model(model)
569
-
570
- self.max_length = max_length
571
- self.generate_during_eval = args.generate_during_eval
572
- self.label_pad_token_id = args.label_pad_token_id
573
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
574
- self.max_prompt_length = max_prompt_length
575
- self.truncation_mode = args.truncation_mode
576
- self.processing_class = processing_class
577
-
578
- self.beta = args.beta
579
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
580
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
581
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
582
- warnings.warn(
583
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
584
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
585
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
586
- "loss.",
587
- UserWarning,
588
- )
589
-
590
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
591
-
592
- # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
593
- # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
594
- # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
595
- # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
596
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
597
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
598
- # that the warning has already been issued.
599
- model.warnings_issued["estimate_tokens"] = True
600
-
601
- # Compute that only on the main process for faster data processing.
602
- # see: https://github.com/huggingface/trl/pull/1255
603
- with PartialState().local_main_process_first():
604
- # Extract the prompt if needed, and apply the chat template if needed
605
- train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
606
- train_dataset = train_dataset.map(
607
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
608
- )
609
- train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
610
- if eval_dataset is not None:
611
- eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
612
- eval_dataset = eval_dataset.map(
613
- maybe_apply_chat_template,
614
- fn_kwargs={"tokenizer": processing_class},
615
- num_proc=args.dataset_num_proc,
616
- )
617
- eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
618
-
619
- super().__init__(
620
- model=model,
621
- args=args,
622
- data_collator=data_collator,
623
- train_dataset=train_dataset,
624
- eval_dataset=eval_dataset,
625
- processing_class=processing_class,
626
- model_init=model_init,
627
- compute_metrics=compute_metrics,
628
- callbacks=callbacks,
629
- optimizers=optimizers,
630
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
631
- )
632
-
633
- # Add tags for models that have been loaded with the correct transformers version
634
- if hasattr(self.model, "add_model_tags"):
635
- self.model.add_model_tags(self._tag_names)
636
-
637
- if not hasattr(self, "accelerator"):
638
- raise AttributeError(
639
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
640
- )
641
-
642
- def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
643
- # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
644
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
645
- config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
646
-
647
- if model is not None:
648
- if hasattr(model, "config"):
649
- hidden_size = (
650
- max(model.config.hidden_sizes)
651
- if getattr(model.config, "hidden_sizes", None)
652
- else getattr(model.config, "hidden_size", None)
653
- )
654
- if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
655
- # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
656
- # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
657
- config_kwargs.update(
658
- {
659
- "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
660
- "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
661
- "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
662
- }
663
- )
664
-
665
- # If ZeRO-3 is used, we shard both the active and reference model.
666
- # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
667
- if config_kwargs["zero_optimization"]["stage"] != 3:
668
- config_kwargs["zero_optimization"]["stage"] = 0
669
- model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
670
- model.eval()
671
- return model
672
-
673
- def build_tokenized_answer(self, prompt, answer):
674
- """
675
- Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
676
- It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
677
- Reference:
678
- https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
679
- """
680
-
681
- full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
682
- prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
683
-
684
- answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
685
- answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
686
-
687
- # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
688
- full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
689
-
690
- # Prepare input tokens for token by token comparison
691
- full_input_ids = np.array(full_tokenized["input_ids"])
692
-
693
- if len(full_input_ids) != len(full_concat_input_ids):
694
- raise ValueError("Prompt input ids and answer input ids should have the same length.")
695
-
696
- # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
697
- # can be merged together when tokenizing prompt+answer. This could result
698
- # on the last token from the prompt being different when tokenized on its own
699
- # vs when done as prompt+answer.
700
- response_token_ids_start_idx = len(prompt_input_ids)
701
-
702
- # If tokenized prompt is different than both prompt+answer, then it means the
703
- # last token has changed due to merging.
704
- if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
705
- response_token_ids_start_idx -= 1
706
-
707
- prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
708
- prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
709
-
710
- if len(prompt_input_ids) != len(prompt_attention_mask):
711
- raise ValueError("Prompt input ids and attention mask should have the same length.")
712
-
713
- answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
714
- answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
715
-
716
- return dict(
717
- prompt_input_ids=prompt_input_ids,
718
- prompt_attention_mask=prompt_attention_mask,
719
- input_ids=answer_input_ids,
720
- attention_mask=answer_attention_mask,
721
- )
722
-
723
- def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
724
- """Tokenize a single row from a ORPO specific dataset.
725
-
726
- At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
727
- in case the prompt + chosen or prompt + rejected responses is/are too long. First
728
- we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
729
-
730
- We also create the labels for the chosen/rejected responses, which are of length equal to
731
- the sum of the length of the prompt and the chosen/rejected response, with
732
- label_pad_token_id for the prompt tokens.
733
- """
734
- batch = {}
735
- prompt = feature["prompt"]
736
- chosen = feature["chosen"]
737
- rejected = feature["rejected"]
738
-
739
- if not self.is_encoder_decoder:
740
- # Check issues below for more details
741
- # 1. https://github.com/huggingface/trl/issues/907
742
- # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
743
- # 3. https://github.com/LianjiaTech/BELLE/issues/337
744
-
745
- if not isinstance(prompt, str):
746
- raise ValueError(f"prompt should be an str but got {type(prompt)}")
747
- prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
748
- prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
749
-
750
- if not isinstance(chosen, str):
751
- raise ValueError(f"chosen should be an str but got {type(chosen)}")
752
- chosen_tokens = self.build_tokenized_answer(prompt, chosen)
753
-
754
- if not isinstance(rejected, str):
755
- raise ValueError(f"rejected should be an str but got {type(rejected)}")
756
- rejected_tokens = self.build_tokenized_answer(prompt, rejected)
757
-
758
- # Last prompt token might get merged by tokenizer and
759
- # it should not be included for generation if that happens
760
- prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
761
-
762
- chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
763
- rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
764
- prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
765
-
766
- for k, v in prompt_tokens.items():
767
- prompt_tokens[k] = v[:prompt_len_input_ids]
768
-
769
- # Make sure prompts only have one different token at most an
770
- # and length only differs by 1 at most
771
- num_diff_tokens = sum(
772
- [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
773
- )
774
- num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
775
- if num_diff_tokens > 1 or num_diff_len > 1:
776
- raise ValueError(
777
- "Chosen and rejected prompt_input_ids might only differ on the "
778
- "last token due to tokenizer merge ops."
779
- )
780
-
781
- # add BOS token to head of prompt. Avoid adding if it's already there
782
- prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
783
- self.processing_class.bos_token_id,
784
- prompt_len_input_ids,
785
- prompt_tokens,
786
- chosen_prompt_len_input_ids,
787
- chosen_tokens,
788
- rejected_prompt_len_input_ids,
789
- rejected_tokens,
790
- )
791
-
792
- # add EOS token to end of answer. Avoid adding if it's already there
793
- chosen_tokens, rejected_tokens = add_eos_token_if_needed(
794
- self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
795
- )
796
-
797
- longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
798
-
799
- # if combined sequence is too long, truncate the prompt
800
- for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
801
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
802
- if self.truncation_mode == "keep_start":
803
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
804
- answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
805
- elif self.truncation_mode == "keep_end":
806
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
807
- answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
808
- else:
809
- raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
810
-
811
- # if that's still too long, truncate the response
812
- for answer_tokens in [chosen_tokens, rejected_tokens]:
813
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
814
- for k in ["input_ids", "attention_mask"]:
815
- answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
816
-
817
- # Create labels
818
- chosen_sequence_tokens = {
819
- k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
820
- }
821
- rejected_sequence_tokens = {
822
- k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
823
- }
824
- chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
825
- chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
826
- self.label_pad_token_id
827
- ] * len(chosen_tokens["prompt_input_ids"])
828
- rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
829
- rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
830
- self.label_pad_token_id
831
- ] * len(rejected_tokens["prompt_input_ids"])
832
-
833
- for k, toks in {
834
- "chosen_": chosen_sequence_tokens,
835
- "rejected_": rejected_sequence_tokens,
836
- "": prompt_tokens,
837
- }.items():
838
- for type_key, tokens in toks.items():
839
- if type_key == "token_type_ids":
840
- continue
841
- batch[f"{k}{type_key}"] = tokens
842
-
843
- else:
844
- chosen_tokens = self.processing_class(
845
- chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
846
- )
847
- rejected_tokens = self.processing_class(
848
- rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
849
- )
850
- prompt_tokens = self.processing_class(
851
- prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
852
- )
853
-
854
- batch["chosen_labels"] = chosen_tokens["input_ids"]
855
- batch["rejected_labels"] = rejected_tokens["input_ids"]
856
- batch["prompt_input_ids"] = prompt_tokens["input_ids"]
857
- batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
858
-
859
- if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
860
- batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
861
- labels=torch.tensor(batch["rejected_labels"])
862
- )
863
- batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
864
- labels=torch.tensor(batch["chosen_labels"])
865
- )
866
-
867
- if is_torch_xla_available():
868
- # Pad the sequences to global max_length to avoid TorchXLA recompilation
869
- for k in batch:
870
- if "labels" in k or self.is_encoder_decoder:
871
- pad_value = self.label_pad_token_id
872
- elif k.endswith("_input_ids"):
873
- pad_value = self.padding_value
874
- elif k.endswith("_attention_mask"):
875
- pad_value = 0
876
- batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
877
- return batch
878
-
879
- @staticmethod
880
- def concatenated_inputs(
881
- batch: dict[str, Union[list, torch.LongTensor]],
882
- is_encoder_decoder: bool = False,
883
- label_pad_token_id: int = -100,
884
- padding_value: int = 0,
885
- device: Optional[torch.device] = None,
886
- ) -> dict[str, torch.LongTensor]:
887
- """Concatenate the chosen and rejected inputs into a single tensor.
888
-
889
- Args:
890
- batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
891
- is_encoder_decoder: Whether the model is an encoder-decoder model.
892
- label_pad_token_id: The label pad token id.
893
- padding_value: The padding value to use for the concatenated inputs_ids.
894
- device: The device for the concatenated inputs.
895
-
896
- Returns:
897
- A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
898
- """
899
- concatenated_batch = {}
900
-
901
- if is_encoder_decoder:
902
- max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
903
- else:
904
- max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
905
-
906
- for k in batch:
907
- if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
908
- if "labels" in k or is_encoder_decoder:
909
- pad_value = label_pad_token_id
910
- elif k.endswith("_input_ids"):
911
- pad_value = padding_value
912
- elif k.endswith("_attention_mask"):
913
- pad_value = 0
914
- concatenated_key = k.replace("chosen", "concatenated")
915
- concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
916
- for k in batch:
917
- if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
918
- if "labels" in k or is_encoder_decoder:
919
- pad_value = label_pad_token_id
920
- elif k.endswith("_input_ids"):
921
- pad_value = padding_value
922
- elif k.endswith("_attention_mask"):
923
- pad_value = 0
924
- concatenated_key = k.replace("rejected", "concatenated")
925
- concatenated_batch[concatenated_key] = torch.cat(
926
- (
927
- concatenated_batch[concatenated_key],
928
- pad_to_length(batch[k], max_length, pad_value=pad_value),
929
- ),
930
- dim=0,
931
- ).to(device=device)
932
-
933
- if is_encoder_decoder:
934
- concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
935
- concatenated_batch["concatenated_attention_mask"] = (
936
- batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
937
- )
938
-
939
- return concatenated_batch
940
-
941
- def odds_ratio_loss(
942
- self,
943
- policy_chosen_logps: torch.FloatTensor,
944
- policy_rejected_logps: torch.FloatTensor,
945
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
946
- """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
947
-
948
- Args:
949
- policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
950
- policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
951
-
952
- Returns:
953
- A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
954
- The losses tensor contains the ORPO loss for each example in the batch.
955
- The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
956
- The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
957
- The `log(sigmoid(log_odds_chosen))` for logging purposes.
958
- """
959
-
960
- # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
961
- log_odds = (policy_chosen_logps - policy_rejected_logps) - (
962
- torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
963
- )
964
- ratio = F.logsigmoid(log_odds)
965
- losses = self.beta * ratio
966
-
967
- chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
968
- rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
969
-
970
- return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
971
-
972
- @staticmethod
973
- def get_batch_logps(
974
- logits: torch.FloatTensor,
975
- labels: torch.LongTensor,
976
- average_log_prob: bool = False,
977
- label_pad_token_id: int = -100,
978
- is_encoder_decoder: bool = False,
979
- ) -> torch.FloatTensor:
980
- """Compute the log probabilities of the given labels under the given logits.
981
-
982
- Args:
983
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
984
- labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
985
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
986
- label_pad_token_id: The label pad token id.
987
- is_encoder_decoder: Whether the model is an encoder-decoder model.
988
-
989
- Returns:
990
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
991
- """
992
- if logits.shape[:-1] != labels.shape:
993
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
994
-
995
- if not is_encoder_decoder:
996
- labels = labels[:, 1:].clone()
997
- logits = logits[:, :-1, :]
998
- loss_mask = labels != label_pad_token_id
999
-
1000
- # dummy token; we'll ignore the losses on these tokens later
1001
- labels = torch.where(labels == label_pad_token_id, 0, labels)
1002
-
1003
- per_token_logps = selective_log_softmax(logits, labels)
1004
-
1005
- if average_log_prob:
1006
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1007
- else:
1008
- return (per_token_logps * loss_mask).sum(-1)
1009
-
1010
- def concatenated_forward(
1011
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1012
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1013
- """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1014
-
1015
- We do this to avoid doing two forward passes, because it's faster for FSDP.
1016
- """
1017
- concatenated_batch = self.concatenated_inputs(
1018
- batch,
1019
- is_encoder_decoder=self.is_encoder_decoder,
1020
- label_pad_token_id=self.label_pad_token_id,
1021
- padding_value=self.padding_value,
1022
- device=self.accelerator.device,
1023
- )
1024
- len_chosen = batch["chosen_labels"].shape[0]
1025
-
1026
- model_kwargs = (
1027
- {
1028
- "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1029
- }
1030
- if self.is_encoder_decoder
1031
- else {}
1032
- )
1033
-
1034
- if self.aux_loss_enabled:
1035
- model_kwargs["output_router_logits"] = True
1036
-
1037
- outputs = model(
1038
- concatenated_batch["concatenated_input_ids"],
1039
- attention_mask=concatenated_batch["concatenated_attention_mask"],
1040
- use_cache=False,
1041
- **model_kwargs,
1042
- )
1043
- all_logits = outputs.logits
1044
-
1045
- def cross_entropy_loss(logits, labels):
1046
- if not self.is_encoder_decoder:
1047
- # Shift so that tokens < n predict n
1048
- logits = logits[..., :-1, :].contiguous()
1049
- labels = labels[..., 1:].contiguous()
1050
- # Flatten the tokens
1051
- loss_fct = nn.CrossEntropyLoss()
1052
- logits = logits.view(-1, logits.shape[-1])
1053
- labels = labels.view(-1)
1054
- # Enable model parallelism
1055
- labels = labels.to(logits.device)
1056
- loss = loss_fct(logits, labels)
1057
- return loss
1058
-
1059
- if self.is_encoder_decoder:
1060
- labels = concatenated_batch["concatenated_labels"].clone()
1061
- else:
1062
- labels = concatenated_batch["concatenated_input_ids"].clone()
1063
- attention_mask = concatenated_batch["concatenated_attention_mask"]
1064
- labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
1065
- # The ORPO chosen NLL loss is computed over the full prompt and response
1066
- chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1067
-
1068
- all_logps = self.get_batch_logps(
1069
- all_logits,
1070
- concatenated_batch["concatenated_labels"],
1071
- average_log_prob=True,
1072
- is_encoder_decoder=self.is_encoder_decoder,
1073
- label_pad_token_id=self.label_pad_token_id,
1074
- )
1075
-
1076
- chosen_logps = all_logps[:len_chosen]
1077
- rejected_logps = all_logps[len_chosen:]
1078
-
1079
- if not self.is_encoder_decoder:
1080
- chosen_logits = all_logits[:len_chosen, :-1, :]
1081
- rejected_logits = all_logits[len_chosen:, :-1, :]
1082
- else:
1083
- chosen_logits = all_logits[:len_chosen]
1084
- rejected_logits = all_logits[len_chosen:]
1085
-
1086
- if self.aux_loss_enabled:
1087
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
1088
-
1089
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1090
-
1091
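# Note: chosen and rejected examples are stacked along the batch axis so a single
# forward pass scores both; rows [:len_chosen] are chosen, the remainder rejected.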
- def get_batch_loss_metrics(
1092
- self,
1093
- model,
1094
- batch: dict[str, Union[list, torch.LongTensor]],
1095
- train_eval: Literal["train", "eval"] = "train",
1096
- ):
1097
- """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1098
- metrics = {}
1099
-
1100
- forward_output = self.concatenated_forward(model, batch)
1101
- (
1102
- policy_chosen_logps,
1103
- policy_rejected_logps,
1104
- policy_chosen_logits,
1105
- policy_rejected_logits,
1106
- policy_nll_loss,
1107
- ) = forward_output[:5]
1108
- if self.aux_loss_enabled:
1109
- aux_loss = forward_output[5]
1110
-
1111
- losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1112
- policy_chosen_logps, policy_rejected_logps
1113
- )
1114
- # full ORPO loss
1115
- loss = policy_nll_loss - losses.mean()
1116
-
1117
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
1118
-
1119
- prefix = "eval_" if train_eval == "eval" else ""
1120
- metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
1121
- metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
1122
- metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
1123
- metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
1124
- chosen_rewards - rejected_rewards
1125
- ).mean()
1126
- metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
1127
- metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
1128
- metrics[f"{prefix}logits/rejected"] = (
1129
- self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
1130
- )
1131
- metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
1132
- metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
1133
- metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
1134
- metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
1135
- if is_torch_xla_available():
1136
- xm.mark_step() # needed because .item() calls
1137
- for k, v in metrics.items():
1138
- metrics[k] = v.item()
1139
- if self.aux_loss_enabled:
1140
- loss += self.aux_loss_coef * aux_loss
1141
-
1142
- return loss, metrics
1143
-
1144
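# Summary of the objective assembled above, in the paper's notation:
#   L_ORPO = L_NLL(chosen) - beta * E[log sigmoid(log_odds)]
# realized as `loss = policy_nll_loss - losses.mean()` with
# `losses = beta * F.logsigmoid(log_odds)` from odds_ratio_loss.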
- def compute_loss(
1145
- self,
1146
- model: Union[PreTrainedModel, nn.Module],
1147
- inputs: dict[str, Union[torch.Tensor, Any]],
1148
- return_outputs=False,
1149
- num_items_in_batch=None,
1150
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1151
- compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1152
-
1153
- with compute_loss_context_manager:
1154
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1155
-
1156
- # Make sure to move the loss to the same device as the accumulating loss back in the `Trainer` class:
1157
- loss = loss.to(self.args.device)
1158
-
1159
- # force log the metrics
1160
- self.store_metrics(metrics, train_eval="train")
1161
-
1162
- if return_outputs:
1163
- return (loss, metrics)
1164
- return loss
1165
-
1166
- def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1167
- """Generate samples from the model and reference model for the given batch of inputs."""
1168
-
1169
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1170
- # the torch cuda amp context manager, as some hidden states are silently cast to full precision.
1171
- generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1172
-
1173
- with generate_context_manager:
1174
- policy_output = model.generate(
1175
- input_ids=batch["prompt_input_ids"],
1176
- attention_mask=batch["prompt_attention_mask"],
1177
- max_length=self.max_length,
1178
- do_sample=True,
1179
- pad_token_id=self.processing_class.pad_token_id,
1180
- )
1181
-
1182
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1183
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1184
-
1185
- return policy_output_decoded
1186
-
1187
- def prediction_step(
1188
- self,
1189
- model: Union[PreTrainedModel, nn.Module],
1190
- inputs: dict[str, Union[torch.Tensor, Any]],
1191
- prediction_loss_only: bool,
1192
- ignore_keys: Optional[list[str]] = None,
1193
- ):
1194
- if not self.use_dpo_data_collator:
1195
- warnings.warn(
1196
- "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1197
- "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1198
- )
1199
- if ignore_keys is None:
1200
- if hasattr(model, "config"):
1201
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1202
- else:
1203
- ignore_keys = []
1204
-
1205
- prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1206
-
1207
- with torch.no_grad(), prediction_context_manager:
1208
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1209
-
1210
- # force log the metrics
1211
- self.store_metrics(metrics, train_eval="eval")
1212
-
1213
- if prediction_loss_only:
1214
- return (loss.detach(), None, None)
1215
-
1216
- # logits for the chosen and rejected samples from model
1217
- logits_dict = {
1218
- "eval_logits/chosen": metrics["eval_logits/chosen"],
1219
- "eval_logits/rejected": metrics["eval_logits/rejected"],
1220
- }
1221
- logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1222
- logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1223
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1224
-
1225
- return (loss.detach(), logits, labels)
1226
-
1227
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1228
- for key, value in metrics.items():
1229
- self._stored_metrics[train_eval][key].append(value)
1230
-
1231
- def evaluation_loop(
1232
- self,
1233
- dataloader: DataLoader,
1234
- description: str,
1235
- prediction_loss_only: Optional[bool] = None,
1236
- ignore_keys: Optional[list[str]] = None,
1237
- metric_key_prefix: str = "eval",
1238
- ) -> EvalLoopOutput:
1239
- """
1240
- Overriding built-in evaluation loop to store metrics for each batch.
1241
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1242
-
1243
- Works both with or without labels.
1244
- """
1245
-
1246
- # Sample and save to game log if requested (for one batch to save time)
1247
- if self.generate_during_eval:
1248
- # Generate random indices within the range of the total number of samples
1249
- num_samples = len(dataloader.dataset)
1250
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1251
-
1252
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1253
- random_batch_dataset = dataloader.dataset.select(random_indices)
1254
- random_batch = self.data_collator(random_batch_dataset)
1255
- random_batch = self._prepare_inputs(random_batch)
1256
-
1257
- policy_output_decoded = self.generate_from_model(self.model, random_batch)
1258
-
1259
- table = pd.DataFrame(
1260
- columns=["Prompt", "Policy"],
1261
- data=[
1262
- [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1263
- ],
1264
- )
1265
- if "wandb" in self.args.report_to:
1266
- wandb.log({"game_log": wandb.Table(data=table)})
1267
-
1268
- if "comet_ml" in self.args.report_to:
1269
- log_table_to_comet_experiment(
1270
- name="game_log.csv",
1271
- table=table,
1272
- )
1273
-
1274
- # Base evaluation
1275
- initial_output = super().evaluation_loop(
1276
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1277
- )
1278
-
1279
- return initial_output
1280
-
1281
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1282
- """
1283
- Log `logs` on the various objects watching training, including stored metrics.
1284
-
1285
- Args:
1286
- logs (`dict[str, float]`):
1287
- The values to log.
1288
- start_time (`float` or `None`, *optional*, defaults to `None`):
1289
- Start time of the training.
1290
- """
1291
- # logs either has 'loss' or 'eval_loss'
1292
- train_eval = "train" if "loss" in logs else "eval"
1293
- # Add averaged stored metrics to logs
1294
- for key, metrics in self._stored_metrics[train_eval].items():
1295
- logs[key] = torch.tensor(metrics).mean().item()
1296
- del self._stored_metrics[train_eval]
1297
-
1298
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1299
- return super().log(logs, start_time)
1300
- else: # transformers<=4.46
1301
- return super().log(logs)
1302
-
1303
- def _shift_right(self, input_ids):
1304
- if self.decoder_start_token_id is None:
1305
- raise ValueError(
1306
- "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1307
- )
1308
-
1309
- # shift inputs to the right
1310
- if is_torch_fx_proxy(input_ids):
1311
- # Item assignment is not supported natively for proxies.
1312
- shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1313
- shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1314
- else:
1315
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1316
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1317
- shifted_input_ids[..., 0] = self.decoder_start_token_id
1318
-
1319
- if self.pad_token_id is None:
1320
- raise ValueError("model.config.pad_token_id has to be defined.")
1321
- # replace possible -100 values in labels by `pad_token_id`
1322
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1323
-
1324
- return shifted_input_ids
1325
-
1326
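# A hedged example of the shift-right transform above, with made-up values
# decoder_start_token_id = 0 and pad_token_id = 1:
import torch

labels = torch.tensor([[5, 6, 7, -100]])
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = 0                        # decoder_start_token_id
shifted.masked_fill_(shifted == -100, 1)   # -100 labels become pad_token_id
# shifted is now tensor([[0, 5, 6, 7]])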
- def create_model_card(
1327
- self,
1328
- model_name: Optional[str] = None,
1329
- dataset_name: Optional[str] = None,
1330
- tags: Union[str, list[str], None] = None,
1331
- ):
1332
- """
1333
- Creates a draft of a model card using the information available to the `Trainer`.
1334
-
1335
- Args:
1336
- model_name (`str` or `None`, *optional*, defaults to `None`):
1337
- Name of the model.
1338
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1339
- Name of the dataset used for training.
1340
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1341
- Tags to be associated with the model card.
1342
- """
1343
- if not self.is_world_process_zero():
1344
- return
1345
-
1346
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1347
- base_model = self.model.config._name_or_path
1348
- else:
1349
- base_model = None
1350
-
1351
- tags = tags or []
1352
- if isinstance(tags, str):
1353
- tags = [tags]
1354
-
1355
- if hasattr(self.model.config, "unsloth_version"):
1356
- tags.append("unsloth")
1357
-
1358
- citation = textwrap.dedent("""\
1359
- @article{hong2024orpo,
1360
- title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
1361
- author = {Jiwoo Hong and Noah Lee and James Thorne},
1362
- year = 2024,
1363
- eprint = {arXiv:2403.07691}
1364
- }""")
1365
-
1366
- model_card = generate_model_card(
1367
- base_model=base_model,
1368
- model_name=model_name,
1369
- hub_model_id=self.hub_model_id,
1370
- dataset_name=dataset_name,
1371
- tags=tags,
1372
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1373
- comet_url=get_comet_experiment_url(),
1374
- trainer_name="ORPO",
1375
- trainer_citation=citation,
1376
- paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
1377
- paper_id="2403.07691",
1378
- )
1379
-
1380
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1381
- class UnslothORPOTrainer(_UnslothORPOTrainer):
1382
- """
1383
-
1384
- Initialize ORPOTrainer.
1385
-
1386
- Args:
1387
- model (`transformers.PreTrainedModel`):
1388
- The model to train, preferably an `AutoModelForCausalLM`.
1389
- args (`ORPOConfig`):
1390
- The ORPO config arguments to use for training.
1391
- data_collator (`transformers.DataCollator`):
1392
- The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1393
- which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1394
- train_dataset (`datasets.Dataset`):
1395
- The dataset to use for training.
1396
- eval_dataset (`datasets.Dataset`):
1397
- The dataset to use for evaluation.
1398
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1399
- Processing class used to process the data. If provided, will be used to automatically process the inputs
1400
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1401
- reuse the fine-tuned model.
1402
- model_init (`Callable[[], transformers.PreTrainedModel]`):
1403
- The model initializer to use for training. If None is specified, the default model initializer will be used.
1404
- callbacks (`list[transformers.TrainerCallback]`):
1405
- The callbacks to use for training.
1406
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1407
- The optimizer and scheduler to use for training.
1408
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1409
- The function to use to preprocess the logits before computing the metrics.
1410
- peft_config (`dict`, defaults to `None`):
1411
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1412
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1413
- The function to use to compute the metrics. Must take a `EvalPrediction` and return
1414
- a dictionary string to metric values.
1415
-
1416
- """
1417
- def __init__(
1418
- self,
1419
- model = None,
1420
- args = None,
1421
- data_collator = None,
1422
- train_dataset = None,
1423
- eval_dataset = None,
1424
- processing_class = None,
1425
- model_init = None,
1426
- callbacks = None,
1427
- preprocess_logits_for_metrics = None,
1428
- peft_config = None,
1429
- compute_metrics = None,
1430
- **kwargs
1431
- ):
1432
- if args is None: args = UnslothORPOConfig()
1433
- use_bf16 = getattr(args, 'bf16', False)
1434
- use_fp16 = getattr(args, 'fp16', False)
1435
- force_float32 = False
1436
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1437
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1438
- force_float32 = True
1439
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1440
- dtype = getattr(model.config, 'torch_dtype', None)
1441
- if dtype is None: dtype = model.get_input_embeddings().dtype
1442
- from unsloth_zoo.utils import _get_dtype
1443
- dtype = _get_dtype(dtype)
1444
- float16 = dtype == torch.float16
1445
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1446
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1447
- if force_float32:
1448
- args.fp16 = False
1449
- args.bf16 = False
1450
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1451
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1452
- args.fp16 = float16
1453
- args.bf16 = not float16
1454
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1455
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1456
- args.eval_strategy = 'steps'
1457
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1458
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1459
- if ga_steps is not None and ga_steps > 1:
1460
- from transformers import __version__ as transformers_version
1461
- if Version(transformers_version) <= Version('4.45.2'):
1462
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1463
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1464
- if getattr(args, 'eval_strategy', 'no') != 'no':
1465
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1466
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1467
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1468
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1469
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1470
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1471
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1472
- if force_float32:
1473
- args.bf16_full_eval = False
1474
- args.fp16_full_eval = False
1475
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1476
- args.bf16_full_eval = True
1477
- args.fp16_full_eval = False
1478
- elif not bf16_full_eval and not fp16_full_eval:
1479
- args.bf16_full_eval = args.bf16
1480
- args.fp16_full_eval = args.fp16
1481
- _output_logits = False
1482
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1483
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1484
- if _output_logits:
1485
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1486
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1487
- pass
1488
- else:
1489
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1490
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1491
- if args_max_seq_length is None and model_max_seq_length is not None:
1492
- max_seq_length = model.max_seq_length
1493
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1494
- if model is not None and hasattr(model, 'for_training'):
1495
- model.for_training()
1496
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1497
- if 'processing_class' in locals():
1498
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1499
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1500
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1501
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1502
- if not isinstance(data_collator, UnslothVisionDataCollator):
1503
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1504
- data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1505
- elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1506
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1507
- else:
1508
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1509
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1510
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1511
- if not isinstance(data_collator, UnslothVisionDataCollator):
1512
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1513
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1514
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1515
- else:
1516
- data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1517
- other_metrics = []
1518
-
1519
- from unsloth_zoo.logging_utils import PatchRLStatistics
1520
- PatchRLStatistics('orpo_trainer', other_metrics)
1521
-
1522
- super().__init__(
1523
- model = model,
1524
- args = args,
1525
- data_collator = data_collator,
1526
- train_dataset = train_dataset,
1527
- eval_dataset = eval_dataset,
1528
- processing_class = processing_class,
1529
- model_init = model_init,
1530
- callbacks = callbacks,
1531
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1532
- peft_config = peft_config,
1533
- compute_metrics = compute_metrics,**kwargs)
1534
- if hasattr(self, 'neftune_hook_handle'):
1535
- self.neftune_hook_handle.remove()
1536
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1537
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1538
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1539
- pass
1540
-
1541
- pass
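# A hedged usage sketch for the wrapper above (illustrative names; the dataset is
# assumed to have "prompt"/"chosen"/"rejected" columns):
#
#     trainer = UnslothORPOTrainer(
#         model = model,
#         args = UnslothORPOConfig(output_dir = "outputs"),
#         train_dataset = dataset,
#         processing_class = tokenizer,
#     )
#     trainer.train()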
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py DELETED
@@ -1,1267 +0,0 @@
1
- """
2
- 2025.3.13
3
- 2025.3.15
4
- 4.48.3
5
- 0.15.2
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
13
-
14
-
15
- import os
16
- from typing import *
17
- from dataclasses import dataclass, field
18
- from packaging.version import Version
19
- import torch
20
- import numpy as np
21
- from contextlib import nullcontext
22
- from torch.nn import functional as F
23
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
-
25
- torch_compile_options = {
26
- "epilogue_fusion" : True,
27
- "max_autotune" : False,
28
- "shape_padding" : True,
29
- "trace.enabled" : False,
30
- "triton.cudagraphs" : False,
31
- }
32
-
33
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
- def selective_log_softmax(logits, index):
35
- logits = logits.to(torch.float32)
36
- selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
- # loop to reduce peak mem consumption
38
- # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
- logsumexp_values = torch.logsumexp(logits, dim = -1)
40
- per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
- return per_token_logps
42
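# A hedged sanity check that the helper above computes gather(log_softmax):
import torch

logits = torch.randn(2, 4, 8)
index = torch.randint(0, 8, (2, 4))
expected = torch.gather(logits.log_softmax(-1), -1, index.unsqueeze(-1)).squeeze(-1)
actual = torch.gather(logits, -1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, -1)
assert torch.allclose(expected, actual, atol = 1e-5)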
- def vLLMSamplingParams(**kwargs):
43
- from vllm import SamplingParams
44
- sampling_params = SamplingParams(**kwargs)
45
- sampling_params._set_kwargs = kwargs
46
- return sampling_params
47
- @dataclass
48
- class UnslothOnlineDPOConfig(OnlineDPOConfig):
49
- """
50
-
51
- Configuration class for the [`OnlineDPOTrainer`].
52
-
53
- Using [`~transformers.HfArgumentParser`] we can turn this class into
54
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
55
- command line.
56
-
57
- Parameters:
58
- learning_rate (`float`, *optional*, defaults to `5e-7`):
59
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
60
- [`~transformers.TrainingArguments`].
61
- reward_model_path (`str` or `None`, *optional*, defaults to `None`):
62
- Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
63
- judge (`str` or `None`, *optional*, defaults to `None`):
64
- Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
65
- max_new_tokens (`int`, *optional*, defaults to `64`):
66
- Maximum number of tokens to generate per completion.
67
- max_length (`int`, *optional*, defaults to `512`):
68
- Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
69
- sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
70
- possible.
71
- temperature (`float`, *optional*, defaults to `0.9`):
72
- Temperature for sampling. The higher the temperature, the more random the completions.
73
- missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
74
- Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage the model
75
- to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
76
- value.
77
- beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
78
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
79
- reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
80
- the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
81
- selected for each new epoch and the last β is used for the rest of the epochs.
82
- loss_type (`str`, *optional*, defaults to `"sigmoid"`):
83
- Type of loss to use. Possible values are:
84
-
85
- - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
86
- - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
87
-
88
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
- Number of processes to use for processing the dataset.
90
- disable_dropout (`bool`, *optional*, defaults to `True`):
91
- Whether to disable dropout in the model and reference model.
92
- use_vllm (`bool`, *optional*, defaults to `False`):
93
- Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
94
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
95
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
96
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
97
- capacity of a single GPU, albeit at the cost of slower generation.
98
-
99
- """
100
- vllm_sampling_params: Optional[Any] = field(
101
- default = None,
102
- metadata = {'help': 'vLLM SamplingParams'},
103
- )
104
- unsloth_num_chunks : Optional[int] = field(
105
- default = -1,
106
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
107
- )
108
- def __init__(
109
- self,
110
- output_dir = None,
111
- overwrite_output_dir = None,
112
- do_train = False,
113
- do_eval = False,
114
- do_predict = False,
115
- eval_strategy = 'no',
116
- prediction_loss_only = False,
117
- per_device_train_batch_size = 4,
118
- per_device_eval_batch_size = 4,
119
- per_gpu_train_batch_size = None,
120
- per_gpu_eval_batch_size = None,
121
- gradient_accumulation_steps = 2,
122
- eval_accumulation_steps = 2,
123
- eval_delay = 0,
124
- torch_empty_cache_steps = 250,
125
- learning_rate = 5e-05,
126
- weight_decay = 0.01,
127
- adam_beta1 = 0.9,
128
- adam_beta2 = 0.999,
129
- adam_epsilon = 1e-08,
130
- max_grad_norm = 1.0,
131
- num_train_epochs = 3.0,
132
- max_steps = -1,
133
- lr_scheduler_type = 'linear',
134
- warmup_ratio = 0.1,
135
- warmup_steps = 0,
136
- log_level = 'passive',
137
- log_level_replica = 'warning',
138
- log_on_each_node = True,
139
- logging_dir = None,
140
- logging_strategy = 'steps',
141
- logging_first_step = False,
142
- logging_steps = 1,
143
- logging_nan_inf_filter = False,
144
- save_strategy = 'steps',
145
- save_steps = 500,
146
- save_total_limit = None,
147
- save_safetensors = True,
148
- save_on_each_node = False,
149
- save_only_model = False,
150
- restore_callback_states_from_checkpoint = False,
151
- no_cuda = False,
152
- use_cpu = False,
153
- use_mps_device = False,
154
- seed = 3407,
155
- data_seed = 3407,
156
- jit_mode_eval = False,
157
- use_ipex = False,
158
- bf16 = False,
159
- fp16 = False,
160
- fp16_opt_level = 'O1',
161
- half_precision_backend = 'auto',
162
- bf16_full_eval = False,
163
- fp16_full_eval = False,
164
- tf32 = None,
165
- local_rank = -1,
166
- ddp_backend = None,
167
- tpu_num_cores = None,
168
- tpu_metrics_debug = False,
169
- debug = '',
170
- dataloader_drop_last = False,
171
- eval_steps = None,
172
- dataloader_num_workers = 0,
173
- dataloader_prefetch_factor = None,
174
- past_index = -1,
175
- run_name = None,
176
- disable_tqdm = None,
177
- remove_unused_columns = True,
178
- label_names = None,
179
- load_best_model_at_end = False,
180
- metric_for_best_model = None,
181
- greater_is_better = None,
182
- ignore_data_skip = False,
183
- fsdp = '',
184
- fsdp_min_num_params = 0,
185
- fsdp_config = None,
186
- fsdp_transformer_layer_cls_to_wrap = None,
187
- accelerator_config = None,
188
- deepspeed = None,
189
- label_smoothing_factor = 0.0,
190
- optim = 'adamw_8bit',
191
- optim_args = None,
192
- adafactor = False,
193
- group_by_length = False,
194
- length_column_name = 'length',
195
- report_to = None,
196
- ddp_find_unused_parameters = None,
197
- ddp_bucket_cap_mb = None,
198
- ddp_broadcast_buffers = None,
199
- dataloader_pin_memory = True,
200
- dataloader_persistent_workers = False,
201
- skip_memory_metrics = True,
202
- use_legacy_prediction_loop = False,
203
- push_to_hub = False,
204
- resume_from_checkpoint = None,
205
- hub_model_id = None,
206
- hub_strategy = 'every_save',
207
- hub_token = None,
208
- hub_private_repo = None,
209
- hub_always_push = False,
210
- gradient_checkpointing = False,
211
- gradient_checkpointing_kwargs = None,
212
- include_inputs_for_metrics = False,
213
- eval_do_concat_batches = True,
214
- fp16_backend = 'auto',
215
- evaluation_strategy = None,
216
- push_to_hub_model_id = None,
217
- push_to_hub_organization = None,
218
- push_to_hub_token = None,
219
- mp_parameters = '',
220
- auto_find_batch_size = False,
221
- full_determinism = False,
222
- torchdynamo = None,
223
- ray_scope = 'last',
224
- ddp_timeout = 1800,
225
- torch_compile = False,
226
- torch_compile_backend = None,
227
- torch_compile_mode = None,
228
- dispatch_batches = None,
229
- split_batches = None,
230
- include_tokens_per_second = False,
231
- include_num_input_tokens_seen = False,
232
- neftune_noise_alpha = None,
233
- optim_target_modules = None,
234
- batch_eval_metrics = False,
235
- eval_on_start = False,
236
- use_liger_kernel = False,
237
- eval_use_gather_object = False,
238
- average_tokens_across_devices = False,
239
- reward_model_path = None,
240
- judge = None,
241
- max_new_tokens = 64,
242
- max_length = 512,
243
- temperature = 0.9,
244
- missing_eos_penalty = None,
245
- loss_type = 'sigmoid',
246
- dataset_num_proc = None,
247
- disable_dropout = True,
248
- use_vllm = False,
249
- ds3_gather_for_generation = True,
250
- vllm_sampling_params = None,
251
- unsloth_num_chunks = -1,
252
- **kwargs,
253
- ):
254
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
255
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
256
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
257
- output_dir = 'unsloth_training_checkpoints'
258
- save_strategy = 'no'
259
- if dataset_num_proc is None:
260
- from multiprocessing import cpu_count
261
- dataset_num_proc = cpu_count()
262
-
263
- super().__init__(
264
- output_dir = output_dir,
265
- overwrite_output_dir = overwrite_output_dir,
266
- do_train = do_train,
267
- do_eval = do_eval,
268
- do_predict = do_predict,
269
- eval_strategy = eval_strategy,
270
- prediction_loss_only = prediction_loss_only,
271
- per_device_train_batch_size = per_device_train_batch_size,
272
- per_device_eval_batch_size = per_device_eval_batch_size,
273
- per_gpu_train_batch_size = per_gpu_train_batch_size,
274
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
275
- gradient_accumulation_steps = gradient_accumulation_steps,
276
- eval_accumulation_steps = eval_accumulation_steps,
277
- eval_delay = eval_delay,
278
- torch_empty_cache_steps = torch_empty_cache_steps,
279
- learning_rate = learning_rate,
280
- weight_decay = weight_decay,
281
- adam_beta1 = adam_beta1,
282
- adam_beta2 = adam_beta2,
283
- adam_epsilon = adam_epsilon,
284
- max_grad_norm = max_grad_norm,
285
- num_train_epochs = num_train_epochs,
286
- max_steps = max_steps,
287
- lr_scheduler_type = lr_scheduler_type,
288
- warmup_ratio = warmup_ratio,
289
- warmup_steps = warmup_steps,
290
- log_level = log_level,
291
- log_level_replica = log_level_replica,
292
- log_on_each_node = log_on_each_node,
293
- logging_dir = logging_dir,
294
- logging_strategy = logging_strategy,
295
- logging_first_step = logging_first_step,
296
- logging_steps = logging_steps,
297
- logging_nan_inf_filter = logging_nan_inf_filter,
298
- save_strategy = save_strategy,
299
- save_steps = save_steps,
300
- save_total_limit = save_total_limit,
301
- save_safetensors = save_safetensors,
302
- save_on_each_node = save_on_each_node,
303
- save_only_model = save_only_model,
304
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
305
- no_cuda = no_cuda,
306
- use_cpu = use_cpu,
307
- use_mps_device = use_mps_device,
308
- seed = seed,
309
- data_seed = data_seed,
310
- jit_mode_eval = jit_mode_eval,
311
- use_ipex = use_ipex,
312
- bf16 = bf16,
313
- fp16 = fp16,
314
- fp16_opt_level = fp16_opt_level,
315
- half_precision_backend = half_precision_backend,
316
- bf16_full_eval = bf16_full_eval,
317
- fp16_full_eval = fp16_full_eval,
318
- tf32 = tf32,
319
- local_rank = local_rank,
320
- ddp_backend = ddp_backend,
321
- tpu_num_cores = tpu_num_cores,
322
- tpu_metrics_debug = tpu_metrics_debug,
323
- debug = debug,
324
- dataloader_drop_last = dataloader_drop_last,
325
- eval_steps = eval_steps,
326
- dataloader_num_workers = dataloader_num_workers,
327
- dataloader_prefetch_factor = dataloader_prefetch_factor,
328
- past_index = past_index,
329
- run_name = run_name,
330
- disable_tqdm = disable_tqdm,
331
- remove_unused_columns = remove_unused_columns,
332
- label_names = label_names,
333
- load_best_model_at_end = load_best_model_at_end,
334
- metric_for_best_model = metric_for_best_model,
335
- greater_is_better = greater_is_better,
336
- ignore_data_skip = ignore_data_skip,
337
- fsdp = fsdp,
338
- fsdp_min_num_params = fsdp_min_num_params,
339
- fsdp_config = fsdp_config,
340
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
341
- accelerator_config = accelerator_config,
342
- deepspeed = deepspeed,
343
- label_smoothing_factor = label_smoothing_factor,
344
- optim = optim,
345
- optim_args = optim_args,
346
- adafactor = adafactor,
347
- group_by_length = group_by_length,
348
- length_column_name = length_column_name,
349
- report_to = report_to,
350
- ddp_find_unused_parameters = ddp_find_unused_parameters,
351
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
352
- ddp_broadcast_buffers = ddp_broadcast_buffers,
353
- dataloader_pin_memory = dataloader_pin_memory,
354
- dataloader_persistent_workers = dataloader_persistent_workers,
355
- skip_memory_metrics = skip_memory_metrics,
356
- use_legacy_prediction_loop = use_legacy_prediction_loop,
357
- push_to_hub = push_to_hub,
358
- resume_from_checkpoint = resume_from_checkpoint,
359
- hub_model_id = hub_model_id,
360
- hub_strategy = hub_strategy,
361
- hub_token = hub_token,
362
- hub_private_repo = hub_private_repo,
363
- hub_always_push = hub_always_push,
364
- gradient_checkpointing = gradient_checkpointing,
365
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
366
- include_inputs_for_metrics = include_inputs_for_metrics,
367
- eval_do_concat_batches = eval_do_concat_batches,
368
- fp16_backend = fp16_backend,
369
- evaluation_strategy = evaluation_strategy,
370
- push_to_hub_model_id = push_to_hub_model_id,
371
- push_to_hub_organization = push_to_hub_organization,
372
- push_to_hub_token = push_to_hub_token,
373
- mp_parameters = mp_parameters,
374
- auto_find_batch_size = auto_find_batch_size,
375
- full_determinism = full_determinism,
376
- torchdynamo = torchdynamo,
377
- ray_scope = ray_scope,
378
- ddp_timeout = ddp_timeout,
379
- torch_compile = torch_compile,
380
- torch_compile_backend = torch_compile_backend,
381
- torch_compile_mode = torch_compile_mode,
382
- dispatch_batches = dispatch_batches,
383
- split_batches = split_batches,
384
- include_tokens_per_second = include_tokens_per_second,
385
- include_num_input_tokens_seen = include_num_input_tokens_seen,
386
- neftune_noise_alpha = neftune_noise_alpha,
387
- optim_target_modules = optim_target_modules,
388
- batch_eval_metrics = batch_eval_metrics,
389
- eval_on_start = eval_on_start,
390
- use_liger_kernel = use_liger_kernel,
391
- eval_use_gather_object = eval_use_gather_object,
392
- average_tokens_across_devices = average_tokens_across_devices,
393
- reward_model_path = reward_model_path,
394
- judge = judge,
395
- max_new_tokens = max_new_tokens,
396
- max_length = max_length,
397
- temperature = temperature,
398
- missing_eos_penalty = missing_eos_penalty,
399
- loss_type = loss_type,
400
- dataset_num_proc = dataset_num_proc,
401
- disable_dropout = disable_dropout,
402
- use_vllm = use_vllm,
403
- ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
404
- self.vllm_sampling_params = vllm_sampling_params
405
- self.unsloth_num_chunks = unsloth_num_chunks
406
- pass
407
-
408
- class _UnslothOnlineDPOTrainer(Trainer):
409
- r""""""
410
-
411
- _tag_names = ["trl", "online-dpo"]
412
-
413
- def __init__(
414
- self,
415
- model: Union[PreTrainedModel, nn.Module],
416
- ref_model: Union[PreTrainedModel, nn.Module, None] = None,
417
- reward_model: Union[PreTrainedModel, nn.Module, None] = None,
418
- judge: Optional[BasePairwiseJudge] = None,
419
- args: Optional[OnlineDPOConfig] = None,
420
- data_collator: Optional[DataCollator] = None,
421
- train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
422
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
423
- processing_class: Optional[
424
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
425
- ] = None,
426
- reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
427
- peft_config: Optional[dict] = None,
428
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
429
- callbacks: Optional[list[TrainerCallback]] = None,
430
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
431
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
432
- ) -> None:
433
-
434
- if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
435
- if ref_model is model:
436
- raise ValueError(
437
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
438
- "same as `model`, either omit the `ref_model` argument or pass `None`."
439
- )
440
-
441
- self.ref_model = ref_model
442
-
443
- if reward_model is not None and judge is not None:
444
- warnings.warn(
445
- "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
446
- "Ignoring `judge` and using `reward_model`.",
447
- UserWarning,
448
- )
449
- judge = None
450
- elif reward_model is None and judge is None:
451
- raise ValueError("Either `reward_model` or `judge` must be provided.")
452
-
453
- self.reward_model = reward_model
454
- self.reward_processing_class = reward_processing_class
455
- self.judge = judge
456
-
457
- if args.missing_eos_penalty is not None and judge is not None:
458
- raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
459
-
460
- if args is None:
461
- raise ValueError("`args` must be provided.")
462
-
463
- # Check that the processing_class is provided
464
- if processing_class is None:
465
- raise ValueError("`processing_class` must be provided.")
466
-
467
- # Convert to PEFT model if peft_config is provided
468
- if False:
469
- # Check if PEFT is available
470
- if not is_peft_available():
471
- raise ImportError(
472
- "PEFT is not available and passed `peft_config`. Please install PEFT with "
473
- "`pip install peft` to use it."
474
- )
475
-
476
- # If the model is already a PeftModel, we need to merge and unload it.
477
- # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
478
- if isinstance(model, PeftModel):
479
- model = model.merge_and_unload()
480
-
481
- # Get peft model with the given config
482
- model = model
483
-
484
- # Disable dropout in the model and reference model
485
- if args.disable_dropout:
486
- disable_dropout_in_model(model)
487
- if self.ref_model is not None:
488
- disable_dropout_in_model(self.ref_model)
489
-
490
- # Handle the ref_model
491
- # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
492
- # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
493
- # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
494
- if ref_model is None: # No ref model provided, the most common case
495
- if False:
496
- self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
497
- else:
498
- self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
499
- else: # rare case, the user provided a ref model
500
- self.ref_model = ref_model
501
- self.ref_model.eval()
502
-
503
- # Disable the gradient and set the reward model in eval mode
504
- if self.reward_model is not None:
505
- self.reward_model.eval()
506
-
507
- # Define the data collator if it is not provided
508
- if data_collator is None:
509
- data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
510
-
511
- self.max_length = args.max_length
512
-
513
- self.stats = {
514
- "objective/kl": [],
515
- "objective/entropy": [],
516
- "objective/non_score_reward": [],
517
- "rewards/chosen": [],
518
- "rewards/rejected": [],
519
- "rewards/accuracies": [],
520
- "rewards/margins": [],
521
- "logps/chosen": [],
522
- "logps/rejected": [],
523
- "val/contain_eos_token": [],
524
- "beta": [],
525
- }
526
- if self.reward_model is not None:
527
- self.stats["objective/rlhf_reward"] = []
528
- self.stats["objective/scores_margin"] = []
529
- self.stats["objective/scores"] = []
530
-
531
- if args.use_vllm:
532
- from vllm import SamplingParams; self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
533
- n=2, max_tokens=args.max_new_tokens,
534
- temperature=args.temperature,
535
- top_k=50,
536
- top_p=1.0,
537
- detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
538
- else:
539
- self.generation_config = GenerationConfig(
540
- max_new_tokens=args.max_new_tokens,
541
- temperature=args.temperature,
542
- top_k=50,
543
- top_p=1.0,
544
- do_sample=True,
545
- use_cache=False if args.gradient_checkpointing else True,
546
- )
547
-
548
- # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
549
- # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
550
- # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
551
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
552
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
553
- # that the warning has already been issued.
554
- model.warnings_issued["estimate_tokens"] = True
555
-
556
- super().__init__(
557
- model=model,
558
- args=args,
559
- data_collator=data_collator,
560
- train_dataset=train_dataset,
561
- eval_dataset=eval_dataset,
562
- processing_class=processing_class,
563
- compute_metrics=compute_metrics,
564
- callbacks=callbacks,
565
- optimizers=optimizers,
566
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
567
- )
568
-
569
- # Add tags for models that have been loaded with the correct transformers version
570
- if hasattr(self.model, "add_model_tags"):
571
- self.model.add_model_tags(self._tag_names)
572
-
573
- self._beta = args.beta
574
-
575
- # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
576
- if self.is_deepspeed_enabled:
577
- if self.reward_model is not None:
578
- self.reward_model = prepare_deepspeed(
579
- self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
580
- )
581
- if self.ref_model is not None:
582
- self.ref_model = prepare_deepspeed(
583
- self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
584
- )
585
- else:
586
- if self.ref_model is not None:
587
- self.ref_model = self.ref_model.to(self.accelerator.device)
588
- if self.reward_model is not None:
589
- self.reward_model = self.reward_model.to(self.accelerator.device)
590
-
591
- @property
592
- def beta(self):
593
- if isinstance(self._beta, list):
594
- epoch = self.state.epoch
595
- return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
596
- else:
597
- return self._beta
598
-
599
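# Usage note: with a list-valued beta, e.g. beta=[0.1, 0.05, 0.02], epoch 0 uses 0.1,
# epoch 1 uses 0.05, epoch 2 uses 0.02, and any later epoch reuses the last value.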
- @staticmethod
600
- def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
601
- """Tokenize a single row from a DPO specific dataset."""
602
- if not is_encoder_decoder:
603
- batch = tokenizer(feature["prompt"], add_special_tokens=False)
604
- # Add BOS token to head of prompt. Avoid adding if it's already there
605
- if tokenizer.bos_token_id is not None:
606
- prompt_len_input_ids = len(batch["input_ids"])
607
- if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
608
- batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
609
- batch["attention_mask"] = [1] + batch["attention_mask"]
610
- else:
611
- batch = tokenizer(feature["prompt"], add_special_tokens=True)
612
- batch = {f"prompt_{key}": value for key, value in batch.items()}
613
- return batch
614
-
615
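# A hedged example for a decoder-only tokenizer with a BOS token:
#   tokenize_row({"prompt": "Hello"}, is_encoder_decoder=False, tokenizer=tok)
#   -> {"prompt_input_ids": [bos_id, ...], "prompt_attention_mask": [1, ...]}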
- # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
616
- @wraps(Trainer.get_train_dataloader)
617
- def get_train_dataloader(self) -> DataLoader:
618
- if self.train_dataset is None:
619
- raise ValueError("Trainer: training requires a train_dataset.")
620
-
621
- train_dataset = self.train_dataset
622
- data_collator = self.data_collator
623
- dataloader_params = {
624
- "batch_size": self._train_batch_size,
625
- "collate_fn": data_collator,
626
- "num_workers": self.args.dataloader_num_workers,
627
- "pin_memory": self.args.dataloader_pin_memory,
628
- "persistent_workers": self.args.dataloader_persistent_workers,
629
- }
630
-
631
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
632
- dataloader_params["sampler"] = self._get_train_sampler()
633
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
634
- dataloader_params["worker_init_fn"] = seed_worker
635
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
636
-
637
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
638
-
639
- # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
640
- @wraps(Trainer.get_eval_dataloader)
641
- def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
642
- if eval_dataset is None and self.eval_dataset is None:
643
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
644
-
645
- # If we have persistent workers, don't do a fork bomb especially as eval datasets
646
- # don't change during training
647
- dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
648
- if (
649
- hasattr(self, "_eval_dataloaders")
650
- and dataloader_key in self._eval_dataloaders
651
- and self.args.dataloader_persistent_workers
652
- ):
653
- return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
654
-
655
- eval_dataset = (
656
- self.eval_dataset[eval_dataset]
657
- if isinstance(eval_dataset, str)
658
- else eval_dataset
659
- if eval_dataset is not None
660
- else self.eval_dataset
661
- )
662
- data_collator = self.data_collator
663
-
664
- dataloader_params = {
665
- "batch_size": self.args.eval_batch_size,
666
- "collate_fn": data_collator,
667
- "num_workers": self.args.dataloader_num_workers,
668
- "pin_memory": self.args.dataloader_pin_memory,
669
- "persistent_workers": self.args.dataloader_persistent_workers,
670
- }
671
-
672
- if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
673
- dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
674
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
675
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
676
-
677
- # accelerator.free_memory() will destroy the references, so
678
- # we need to store the non-prepared version
679
- eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
680
- if self.args.dataloader_persistent_workers:
681
- if hasattr(self, "_eval_dataloaders"):
682
- self._eval_dataloaders[dataloader_key] = eval_dataloader
683
- else:
684
- self._eval_dataloaders = {dataloader_key: eval_dataloader}
685
-
686
- return self.accelerator.prepare(eval_dataloader)
687
-
688
    def _generate_vllm(self, model, prompts):
        eos_token_id = self.processing_class.eos_token_id
        pad_token_id = self.processing_class.pad_token_id

        # Load the latest weights

        pass

        pass

        if is_conversational({"prompt": prompts[0]}):
            outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
        else:
            outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))

        completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
        prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]

        # Create mask and pad the prompt and completion
        max_prompt_length = max(len(ids) for ids in prompt_ids)
        prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
        prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
        max_tokens = self.generation_config.max_tokens
        completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
        completion_ids = [
            ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
            for ids in completion_ids
        ]
        completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]

        # Convert to tensors
        prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
        prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
        completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
        completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)

        return prompt_ids, prompt_mask, completion_ids, completion_mask

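The padding scheme above, in isolation: prompts are left-padded to a common length and completions right-padded, with 0/1 masks marking real tokens (toy token ids; pad id assumed 0):

# Toy illustration of left-padding prompts / right-padding completions (not from the deleted file).
pad_token_id = 0  # assumed for the example
prompt_ids = [[5, 6, 7], [8, 9]]
max_prompt_length = max(len(ids) for ids in prompt_ids)  # 3
prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
assert prompt_ids == [[5, 6, 7], [0, 8, 9]] and prompt_mask == [[1, 1, 1], [0, 1, 1]]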
    def _generate(self, model, prompts):
        eos_token_id = self.processing_class.eos_token_id
        pad_token_id = self.processing_class.pad_token_id

        # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
        # policies with different tokenizers / chat templates.
        inputs = [{"prompt": prompt} for prompt in prompts]
        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
        inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
        inputs = self.data_collator(inputs)

        # Sample 2 completions per prompt of size `max_new_tokens` from the model
        inputs = self._prepare_inputs(inputs)
        prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
        prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
        with unwrap_model_for_generation(
            model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            output = unwrapped_model.generate(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                generation_config=self.generation_config,
            )

        completion_ids = output[:, prompt_ids.size(1) :]
        completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)

        return prompt_ids, prompt_mask, completion_ids, completion_mask

    def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
        # Get the number of tokens to truncate from prompt
        num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)

        # Truncate left to avoid oom
        prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
        prompt_mask = prompt_mask[:, num_tokens_to_truncate:]

        # Concat the prompt and completion
        prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
        prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)

        # Get the logprobs of the completions from the model
        output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)

        # There is 1 offset, because the model predicts the next token
        logits = output.logits[:, prompt_ids.size(1) - 1 : -1]

        # Take the completion tokens logprob
        logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
        return logprobs

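A self-contained check of the one-position offset used above (position i of the logits predicts token i+1, so completion tokens are scored by `logits[:, prompt_len - 1 : -1]`):

# Toy check of the next-token offset used in _forward (illustrative only).
import torch

batch, prompt_len, completion_len, vocab = 1, 3, 2, 10
logits = torch.randn(batch, prompt_len + completion_len, vocab)
completion_ids = torch.randint(vocab, (batch, completion_len))
shifted = logits[:, prompt_len - 1 : -1]
logprobs = torch.take_along_dim(shifted.log_softmax(-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
assert logprobs.shape == (batch, completion_len)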
    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        model.train()

        prompts = inputs["prompt"]
        batch_size = len(prompts)

        if self.args.use_vllm:
            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
        else:
            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)

        contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)

        logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
        with torch.no_grad():
            if self.ref_model is not None:
                ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
            else:  # peft case: we just need to disable the adapter
                with self.model.disable_adapter():
                    ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)

        # Decode the completions, and format them if the input is conversational
        device = logprobs.device
        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational({"prompt": prompts[0]}):
            completions = [[{"role": "assistant", "content": completion}] for completion in completions]

        # Get the reward from the reward model or judge
        if self.judge is not None:
            # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
            # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
            # independent of the model's chat template, we use the raw conversation data, and apply our own chat
            # template to it.
            if is_conversational({"prompt": prompts[0]}):
                environment = jinja2.Environment()
                template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
                prompts = [template.render(messages=prompt) for prompt in prompts]
                completions = [template.render(messages=completion) for completion in completions]

            ranks_of_first_completion = self.judge.judge(
                prompts, list(zip(completions[:batch_size], completions[batch_size:]))
            )

            # convert ranks to a True/False mask:
            # when rank == 0, it means the first completion is the best
            # when rank == 1, it means the second completion is the best
            mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
        else:
            # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
            # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
            prompts = 2 * prompts  # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
            if is_conversational({"prompt": prompts[0]}):
                examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
                examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
                prompts = [example["prompt"] for example in examples]
                completions = [example["completion"] for example in examples]

            # Tokenize the prompts
            prompts_ids = self.reward_processing_class(
                prompts, padding=True, return_tensors="pt", padding_side="left"
            )["input_ids"].to(device)
            context_length = prompts_ids.shape[1]

            # Tokenize the completions
            completions_ids = self.reward_processing_class(
                completions, padding=True, return_tensors="pt", padding_side="right"
            )["input_ids"].to(device)

            # Concatenate the prompts and completions and get the reward
            prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
            with torch.inference_mode():
                _, scores, _ = get_reward(
                    self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
                )

                # Filter completion. Ensure that the sample contains stop_token_id
                # Completions not passing that filter will receive a lower score.
                if self.args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty

            # Split the scores in 2 (the prompts of the first half are the same as the second half)
            first_half, second_half = scores.split(batch_size)

            # Get the indices of the chosen and rejected examples
            mask = first_half >= second_half

        batch_range = torch.arange(batch_size, device=device)
        chosen_indices = batch_range + (~mask * batch_size)
        rejected_indices = batch_range + (mask * batch_size)

        # Build tensor so that the first half is the chosen examples and the second half the rejected examples
        cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0)  # cr = chosen and rejected
        cr_logprobs = logprobs[cr_indices]
        cr_ref_logprobs = ref_logprobs[cr_indices]

        # mask out the padding tokens
        padding_mask = ~completion_mask.bool()
        cr_padding_mask = padding_mask[cr_indices]

        cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
        cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)

        # Split the chosen and rejected examples
        chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
        chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
        pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
        ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum

        logits = pi_logratios - ref_logratios

        if self.args.loss_type == "sigmoid":
            losses = -F.logsigmoid(self.beta * logits)
        elif self.args.loss_type == "ipo":
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            raise NotImplementedError(f"invalid loss type {self.args.loss_type}")

        loss = losses.mean()

        # Log everything
        if self.reward_model is not None:
            scores_margin = scores[chosen_indices] - scores[rejected_indices]
            self.stats["objective/scores_margin"].append(
                self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
            )
            self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
        self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
        self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
        self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())

        kl = logprobs - ref_logprobs
        mean_kl = kl.sum(1).mean()
        self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        non_score_reward = (-self.beta * kl).sum(1)
        mean_non_score_reward = non_score_reward.mean()
        self.stats["objective/non_score_reward"].append(
            self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
        )
        if self.reward_model is not None:
            rlhf_reward = scores + non_score_reward
            self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
        mean_entropy = -logprobs.sum(1).mean()
        self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
        chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
        gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
        self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
        rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
        gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
        self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
        margin = gathered_chosen_rewards - gathered_rejected_rewards
        self.stats["rewards/margins"].append(margin.mean().item())
        accuracy = margin > 0
        self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
        self.stats["beta"].append(self.beta)

        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            empty_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss, **kwargs)

        return loss.detach() / self.args.gradient_accumulation_steps

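For reference, the two preference losses selected above evaluated on made-up log-ratio values (beta is an assumed setting, not taken from the commit):

# Toy evaluation of the sigmoid (DPO) and IPO losses used in training_step (illustrative only).
import torch
import torch.nn.functional as F

beta = 0.1  # assumed value
logits = torch.tensor([1.5, -0.3])  # (pi_logratios - ref_logratios) per pair, made up
sigmoid_loss = -F.logsigmoid(beta * logits).mean()
ipo_loss = ((logits - 1 / (2 * beta)) ** 2).mean()
print(sigmoid_loss.item(), ipo_loss.item())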
    # Same as Trainer._maybe_log_save_evaluate but log our metrics
    # start_time defaults to None to allow compatibility with transformers<=4.46
    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            logs: dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            logs["learning_rate"] = self._get_learning_rate()

            # Add our metrics
            for key, val in self.stats.items():
                logs[key] = sum(val) / len(val)
            self.stats = {key: [] for key in self.stats}  # reset stats

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
                self.log(logs, start_time)
            else:  # transformers<=4.46
                self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self._evaluate(trial, ignore_keys_for_eval)
            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

            if self.args.save_strategy == "best":
                self.control.should_save = is_new_best_metric

        if self.control.should_save:
            self._save_checkpoint(model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
    # This can be removed once the minimum transformers version is updated to 4.47.
    # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
    def _determine_best_metric(self, metrics, trial):
        """
        Determine if the model should be saved based on the evaluation metrics.
        If args.metric_for_best_model is not set, the loss is used.
        Returns:
            bool: True if a new best metric was found, else False
        """
        is_new_best_metric = False

        if self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model

            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"

            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                raise KeyError(
                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less

            if self.state.best_metric is None:
                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

            if operator(metric_value, self.state.best_metric):
                run_dir = self._get_output_dir(trial=trial)
                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
                output_dir = os.path.join(run_dir, checkpoint_folder)
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

                is_new_best_metric = True

        return is_new_best_metric

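In other words, the comparison direction simply follows `greater_is_better`; with hypothetical values:

# Toy illustration of the best-metric comparison above (not from the deleted file).
import numpy as np

greater_is_better = True  # assumed setting
operator = np.greater if greater_is_better else np.less
best = float("-inf") if greater_is_better else float("inf")
assert operator(0.82, best)  # the first evaluation always becomes the new best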
    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @article{guo2024direct,
            title = {{Direct Language Model Alignment from Online AI Feedback}},
            author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
            year = 2024,
            eprint = {arXiv:2402.04792}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="Online DPO",
            trainer_citation=citation,
            paper_title="Direct Language Model Alignment from Online AI Feedback",
            paper_id="2402.04792",
        )
        model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
    """

    Initialize OnlineDPOTrainer.

    Args:
        model (`transformers.PreTrainedModel` or `torch.nn.Module`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reference model to use for training. If None is specified, the reference model will be created from
            the model.
        reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`OnlineDPOConfig`):
            The online DPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.

    """
    def __init__(
        self,
        model,
        ref_model = None,
        reward_model = None,
        judge = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        reward_processing_class = None,
        peft_config = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothOnlineDPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('online_dpo_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            reward_model = reward_model,
            judge = judge,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            reward_processing_class = reward_processing_class,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass
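A hypothetical minimal invocation of this generated wrapper (every name below is a placeholder, not taken from the commit):

# Hypothetical usage sketch; model, tokenizer and dataset are assumed to exist.
# from trl import PairRMJudge
# trainer = UnslothOnlineDPOTrainer(
#     model = model,                               # e.g. an unsloth FastLanguageModel
#     judge = PairRMJudge(),                       # or pass reward_model=... instead
#     args = UnslothOnlineDPOConfig(output_dir = "outputs"),
#     train_dataset = train_dataset,               # dataset with a "prompt" column
#     processing_class = tokenizer,
# )
# trainer.train()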
unsloth_compiled_cache/UnslothPPOTrainer.py DELETED
@@ -1,1257 +0,0 @@
"""
2025.3.13
2025.3.15
4.48.3
0.15.2
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
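As a sanity check on the identity in the comment above, the gathered log-probabilities agree with `F.log_softmax` (illustrative, CPU-sized tensors):

# Verify log_softmax(x)_i == x_i - logsumexp(x) on random data (illustrative only).
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)
index = torch.randint(11, (2, 5))
expected = torch.gather(F.log_softmax(logits.float(), dim=-1), -1, index.unsqueeze(-1)).squeeze(-1)
actual = logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits.float(), dim=-1)
assert torch.allclose(expected, actual, atol=1e-5)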
@dataclass
class UnslothPPOConfig(PPOConfig):
    """

    Configuration class for the [`PPOTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
            Name of this experiment.
        reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
            Path to the reward model.
        model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
        num_ppo_epochs (`int`, *optional*, defaults to `4`):
            Number of epochs to train.
        whiten_rewards (`bool`, *optional*, defaults to `False`):
            Whether to whiten the rewards.
        kl_coef (`float`, *optional*, defaults to `0.05`):
            KL coefficient.
        cliprange (`float`, *optional*, defaults to `0.2`):
            Clip range.
        vf_coef (`float`, *optional*, defaults to `0.1`):
            Value function coefficient.
        cliprange_value (`float`, *optional*, defaults to `0.2`):
            Clip range for the value function.
        gamma (`float`, *optional*, defaults to `1.0`):
            Discount factor.
        lam (`float`, *optional*, defaults to `0.95`):
            Lambda value for GAE.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation.

    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        dataset_num_proc = None,
        num_mini_batches = 1,
        total_episodes = None,
        local_rollout_forward_batch_size = 64,
        num_sample_generations = 10,
        response_length = 53,
        stop_token = None,
        stop_token_id = None,
        temperature = 0.7,
        missing_eos_penalty = None,
        sft_model_path = 'EleutherAI/pythia-160m',
        world_size = None,
        num_total_batches = None,
        micro_batch_size = None,
        local_batch_size = None,
        batch_size = None,
        local_mini_batch_size = None,
        mini_batch_size = None,
        exp_name = 'ppo_config',
        reward_model_path = 'EleutherAI/pythia-160m',
        model_adapter_name = None,
        ref_adapter_name = None,
        num_ppo_epochs = 4,
        whiten_rewards = False,
        kl_coef = 0.05,
        cliprange = 0.2,
        vf_coef = 0.1,
        cliprange_value = 0.2,
        gamma = 1.0,
        lam = 0.95,
        ds3_gather_for_generation = True,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            dataset_num_proc = dataset_num_proc,
            num_mini_batches = num_mini_batches,
            total_episodes = total_episodes,
            local_rollout_forward_batch_size = local_rollout_forward_batch_size,
            num_sample_generations = num_sample_generations,
            response_length = response_length,
            stop_token = stop_token,
            stop_token_id = stop_token_id,
            temperature = temperature,
            missing_eos_penalty = missing_eos_penalty,
            sft_model_path = sft_model_path,
            world_size = world_size,
            num_total_batches = num_total_batches,
            micro_batch_size = micro_batch_size,
            local_batch_size = local_batch_size,
            batch_size = batch_size,
            local_mini_batch_size = local_mini_batch_size,
            mini_batch_size = mini_batch_size,
            exp_name = exp_name,
            reward_model_path = reward_model_path,
            model_adapter_name = model_adapter_name,
            ref_adapter_name = ref_adapter_name,
            num_ppo_epochs = num_ppo_epochs,
            whiten_rewards = whiten_rewards,
            kl_coef = kl_coef,
            cliprange = cliprange,
            vf_coef = vf_coef,
            cliprange_value = cliprange_value,
            gamma = gamma,
            lam = lam,
            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        pass
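A hypothetical way this generated config would be constructed (values are placeholders, not from the commit):

# Hypothetical usage sketch of UnslothPPOConfig (illustrative only).
# config = UnslothPPOConfig(
#     output_dir = "outputs",
#     per_device_train_batch_size = 4,
#     num_ppo_epochs = 4,
#     kl_coef = 0.05,
# )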

class _UnslothPPOTrainer(Trainer):
    _tag_names = ["trl", "ppo"]

    def __init__(
        self,
        args: PPOConfig,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ],
        model: nn.Module,
        ref_model: Optional[nn.Module],
        reward_model: nn.Module,
        train_dataset: Dataset,
        value_model: Optional[nn.Module] = None,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[list[TrainerCallback]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ) -> None:
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must make a copy of it, or `None` if you use peft."
            )

        self.args = args
        self.processing_class = processing_class
        self.policy_model = model

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DataCollatorWithPadding(self.processing_class)

        # Handle stop token settings: update policy model's generation_config to use provided stop token
        if args.stop_token and args.stop_token_id:
            raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
        elif args.stop_token:
            if args.stop_token == "eos":
                self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
            else:
                raise ValueError(
                    f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
                )
        else:
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

        # peft support
        if not is_peft_available() and peft_config is not None:
            raise ImportError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(self.policy_model, PeftModel):
                self.policy_model = self.policy_model.merge_and_unload()

            # get peft model with the given config
            self.policy_model = get_peft_model(self.policy_model, peft_config)
            if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(self.policy_model)

        self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
        self.model_adapter_name = args.model_adapter_name
        self.ref_adapter_name = args.ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model:
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(self.policy_model)

        self.reward_model = reward_model
        self.train_dataset = train_dataset
        self.train_dataset_len = len(train_dataset)
        self.value_model = value_model
        self.data_collator = data_collator
        self.eval_dataset = eval_dataset
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

        #########
        # calculate various batch sizes
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
        accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
        self.accelerator = accelerator
        args.world_size = accelerator.num_processes
        args.local_batch_size = (
            args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
        )
        args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
        args.batch_size = int(args.local_batch_size * args.world_size)
        args.mini_batch_size = exact_div(
            args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
        )
        args.local_mini_batch_size = exact_div(
            args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
        )
        if args.whiten_rewards:
            assert (
                args.local_mini_batch_size >= 8
            ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
        # `per_rank_rollout_batch_size` is our `args.local_batch_size`
        # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
        args.num_total_batches = math.ceil(
            args.total_episodes / args.batch_size
        )  # we may train for more than `total_episodes`
        time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
        time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
        args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
        self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
        if args.num_sample_generations > 0:
            self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
        self.local_dataloader_batch_size = args.local_batch_size
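Concretely, with made-up settings the derived sizes above relate as follows:

# Worked example of the batch-size bookkeeping above (values are hypothetical).
world_size = 2
per_device_train_batch_size, gradient_accumulation_steps, num_mini_batches = 4, 2, 1
local_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches  # 8 per rank
micro_batch_size = per_device_train_batch_size * world_size  # 8: one forward/backward across all ranks
batch_size = local_batch_size * world_size  # 16 samples per rollout batch
mini_batch_size = batch_size // num_mini_batches  # 16
assert (local_batch_size, micro_batch_size, batch_size, mini_batch_size) == (8, 8, 16, 16)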

        #########
        # setup model, optimizer, and others
        #########
        for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
            if module is not None:
                disable_dropout_in_model(module)
        self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
        self.model.config = self.policy_model.config  # needed for pushing to hub
        self.create_optimizer_and_scheduler(
            num_training_steps=args.num_total_batches
        )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

        #########
        ### trainer specifics
        #########
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
        self.state = OnlineTrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )
        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        # Create distant repo and output directory if needed
        self.hub_model_id = None
        if self.args.push_to_hub:
            self.init_hf_repo()
        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        #########
        ### setup dataloader
        #########
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.local_dataloader_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            drop_last=True,  # needed; otherwise the last batch will be of ragged shape
        )
        # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
        # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
        torch.manual_seed(args.seed)
        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
        torch.manual_seed(self.local_seed)  # reset the local seed again

        self.eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=True,
        )  # no need to shuffle eval dataset
        self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

        if self.is_deepspeed_enabled:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )

            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = prepare_deepspeed(
                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
        else:
            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = self.ref_model.to(self.accelerator.device)
            self.reward_model = self.reward_model.to(self.accelerator.device)

    def get_train_dataloader(self) -> DataLoader:
        return self.dataloader

    def get_eval_dataloader(self) -> DataLoader:
        return self.eval_dataloader

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with (
            self.accelerator.unwrap_model(self.model.policy).disable_adapter()
            if self.is_peft_model and not self.ref_adapter_name
            else nullcontext()
        ):
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        backup_model = self.model
        self.model = self.model.policy  # save only the policy

        if self.is_deepspeed_enabled:
            backup_deepspeed = self.deepspeed
            self.deepspeed = self.model

        super().save_model(output_dir, _internal_call)

        self.model = backup_model

        if self.is_deepspeed_enabled:
            self.deepspeed = backup_deepspeed

    def train(self):
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        ref_policy = self.ref_model
        reward_model = self.reward_model
        processing_class = self.processing_class
        dataloader = self.dataloader
        device = accelerator.device

        def repeat_generator():
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())
        generation_config = GenerationConfig(
690
- max_new_tokens=args.response_length,
691
- temperature=(args.temperature + 1e-7),
692
- top_k=0.0,
693
- top_p=1.0,
694
- do_sample=True,
695
- )
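-         # temperature + 1e-7 matches the logit rescaling below (logits /= temperature + 1e-7),
-         # so sampling and scoring use the same effective temperature and never divide by zero;
-         # top_k=0.0 / top_p=1.0 disable top-k and nucleus filtering, i.e. pure temperature sampling.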
-
-         accelerator.print("===training policy===")
-         start_time = time.time()
-         stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
-         approxkl_stats = torch.zeros(stats_shape, device=device)
-         pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
-         pg_loss_stats = torch.zeros(stats_shape, device=device)
-         vf_loss_stats = torch.zeros(stats_shape, device=device)
-         vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
-         entropy_stats = torch.zeros(stats_shape, device=device)
-         ratio_stats = torch.zeros(stats_shape, device=device)
-         model.train()
-
-         # trainer state initialization
-         self.state.global_step = 0
-         self.state.episode = 0
-         self.state.max_steps = args.num_total_batches * args.num_mini_batches
-         self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
-         # Compute absolute values for logging, eval, and save if given as ratio
-         if args.logging_steps is not None:
-             if args.logging_steps < 1:
-                 self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
-             else:
-                 self.state.logging_steps = args.logging_steps
-         if args.eval_steps is not None:
-             if args.eval_steps < 1:
-                 self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
-             else:
-                 self.state.eval_steps = args.eval_steps
-         if args.save_steps is not None:
-             if args.save_steps < 1:
-                 self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
-             else:
-                 self.state.save_steps = args.save_steps
-         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
-
-         # backward compatibility
-         if self.is_deepspeed_enabled:
-             self.deepspeed = self.model
-             self.model_wrapped = self.model
-
-         for update in range(1, args.num_total_batches + 1):
-             self.state.episode += 1 * args.batch_size
-             data = next(iter_dataloader)
-             with torch.no_grad():
-                 queries = data["input_ids"].to(device)
-                 context_length = queries.shape[1]
-                 responses = []
-                 postprocessed_responses = []
-                 logprobs = []
-                 ref_logprobs = []
-                 scores = []
-                 sequence_lengths = []
-                 values = []
-                 with unwrap_model_for_generation(
-                     self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-                 ) as unwrapped_model:
-                     query_responses, logitss = batch_generation(
-                         unwrapped_model.policy,
-                         queries,
-                         args.local_rollout_forward_batch_size,
-                         processing_class.pad_token_id,
-                         generation_config,
-                     )
-
-                 for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
-                     query = queries[i : i + args.local_rollout_forward_batch_size]
-                     query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
-                     response = query_response[:, context_length:]
-                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
-                     logprob = selective_log_softmax(logits, response)
-                     del logits
-                     torch.cuda.empty_cache()
-
-                     if ref_policy is None:
-                         with self.null_ref_context():
-                             ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
-                     else:
-                         ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
-                     ref_logits = ref_output.logits[:, context_length - 1 : -1]
-                     ref_logits /= args.temperature + 1e-7
-                     ref_logprob = selective_log_softmax(ref_logits, response)
-                     del ref_output, ref_logits
-                     torch.cuda.empty_cache()
-
-                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
-                     postprocessed_response = response
-                     if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
-                         postprocessed_response = truncate_response(
-                             self.stop_token_id, processing_class.pad_token_id, response
-                         )
-
-                     # Response Processing 2. run reward model on the truncated responses
-                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                     sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
-                     unwrapped_value_model = accelerator.unwrap_model(model).value_model
-                     full_value, _, _ = get_reward(
-                         unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
-                     )
-                     value = full_value[:, context_length - 1 : -1].squeeze(-1)
-                     _, score, _ = get_reward(
-                         reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                     )
-
-                     responses.append(response)
-                     postprocessed_responses.append(postprocessed_response)
-                     logprobs.append(logprob)
-                     ref_logprobs.append(ref_logprob)
-                     sequence_lengths.append(sequence_length)
-                     scores.append(score)
-                     values.append(value)
-                 responses = torch.cat(responses, 0)
-                 postprocessed_responses = torch.cat(postprocessed_responses, 0)
-                 logprobs = torch.cat(logprobs, 0)
-                 ref_logprobs = torch.cat(ref_logprobs, 0)
-                 sequence_lengths = torch.cat(sequence_lengths, 0)
-                 scores = torch.cat(scores, 0)
-                 values = torch.cat(values, 0)
-                 del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
-                 torch.cuda.empty_cache()
-                 gc.collect()
-
-                 # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
-                 # Completions not passing that filter will receive a lower score.
-                 contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
-                 if self.args.missing_eos_penalty is not None:
-                     scores[~contain_eos_token] -= self.args.missing_eos_penalty
-                 # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
-
-                 # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
-                 response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
-                 padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
-                 logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
-                 ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
-                 sequence_lengths_p1 = sequence_lengths + 1
-                 padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
-                 values = torch.masked_fill(values, padding_mask_p1, 0)
-
-                 # 4. compute rewards
-                 kl = logprobs - ref_logprobs
-                 non_score_reward = -args.kl_coef * kl
-                 rewards = non_score_reward.clone()
-                 actual_start = torch.arange(rewards.size(0), device=rewards.device)
-                 actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
-                 rewards[[actual_start, actual_end]] += scores
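-                 # Per-token reward: r_t = -kl_coef * (logprob_t - ref_logprob_t); the scalar
-                 # reward-model score is added only at the last non-padding token of each response.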
-
-                 # 5. whiten rewards
-                 if args.whiten_rewards:
-                     rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
-                     rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
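-                 # masked_whiten rescales rewards to unit variance over non-padding tokens
-                 # (shift_mean=False keeps the mean); this is why local_mini_batch_size >= 8 is
-                 # asserted in __init__: variance estimates from very small batches are too noisy.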
-
-                 # 6. compute advantages and returns
-                 lastgaelam = 0
-                 advantages_reversed = []
-                 gen_length = responses.shape[1]
-                 for t in reversed(range(gen_length)):
-                     nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
-                     delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
-                     lastgaelam = delta + args.gamma * args.lam * lastgaelam
-                     advantages_reversed.append(lastgaelam)
-                 advantages = torch.stack(advantages_reversed[::-1], axis=1)
-                 returns = advantages + values
-                 advantages = masked_whiten(advantages, ~padding_mask)
-                 advantages = torch.masked_fill(advantages, padding_mask, 0)
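-                 # Generalized Advantage Estimation, computed right to left:
-                 #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
-                 #   A_t     = delta_t + gamma * lam * A_{t+1}
-                 # with returns recovered as R_t = A_t + V(s_t) before the advantages are whitened.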
-                 torch.cuda.empty_cache()
-
-             # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
-             for ppo_epoch_idx in range(args.num_ppo_epochs):
-                 b_inds = np.random.permutation(args.local_batch_size)
-                 minibatch_idx = 0
-                 for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
-                     mini_batch_end = mini_batch_start + args.local_mini_batch_size
-                     mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
-                     gradient_accumulation_idx = 0
-                     for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
-                         with accelerator.accumulate(model):
-                             micro_batch_end = micro_batch_start + args.per_device_train_batch_size
-                             micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
-                             mb_advantage = advantages[micro_batch_inds]
-                             mb_responses = responses[micro_batch_inds]
-                             mb_query_responses = query_responses[micro_batch_inds]
-                             mb_logprobs = logprobs[micro_batch_inds]
-                             mb_return = returns[micro_batch_inds]
-                             mb_values = values[micro_batch_inds]
-
-                             output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
-                             logits = output.logits[:, context_length - 1 : -1]
-                             logits /= args.temperature + 1e-7
-                             new_logprobs = selective_log_softmax(logits, mb_responses)
-                             new_logprobs = torch.masked_fill(
-                                 new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
-                             )
-                             vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
-                             vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
-                             vpredclipped = torch.clamp(
-                                 vpred,
-                                 mb_values - args.cliprange_value,
-                                 mb_values + args.cliprange_value,
-                             )
-                             vf_losses1 = torch.square(vpred - mb_return)
-                             vf_losses2 = torch.square(vpredclipped - mb_return)
-                             vf_loss_max = torch.max(vf_losses1, vf_losses2)
-                             vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
-                             vf_clipfrac = masked_mean(
-                                 (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
-                             )
-                             logprobs_diff = new_logprobs - mb_logprobs
-                             ratio = torch.exp(logprobs_diff)
-                             pg_losses = -mb_advantage * ratio
-                             pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
-                             pg_loss_max = torch.max(pg_losses, pg_losses2)
-                             pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
-                             loss = pg_loss + args.vf_coef * vf_loss
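-                             # Clipped PPO objective: the policy term is the elementwise max of
-                             # -A * ratio and -A * clip(ratio, 1 - cliprange, 1 + cliprange); the value
-                             # term is 0.5 * max((V - R)^2, (V_clipped - R)^2), weighted by vf_coef.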
-                             accelerator.backward(loss)
-                             optimizer.step()
-                             optimizer.zero_grad()
-                             with torch.no_grad():
-                                 pg_clipfrac = masked_mean(
-                                     (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
-                                 )
-                                 prob_dist = torch.nn.functional.softmax(logits, dim=-1)
-                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
-                                 approxkl = 0.5 * (logprobs_diff**2).mean()
-                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                 pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                     pg_clipfrac
-                                 )
-                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
-                                 vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
-                                 vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                     vf_clipfrac
-                                 )
-                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
-                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
-                         gradient_accumulation_idx += 1
-                     minibatch_idx += 1
-                     # del everything and empty cache
-                     # fmt: off
-                     del (
-                         output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
-                         vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
-                         pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
-                         mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
-                     )
-                     # fmt: on
-                     torch.cuda.empty_cache()
-             with torch.no_grad():
-                 mean_kl = kl.sum(1).mean()
-                 mean_entropy = (-logprobs).sum(1).mean()
-                 mean_non_score_reward = non_score_reward.sum(1).mean()
-                 rlhf_reward = mean_non_score_reward + scores.mean()
-                 eps = int(self.state.episode / (time.time() - start_time))
-                 metrics = {}
-                 metrics["eps"] = eps
-                 metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
-                 metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
-                 metrics["objective/non_score_reward"] = (
-                     self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
-                 )
-                 metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
-                 metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
-                 metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
-                 metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
-                 metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
-                 metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
-                 metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
-                 metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
-                 metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
-                 metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
-                 metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
-                 metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
-                 metrics["episode"] = self.state.episode
-                 self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
-                 self.state.global_step += 1
-                 self.log(metrics)
-
-             self.lr_scheduler.step()
-             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-             if self.control.should_save:
-                 self._save_checkpoint(model, trial=None)
-                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
-             torch.cuda.empty_cache()
-             gc.collect()
-
-             if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
-                 self.generate_completions(sampling=True)
-                 torch.cuda.empty_cache()
-             del (
-                 query_responses,
-                 responses,
-                 postprocessed_responses,
-                 logprobs,
-                 ref_logprobs,
-                 values,
-                 sequence_lengths,
-                 contain_eos_token,
-                 sequence_lengths_p1,
-                 response_idxs,
-                 padding_mask,
-                 padding_mask_p1,
-                 rewards,
-                 actual_start,
-                 actual_end,
-                 advantages,
-                 returns,
-             )
-             torch.cuda.empty_cache()
-
-         # HF trainer specifics
-         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
-         if self.control.should_save:
-             self._save_checkpoint(model, trial=None, metrics=None)
-             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-
-     def generate_completions(self, sampling: bool = False):
-         args = self.args
-         processing_class = self.processing_class
-         generation_config = GenerationConfig(
-             max_new_tokens=self.args.response_length,
-             temperature=(0.01 + 1e-7),
-             top_k=0.0,
-             top_p=1.0,
-             do_sample=True,
-         )
-
-         table = defaultdict(list)
-         with unwrap_model_for_generation(
-             self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-         ) as unwrapped_model:
-             for batch in self.eval_dataloader:
-                 query = batch["input_ids"]
-                 with torch.no_grad():
-                     context_length = query.shape[1]
-                     query_response, _ = batch_generation(
-                         unwrapped_model.policy,
-                         query,
-                         query.shape[0],
-                         processing_class.pad_token_id,
-                         generation_config,
-                     )
-                     response = query_response[:, context_length:]
-                     postprocessed_response = response
-                     if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
-                         postprocessed_response = truncate_response(
-                             self.stop_token_id, processing_class.pad_token_id, response
-                         )
-                     table["query"].extend(
-                         gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
-                     )
-                     table["model response"].extend(
-                         gather_object(processing_class.batch_decode(postprocessed_response))
-                     )
-
-                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                     _, score, _ = get_reward(
-                         self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                     )
-                     table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
-
-                 if sampling:
-                     break
-         df = pd.DataFrame(table)
-
-         if self.accelerator.is_main_process:
-             print_rich_table(df.iloc[0 : 0 + 5])
-             if "wandb" in args.report_to:
-                 import wandb
-
-                 if wandb.run is not None:
-                     wandb.log({"completions": wandb.Table(dataframe=df)})
-
-             if "comet_ml" in args.report_to:
-                 log_table_to_comet_experiment(
-                     name="completions.csv",
-                     table=df,
-                 )
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-         @article{mziegler2019fine-tuning,
-             title = {{Fine-Tuning Language Models from Human Preferences}},
-             author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
-             year = 2019,
-             eprint = {arXiv:1909.08593}
-         }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="PPO",
-             trainer_citation=citation,
-             paper_title="Fine-Tuning Language Models from Human Preferences",
-             paper_id="1909.08593",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothPPOTrainer(_UnslothPPOTrainer):
-     """
-
-     """
-     def __init__(
-         self,
-         args,
-         processing_class,
-         model,
-         ref_model,
-         reward_model,
-         train_dataset,
-         value_model = None,
-         data_collator = None,
-         eval_dataset = None,
-         callbacks = None,
-         peft_config = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothPPOConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         use_fp16 = getattr(args, 'fp16', False)
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
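-         # The autocast dtype is chosen to match the model's own storage dtype: float16
-         # weights train with fp16 mixed precision and bfloat16 weights with bf16, unless
-         # UNSLOTH_FORCE_FLOAT32 disables mixed precision entirely.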
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
-             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('ppo_trainer', other_metrics)
-
-         super().__init__(
-             args = args,
-             processing_class = processing_class,
-             model = model,
-             ref_model = ref_model,
-             reward_model = reward_model,
-             train_dataset = train_dataset,
-             value_model = value_model,
-             data_collator = data_collator,
-             eval_dataset = eval_dataset,
-             callbacks = callbacks,
-             peft_config = peft_config,
-             **kwargs,
-         )
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-         if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-
- pass
unsloth_compiled_cache/UnslothPRMTrainer.py DELETED
@@ -1,798 +0,0 @@
- """
- 2025.3.13
- 2025.3.15
- 4.48.3
- 0.15.2
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def selective_log_softmax(logits, index):
-     logits = logits.to(torch.float32)
-     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-     # loop to reduce peak mem consumption
-     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-     logsumexp_values = torch.logsumexp(logits, dim = -1)
-     per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
-     return per_token_logps
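-
- # selective_log_softmax applies the identity log_softmax(x)_i = x_i - logsumexp(x):
- # for logits of shape (batch, seq, vocab) and token ids of shape (batch, seq) it
- # returns per-token log-probabilities of shape (batch, seq) without materializing
- # the full log-softmax over the vocabulary.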
- @dataclass
- class UnslothPRMConfig(PRMConfig):
-     """
-
-     Configuration class for the [`PRMTrainer`].
-
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
-
-     Parameters:
-         learning_rate (`float`, *optional*, defaults to `1e-5`):
-             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
-             [`~transformers.TrainingArguments`].
-         max_length (`int` or `None`, *optional*, defaults to `1024`):
-             Maximum length of the sequences (prompt + completion) used for truncation.
-         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-             Maximum length of the prompt used for truncation.
-         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
-             Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
-         disable_dropout (`bool`, *optional*, defaults to `True`):
-             Whether to disable dropout in the model.
-         step_separator (`str`, *optional*, defaults to `"\n"`):
-             Separator used to separate each step of the reasoning process.
-         train_on_last_step_only (`bool`, *optional*, defaults to `False`):
-             Whether to train only on the last step.
-         dataset_num_proc (`int`, *optional*, defaults to `None`):
-             Number of processes to use for processing the dataset.
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = True,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         evaluation_strategy = None,
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = False,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         dispatch_batches = None,
-         split_batches = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = False,
-         max_length = 1024,
-         max_prompt_length = 512,
-         max_completion_length = None,
-         disable_dropout = True,
-         step_separator = '\n',
-         train_on_last_step_only = False,
-         dataset_num_proc = None,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = cpu_count()
-
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             evaluation_strategy = evaluation_strategy,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             dispatch_batches = dispatch_batches,
-             split_batches = split_batches,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             max_length = max_length,
-             max_prompt_length = max_prompt_length,
-             max_completion_length = max_completion_length,
-             disable_dropout = disable_dropout,
-             step_separator = step_separator,
-             train_on_last_step_only = train_on_last_step_only,
-             dataset_num_proc = dataset_num_proc,
-             **kwargs,
-         )
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-         pass
-
- class _UnslothPRMTrainer(Trainer):
-     """"""
-
-     _tag_names = ["trl", "prm"]
-
-     def __init__(
-         self,
-         model: Optional[Union[PreTrainedModel, nn.Module]] = None,
-         args: Optional[PRMConfig] = None,
-         data_collator: Optional[DataCollator] = None,
-         train_dataset: Optional[Dataset] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         model_init: Optional[Callable[[], PreTrainedModel]] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-             None,
-             None,
-         ),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional[dict] = None,
-     ):
-         if not is_peft_available() and peft_config is not None:
-             raise ValueError(
-                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
-             )
-         elif is_peft_available() and peft_config is not None:
-             if not isinstance(model, PeftModel):
-                 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-                     _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
-                         inspect.signature(prepare_model_for_kbit_training).parameters
-                     )
-
-                     prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
-                     if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         warnings.warn(
-                             "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
-                             "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
-                         )
-                     elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
-                     model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-
-                 model = model
-
-         # Disable dropout in the model
-         if args.disable_dropout:
-             disable_dropout_in_model(model)
-
-         if compute_metrics is None:
-             compute_metrics = compute_accuracy
-
-         if data_collator is None:
-             if processing_class is None:
-                 raise ValueError(
-                     "A processing_class must be specified when using the default DataCollatorForTokenClassification"
-                 )
-             data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
-
-         if "input_ids" not in train_dataset.column_names:
-             with PartialState().local_main_process_first():
-                 fn_kwargs = {
-                     "tokenizer": processing_class,
-                     "step_separator": args.step_separator,
-                     "max_length": args.max_length,
-                     "max_prompt_length": args.max_prompt_length,
-                     "max_completion_length": args.max_completion_length,
-                     "train_on_last_step_only": args.train_on_last_step_only,
-                 }
-                 train_fn_kwargs = {**fn_kwargs, "is_eval": False}
-                 train_dataset = train_dataset.map(
-                     self.tokenize_row,
-                     fn_kwargs=train_fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                     remove_columns=train_dataset.features,
-                     desc="Tokenizing train dataset",
-                     features=features.Features(  # needed to avoid map to cast labels to bool
-                         {
-                             "labels": features.Sequence(features.Value("int64")),
-                             "input_ids": features.Sequence(features.Value("int64")),
-                         }
-                     ),
-                 )
-
-                 eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
-                 if eval_dataset is not None:
-                     eval_dataset = eval_dataset.map(
-                         self.tokenize_row,
-                         fn_kwargs=eval_fn_kwargs,
-                         num_proc=args.dataset_num_proc,
-                         remove_columns=eval_dataset.features,
-                         desc="Tokenizing eval dataset",
-                         features=features.Features(  # needed to avoid map to cast labels to bool
-                             {
-                                 "labels": features.Sequence(features.Value("int64")),
-                                 "input_ids": features.Sequence(features.Value("int64")),
-                             }
-                         ),
-                     )
-
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             model_init=model_init,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
-
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
-
-     @staticmethod
-     def tokenize_row(
-         features,
-         tokenizer,
-         step_separator,
-         max_length,
-         max_prompt_length,
-         max_completion_length,
-         train_on_last_step_only,
-         is_eval,
-     ):
-         r"""
-         Tokenize a row of the dataset.
-
-         Args:
-             features (`dict[str, str]`):
-                 Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
-             tokenizer (`PreTrainedTokenizerBase`):
-                 Tokenizer used to process the data.
-             step_separator (`str`):
-                 Separator between steps in the completion.
-             max_length (`int` or `None`):
-                 Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
-             max_prompt_length (`int` or `None`):
-                 Maximum length of the prompt. If `None`, the prompt is not truncated.
-             max_completion_length (`int` or `None`):
-                 Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
-             train_on_last_step_only (`bool`):
-                 Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
-                 token of the completion.
-             is_eval (`bool`):
-                 Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only
-                 if `train_on_last_step_only` is set to `True`.
-
-         Returns:
-             `dict[str, list[int]]`:
-                 Tokenized sequences with the keys `"input_ids"`, and `"labels"`.
-
-         Example:
-         ```python
-         >>> from transformers import AutoTokenizer
-         >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
-         >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
-         ...             "completions": ["11 is greater than 8.",
-         ...                             "Hence, 9.11 > 9.8."],
-         ...             "labels": [True, False]}
-         >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_length=None, max_prompt_length=None, max_completion_length=None, train_on_last_step_only=False, is_eval=False)
-         {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
-          'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
-         ```
-         """
-         # Tokenize the prompt and completions
-         prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
-         completions_ids = [
-             tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
-         ]
-         if train_on_last_step_only and not is_eval:
-             labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
-         else:
-             labels = [int(label) for label in features["labels"]]
-
-         # Get the ID of the separator token and add it to the completions
-         separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
-         completions_ids = [completion + separator_ids for completion in completions_ids]
-
-         # Create the label
-         labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
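-         # Each step contributes exactly one supervised position: the separator token that
-         # closes the step carries the step's 0/1 label, and every other token is masked to -100.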
-
-         # Join the completions and labels steps
-         completion_ids = list(chain(*completions_ids))
-         labels = list(chain(*labels))
-
-         if tokenizer.bos_token_id is not None:
-             prompt_ids = [tokenizer.bos_token_id] + prompt_ids
-
-         # Truncate prompt and completion sequences
-         if max_prompt_length is not None:
-             prompt_ids = prompt_ids[-max_prompt_length:]
-         if max_completion_length is not None:
-             completion_ids = completion_ids[:max_completion_length]
-             labels = labels[:max_completion_length]
-
-         input_ids = prompt_ids + completion_ids
-         labels = [-100] * len(prompt_ids) + labels
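-         # Prompt positions are set to -100 so the token-classification loss is computed
-         # only at the per-step label positions inside the completion.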
579
-
580
- if max_length is not None:
581
- input_ids = input_ids[:max_length]
582
- labels = labels[:max_length]
583
-
584
- return {"input_ids": input_ids, "labels": labels}
585
-
586
- def create_model_card(
587
- self,
588
- model_name: Optional[str] = None,
589
- dataset_name: Optional[str] = None,
590
- tags: Union[str, list[str], None] = None,
591
- ):
592
- """
593
- Creates a draft of a model card using the information available to the `Trainer`.
594
-
595
- Args:
596
- model_name (`str` or `None`, *optional*, defaults to `None`):
597
- Name of the model.
598
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
599
- Name of the dataset used for training.
600
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
601
- Tags to be associated with the model card.
602
- """
603
- if not self.is_world_process_zero():
604
- return
605
-
606
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
607
- base_model = self.model.config._name_or_path
608
- else:
609
- base_model = None
610
-
611
- tags = tags or []
612
- if isinstance(tags, str):
613
- tags = [tags]
614
-
615
- if hasattr(self.model.config, "unsloth_version"):
616
- tags.append("unsloth")
617
-
618
- citation = textwrap.dedent("""\
619
- @article{uesato2022solving,
620
- title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
621
- author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
622
- year = 2022,
623
- journal = {arXiv preprint arXiv:2211.14275}
624
- }""")
625
-
626
- model_card = generate_model_card(
627
- base_model=base_model,
628
- model_name=model_name,
629
- hub_model_id=self.hub_model_id,
630
- dataset_name=dataset_name,
631
- tags=tags,
632
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
633
- trainer_name="PRM",
634
- trainer_citation=citation,
635
- paper_title="Solving math word problems with process-and outcome-based feedback",
636
- )
637
-
638
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
639
- class UnslothPRMTrainer(_UnslothPRMTrainer):
640
- """
641
-
642
- Initialize PRMTrainer.
643
-
644
- Args:
645
- model (`transformers.PreTrainedModel`):
646
- The model to train, preferably an `AutoModelForTokenClassification`.
647
- args (`PRMConfig`):
648
- The arguments to use for training.
649
- data_collator (`transformers.DataCollator`):
650
- The data collator to use for training. If `None` is specified, the default data collator (`DataCollatorForTokenClassification`) will be used,
651
- which pads the sequences to the maximum length of the sequences in the batch.
652
- train_dataset (`datasets.Dataset`):
653
- The dataset to use for training.
654
- eval_dataset (`datasets.Dataset`):
655
- The dataset to use for evaluation.
656
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
657
- Processing class used to process the data. If provided, will be used to automatically process the inputs
658
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
659
- reuse the fine-tuned model.
660
- model_init (`Callable[[], transformers.PreTrainedModel]`):
661
- The model initializer to use for training. If None is specified, the default model initializer will be used.
662
- compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
663
- The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
664
- callbacks (`list[transformers.TrainerCallback]`):
665
- The callbacks to use for training.
666
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
667
- The optimizer and scheduler to use for training.
668
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
669
- The function to use to preprocess the logits before computing the metrics.
670
- peft_config (`dict`, defaults to `None`):
671
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
672
-
673
- """
674
- def __init__(
675
- self,
676
- model = None,
677
- args = None,
678
- data_collator = None,
679
- train_dataset = None,
680
- eval_dataset = None,
681
- processing_class = None,
682
- model_init = None,
683
- compute_metrics = None,
684
- callbacks = None,
685
- preprocess_logits_for_metrics = None,
686
- peft_config = None,
687
- **kwargs
688
- ):
689
- if args is None: args = UnslothPRMConfig()
690
- use_bf16 = getattr(args, 'bf16', False)
691
- use_fp16 = getattr(args, 'fp16', False)
692
- force_float32 = False
693
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
694
- print('Unsloth: Switching to float32 training since model cannot work with float16')
695
- force_float32 = True
696
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
697
- dtype = getattr(model.config, 'torch_dtype', None)
698
- if dtype is None: dtype = model.get_input_embeddings().dtype
699
- from unsloth_zoo.utils import _get_dtype
700
- dtype = _get_dtype(dtype)
701
- float16 = dtype == torch.float16
702
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
703
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
704
- if force_float32:
705
- args.fp16 = False
706
- args.bf16 = False
707
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
708
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
709
- args.fp16 = float16
710
- args.bf16 = not float16
711
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
712
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
713
- args.eval_strategy = 'steps'
714
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
715
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
716
- if ga_steps is not None and ga_steps > 1:
717
- from transformers import __version__ as transformers_version
718
- if Version(transformers_version) <= Version('4.45.2'):
719
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
720
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
721
- if getattr(args, 'eval_strategy', 'no') != 'no':
722
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
723
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
724
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
725
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
726
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
727
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
728
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
729
- if force_float32:
730
- args.bf16_full_eval = False
731
- args.fp16_full_eval = False
732
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
733
- args.bf16_full_eval = True
734
- args.fp16_full_eval = False
735
- elif not bf16_full_eval and not fp16_full_eval:
736
- args.bf16_full_eval = args.bf16
737
- args.fp16_full_eval = args.fp16
738
- _output_logits = False
739
- if locals().get('compute_metrics', None) is not None: _output_logits = True
740
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
741
- if _output_logits:
742
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
743
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
744
- pass
745
- else:
746
- model_max_seq_length = getattr(model, 'max_seq_length', None)
747
- args_max_seq_length = getattr(args, 'max_seq_length', None)
748
- if args_max_seq_length is None and model_max_seq_length is not None:
749
- max_seq_length = model.max_seq_length
750
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
751
- if model is not None and hasattr(model, 'for_training'):
752
- model.for_training()
753
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
754
- if 'processing_class' in locals():
755
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
756
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
757
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
758
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
759
- if not isinstance(data_collator, UnslothVisionDataCollator):
760
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
761
- data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
762
- elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
763
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
764
- else:
765
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
766
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
767
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
768
- if not isinstance(data_collator, UnslothVisionDataCollator):
769
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
770
- if isinstance(data_collator, DataCollatorForSeq2Seq):
771
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
772
- else:
773
- data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
774
- other_metrics = []
775
-
776
- from unsloth_zoo.logging_utils import PatchRLStatistics
777
- PatchRLStatistics('prm_trainer', other_metrics)
778
-
779
- super().__init__(
780
- model = model,
781
- args = args,
782
- data_collator = data_collator,
783
- train_dataset = train_dataset,
784
- eval_dataset = eval_dataset,
785
- processing_class = processing_class,
786
- model_init = model_init,
787
- compute_metrics = compute_metrics,
788
- callbacks = callbacks,
789
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
790
- peft_config = peft_config,**kwargs)
791
- if hasattr(self, 'neftune_hook_handle'):
792
- self.neftune_hook_handle.remove()
793
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
794
- if getattr(args, 'neftune_noise_alpha', None) is not None:
795
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
796
- pass
797
-
798
- pass
 
unsloth_compiled_cache/UnslothRLOOTrainer.py DELETED
@@ -1,1131 +0,0 @@
1
- """
2
- 2025.3.13
3
- 2025.3.15
4
- 4.48.3
5
- 0.15.2
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
13
-
14
-
15
- import os
16
- from typing import *
17
- from dataclasses import dataclass, field
18
- from packaging.version import Version
19
- import torch
20
- import numpy as np
21
- from contextlib import nullcontext
22
- from torch.nn import functional as F
23
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
-
25
- torch_compile_options = {
26
- "epilogue_fusion" : True,
27
- "max_autotune" : False,
28
- "shape_padding" : True,
29
- "trace.enabled" : False,
30
- "triton.cudagraphs" : False,
31
- }
32
-
33
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
- def selective_log_softmax(logits, index):
35
- logits = logits.to(torch.float32)
36
- selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
- # loop to reduce peak mem consumption
38
- # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
- logsumexp_values = torch.logsumexp(logits, dim = -1)
40
- per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
- return per_token_logps
42
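# A quick equivalence check for `selective_log_softmax` above: it should match
# gathering from a full log_softmax, while the logsumexp path avoids keeping the
# full [batch, seq, vocab] log-probability tensor around. Toy shapes; this assumes
# `torch.compile` works in the current environment (otherwise test the undecorated body).
import torch
import torch.nn.functional as F
logits = torch.randn(2, 5, 10)            # [batch, seq, vocab]
index = torch.randint(0, 10, (2, 5))      # chosen token ids per position
reference = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(selective_log_softmax(logits, index), reference, atol=1e-5)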
- @dataclass
43
- class UnslothRLOOConfig(RLOOConfig):
44
- """
45
-
46
- Configuration class for the [`RLOOTrainer`].
47
-
48
- Using [`~transformers.HfArgumentParser`] we can turn this class into
49
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
- command line.
51
-
52
- Parameters:
53
- exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
54
- Name of this experiment.
55
- reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
56
- Path to the reward model.
57
- num_ppo_epochs (`int`, *optional*, defaults to `4`):
58
- Number of epochs to train.
59
- whiten_rewards (`bool`, *optional*, defaults to `False`):
60
- Whether to whiten the rewards.
61
- kl_coef (`float`, *optional*, defaults to `0.05`):
62
- KL coefficient.
63
- cliprange (`float`, *optional*, defaults to `0.2`):
64
- Clip range.
65
- rloo_k (`int`, *optional*, defaults to `2`):
66
- Number of online samples to draw per prompt for the REINFORCE Leave-One-Out (RLOO) baseline.
67
- normalize_reward (`bool`, *optional*, defaults to `False`):
68
- Whether to normalize rewards.
69
- reward_clip_range (`float`, *optional*, defaults to `10.0`):
70
- Clip range for rewards.
71
- normalize_advantage (`bool`, *optional*, defaults to `False`):
72
- Whether to normalize advantages.
73
- token_level_kl (`bool`, *optional*, defaults to `True`):
74
- Whether to use token-level KL penalty or sequence-level KL penalty.
75
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
76
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
77
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
78
- capacity of a single GPU, albeit at the cost of slower generation.
79
-
80
- """
81
- vllm_sampling_params: Optional[Any] = field(
82
- default = None,
83
- metadata = {'help': 'vLLM SamplingParams'},
84
- )
85
- unsloth_num_chunks : Optional[int] = field(
86
- default = -1,
87
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
88
- )
89
- def __init__(
90
- self,
91
- output_dir = None,
92
- overwrite_output_dir = None,
93
- do_train = False,
94
- do_eval = False,
95
- do_predict = False,
96
- eval_strategy = 'no',
97
- prediction_loss_only = False,
98
- per_device_train_batch_size = 4,
99
- per_device_eval_batch_size = 4,
100
- per_gpu_train_batch_size = None,
101
- per_gpu_eval_batch_size = None,
102
- gradient_accumulation_steps = 2,
103
- eval_accumulation_steps = 2,
104
- eval_delay = 0,
105
- torch_empty_cache_steps = 250,
106
- learning_rate = 5e-05,
107
- weight_decay = 0.01,
108
- adam_beta1 = 0.9,
109
- adam_beta2 = 0.999,
110
- adam_epsilon = 1e-08,
111
- max_grad_norm = 1.0,
112
- num_train_epochs = 3.0,
113
- max_steps = -1,
114
- lr_scheduler_type = 'linear',
115
- warmup_ratio = 0.1,
116
- warmup_steps = 0,
117
- log_level = 'passive',
118
- log_level_replica = 'warning',
119
- log_on_each_node = True,
120
- logging_dir = None,
121
- logging_strategy = 'steps',
122
- logging_first_step = False,
123
- logging_steps = 1,
124
- logging_nan_inf_filter = False,
125
- save_strategy = 'steps',
126
- save_steps = 500,
127
- save_total_limit = None,
128
- save_safetensors = True,
129
- save_on_each_node = False,
130
- save_only_model = False,
131
- restore_callback_states_from_checkpoint = False,
132
- no_cuda = False,
133
- use_cpu = False,
134
- use_mps_device = False,
135
- seed = 3407,
136
- data_seed = 3407,
137
- jit_mode_eval = False,
138
- use_ipex = False,
139
- bf16 = False,
140
- fp16 = False,
141
- fp16_opt_level = 'O1',
142
- half_precision_backend = 'auto',
143
- bf16_full_eval = False,
144
- fp16_full_eval = False,
145
- tf32 = None,
146
- local_rank = -1,
147
- ddp_backend = None,
148
- tpu_num_cores = None,
149
- tpu_metrics_debug = False,
150
- debug = '',
151
- dataloader_drop_last = False,
152
- eval_steps = None,
153
- dataloader_num_workers = 0,
154
- dataloader_prefetch_factor = None,
155
- past_index = -1,
156
- run_name = None,
157
- disable_tqdm = None,
158
- remove_unused_columns = True,
159
- label_names = None,
160
- load_best_model_at_end = False,
161
- metric_for_best_model = None,
162
- greater_is_better = None,
163
- ignore_data_skip = False,
164
- fsdp = '',
165
- fsdp_min_num_params = 0,
166
- fsdp_config = None,
167
- fsdp_transformer_layer_cls_to_wrap = None,
168
- accelerator_config = None,
169
- deepspeed = None,
170
- label_smoothing_factor = 0.0,
171
- optim = 'adamw_8bit',
172
- optim_args = None,
173
- adafactor = False,
174
- group_by_length = False,
175
- length_column_name = 'length',
176
- report_to = None,
177
- ddp_find_unused_parameters = None,
178
- ddp_bucket_cap_mb = None,
179
- ddp_broadcast_buffers = None,
180
- dataloader_pin_memory = True,
181
- dataloader_persistent_workers = False,
182
- skip_memory_metrics = True,
183
- use_legacy_prediction_loop = False,
184
- push_to_hub = False,
185
- resume_from_checkpoint = None,
186
- hub_model_id = None,
187
- hub_strategy = 'every_save',
188
- hub_token = None,
189
- hub_private_repo = None,
190
- hub_always_push = False,
191
- gradient_checkpointing = False,
192
- gradient_checkpointing_kwargs = None,
193
- include_inputs_for_metrics = False,
194
- eval_do_concat_batches = True,
195
- fp16_backend = 'auto',
196
- evaluation_strategy = None,
197
- push_to_hub_model_id = None,
198
- push_to_hub_organization = None,
199
- push_to_hub_token = None,
200
- mp_parameters = '',
201
- auto_find_batch_size = False,
202
- full_determinism = False,
203
- torchdynamo = None,
204
- ray_scope = 'last',
205
- ddp_timeout = 1800,
206
- torch_compile = False,
207
- torch_compile_backend = None,
208
- torch_compile_mode = None,
209
- dispatch_batches = None,
210
- split_batches = None,
211
- include_tokens_per_second = False,
212
- include_num_input_tokens_seen = False,
213
- neftune_noise_alpha = None,
214
- optim_target_modules = None,
215
- batch_eval_metrics = False,
216
- eval_on_start = False,
217
- use_liger_kernel = False,
218
- eval_use_gather_object = False,
219
- average_tokens_across_devices = False,
220
- dataset_num_proc = None,
221
- num_mini_batches = 1,
222
- total_episodes = None,
223
- local_rollout_forward_batch_size = 64,
224
- num_sample_generations = 10,
225
- response_length = 53,
226
- stop_token = None,
227
- stop_token_id = None,
228
- temperature = 0.7,
229
- missing_eos_penalty = None,
230
- sft_model_path = 'EleutherAI/pythia-160m',
231
- world_size = None,
232
- num_total_batches = None,
233
- micro_batch_size = None,
234
- local_batch_size = None,
235
- batch_size = None,
236
- local_mini_batch_size = None,
237
- mini_batch_size = None,
238
- exp_name = 'rloo_config',
239
- reward_model_path = 'EleutherAI/pythia-160m',
240
- num_ppo_epochs = 4,
241
- whiten_rewards = False,
242
- kl_coef = 0.05,
243
- cliprange = 0.2,
244
- rloo_k = 2,
245
- normalize_reward = False,
246
- reward_clip_range = 10.0,
247
- normalize_advantage = False,
248
- token_level_kl = False,
249
- ds3_gather_for_generation = True,
250
- vllm_sampling_params = None,
251
- unsloth_num_chunks = -1,
252
- **kwargs,
253
- ):
254
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is less than 1e-7, which is too small! Consider increasing it, otherwise gradient updates will be close to 0!')
255
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
256
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
257
- output_dir = 'unsloth_training_checkpoints'
258
- save_strategy = 'no'
259
- if dataset_num_proc is None:
260
- from multiprocessing import cpu_count
261
- dataset_num_proc = cpu_count()
262
-
263
- super().__init__(
264
- output_dir = output_dir,
265
- overwrite_output_dir = overwrite_output_dir,
266
- do_train = do_train,
267
- do_eval = do_eval,
268
- do_predict = do_predict,
269
- eval_strategy = eval_strategy,
270
- prediction_loss_only = prediction_loss_only,
271
- per_device_train_batch_size = per_device_train_batch_size,
272
- per_device_eval_batch_size = per_device_eval_batch_size,
273
- per_gpu_train_batch_size = per_gpu_train_batch_size,
274
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
275
- gradient_accumulation_steps = gradient_accumulation_steps,
276
- eval_accumulation_steps = eval_accumulation_steps,
277
- eval_delay = eval_delay,
278
- torch_empty_cache_steps = torch_empty_cache_steps,
279
- learning_rate = learning_rate,
280
- weight_decay = weight_decay,
281
- adam_beta1 = adam_beta1,
282
- adam_beta2 = adam_beta2,
283
- adam_epsilon = adam_epsilon,
284
- max_grad_norm = max_grad_norm,
285
- num_train_epochs = num_train_epochs,
286
- max_steps = max_steps,
287
- lr_scheduler_type = lr_scheduler_type,
288
- warmup_ratio = warmup_ratio,
289
- warmup_steps = warmup_steps,
290
- log_level = log_level,
291
- log_level_replica = log_level_replica,
292
- log_on_each_node = log_on_each_node,
293
- logging_dir = logging_dir,
294
- logging_strategy = logging_strategy,
295
- logging_first_step = logging_first_step,
296
- logging_steps = logging_steps,
297
- logging_nan_inf_filter = logging_nan_inf_filter,
298
- save_strategy = save_strategy,
299
- save_steps = save_steps,
300
- save_total_limit = save_total_limit,
301
- save_safetensors = save_safetensors,
302
- save_on_each_node = save_on_each_node,
303
- save_only_model = save_only_model,
304
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
305
- no_cuda = no_cuda,
306
- use_cpu = use_cpu,
307
- use_mps_device = use_mps_device,
308
- seed = seed,
309
- data_seed = data_seed,
310
- jit_mode_eval = jit_mode_eval,
311
- use_ipex = use_ipex,
312
- bf16 = bf16,
313
- fp16 = fp16,
314
- fp16_opt_level = fp16_opt_level,
315
- half_precision_backend = half_precision_backend,
316
- bf16_full_eval = bf16_full_eval,
317
- fp16_full_eval = fp16_full_eval,
318
- tf32 = tf32,
319
- local_rank = local_rank,
320
- ddp_backend = ddp_backend,
321
- tpu_num_cores = tpu_num_cores,
322
- tpu_metrics_debug = tpu_metrics_debug,
323
- debug = debug,
324
- dataloader_drop_last = dataloader_drop_last,
325
- eval_steps = eval_steps,
326
- dataloader_num_workers = dataloader_num_workers,
327
- dataloader_prefetch_factor = dataloader_prefetch_factor,
328
- past_index = past_index,
329
- run_name = run_name,
330
- disable_tqdm = disable_tqdm,
331
- remove_unused_columns = remove_unused_columns,
332
- label_names = label_names,
333
- load_best_model_at_end = load_best_model_at_end,
334
- metric_for_best_model = metric_for_best_model,
335
- greater_is_better = greater_is_better,
336
- ignore_data_skip = ignore_data_skip,
337
- fsdp = fsdp,
338
- fsdp_min_num_params = fsdp_min_num_params,
339
- fsdp_config = fsdp_config,
340
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
341
- accelerator_config = accelerator_config,
342
- deepspeed = deepspeed,
343
- label_smoothing_factor = label_smoothing_factor,
344
- optim = optim,
345
- optim_args = optim_args,
346
- adafactor = adafactor,
347
- group_by_length = group_by_length,
348
- length_column_name = length_column_name,
349
- report_to = report_to,
350
- ddp_find_unused_parameters = ddp_find_unused_parameters,
351
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
352
- ddp_broadcast_buffers = ddp_broadcast_buffers,
353
- dataloader_pin_memory = dataloader_pin_memory,
354
- dataloader_persistent_workers = dataloader_persistent_workers,
355
- skip_memory_metrics = skip_memory_metrics,
356
- use_legacy_prediction_loop = use_legacy_prediction_loop,
357
- push_to_hub = push_to_hub,
358
- resume_from_checkpoint = resume_from_checkpoint,
359
- hub_model_id = hub_model_id,
360
- hub_strategy = hub_strategy,
361
- hub_token = hub_token,
362
- hub_private_repo = hub_private_repo,
363
- hub_always_push = hub_always_push,
364
- gradient_checkpointing = gradient_checkpointing,
365
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
366
- include_inputs_for_metrics = include_inputs_for_metrics,
367
- eval_do_concat_batches = eval_do_concat_batches,
368
- fp16_backend = fp16_backend,
369
- evaluation_strategy = evaluation_strategy,
370
- push_to_hub_model_id = push_to_hub_model_id,
371
- push_to_hub_organization = push_to_hub_organization,
372
- push_to_hub_token = push_to_hub_token,
373
- mp_parameters = mp_parameters,
374
- auto_find_batch_size = auto_find_batch_size,
375
- full_determinism = full_determinism,
376
- torchdynamo = torchdynamo,
377
- ray_scope = ray_scope,
378
- ddp_timeout = ddp_timeout,
379
- torch_compile = torch_compile,
380
- torch_compile_backend = torch_compile_backend,
381
- torch_compile_mode = torch_compile_mode,
382
- dispatch_batches = dispatch_batches,
383
- split_batches = split_batches,
384
- include_tokens_per_second = include_tokens_per_second,
385
- include_num_input_tokens_seen = include_num_input_tokens_seen,
386
- neftune_noise_alpha = neftune_noise_alpha,
387
- optim_target_modules = optim_target_modules,
388
- batch_eval_metrics = batch_eval_metrics,
389
- eval_on_start = eval_on_start,
390
- use_liger_kernel = use_liger_kernel,
391
- eval_use_gather_object = eval_use_gather_object,
392
- average_tokens_across_devices = average_tokens_across_devices,
393
- dataset_num_proc = dataset_num_proc,
394
- num_mini_batches = num_mini_batches,
395
- total_episodes = total_episodes,
396
- local_rollout_forward_batch_size = local_rollout_forward_batch_size,
397
- num_sample_generations = num_sample_generations,
398
- response_length = response_length,
399
- stop_token = stop_token,
400
- stop_token_id = stop_token_id,
401
- temperature = temperature,
402
- missing_eos_penalty = missing_eos_penalty,
403
- sft_model_path = sft_model_path,
404
- world_size = world_size,
405
- num_total_batches = num_total_batches,
406
- micro_batch_size = micro_batch_size,
407
- local_batch_size = local_batch_size,
408
- batch_size = batch_size,
409
- local_mini_batch_size = local_mini_batch_size,
410
- mini_batch_size = mini_batch_size,
411
- exp_name = exp_name,
412
- reward_model_path = reward_model_path,
413
- num_ppo_epochs = num_ppo_epochs,
414
- whiten_rewards = whiten_rewards,
415
- kl_coef = kl_coef,
416
- cliprange = cliprange,
417
- rloo_k = rloo_k,
418
- normalize_reward = normalize_reward,
419
- reward_clip_range = reward_clip_range,
420
- normalize_advantage = normalize_advantage,
421
- token_level_kl = token_level_kl,
422
- ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
423
- self.vllm_sampling_params = vllm_sampling_params
424
- self.unsloth_num_chunks = unsloth_num_chunks
425
- pass
426
-
427
- class _UnslothRLOOTrainer(Trainer):
428
- _tag_names = ["trl", "rloo"]
429
-
430
- def __init__(
431
- self,
432
- config: RLOOConfig,
433
- processing_class: Optional[
434
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
435
- ],
436
- policy: nn.Module,
437
- ref_policy: nn.Module,
438
- reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
439
- train_dataset: Dataset,
440
- data_collator: Optional[DataCollatorWithPadding] = None,
441
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
442
- # less commonly used
443
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
444
- callbacks: Optional[list[TrainerCallback]] = None,
445
- ) -> None:
446
- if ref_policy is policy:
447
- raise ValueError(
448
- "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
449
- "same as `policy`, you must mass a copy of it, or `None` if you use peft."
450
- )
451
-
452
- self.args = config
453
- args = config
454
- self.processing_class = processing_class
455
- self.policy = policy
456
-
457
- # Define the collator if not provided
458
- if data_collator is None:
459
- data_collator = DataCollatorWithPadding(self.processing_class)
460
-
461
- self.policy.generation_config.eos_token_id = (
462
- None # disable `pad_token_id` and `eos_token_id` because we just want to
463
- )
464
- self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
465
-
466
- self.ref_policy = ref_policy
467
- self.reward_model = reward_model
468
- self.train_dataset = train_dataset
469
- self.train_dataset_len = len(train_dataset)
470
- self.data_collator = data_collator
471
- self.eval_dataset = eval_dataset
472
- self.optimizer, self.lr_scheduler = optimizers
473
- self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
474
-
475
- #########
476
- # calculate various batch sizes
477
- #########
478
- if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
479
- args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
480
- accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
481
- self.accelerator = accelerator
482
- args.world_size = accelerator.num_processes
483
- args.local_batch_size = (
484
- args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
485
- )
486
- args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
487
- args.batch_size = int(args.local_batch_size * args.world_size)
488
- args.mini_batch_size = exact_div(
489
- args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
490
- )
491
- args.local_mini_batch_size = exact_div(
492
- args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
493
- )
494
- args.num_total_batches = math.ceil(
495
- args.total_episodes / args.batch_size
496
- ) # we may train for more than `total_episodes`
497
- time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
498
- time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
499
- args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
500
- self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
501
- if args.num_sample_generations > 0:
502
- self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
503
- self.local_dataloader_batch_size = exact_div(
504
- args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
505
- ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
506
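# A worked example (made-up values) of the batch-size bookkeeping above. `exact_div`
# raises when a division is not exact, which is what the error messages attached to
# the calls above are guarding against.
per_device, grad_accum, n_mini, world, rloo_k = 4, 2, 2, 2, 2
local_batch_size = per_device * grad_accum * n_mini        # 16 per process, per update
micro_batch_size = per_device * world                      # 8 per optimizer micro-step
batch_size = local_batch_size * world                      # 32 across all processes
mini_batch_size = batch_size // n_mini                     # 16
local_mini_batch_size = local_batch_size // n_mini         # 8
local_dataloader_batch_size = local_batch_size // rloo_k   # 8 prompts, each repeated rloo_k times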
-
507
- #########
508
- # setup model, optimizer, and others
509
- #########
510
- for module in [policy, ref_policy, reward_model]:
511
- if isinstance(module, nn.Module):
512
- disable_dropout_in_model(module)
513
- if args.stop_token and args.stop_token == "eos":
514
- args.stop_token_id = self.processing_class.eos_token_id
515
- self.model = policy
516
- self.create_optimizer_and_scheduler(
517
- num_training_steps=args.num_total_batches
518
- ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
519
-
520
- #########
521
- ### trainer specifics
522
- #########
523
- default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
524
- self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
525
- self.callback_handler = CallbackHandler(
526
- self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
527
- )
528
- self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
529
- self.control = TrainerControl()
530
- self.state = OnlineTrainerState(
531
- is_local_process_zero=self.is_local_process_zero(),
532
- is_world_process_zero=self.is_world_process_zero(),
533
- stateful_callbacks=[
534
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
535
- ],
536
- )
537
-
538
- self.current_flos = 0
539
- self.hp_search_backend = None
540
- self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
541
- self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
542
- # Create distant repo and output directory if needed
543
- self.hub_model_id = None
544
- if self.args.push_to_hub:
545
- self.init_hf_repo()
546
- if self.args.should_save:
547
- os.makedirs(self.args.output_dir, exist_ok=True)
548
- self.backup_model = None
549
-
550
- # Add tags for models that have been loaded with the correct transformers version
551
- if hasattr(self.model, "add_model_tags"):
552
- self.model.add_model_tags(self._tag_names)
553
-
554
- #########
555
- ### setup dataloader
556
- #########
557
- self.dataloader = DataLoader(
558
- self.train_dataset,
559
- batch_size=self.local_dataloader_batch_size,
560
- shuffle=True,
561
- collate_fn=self.data_collator,
562
- drop_last=True, # needed; otherwise the last batch will be of ragged shape
563
- )
564
- # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
565
- # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
566
- torch.manual_seed(args.seed)
567
- self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
568
- torch.manual_seed(self.local_seed) # reset the local seed again
569
-
570
- self.eval_dataloader = DataLoader(
571
- self.eval_dataset,
572
- batch_size=args.per_device_eval_batch_size,
573
- collate_fn=self.data_collator,
574
- drop_last=True,
575
- ) # no need to shuffle eval dataset
576
- self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
577
-
578
- if self.is_deepspeed_enabled:
579
- if isinstance(self.reward_model, nn.Module):
580
- self.reward_model = prepare_deepspeed(
581
- self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
582
- )
583
- self.ref_policy = prepare_deepspeed(
584
- self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
585
- )
586
- self.deepspeed = self.model
587
- else:
588
- self.ref_policy = self.ref_policy.to(self.accelerator.device)
589
- if isinstance(self.reward_model, nn.Module):
590
- self.reward_model = self.reward_model.to(self.accelerator.device)
591
-
592
- def get_train_dataloader(self) -> DataLoader:
593
- return self.dataloader
594
-
595
- def get_eval_dataloader(self) -> DataLoader:
596
- return self.eval_dataloader
597
-
598
- def train(self):
599
- args = self.args
600
- accelerator = self.accelerator
601
- optimizer = self.optimizer
602
- model = self.model
603
- self.model_wrapped = self.model
604
- ref_policy = self.ref_policy
605
- reward_model = self.reward_model
606
- processing_class = self.processing_class
607
- dataloader = self.dataloader
608
- device = accelerator.device
609
-
610
- def repeat_generator():
611
- while True:
612
- yield from dataloader
613
-
614
- iter_dataloader = iter(repeat_generator())
615
- generation_config = GenerationConfig(
616
- max_new_tokens=args.response_length,
617
- temperature=(args.temperature + 1e-7),
618
- top_k=0.0,
619
- top_p=1.0,
620
- do_sample=True,
621
- )
622
-
623
- accelerator.print("===training policy===")
624
- start_time = time.time()
625
- stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
626
- approxkl_stats = torch.zeros(stats_shape, device=device)
627
- pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
628
- pg_loss_stats = torch.zeros(stats_shape, device=device)
629
- vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
630
- entropy_stats = torch.zeros(stats_shape, device=device)
631
- ratio_stats = torch.zeros(stats_shape, device=device)
632
- model.train()
633
-
634
- # trainer state initialization
635
- self.state.global_step = 0
636
- self.state.episode = 0
637
- self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
638
- self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
639
- # Compute absolute values for logging, eval, and save if given as ratio
640
- if args.logging_steps is not None:
641
- if args.logging_steps < 1:
642
- self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
643
- else:
644
- self.state.logging_steps = args.logging_steps
645
- if args.eval_steps is not None:
646
- if args.eval_steps < 1:
647
- self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
648
- else:
649
- self.state.eval_steps = args.eval_steps
650
- if args.save_steps is not None:
651
- if args.save_steps < 1:
652
- self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
653
- else:
654
- self.state.save_steps = args.save_steps
655
- self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
656
-
657
- for update in range(1, args.num_total_batches + 1):
658
- self.state.episode += 1 * args.batch_size
659
- data = next(iter_dataloader)
660
- with torch.no_grad():
661
- queries = data["input_ids"].to(device)
662
- queries = queries.repeat(args.rloo_k, 1)
663
- context_length = queries.shape[1]
664
- responses = []
665
- postprocessed_responses = []
666
- logprobs = []
667
- ref_logprobs = []
668
- scores = []
669
- sequence_lengths = []
670
-
671
- # Generate responses and compute logprobs
672
- with unwrap_model_for_generation(
673
- self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
674
- ) as unwrapped_model:
675
- query_responses, logitss = batch_generation(
676
- unwrapped_model,
677
- queries,
678
- args.local_rollout_forward_batch_size,
679
- processing_class.pad_token_id,
680
- generation_config,
681
- )
682
-
683
- # Process responses in batches
684
- for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
685
- query = queries[i : i + args.local_rollout_forward_batch_size]
686
- query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
687
- response = query_response[:, context_length:]
688
- logits = logitss[i : i + args.local_rollout_forward_batch_size]
689
- logprob = selective_log_softmax(logits, response)
690
- del logits
691
- torch.cuda.empty_cache()
692
-
693
- ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
694
- ref_logits = ref_output.logits[:, context_length - 1 : -1]
695
- ref_logits /= args.temperature + 1e-7
696
- ref_logprob = selective_log_softmax(ref_logits, response)
697
- del ref_output, ref_logits
698
- torch.cuda.empty_cache()
699
-
700
- # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
701
- postprocessed_response = response
702
- if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
703
- postprocessed_response = truncate_response(
704
- args.stop_token_id, processing_class.pad_token_id, response
705
- )
706
-
707
- # Response Processing 2. run reward model on the truncated responses
708
- postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
709
- sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
710
-
711
- if isinstance(reward_model, nn.Module):
712
- _, score, _ = get_reward(
713
- reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
714
- )
715
- else:
716
- score = torch.tensor(
717
- reward_model(
718
- processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
719
- ),
720
- dtype=torch.float,
721
- ).to(device)
722
-
723
- # Store batch results
724
- responses.append(response)
725
- postprocessed_responses.append(postprocessed_response)
726
- logprobs.append(logprob)
727
- ref_logprobs.append(ref_logprob)
728
- sequence_lengths.append(sequence_length)
729
- scores.append(score)
730
-
731
- # Concatenate all batched results
732
- responses = torch.cat(responses, 0)
733
- postprocessed_responses = torch.cat(postprocessed_responses, 0)
734
- logprobs = torch.cat(logprobs, 0)
735
- ref_logprobs = torch.cat(ref_logprobs, 0)
736
- sequence_lengths = torch.cat(sequence_lengths, 0)
737
- scores = torch.cat(scores, 0)
738
- del (logprob, ref_logprob, score)
739
- torch.cuda.empty_cache()
740
- gc.collect()
741
-
742
- # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
743
- # responses not passing that filter will receive a low (fixed) score
744
- # only query humans on responses that pass that filter
745
- contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
746
- if args.missing_eos_penalty is not None:
747
- scores[~contain_eos_token] -= self.args.missing_eos_penalty
748
- # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
749
-
750
- # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
751
- response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
752
- padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
753
- logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
754
- ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
755
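# A toy illustration of the padding mask built above. Positions strictly after each
# response's last real token index (`sequence_lengths`) are masked, and their
# log-probs are overwritten with the INVALID_LOGPROB sentinel. Made-up lengths:
import torch
sequence_lengths = torch.tensor([2, 0])        # last non-pad index per response
response_idxs = torch.arange(4).repeat(2, 1)   # [[0, 1, 2, 3], [0, 1, 2, 3]]
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
print(padding_mask)
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])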
-
756
- # 4. compute rewards
757
- # Compute KL divergence
758
- kl = logprobs - ref_logprobs
759
-
760
- # Normalize rewards
761
- if args.normalize_reward:
762
- scores = (scores - scores.mean()) / (scores.std() + 1e-8)
763
- scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
764
-
765
- # Compute total reward with KL penalty
766
- if args.token_level_kl:
767
- # Token-level KL penalty: apply KL penalty per token
768
- kl_reward = -args.kl_coef * kl
769
-
770
- # Get the index of the last non-padded token for each sequence
771
- eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
772
- last_reward = torch.zeros_like(kl)
773
- # Ensure scores has correct shape and type
774
- scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
775
- last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
776
-
777
- # Combine KL reward and last reward
778
- non_score_reward = kl_reward.sum(1) # Keep this for logging
779
- reward = last_reward + kl_reward
780
- rlhf_reward = reward.sum(1) # Sum across sequence length
781
- else:
782
- # Sequence-level KL penalty: sum KL across tokens first
783
- sequence_kl = kl.sum(1)
784
- non_score_reward = -args.kl_coef * sequence_kl
785
- rlhf_reward = non_score_reward + scores
786
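# A toy trace of the token-level KL branch above. The scalar score is scattered onto
# a single position per row (`eos_indices`) and the per-token KL penalty is added
# everywhere; after `.sum(1)` this equals the sequence-level branch, because the KL
# at padded positions is zero (both log-prob tensors carry the same INVALID_LOGPROB
# sentinel there, so their difference vanishes). Values below are made up.
import torch
kl = torch.tensor([[0.10, 0.20, 0.00], [0.40, 0.50, 0.60]])   # 0.0 at padded slots
padding_mask = torch.tensor([[False, False, True], [False, False, False]])
scores, kl_coef = torch.tensor([1.0, 2.0]), 0.05
kl_reward = -kl_coef * kl
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
last_reward = torch.zeros_like(kl).scatter(1, eos_indices, scores.reshape(-1, 1))
token_level = (last_reward + kl_reward).sum(1)
sequence_level = -kl_coef * kl.sum(1) + scores
assert torch.allclose(token_level, sequence_level)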
-
787
- # vectorized RLOO advantages implementation
788
- rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
789
- baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
790
- advantages = rlhf_reward - baseline
791
- advantages = advantages.flatten()
792
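# A tiny numeric example of the leave-one-out baseline above. With rloo_k = 2 samples
# per prompt, each sample's baseline is simply the other sample's reward, so the two
# advantages are equal and opposite. Made-up rewards:
import torch
rloo_k = 2
rlhf_reward = torch.tensor([[1.0, 3.0], [2.0, 5.0]])      # [rloo_k, num_prompts]
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = (rlhf_reward - baseline).flatten()
print(advantages)   # tensor([-1., -2.,  1.,  2.])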
-
793
- # Normalize advantages
794
- if args.normalize_advantage:
795
- advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
796
-
797
- torch.cuda.empty_cache()
798
-
799
- # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
800
- for ppo_epoch_idx in range(args.num_ppo_epochs):
801
- b_inds = np.random.permutation(args.local_batch_size)
802
- minibatch_idx = 0
803
- for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
804
- mini_batch_end = mini_batch_start + args.local_mini_batch_size
805
- mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
806
- gradient_accumulation_idx = 0
807
- for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
808
- with accelerator.accumulate(model):
809
- micro_batch_end = micro_batch_start + args.per_device_train_batch_size
810
- micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
811
-
812
- # Get batch data
813
- mb_advantage = advantages[micro_batch_inds]
814
- mb_responses = responses[micro_batch_inds]
815
- mb_query_responses = query_responses[micro_batch_inds]
816
- mb_logprobs = logprobs[micro_batch_inds]
817
-
818
- # Forward pass
819
- output = forward(model, mb_query_responses, processing_class.pad_token_id)
820
- logits = output.logits[:, context_length - 1 : -1]
821
- logits /= args.temperature + 1e-7
822
-
823
- # Compute new logprobs
824
- new_logprobs = selective_log_softmax(logits, mb_responses)
825
- new_logprobs = torch.masked_fill(
826
- new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
827
- )
828
-
829
- # Compute probability ratios
830
- new_ratio = (new_logprobs - mb_logprobs).exp()
831
- new_logprobs = new_logprobs.sum(1)
832
- mb_logprobs = mb_logprobs.sum(1)
833
- logprobs_diff = new_logprobs - mb_logprobs
834
- ratio = torch.exp(logprobs_diff)
835
-
836
- # PPO clipped loss
837
- pg_losses = -mb_advantage * ratio
838
- pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
839
- pg_loss_max = torch.max(pg_losses, pg_losses2)
840
- pg_loss = pg_loss_max.mean()
841
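# A standalone sketch of the clipped surrogate computed above. With a positive
# advantage, pushing the ratio past 1 + cliprange buys no further reduction in the
# loss, because the clipped term takes over in the pessimistic max. Toy values:
import torch
advantage, cliprange = torch.tensor(2.0), 0.2
for ratio in (0.7, 1.0, 1.5, 3.0):
    r = torch.tensor(ratio)
    unclipped = -advantage * r
    clipped = -advantage * torch.clamp(r, 1.0 - cliprange, 1.0 + cliprange)
    print(ratio, torch.max(unclipped, clipped).item())  # 0.7 -1.4, 1.0 -2.0, 1.5 -2.4, 3.0 -2.4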
-
842
- # Final loss
843
- loss = pg_loss
844
-
845
- # Optimization step
846
- accelerator.backward(loss)
847
- optimizer.step()
848
- optimizer.zero_grad()
849
-
850
- with torch.no_grad():
851
- pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
852
- prob_dist = torch.nn.functional.softmax(logits, dim=-1)
853
- entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
854
- approxkl = 0.5 * (logprobs_diff**2).mean()
855
- approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
856
- pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
857
- pg_clipfrac
858
- )
859
- pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
860
- entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
861
- ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
862
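# A quick numerical check of the entropy identity used in the stats block above:
# H(p) = logsumexp(z) - sum(softmax(z) * z), which follows from
# log p_i = z_i - logsumexp(z). Random toy logits:
import torch
z = torch.randn(3, 7)
p = torch.softmax(z, dim=-1)
h_direct = -(p * torch.log(p)).sum(-1)
h_identity = torch.logsumexp(z, -1) - (p * z).sum(-1)
assert torch.allclose(h_direct, h_identity, atol=1e-5)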
- gradient_accumulation_idx += 1
863
- minibatch_idx += 1
864
-
865
- # del everything and empty cache
866
- # fmt: off
867
- del (
868
- output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
869
- pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
870
- mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
871
- )
872
- # fmt: on
873
- torch.cuda.empty_cache()
874
-
875
- # Compute metrics
876
- with torch.no_grad():
877
- mean_kl = kl.sum(1).mean()
878
- mean_entropy = (-logprobs).sum(1).mean()
879
- mean_non_score_reward = non_score_reward.mean()
880
- eps = int(self.state.episode / (time.time() - start_time))
881
- metrics = {}
882
- metrics["eps"] = eps
883
- metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
884
- metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
885
- metrics["objective/non_score_reward"] = (
886
- self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
887
- )
888
- metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
889
- metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
890
- metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
891
- metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
892
- metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
893
- metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
894
- metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
895
- metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
896
- metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
897
- metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
898
- metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
899
- metrics["episode"] = self.state.episode
900
- self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
901
- self.log(metrics)
902
- del kl, mean_kl, mean_entropy, scores
903
-
904
- self.lr_scheduler.step()
905
- self.state.global_step += 1
906
- self.control = self.callback_handler.on_step_end(args, self.state, self.control)
907
- if self.control.should_save:
908
- self._save_checkpoint(model, trial=None)
909
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
910
- torch.cuda.empty_cache()
911
- gc.collect()
912
-
913
- if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
914
- self.generate_completions(sampling=True)
915
-
916
- # HF trainer specifics
917
- self.control = self.callback_handler.on_train_end(args, self.state, self.control)
918
- if self.control.should_save:
919
- self._save_checkpoint(model, trial=None, metrics=None)
920
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
921
-
922
- def generate_completions(self, sampling: bool = False):
923
- args = self.args
924
- processing_class = self.processing_class
925
- generation_config = GenerationConfig(
926
- max_new_tokens=self.args.response_length,
927
- temperature=(0.01 + 1e-7),
928
- top_k=0.0,
929
- top_p=1.0,
930
- do_sample=True,
931
- )
932
-
933
- table = defaultdict(list)
934
- with unwrap_model_for_generation(
935
- self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
936
- ) as unwrapped_model:
937
- for batch in self.eval_dataloader:
938
- query = batch["input_ids"]
939
- with torch.no_grad():
940
- context_length = query.shape[1]
941
- query_response, _ = batch_generation(
942
- unwrapped_model,
943
- query,
944
- query.shape[0],
945
- processing_class.pad_token_id,
946
- generation_config,
947
- )
948
- response = query_response[:, context_length:]
949
- postprocessed_response = response
950
- if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
951
- postprocessed_response = truncate_response(
952
- args.stop_token_id, processing_class.pad_token_id, response
953
- )
954
- table["query"].extend(
955
- gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
956
- )
957
- table["model response"].extend(
958
- gather_object(processing_class.batch_decode(postprocessed_response))
959
- )
960
-
961
- postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
962
-
963
- if isinstance(self.reward_model, nn.Module):
964
- _, score, _ = get_reward(
965
- self.reward_model,
966
- postprocessed_query_response,
967
- processing_class.pad_token_id,
968
- context_length,
969
- )
970
- else:
971
- score = torch.tensor(
972
- self.reward_model(
973
- processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
974
- ),
975
- dtype=torch.float,
976
- ).to(postprocessed_query_response.device)
977
- table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
978
-
979
- if sampling:
980
- break
981
- df = pd.DataFrame(table)
982
-
983
- if self.accelerator.is_main_process:
984
- print_rich_table(df.head(5))
985
- if "wandb" in args.report_to:
986
- import wandb
987
-
988
- if wandb.run is not None:
989
- wandb.log({"completions": wandb.Table(dataframe=df)})
990
-
991
- if "comet_ml" in args.report_to:
992
- log_table_to_comet_experiment(
993
- name="completions.csv",
994
- table=df,
995
- )
996
-
997
- def create_model_card(
998
- self,
999
- model_name: Optional[str] = None,
1000
- dataset_name: Optional[str] = None,
1001
- tags: Union[str, list[str], None] = None,
1002
- ):
1003
- """
1004
- Creates a draft of a model card using the information available to the `Trainer`.
1005
-
1006
- Args:
1007
- model_name (`str` or `None`, *optional*, defaults to `None`):
1008
- Name of the model.
1009
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1010
- Name of the dataset used for training.
1011
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1012
- Tags to be associated with the model card.
1013
- """
1014
- if not self.is_world_process_zero():
1015
- return
1016
-
1017
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1018
- base_model = self.model.config._name_or_path
1019
- else:
1020
- base_model = None
1021
-
1022
- tags = tags or []
1023
- if isinstance(tags, str):
1024
- tags = [tags]
1025
-
1026
- if hasattr(self.model.config, "unsloth_version"):
1027
- tags.append("unsloth")
1028
-
1029
- citation = textwrap.dedent("""\
1030
- @inproceedings{ahmadian2024back,
1031
- title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
1032
- author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
1033
- year = 2024,
1034
- booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
1035
- publisher = {Association for Computational Linguistics},
1036
- pages = {12248--12267},
1037
- editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
1038
- }""")
1039
-
1040
- model_card = generate_model_card(
1041
- base_model=base_model,
1042
- model_name=model_name,
1043
- hub_model_id=self.hub_model_id,
1044
- dataset_name=dataset_name,
1045
- tags=tags,
1046
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1047
- comet_url=get_comet_experiment_url(),
1048
- trainer_name="RLOO",
1049
- trainer_citation=citation,
1050
- paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
1051
- paper_id="2402.14740",
1052
- )
1053
-
1054
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1055
- class UnslothRLOOTrainer(_UnslothRLOOTrainer):
1056
- """
1057
-
1058
- """
1059
- def __init__(
1060
- self,
1061
- config,
1062
- processing_class,
1063
- policy,
1064
- ref_policy,
1065
- reward_model,
1066
- train_dataset,
1067
- data_collator = None,
1068
- eval_dataset = None,
1069
- callbacks = None,
1070
- **kwargs
1071
- ):
1072
- if args is None: args = UnslothRLOOConfig()
1073
- _output_logits = False
1074
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1075
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1076
- if _output_logits:
1077
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1078
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1079
- pass
1080
- else:
1081
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1082
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1083
- if args_max_seq_length is None and model_max_seq_length is not None:
1084
- max_seq_length = model.max_seq_length
1085
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1086
- if model is not None and hasattr(model, 'for_training'):
1087
- model.for_training()
1088
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1089
- if 'processing_class' in locals():
1090
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1091
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1092
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1093
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1094
- if not isinstance(data_collator, UnslothVisionDataCollator):
1095
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1096
- data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1097
- elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1098
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1099
- else:
1100
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1101
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1102
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1103
- if not isinstance(data_collator, UnslothVisionDataCollator):
1104
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1105
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1106
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1107
- else:
1108
- data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1109
- other_metrics = []
1110
-
1111
- from unsloth_zoo.logging_utils import PatchRLStatistics
1112
- PatchRLStatistics('rloo_trainer', other_metrics)
1113
-
1114
- super().__init__(
1115
- config = config,
1116
- processing_class = processing_class,
1117
- policy = policy,
1118
- ref_policy = ref_policy,
1119
- reward_model = reward_model,
1120
- train_dataset = train_dataset,
1121
- data_collator = data_collator,
1122
- eval_dataset = eval_dataset,
1123
- callbacks = callbacks,**kwargs)
1124
- if hasattr(self, 'neftune_hook_handle'):
1125
- self.neftune_hook_handle.remove()
1126
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1127
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1128
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1129
- pass
1130
-
1131
- pass
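
The `else` branch near the top of this hunk is the path the deleted RLOO trainer takes when `reward_model` is a plain Python callable rather than an `nn.Module`: completions are decoded to text, scored in Python, and re-packed as a float tensor on the completions' device. A minimal self-contained sketch of that pattern, assuming a hypothetical word-count reward (`length_bonus` and the toy strings are illustrations, not part of the deleted file):

import torch

def length_bonus(texts):
    # Hypothetical text-level reward: more words -> higher score.
    return [float(len(t.split())) for t in texts]

decoded = ["the quick brown fox", "hi"]  # stand-in for processing_class.batch_decode(...)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Same re-packing as the deleted branch: callable output -> float tensor on device.
score = torch.tensor(length_bonus(decoded), dtype=torch.float).to(device)
print(score)  # tensor([4., 1.])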
unsloth_compiled_cache/UnslothRewardTrainer.py DELETED
@@ -1,817 +0,0 @@
- """
- 2025.3.13
- 2025.3.15
- 4.48.3
- 0.15.2
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion"   : True,
-     "max_autotune"      : False,
-     "shape_padding"     : True,
-     "trace.enabled"     : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def selective_log_softmax(logits, index):
-     logits = logits.to(torch.float32)
-     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-     # loop to reduce peak mem consumption
-     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-     logsumexp_values = torch.logsumexp(logits, dim = -1)
-     per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
-     return per_token_logps
- @dataclass
- class UnslothRewardConfig(RewardConfig):
-     """
- 
-     Configuration class for the [`RewardTrainer`].
- 
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
- 
-     Parameters:
-         max_length (`int` or `None`, *optional*, defaults to `1024`):
-             Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
-             limit. This argument is required if you want to use the default data collator.
-         disable_dropout (`bool`, *optional*, defaults to `True`):
-             Whether to disable dropout in the model.
-         dataset_num_proc (`int`, *optional*, defaults to `None`):
-             Number of processes to use for processing the dataset.
-         center_rewards_coefficient (`float`, *optional*, defaults to `None`):
-             Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
-             https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
-         remove_unused_columns (`bool`, *optional*, defaults to `False`):
-             Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
-             the dataset is pretokenized.
- 
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = False,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         evaluation_strategy = None,
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = False,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         dispatch_batches = None,
-         split_batches = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = False,
-         max_length = 1024,
-         disable_dropout = True,
-         dataset_num_proc = None,
-         center_rewards_coefficient = None,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = cpu_count()
- 
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             evaluation_strategy = evaluation_strategy,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             dispatch_batches = dispatch_batches,
-             split_batches = split_batches,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             max_length = max_length,
-             disable_dropout = disable_dropout,
-             dataset_num_proc = dataset_num_proc,
-             center_rewards_coefficient = center_rewards_coefficient,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-     pass
-
- class _UnslothRewardTrainer(Trainer):
-     _tag_names = ["trl", "reward-trainer"]
-
-     def __init__(
-         self,
-         model: Optional[Union[PreTrainedModel, nn.Module]] = None,
-         args: Optional[RewardConfig] = None,
-         data_collator: Optional[DataCollator] = None,
-         train_dataset: Optional[Dataset] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         model_init: Optional[Callable[[], PreTrainedModel]] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-             None,
-             None,
-         ),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional[dict] = None,
-     ):
-         """
-         Initialize RewardTrainer.
-
-         Args:
-             model (`transformers.PreTrainedModel`):
-                 The model to train, preferably an `AutoModelForSequenceClassification`.
-             args (`RewardConfig`):
-                 The arguments to use for training.
-             data_collator (`transformers.DataCollator`):
-                 The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
-                 which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-             train_dataset (`datasets.Dataset`):
-                 The dataset to use for training.
-             eval_dataset (`datasets.Dataset`):
-                 The dataset to use for evaluation.
-             processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-                 Processing class used to process the data. If provided, will be used to automatically process the inputs
-                 for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-                 reuse the fine-tuned model.
-             model_init (`Callable[[], transformers.PreTrainedModel]`):
-                 The model initializer to use for training. If None is specified, the default model initializer will be used.
-             compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
-                 The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
-             callbacks (`list[transformers.TrainerCallback]`):
-                 The callbacks to use for training.
-             optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-                 The optimizer and scheduler to use for training.
-             preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-                 The function to use to preprocess the logits before computing the metrics.
-             peft_config (`dict`, defaults to `None`):
-                 The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-         """
-         if not is_peft_available() and peft_config is not None:
-             raise ValueError(
-                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
-             )
-         elif is_peft_available() and peft_config is not None:
-             if not isinstance(model, PeftModel):
-                 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-                     _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
-                         inspect.signature(prepare_model_for_kbit_training).parameters
-                     )
-
-                     prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
-                     if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         warnings.warn(
-                             "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
-                             "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
-                             UserWarning,
-                         )
-                     elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
-                     model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-
-                 model = model
-
-         # Disable dropout in the model
-         if args.disable_dropout:
-             disable_dropout_in_model(model)
-
-         if compute_metrics is None:
-             compute_metrics = compute_accuracy
-
-         if data_collator is None:
-             if processing_class is None:
-                 raise ValueError(
-                     "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
-                 )
-
-             max_length = args.max_length
-
-             data_collator = RewardDataCollatorWithPadding(processing_class)
-
-             if args.remove_unused_columns:
-                 try: # for bc before https://github.com/huggingface/transformers/pull/25435
-                     args.remove_unused_columns = False
-                 except FrozenInstanceError:
-                     args = replace(args, remove_unused_columns=False)
-                 # warn users
-                 warnings.warn(
-                     "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
-                     " we have set it for you, but you should do it yourself in the future.",
-                     UserWarning,
-                 )
-
-             self.use_reward_data_collator = True
-         else:
-             self.use_reward_data_collator = False
-
-         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-         # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
-         # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
-         # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-         # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-         # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-         # issued.
-         model.warnings_issued["estimate_tokens"] = True
-
-         if "input_ids_chosen" not in train_dataset.column_names:
-             with PartialState().local_main_process_first():
-                 fn_kwargs = {"tokenizer": processing_class}
-                 train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
-                 train_dataset = train_dataset.map(
-                     _tokenize,
-                     batched=True,
-                     fn_kwargs=fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                 )
-                 # This filter is important because otherwise you get samples that exceed the model's context length and
-                 # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                 # user might get surprised if N samples are missing from training.
-                 train_dataset = train_dataset.filter(
-                     lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
-                     num_proc=args.dataset_num_proc,
-                 )
-                 if eval_dataset is not None:
-                     eval_dataset = eval_dataset.map(
-                         maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
-                     )
-                     eval_dataset = eval_dataset.map(
-                         _tokenize,
-                         fn_kwargs=fn_kwargs,
-                         batched=True,
-                         num_proc=args.dataset_num_proc,
-                     )
-                     # This filter is important because otherwise you get samples that exceed the model's context length and
-                     # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                     # user might get surprised if N samples are missing from training.
-                     eval_dataset = eval_dataset.filter(
-                         lambda x: len(x["input_ids_chosen"]) <= max_length
-                         and len(x["input_ids_rejected"]) <= max_length,
-                         num_proc=args.dataset_num_proc,
-                     )
-
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             model_init=model_init,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
-
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
-
-     def compute_loss(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         return_outputs=False,
-         num_items_in_batch=None,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-         rewards_chosen = model(
-             input_ids=inputs["input_ids_chosen"],
-             attention_mask=inputs["attention_mask_chosen"],
-             return_dict=True,
-         )["logits"]
-         rewards_rejected = model(
-             input_ids=inputs["input_ids_rejected"],
-             attention_mask=inputs["attention_mask_rejected"],
-             return_dict=True,
-         )["logits"]
-         # calculate loss, optionally modulate with margin
-         if "margin" in inputs:
-             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
-         else:
-             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
-
-         if self.args.center_rewards_coefficient is not None:
-             loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
-
-         if return_outputs:
-             return loss, {
-                 "rewards_chosen": rewards_chosen,
-                 "rewards_rejected": rewards_rejected,
-             }
-         return loss
-
-     def prediction_step(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         prediction_loss_only: bool,
-         ignore_keys: Optional[list[str]] = None,
-     ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-         inputs = self._prepare_inputs(inputs)
-         if ignore_keys is None:
-             if hasattr(self.model, "config"):
-                 ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-             else:
-                 ignore_keys = []
-
-         with torch.no_grad():
-             loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
-
-         if prediction_loss_only:
-             return (loss, None, None)
-
-         loss = loss.detach()
-         logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-         logits = nested_detach(logits)
-         # Stack accepted against rejected, mean over logits
-         # and softmax to get preferences between accepted and rejected to sum to 1
-         logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
-
-         labels = torch.zeros(logits.shape[0])
-         labels = self._prepare_inputs(labels)
-
-         return loss, logits, labels
-
-     def evaluate(self, *args, **kwargs):
-         num_print_samples = kwargs.pop("num_print_samples", 4)
-         self.visualize_samples(num_print_samples)
-         return super().evaluate(*args, **kwargs)
-
-     def visualize_samples(self, num_print_samples: int):
-         """
-         Visualize the reward model logits prediction
-
-         Args:
-             num_print_samples (`int`, defaults to `4`):
-                 The number of samples to print. Set to `-1` to print all samples.
-         """
-         eval_dataloader = self.get_eval_dataloader()
-         table = defaultdict(list)
-         for _, inputs in enumerate(eval_dataloader):
-             _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
-             chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
-             rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
-             table["chosen_text"].extend(gather_object(chosen_text))
-             table["rejected_text"].extend(gather_object(rejected_text))
-             table["logits"].extend(
-                 gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
-             )
-             if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
-                 break
-         df = pd.DataFrame(table)
-         if self.accelerator.process_index == 0:
-             print_rich_table(df[:num_print_samples])
-             if "wandb" in self.args.report_to:
-                 import wandb
-
-                 if wandb.run is not None:
-                     wandb.log({"completions": wandb.Table(dataframe=df)})
-
-             if "comet_ml" in self.args.report_to:
-                 log_table_to_comet_experiment(
-                     name="completions.csv",
-                     table=df,
-                 )
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="Reward",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothRewardTrainer(_UnslothRewardTrainer):
-     """
- 
-     """
-     def __init__(
-         self,
-         model = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         model_init = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         peft_config = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothRewardConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         use_fp16 = getattr(args, 'fp16', False)
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
-             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('reward_trainer', other_metrics)
-
-         super().__init__(
-             model = model,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             model_init = model_init,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-             peft_config = peft_config,**kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-
- pass
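
The deleted file's `selective_log_softmax` gathers the chosen logits and subtracts a `logsumexp`, which equals `log_softmax` at the selected indices without materializing the full log-probability tensor. A quick check of that identity, written as a sketch without the `torch.compile` decorator the original applies:

import torch
import torch.nn.functional as F

def selective_log_softmax(logits, index):
    # log_softmax(x)_i = x_i - logsumexp(x), evaluated only at the gathered indices.
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

logits = torch.randn(2, 5, 32)       # (batch, seq, vocab)
ids = torch.randint(0, 32, (2, 5))   # selected token ids
fast = selective_log_softmax(logits, ids)
slow = F.log_softmax(logits, dim=-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(fast, slow, atol=1e-5)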
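The `compute_loss` above implements the standard pairwise reward-model objective, -logsigmoid(r_chosen - r_rejected), optionally shifted by a per-pair margin and regularized toward mean-zero rewards via `center_rewards_coefficient` (Eq. 2 of https://huggingface.co/papers/2312.09244). A standalone sketch of the same arithmetic on toy scores (the function name and values are illustrative only):

import torch
import torch.nn.functional as F

def pairwise_reward_loss(chosen, rejected, margin=None, center_coeff=None):
    diff = chosen - rejected
    if margin is not None:
        diff = diff - margin                      # optional per-pair margin
    loss = -F.logsigmoid(diff).mean()             # Bradley-Terry style pairwise loss
    if center_coeff is not None:
        # Centering term pushes r_chosen + r_rejected toward zero.
        loss = loss + center_coeff * torch.mean((chosen + rejected) ** 2)
    return loss

chosen = torch.tensor([1.2, 0.3])
rejected = torch.tensor([0.1, -0.4])
print(pairwise_reward_loss(chosen, rejected))                     # plain pairwise loss
print(pairwise_reward_loss(chosen, rejected, center_coeff=0.01))  # with centering term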
unsloth_compiled_cache/UnslothSFTTrainer.py DELETED
@@ -1,1025 +0,0 @@
1
- """
2
- 2025.3.13
3
- 2025.3.15
4
- 4.48.3
5
- 0.15.2
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
13
-
14
-
15
- import os
16
- from typing import *
17
- from dataclasses import dataclass, field
18
- from packaging.version import Version
19
- import torch
20
- import numpy as np
21
- from contextlib import nullcontext
22
- from torch.nn import functional as F
23
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
-
25
- torch_compile_options = {
26
- "epilogue_fusion" : True,
27
- "max_autotune" : False,
28
- "shape_padding" : True,
29
- "trace.enabled" : False,
30
- "triton.cudagraphs" : False,
31
- }
32
-
33
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
- def selective_log_softmax(logits, index):
35
- logits = logits.to(torch.float32)
36
- selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
- # loop to reduce peak mem consumption
38
- # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
- logsumexp_values = torch.logsumexp(logits, dim = -1)
40
- per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
- return per_token_logps
42
- @dataclass
43
- class UnslothSFTConfig(SFTConfig):
44
- """
45
-
46
- Configuration class for the [`SFTTrainer`].
47
-
48
- Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
49
- [`~transformers.TrainingArguments`] documentation.
50
-
51
- Using [`~transformers.HfArgumentParser`] we can turn this class into
52
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
53
- command line.
54
-
55
- Parameters:
56
- > Parameters that control the model
57
-
58
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
59
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
60
- argument of the [`SFTTrainer`] is provided as a string.
61
- use_liger (`bool`, *optional*, defaults to `False`):
62
- Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
63
-
64
- > Parameters that control the data preprocessing
65
-
66
- dataset_text_field (`str`, *optional*, defaults to `"text"`):
67
- Name of the column that contains text data in the dataset.
68
- dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
69
- Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
70
- `skip_prepare_dataset`.
71
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
72
- Number of processes to use for processing the dataset.
73
- max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
74
- Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
75
- right.
76
- If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
77
- packing (`bool`, *optional*, defaults to `False`):
78
- Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
79
- length.
80
- eval_packing (`bool` or `None`, *optional*, defaults to `None`):
81
- Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
82
-
83
- > Parameters that control the training
84
-
85
- learning_rate (`float`, *optional*, defaults to `2e-5`):
86
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
87
- [`~transformers.TrainingArguments`].
88
-
89
- """
90
- vllm_sampling_params: Optional[Any] = field(
91
- default = None,
92
- metadata = {'help': 'vLLM SamplingParams'},
93
- )
94
- unsloth_num_chunks : Optional[int] = field(
95
- default = -1,
96
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
97
- )
98
- def __init__(
99
- self,
100
- output_dir = None,
101
- overwrite_output_dir = None,
102
- do_train = False,
103
- do_eval = False,
104
- do_predict = False,
105
- eval_strategy = 'no',
106
- prediction_loss_only = False,
107
- per_device_train_batch_size = 4,
108
- per_device_eval_batch_size = 4,
109
- per_gpu_train_batch_size = None,
110
- per_gpu_eval_batch_size = None,
111
- gradient_accumulation_steps = 2,
112
- eval_accumulation_steps = 2,
113
- eval_delay = 0,
114
- torch_empty_cache_steps = 250,
115
- learning_rate = 5e-05,
116
- weight_decay = 0.01,
117
- adam_beta1 = 0.9,
118
- adam_beta2 = 0.999,
119
- adam_epsilon = 1e-08,
120
- max_grad_norm = 1.0,
121
- num_train_epochs = 3.0,
122
- max_steps = -1,
123
- lr_scheduler_type = 'linear',
124
- warmup_ratio = 0.1,
125
- warmup_steps = 0,
126
- log_level = 'passive',
127
- log_level_replica = 'warning',
128
- log_on_each_node = True,
129
- logging_dir = None,
130
- logging_strategy = 'steps',
131
- logging_first_step = False,
132
- logging_steps = 1,
133
- logging_nan_inf_filter = False,
134
- save_strategy = 'steps',
135
- save_steps = 500,
136
- save_total_limit = None,
137
- save_safetensors = True,
138
- save_on_each_node = False,
139
- save_only_model = False,
140
- restore_callback_states_from_checkpoint = False,
141
- no_cuda = False,
142
- use_cpu = False,
143
- use_mps_device = False,
144
- seed = 3407,
145
- data_seed = 3407,
146
- jit_mode_eval = False,
147
- use_ipex = False,
148
- bf16 = False,
149
- fp16 = False,
150
- fp16_opt_level = 'O1',
151
- half_precision_backend = 'auto',
152
- bf16_full_eval = False,
153
- fp16_full_eval = False,
154
- tf32 = None,
155
- local_rank = -1,
156
- ddp_backend = None,
157
- tpu_num_cores = None,
158
- tpu_metrics_debug = False,
159
- debug = '',
160
- dataloader_drop_last = False,
161
- eval_steps = None,
162
- dataloader_num_workers = 0,
163
- dataloader_prefetch_factor = None,
164
- past_index = -1,
165
- run_name = None,
166
- disable_tqdm = None,
167
- remove_unused_columns = True,
168
- label_names = None,
169
- load_best_model_at_end = False,
170
- metric_for_best_model = None,
171
- greater_is_better = None,
172
- ignore_data_skip = False,
173
- fsdp = '',
174
- fsdp_min_num_params = 0,
175
- fsdp_config = None,
176
- fsdp_transformer_layer_cls_to_wrap = None,
177
- accelerator_config = None,
178
- deepspeed = None,
179
- label_smoothing_factor = 0.0,
180
- optim = 'adamw_8bit',
181
- optim_args = None,
182
- adafactor = False,
183
- group_by_length = False,
184
- length_column_name = 'length',
185
- report_to = None,
186
- ddp_find_unused_parameters = None,
187
- ddp_bucket_cap_mb = None,
188
- ddp_broadcast_buffers = None,
189
- dataloader_pin_memory = True,
190
- dataloader_persistent_workers = False,
191
- skip_memory_metrics = True,
192
- use_legacy_prediction_loop = False,
193
- push_to_hub = False,
194
- resume_from_checkpoint = None,
195
- hub_model_id = None,
196
- hub_strategy = 'every_save',
197
- hub_token = None,
198
- hub_private_repo = None,
199
- hub_always_push = False,
200
- gradient_checkpointing = False,
201
- gradient_checkpointing_kwargs = None,
202
- include_inputs_for_metrics = False,
203
- eval_do_concat_batches = True,
204
- fp16_backend = 'auto',
205
- evaluation_strategy = None,
206
- push_to_hub_model_id = None,
207
- push_to_hub_organization = None,
208
- push_to_hub_token = None,
209
- mp_parameters = '',
210
- auto_find_batch_size = False,
211
- full_determinism = False,
212
- torchdynamo = None,
213
- ray_scope = 'last',
214
- ddp_timeout = 1800,
215
- torch_compile = False,
216
- torch_compile_backend = None,
217
- torch_compile_mode = None,
218
- dispatch_batches = None,
219
- split_batches = None,
220
- include_tokens_per_second = False,
221
- include_num_input_tokens_seen = False,
222
- neftune_noise_alpha = None,
223
- optim_target_modules = None,
224
- batch_eval_metrics = False,
225
- eval_on_start = False,
226
- use_liger_kernel = False,
227
- eval_use_gather_object = False,
228
- average_tokens_across_devices = False,
229
- model_init_kwargs = None,
230
- use_liger = False,
231
- dataset_text_field = 'text',
232
- dataset_kwargs = None,
233
- dataset_num_proc = None,
234
- max_seq_length = None,
235
- packing = False,
236
- eval_packing = None,
237
- dataset_batch_size = None,
238
- num_of_sequences = None,
239
- chars_per_token = None,
240
- vllm_sampling_params = None,
241
- unsloth_num_chunks = -1,
242
- **kwargs,
243
- ):
244
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
245
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
246
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
247
- output_dir = 'unsloth_training_checkpoints'
248
- save_strategy = 'no'
249
- if dataset_num_proc is None:
250
- from multiprocessing import cpu_count
251
-             dataset_num_proc = cpu_count()
- 
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             evaluation_strategy = evaluation_strategy,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             dispatch_batches = dispatch_batches,
-             split_batches = split_batches,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             model_init_kwargs = model_init_kwargs,
-             use_liger = use_liger,
-             dataset_text_field = dataset_text_field,
-             dataset_kwargs = dataset_kwargs,
-             dataset_num_proc = dataset_num_proc,
-             max_seq_length = max_seq_length,
-             packing = packing,
-             eval_packing = eval_packing,
-             dataset_batch_size = dataset_batch_size,
-             num_of_sequences = num_of_sequences,
-             chars_per_token = chars_per_token,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-         pass
- 
- class _UnslothSFTTrainer(Trainer):
-     """"""
- 
-     _tag_names = ["trl", "sft"]
- 
-     @deprecate_kwarg(
-         "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
-     )
-     def __init__(
-         self,
-         model: Union[str, nn.Module, PreTrainedModel],
-         args: Optional[Union[SFTConfig, TrainingArguments]] = None,
-         data_collator: Optional[DataCollator] = None, # type: ignore
-         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         compute_loss_func: Optional[Callable] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
-         optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional["PeftConfig"] = None,
-         formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
-     ):
-         # Args
-         if args is None:
-             model_name = model if isinstance(model, str) else model.config._name_or_path
-             model_name = model_name.split("/")[-1]
-             args = SFTConfig(f"{model_name}-SFT")
-         elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
-             dict_args = args.to_dict()
-             dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
-             dict_args.pop("push_to_hub_token")
-             args = SFTConfig(**dict_args)
- 
-         # Model
-         if args.model_init_kwargs is not None and not isinstance(model, str):
-             warnings.warn(
-                 "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
-                 "The `model_init_kwargs` will be ignored."
-             )
-         if isinstance(model, str):
-             model = self._create_model_from_path(model, args)
- 
-         # PEFT configuration and model wrapping
-         if False:
-             model = self._prepare_peft_model(model, peft_config, args)
- 
-         # Handle the tokenizer
-         if processing_class is None:
-             processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
-             if processing_class.pad_token is None:
-                 processing_class.pad_token = processing_class.eos_token # required for padding when collating data
- 
-         # Dataset
-         preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
-         if preprocess_dataset:
-             train_dataset = self._prepare_dataset(
-                 train_dataset, processing_class, args, args.packing, formatting_func, "train"
-             )
-             if eval_dataset is not None:
-                 packing = args.packing if args.eval_packing is None else args.eval_packing
-                 if isinstance(eval_dataset, dict):
-                     eval_dataset = {
-                         key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
-                         for key, dataset in eval_dataset.items()
-                     }
-                 else:
-                     eval_dataset = self._prepare_dataset(
-                         eval_dataset, processing_class, args, packing, formatting_func, "eval"
-                     )
- 
-         # Data collator
-         if data_collator is None:
-             data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
- 
-         # Initialize the metrics
-         self._metrics = defaultdict(list)
- 
-         # Initialize the Trainer. Parent class will handle:
-         # - DeepSpeed configuration (through create_accelerator_and_postprocess)
-         # - FSDP setup
-         # - Distributed training setup
-         # - Optimizer and scheduler creation
-         # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
-         super_init_kwargs = {}
-         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
-             super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
-         else:
-             if optimizer_cls_and_kwargs is not None:
-                 warnings.warn(
-                     "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
-                     "The default optimizer will be used. "
-                     "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
-                 )
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             compute_loss_func=compute_loss_func,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-             **super_init_kwargs,
-         )
- 
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
- 
-     def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
-         """Creates a model from a path or model identifier."""
-         model_init_kwargs = args.model_init_kwargs or {}
-         # Handle torch dtype
-         torch_dtype = model_init_kwargs.get("torch_dtype")
-         if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
-             pass # torch_dtype is already a torch.dtype or "auto" or None
-         elif isinstance(torch_dtype, str): # it's a str, but not "auto"
-             torch_dtype = getattr(torch, torch_dtype)
-             model_init_kwargs["torch_dtype"] = torch_dtype
-         else:
-             raise ValueError(
-                 "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
-                 f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
-             )
-         # Disable caching if gradient checkpointing is enabled (not supported)
-         if args.gradient_checkpointing:
-             model_init_kwargs["use_cache"] = False
- 
-         # Create model
-         if args.use_liger:
-             if not is_liger_kernel_available():
-                 raise ImportError("Please install Liger-kernel for use_liger=True")
-             model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
-         else:
-             model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
-         return model
- 
-     def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
-         """Prepares a model for PEFT training."""
-         if not is_peft_available():
-             raise ImportError("To use PeftModel, you need to install the `peft` library.")
- 
-         if not isinstance(peft_config, PeftConfig):
-             raise ValueError(
-                 f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
-                 "to pass a PeftConfig object to the SFTTrainer."
-             )
- 
-         if isinstance(model, PeftModel):
-             return model
- 
-         # Handle quantized models (QLoRA)
-         is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
- 
-         is_sharded_qlora = False
-         if getattr(model, "is_loaded_in_4bit", False):
-             # Check if model is sharded (FSDP/DS-Zero3)
-             for _, param in model.named_parameters():
-                 if param.__class__.__name__ == "Params4bit":
-                     is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
-                     break
- 
-         # Prepare model for kbit training if needed
-         if is_qlora and not is_sharded_qlora:
-             model = self._prepare_model_for_kbit_training(model, args)
-             # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
-             args = dataclasses.replace(args, gradient_checkpointing=False)
-         elif args.gradient_checkpointing:
-             model = self._enable_gradient_checkpointing(model, args)
- 
-         # Create PEFT model
-         if (
-             version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
-             and getattr(model, "is_loaded_in_4bit", False)
-             and is_sharded_qlora
-         ):
-             model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
-         else:
-             model = get_peft_model(model, peft_config)
- 
-         # Handle bf16 casting for 4-bit models
-         if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
-             peft_module_casting_to_bf16(model)
- 
-         return model
- 
-     def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
-         """Prepares a quantized model for kbit training."""
-         prepare_model_kwargs = {
-             "use_gradient_checkpointing": args.gradient_checkpointing,
-             "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
-         }
- 
-         return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
- 
-     def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
-         """Enables gradient checkpointing for the model."""
-         gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
-         use_reentrant = (
-             "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
-         )
- 
-         if use_reentrant:
-             if hasattr(model, "enable_input_require_grads"):
-                 model.enable_input_require_grads()
-             else:
- 
-                 def make_inputs_require_grad(module, input, output):
-                     output.requires_grad_(True)
- 
-                 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
- 
-         return model
- 
-     def _prepare_dataset(
-         self,
-         dataset: Union[Dataset, IterableDataset],
-         processing_class,
-         args,
-         packing: bool,
-         formatting_func: Optional[Callable[[dict], str]],
-         dataset_name: str,
-     ) -> Union[Dataset, IterableDataset]:
-         # All Unsloth Zoo code licensed under LGPLv3
-         if isinstance(dataset, ConstantLengthDataset): return dataset
- 
-         map_kwargs = {}
-         use_desc = isinstance(dataset, Dataset)
-         is_vlm = hasattr(processing_class, "tokenizer")
-         tokenizer = processing_class
-         if is_vlm: tokenizer = processing_class.tokenizer
- 
-         # Get max length
-         max_seq_length = getattr(args, "max_length", 0)
-         if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
-         if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
-         if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
-         if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
-         dataset_text_field = getattr(args, "dataset_text_field", "text")
-         do_truncation = max_seq_length != 0
-         do_formatting_func = False
-         do_tokenize = True
- 
-         # Get correct column names
-         column_names = set(next(iter(dataset)).keys())
-         used_column_names = ["input_ids"]
-         if "attention_mask" in column_names:
-             used_column_names.append("attention_mask")
- 
-         # Check if already tokenized; if so, skip preparation
-         from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
-         if "labels" in column_names:
-             # Most likely forgot data collator!
-             if is_vlm and not hasattr(tokenizer, "pad"):
-                 # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
-                 raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
-             self.data_collator = DataCollatorForSeq2Seq(tokenizer)
-             used_column_names.append("labels")
-             do_tokenize = False
-         elif "input_ids" in column_names:
-             # Skip dataset prep, and set data collator
-             if is_vlm and not hasattr(tokenizer, "pad"):
-                 # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
-                 raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
-             self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
-             do_tokenize = False
-         elif dataset_text_field not in column_names:
-             do_formatting_func = True
-             if formatting_func is None:
-                 raise RuntimeError("Unsloth: You must specify a `formatting_func`")
-         pass
- 
-         if do_tokenize:
-             # Check double BOS tokens
-             if do_formatting_func:
-                 test_text = formatting_func(dataset[0])
-                 if not isinstance(test_text, list):
-                     raise ValueError(
-                         "Unsloth: The `formatting_func` should return a list of processed strings."
-                     )
-                 test_text = test_text[0]
-             else:
-                 test_text = dataset[0][dataset_text_field]
- 
-             # Get chat template
-             chat_template = getattr(processing_class, 'chat_template', '')
-             if chat_template == '' and is_vlm:
-                 chat_template = getattr(tokenizer, 'chat_template', '')
-             if chat_template is None:
-                 chat_template = ''
- 
-             # Get bos_token
-             add_special_tokens = True
-             bos_token_1 = getattr(processing_class, 'bos_token', None)
-             bos_token_2 = getattr(tokenizer, 'bos_token', None)
-             bos_token = bos_token_1 or bos_token_2
- 
-             if bos_token is not None:
-                 if test_text.startswith(bos_token) or bos_token in chat_template:
-                     add_special_tokens = False
-                     print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
-             pass
- 
-             # Create tokenize function
-             def _tokenize(example):
-                 return tokenizer(
-                     example[dataset_text_field] if not do_formatting_func else formatting_func(example),
-                     truncation = do_truncation,
-                     max_length = max_seq_length,
-                     return_token_type_ids = False,
-                     add_special_tokens = add_special_tokens,
-                 )
-             pass
- 
-             map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
-             if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
-             dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
- 
-             # If VLM, switch data collator since .pad is needed!
-             if is_vlm and not hasattr(processing_class, "pad"):
-                 data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
-                 self.data_collator = data_collator
-             pass
-         pass
-         if packing:
-             print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
-             return dataset
- 
-             if max_seq_length == 0:
-                 raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
- 
-             if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
-             dataset = dataset.select_columns(used_column_names).map(
-                 pack_examples,
-                 batched = True,
-                 fn_kwargs = {"seq_length": max_seq_length,},
-                 **map_kwargs,
-             )
-         pass
-         return dataset
- 
-     def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
-         outputs = super().compute_loss(
-             model,
-             inputs,
-             return_outputs = return_outputs,
-             num_items_in_batch = num_items_in_batch,
-         )
-         return outputs
- 
-     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-         metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
- 
-         # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
-         # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
-         if next(iter(logs.keys())).startswith("eval_"):
-             metrics = {f"eval_{key}": val for key, val in metrics.items()}
- 
-         logs = {**logs, **metrics}
-         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
-             super().log(logs, start_time)
-         else: # transformers<=4.46
-             super().log(logs)
-         self._metrics.clear()
- 
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
- 
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
- 
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
- 
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
- 
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
- 
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="SFT",
-         )
- 
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothSFTTrainer(_UnslothSFTTrainer):
-     """
- 
-     Trainer for the Supervised Fine-Tuning (SFT) method.
- 
-     This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
- 
-     Example:
- 
-     ```python
-     from datasets import load_dataset
-     from trl import SFTTrainer
- 
-     dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
- 
-     trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
-     trainer.train()
-     ```
- 
-     Args:
-         model (`Union[str, PreTrainedModel]`):
-             Model to be trained. Can be either:
- 
-             - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
-               a path to a *directory* containing model weights saved using
-               [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
-               loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
-               in `args.model_init_kwargs`.
-             - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
-         args ([`SFTConfig`], *optional*, defaults to `None`):
-             Configuration for this trainer. If `None`, a default configuration is used.
-         data_collator (`DataCollator`, *optional*):
-             Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
-             Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, or to an
-             instance of [`~transformers.DataCollatorWithPadding`] if the processing_class is a feature extractor or
-             tokenizer.
-         train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
-             Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
-             [prompt-completion](#prompt-completion) type. The format of the samples can be either:
- 
-             - [Standard](dataset_formats#standard): Each sample contains plain text.
-             - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
-               and content).
- 
-             The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
-         eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
-             Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
-         processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
-             Processing class used to process the data. If `None`, the processing class is loaded from the model's name
-             with [`~transformers.AutoTokenizer.from_pretrained`].
-         callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
-             List of callbacks to customize the training loop. These are added to the list of default callbacks
-             detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
- 
-             If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
-             method.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
-             A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
-             model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
-         optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
-             A tuple containing the optimizer class and keyword arguments to use.
-             Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
- 
-             Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
-             A function that preprocesses the logits right before caching them at each evaluation step. Must take two
-             tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
-             by this function will be reflected in the predictions received by `compute_metrics`.
- 
-             Note that the labels (second parameter) will be `None` if the dataset does not have them.
-         peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
-             PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
-         formatting_func (`Optional[Callable]`):
-             Formatting function applied to the dataset before tokenization.
- 
-     """
-     def __init__(
-         self,
-         model,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         compute_loss_func = None,
-         compute_metrics = None,
-         callbacks = None,
-         optimizer_cls_and_kwargs = None,
-         preprocess_logits_for_metrics = None,
-         peft_config = None,
-         formatting_func = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothSFTConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         use_fp16 = getattr(args, 'fp16', False)
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
-             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
-         other_metrics = []
- 
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('sft_trainer', other_metrics)
-         IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
-         from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
-         from unsloth_zoo.training_utils import fix_zero_training_loss
-         if 'tokenizer' not in locals(): tokenizer = processing_class
-         fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
-         fix_zero_training_loss(model, tokenizer, train_dataset)
- 
-         super().__init__(
-             model = model,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             compute_loss_func = compute_loss_func,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-             peft_config = peft_config,
-             formatting_func = formatting_func,**kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
- 
- pass
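
To make the collator-selection rule in the deleted `_prepare_dataset` above easier to follow: a `labels` column marks the dataset as fully preprocessed (inputs and labels are padded together with `DataCollatorForSeq2Seq`), while a bare `input_ids` column means plain causal LM training (`DataCollatorForLanguageModeling` with `mlm=False`). Below is a minimal standalone sketch of that rule, assuming a generic `tokenizer` and a `datasets.Dataset`; it is an illustration, not the deleted implementation:

```python
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

def pick_collator(dataset, tokenizer):
    # Mirrors the branch in the deleted _prepare_dataset: "labels" implies a
    # fully tokenized supervised dataset, "input_ids" alone implies causal LM.
    columns = set(dataset.column_names)
    if "labels" in columns:
        return DataCollatorForSeq2Seq(tokenizer)
    if "input_ids" in columns:
        return DataCollatorForLanguageModeling(tokenizer, mlm=False)
    return None  # raw text: tokenize first, then choose a collator
```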
 
unsloth_compiled_cache/UnslothXPOTrainer.py DELETED
@@ -1,1008 +0,0 @@
- """
- 2025.3.13
- 2025.3.15
- 4.48.3
- 0.15.2
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
- 
- 
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
- 
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
- 
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def selective_log_softmax(logits, index):
-     logits = logits.to(torch.float32)
-     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
-     # loop to reduce peak mem consumption
-     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
-     logsumexp_values = torch.logsumexp(logits, dim = -1)
-     per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
-     return per_token_logps
- @dataclass
- class UnslothXPOConfig(XPOConfig):
-     """
- 
-     Configuration class for the [`XPOTrainer`].
- 
-     Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
- 
-     Parameters:
-         alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
-             Weight of the XPO loss term. If a list of floats is provided, the alpha for each new epoch is taken
-             from the list, and the last alpha is used for the remaining epochs.
- 
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = True,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         evaluation_strategy = None,
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = False,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         dispatch_batches = None,
-         split_batches = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = False,
-         reward_model_path = None,
-         judge = None,
-         max_new_tokens = 64,
-         max_length = 512,
-         temperature = 0.9,
-         missing_eos_penalty = None,
-         loss_type = 'sigmoid',
-         dataset_num_proc = None,
-         disable_dropout = True,
-         use_vllm = False,
-         ds3_gather_for_generation = True,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = cpu_count()
- 
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             evaluation_strategy = evaluation_strategy,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             dispatch_batches = dispatch_batches,
-             split_batches = split_batches,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             reward_model_path = reward_model_path,
-             judge = judge,
-             max_new_tokens = max_new_tokens,
-             max_length = max_length,
-             temperature = temperature,
-             missing_eos_penalty = missing_eos_penalty,
-             loss_type = loss_type,
-             dataset_num_proc = dataset_num_proc,
-             disable_dropout = disable_dropout,
-             use_vllm = use_vllm,
-             ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-         pass
- 
- class _UnslothXPOTrainer(OnlineDPOTrainer):
365
- r""""""
366
-
367
- _tag_names = ["trl", "xpo"]
368
-
369
- def __init__(
370
- self,
371
- model: Union[PreTrainedModel, nn.Module] = None,
372
- ref_model: Union[PreTrainedModel, nn.Module] = None,
373
- reward_model: Optional[nn.Module] = None,
374
- judge: Optional[BasePairwiseJudge] = None,
375
- args: Optional[XPOConfig] = None,
376
- data_collator: Optional[Callable] = None,
377
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
378
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
379
- processing_class: Optional[
380
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
381
- ] = None,
382
- peft_config: Optional[dict] = None,
383
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
384
- callbacks: Optional[list[TrainerCallback]] = None,
385
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
386
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
387
- ) -> None:
388
- super().__init__(
389
- model=model,
390
- ref_model=ref_model,
391
- judge=judge,
392
- reward_model=reward_model,
393
- args=args,
394
- data_collator=data_collator,
395
- train_dataset=train_dataset,
396
- eval_dataset=eval_dataset,
397
- processing_class=processing_class,
398
- reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
399
- peft_config=peft_config,
400
- compute_metrics=compute_metrics,
401
- callbacks=callbacks,
402
- optimizers=optimizers,
403
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
404
- )
405
-
406
- self._alpha = self.args.alpha
407
-
408
- # Overwrite the stats dictionary to include XPO specific statistics
409
- self.stats = {
410
- # Remove "non_score_reward", "rlhf_reward", "scores"
411
- # Add "loss/dpo", "loss/xpo"
412
- "loss/dpo": [],
413
- "loss/xpo": [],
414
- "objective/kl": [],
415
- "objective/entropy": [],
416
- "rewards/chosen": [],
417
- "rewards/rejected": [],
418
- "rewards/accuracies": [],
419
- "rewards/margins": [],
420
- "logps/chosen": [],
421
- "logps/rejected": [],
422
- # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
423
- "val/model_contain_eos_token": [],
424
- "val/ref_contain_eos_token": [],
425
- "alpha": [],
426
- "beta": [],
427
- }
428
- if self.reward_model is not None:
429
- # Replace "scores" by "model_scores" and "ref_scores"
430
- self.stats["objective/model_scores"] = []
431
- self.stats["objective/ref_scores"] = []
432
- self.stats["objective/scores_margin"] = []
433
-
434
- @property
435
- def alpha(self):
436
- if isinstance(self._alpha, list):
437
- epoch = self.state.epoch
438
- return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
439
- else:
440
- return self._alpha
441
-
442
- def _generate_completions(self, prompts, model):
443
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
444
- model_output = unwrapped_model.generate(
445
- input_ids=prompts["input_ids"],
446
- attention_mask=prompts["attention_mask"],
447
- generation_config=self.generation_config,
448
- )
449
-
450
- ref_model = model if self.ref_model is None else self.ref_model
451
- with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
452
- ref_output = unwrapped_ref_model.generate(
453
- input_ids=prompts["input_ids"],
454
- attention_mask=prompts["attention_mask"],
455
- generation_config=self.generation_config,
456
- )
457
-
458
- return model_output, ref_output
459
-
460
- def _process_completions(self, model_output, ref_output, prompts):
461
- context_length = prompts["input_ids"].shape[1]
462
-
463
- # Process model completions
464
- model_completion_ids = model_output[:, context_length:]
465
- model_completion_ids, model_completion_mask = truncate_right(
466
- model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
467
- )
468
- model_data = {
469
- "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
470
- "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
471
- "raw": prompts["raw"],
472
- }
473
-
474
- # Process reference model completions
475
- ref_completion_ids = ref_output[:, context_length:]
476
- ref_completion_ids, ref_completion_mask = truncate_right(
477
- ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
478
- )
479
- ref_data = {
480
- "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
481
- "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
482
- "raw": prompts["raw"],
483
- }
484
-
485
- return model_data, ref_data
486
-
487
- def _compute_rewards(self, model_data, ref_data, context_length):
488
- with torch.no_grad():
489
- _, model_scores, _ = get_reward(
490
- self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
491
- )
492
- _, ref_scores, _ = get_reward(
493
- self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
494
- )
495
-
496
- # Apply EOS penalty if needed
497
- if self.args.missing_eos_penalty is not None:
498
- model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
499
- ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
500
- model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
501
- ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
502
-
503
- return model_scores, ref_scores
504
-
505
- def _compute_judge(self, model_data, ref_data, context_length):
506
- prompts = model_data["raw"]
507
- model_data_completions = self.processing_class.batch_decode(
508
- model_data["input_ids"][:, context_length:], skip_special_tokens=True
509
- )
510
- model_data_completions = [completion.strip() for completion in model_data_completions]
511
-
512
- ref_data_completions = self.processing_class.batch_decode(
513
- ref_data["input_ids"][:, context_length:], skip_special_tokens=True
514
- )
515
- ref_data_completions = [completion.strip() for completion in ref_data_completions]
516
-
517
- if is_conversational({"prompt": prompts[0]}):
518
- model_data_completions = [
519
- [{"role": "assistant", "content": completion}] for completion in model_data_completions
520
- ]
521
- environment = jinja2.Environment()
522
- template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
523
- prompts = [template.render(messages=message) for message in prompts]
524
- model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
525
-
526
- ref_data_completions = [
527
- [{"role": "assistant", "content": completion}] for completion in ref_data_completions
528
- ]
529
- ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
530
-
531
- ranks_of_first_completion = self.judge.judge(
532
- prompts,
533
- list(zip(model_data_completions, ref_data_completions)),
534
- )
535
- # convert ranks to a True/False mask:
536
- # when rank == 0, it means the first completion is the best
537
- # when rank == 1, it means the second completion is the best
538
- return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
539
-
540
- def _compute_logprobs(self, model, model_data, ref_data, context_length):
541
- def compute_logprobs_for_data(m, data):
542
- output = m(data["input_ids"], attention_mask=data["attention_mask"])
543
- logits = output.logits[:, context_length - 1 : -1]
544
- token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
545
- return token_logprobs
546
-
547
- # Compute logprobs for model completions
548
- model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
549
- # Compute logprobs for model on reference completions (for XPO loss)
550
- model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
551
-
552
- # Compute logprobs for reference model completions
553
- with torch.no_grad():
554
- if self.ref_model is None:
555
- with model.disable_adapter():
556
- ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
557
- ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
558
- else:
559
- ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
560
- ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
561
-
562
- # Mask padding tokens
563
- model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
564
- ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
565
- model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
566
- model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
567
- ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
568
- ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
569
-
570
- return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
571
-
572
- def _compute_losses(
573
- self,
574
- model_logprobs_model_data,
575
- model_logprobs_ref_data,
576
- ref_logprobs_ref_data,
577
- ref_logprobs_model_data,
578
- chosen_mask,
579
- ):
580
- # Compute log probs
581
- model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
582
- model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
583
- ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
584
- ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
585
-
586
- chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
587
- chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
588
- chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
589
-
590
- rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
591
- rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
592
- rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
593
-
594
- # Compute logits as the difference between chosen and rejected log ratios
595
- logits = chosen_log_ratios - rejected_log_ratios
596
-
597
- if self.args.loss_type == "sigmoid":
598
- dpo_losses = -F.logsigmoid(self.beta * logits)
599
- elif self.args.loss_type == "ipo":
600
- dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
601
- else:
602
- raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
603
-
604
- # Compute XPO specific loss
605
- xpo_losses = self.alpha * model_logprobs_ref_data_sum
606
-
607
- # Total loss
608
- loss = (dpo_losses + xpo_losses).mean()
609
-
610
- return loss, dpo_losses, xpo_losses
611
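
For reference, the deleted `_compute_losses` above pairs a DPO-style preference loss with the XPO exploration term `alpha * model_logprobs_ref_data_sum`. A minimal standalone sketch of the same computation in plain PyTorch (the function name, tensor contents, and the `beta`/`alpha` defaults here are illustrative assumptions, not part of the deleted file):

import torch
import torch.nn.functional as F

def xpo_loss_sketch(chosen_log_ratios, rejected_log_ratios,
                    model_logprobs_ref_data_sum, beta=0.1, alpha=1e-5,
                    loss_type="sigmoid"):
    # Preference "logits": how strongly the policy prefers chosen over rejected
    logits = chosen_log_ratios - rejected_log_ratios
    if loss_type == "sigmoid":
        dpo_losses = -F.logsigmoid(beta * logits)    # DPO objective
    elif loss_type == "ipo":
        dpo_losses = (logits - 1 / (2 * beta)) ** 2  # IPO objective
    else:
        raise NotImplementedError(f"invalid loss type {loss_type}")
    # XPO exploration term: penalizes high policy log-prob on reference completions
    xpo_losses = alpha * model_logprobs_ref_data_sum
    return (dpo_losses + xpo_losses).mean()

# Illustrative call with random per-sequence summed log-probs for a batch of 4
loss = xpo_loss_sketch(torch.randn(4), torch.randn(4), torch.randn(4))
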
-
-     def _log_statistics(
-         self,
-         model_data,
-         ref_data,
-         model_logprobs_model_data,
-         model_logprobs_ref_data,
-         ref_logprobs_ref_data,
-         ref_logprobs_model_data,
-         chosen_mask,
-         dpo_losses,
-         xpo_losses,
-         context_length,
-         model_scores=None,
-         ref_scores=None,
-     ):
-         # Helper function to gather and compute mean
-         def gather_mean(tensor):
-             return self.accelerator.gather_for_metrics(tensor).mean().item()
-
-         # Log losses
-         self.stats["loss/dpo"].append(gather_mean(dpo_losses))
-         self.stats["loss/xpo"].append(gather_mean(xpo_losses))
-
-         # Log scores
-         if self.reward_model is not None:
-             self.stats["objective/model_scores"].append(gather_mean(model_scores))
-             self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
-             self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
-
-         # Log logprobs
-         model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
-         model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
-         ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
-         ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
-
-         chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
-         chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
-         chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
-
-         rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
-         rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
-         rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
-
-         self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
-         self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
-
-         # Log rewards
-         # Compute various statistics
-         chosen_rewards = chosen_log_ratios * self.beta
-         rejected_rewards = rejected_log_ratios * self.beta
-         self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
-         self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
-
-         # Calculate KL divergence for model and ref data
-         kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
-         kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
-         mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
-         self.stats["objective/kl"].append(gather_mean(mean_kl))
-
-         # Calculate entropy for model and ref data
-         entropy_model_data = -model_logprobs_model_data.sum(1)
-         entropy_ref_data = -model_logprobs_ref_data.sum(1)
-         mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
-         self.stats["objective/entropy"].append(gather_mean(mean_entropy))
-
-         # Calculate margins
-         margin = chosen_rewards - rejected_rewards
-         self.stats["rewards/margins"].append(gather_mean(margin.mean()))
-
-         # Calculate accuracy
-         accuracy = (margin > 0).float()
-         self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
-
-         # Log EOS token statistics
-         model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
-         ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
-         self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
-         self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
-
-         # Log alpha and beta
-         self.stats["alpha"].append(self.alpha)
-         self.stats["beta"].append(self.beta)
-
-     def training_step(
-         self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
-     ) -> torch.Tensor:
-         model.train()
-
-         # Apply chat template and tokenize the input
-         batch_size = len(next(iter(inputs.values())))
-         prompts = inputs["prompt"]
-         inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
-         inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
-         inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
-         inputs = self.data_collator(inputs)
-
-         # need the prompt_ only
-         inputs = self._prepare_inputs(inputs)
-         context_length = inputs["prompt_input_ids"].shape[1]
-         prompts = {
-             "input_ids": inputs["prompt_input_ids"],
-             "attention_mask": inputs["prompt_attention_mask"],
-             "raw": prompts,
-         }
-         del inputs
-
-         # Sample completions from both the model and the reference model
-         model_output, ref_output = self._generate_completions(prompts, model)
-
-         # Process model completions
-         model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
-
-         # Compute rewards
-         if self.reward_model is not None:
-             model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
-             chosen_mask = model_scores >= ref_scores
-         else:
-             model_scores, ref_scores = None, None
-             chosen_mask = self._compute_judge(model_data, ref_data, context_length)
-
-         # Compute logprobs
-         model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
-             self._compute_logprobs(model, model_data, ref_data, context_length)
-         )
-
-         # Compute loss
-         loss, dpo_losses, xpo_losses = self._compute_losses(
-             model_logprobs_model_data,
-             model_logprobs_ref_data,
-             ref_logprobs_ref_data,
-             ref_logprobs_model_data,
-             chosen_mask,
-         )
-
-         # Log everything
-         self._log_statistics(
-             model_data,
-             ref_data,
-             model_logprobs_model_data.detach(),
-             model_logprobs_ref_data.detach(),
-             ref_logprobs_ref_data,
-             ref_logprobs_model_data,
-             chosen_mask,
-             dpo_losses.detach(),
-             xpo_losses.detach(),
-             context_length,
-             model_scores,
-             ref_scores,
-         )
-
-         if (
-             self.args.torch_empty_cache_steps is not None
-             and self.state.global_step % self.args.torch_empty_cache_steps == 0
-         ):
-             empty_cache()
-
-         kwargs = {}
-         # For LOMO optimizers you need to explicitly use the learning rate
-         if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-             kwargs["learning_rate"] = self._get_learning_rate()
-
-         if self.args.n_gpu > 1:
-             loss = loss.mean()  # mean() to average on multi-gpu parallel training
-
-         if self.use_apex:
-             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                 scaled_loss.backward()
-         else:
-             self.accelerator.backward(loss, **kwargs)
-
-         return loss.detach() / self.args.gradient_accumulation_steps
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-         @article{xie2024exploratory,
-             title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
-             author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
-             year = 2024,
-             eprint = {arXiv:2405.21046}
-         }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="XPO",
-             trainer_citation=citation,
-             paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
-             paper_id="2405.21046",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothXPOTrainer(_UnslothXPOTrainer):
-     """
-
-     Initialize XPOTrainer as a subclass of [`OnlineDPOTrainer`].
-
-     Args:
-         model (`transformers.PreTrainedModel`):
-             The model to train, preferably an `AutoModelForCausalLM`.
-         ref_model (`PreTrainedModelWrapper`):
-             Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
-             reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
-         reward_model (`transformers.PreTrainedModel`):
-             The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
-         judge (`BasePairwiseJudge`):
-             The judge to use for pairwise comparison of model completions.
-         args (`XPOConfig`):
-             The XPO config arguments to use for training.
-         data_collator (`transformers.DataCollator`):
-             The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
-             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-         train_dataset (`datasets.Dataset`):
-             The dataset to use for training.
-         eval_dataset (`datasets.Dataset`):
-             The dataset to use for evaluation.
-         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-             Processing class used to process the data. If provided, it will be used to automatically process the inputs
-             for the model, and it will be saved alongside the model to make it easier to rerun an interrupted training or
-             reuse the fine-tuned model.
-         peft_config (`dict`):
-             The PEFT config to use for training.
-         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-             The function to use to compute the metrics. Must take an `EvalPrediction` and return
-             a dictionary mapping strings to metric values.
-         callbacks (`list[transformers.TrainerCallback]`):
-             The callbacks to use for training.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-             The optimizer and scheduler to use for training.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-             The function to use to preprocess the logits before computing the metrics.
-
-     """
-     def __init__(
-         self,
-         model = None,
-         ref_model = None,
-         reward_model = None,
-         judge = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         peft_config = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothXPOConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         use_fp16 = getattr(args, 'fp16', False)
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                     '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
-             elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('xpo_trainer', other_metrics)
-
-         super().__init__(
-             model = model,
-             ref_model = ref_model,
-             reward_model = reward_model,
-             judge = judge,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             peft_config = peft_config,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics, **kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-
- pass
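
For context, the deleted class above was Unsloth's drop-in wrapper around TRL's `XPOTrainer`. A minimal usage sketch, assuming the standard TRL online-training API (the model checkpoint, judge, and dataset names below are hypothetical illustrations, not taken from this repository):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PairRMJudge, XPOConfig, XPOTrainer

# Hypothetical small instruct model and a prompt-only preference dataset
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()  # pairwise judge used in place of a reward model
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

args = XPOConfig(output_dir="xpo-model", per_device_train_batch_size=1)
trainer = XPOTrainer(
    model=model,
    judge=judge,
    args=args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
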
 
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc DELETED
Binary file (32.9 kB)
 
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc DELETED
Binary file (91.7 kB)
 
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc DELETED
Binary file (75.6 kB)
 
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc DELETED
Binary file (45.5 kB)
 
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:866cfa82efe9b576ed95d6c36ad4eb1e47599692a7b674f9acbb3e671b4d5178
- size 103496
 
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc DELETED
Binary file (37.7 kB)
 
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc DELETED
Binary file (78.4 kB)
 
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc DELETED
Binary file (87.3 kB)
 
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc DELETED
Binary file (47.2 kB)
 
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc DELETED
Binary file (75.5 kB)
 
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc DELETED
Binary file (67.1 kB)
 
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc DELETED
Binary file (62.6 kB)
 
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc DELETED
Binary file (36.4 kB)
 
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc DELETED
Binary file (54.2 kB)
 
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc DELETED
Binary file (38.9 kB)
 
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc DELETED
Binary file (47.8 kB)
 
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc DELETED
Binary file (49.9 kB)