nirbo committed on
Commit dc2b56f · verified · 1 Parent(s): a3580d9

Upload folder using huggingface_hub

Files changed (49)
  1. .gitattributes +1 -0
  2. .gitignore +3 -0
  3. .gradio/certificate.pem +31 -0
  4. README.md +224 -12
  5. dataset_utils.py +203 -0
  6. gguf_utils.py +18 -0
  7. inference_utils.py +25 -0
  8. main.py +5 -0
  9. model_utils.py +60 -0
  10. openai_sample_dataset.json +178 -0
  11. project_plan.md +104 -0
  12. requirements.txt +15 -0
  13. training_utils.py +75 -0
  14. ui.py +231 -0
  15. unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
  16. unsloth_compiled_cache/UnslothBCOTrainer.py +1824 -0
  17. unsloth_compiled_cache/UnslothCPOTrainer.py +1557 -0
  18. unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
  19. unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
  20. unsloth_compiled_cache/UnslothGKDTrainer.py +863 -0
  21. unsloth_compiled_cache/UnslothGRPOTrainer.py +1438 -0
  22. unsloth_compiled_cache/UnslothKTOTrainer.py +1840 -0
  23. unsloth_compiled_cache/UnslothNashMDTrainer.py +955 -0
  24. unsloth_compiled_cache/UnslothORPOTrainer.py +1543 -0
  25. unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1269 -0
  26. unsloth_compiled_cache/UnslothPPOTrainer.py +1259 -0
  27. unsloth_compiled_cache/UnslothPRMTrainer.py +800 -0
  28. unsloth_compiled_cache/UnslothRLOOTrainer.py +1133 -0
  29. unsloth_compiled_cache/UnslothRewardTrainer.py +819 -0
  30. unsloth_compiled_cache/UnslothSFTTrainer.py +1031 -0
  31. unsloth_compiled_cache/UnslothXPOTrainer.py +1010 -0
  32. unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
  33. unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
  34. unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
  35. unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
  36. unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc +3 -0
  37. unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
  38. unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
  39. unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
  40. unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
  41. unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
  42. unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
  43. unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
  44. unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
  45. unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
  46. unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
  47. unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
  48. unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
  49. upload_utils.py +77 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ outputs/
+ results/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,224 @@
- ---
- title: Finetune Test
- emoji: 🔥
- colorFrom: green
- colorTo: purple
- sdk: gradio
- sdk_version: 5.23.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Finetune-Test
+ app_file: ui.py
+ sdk: gradio
+ sdk_version: 5.23.2
+ ---
+ # LLM Finetuner
+
+ This project provides a user-friendly interface for fine-tuning Large Language Models (LLMs) using the Unsloth library. It includes features for dataset preparation, synthetic dataset creation, model training, testing, and GGUF conversion.
+
+ ## Features
+
+ - Load and fine-tune various pre-trained models
+ - Prepare existing datasets or create synthetic datasets
+ - Fine-tune models with customizable hyperparameters
+ - Test fine-tuned models
+ - Convert models to GGUF format for deployment
+
+ ## Prerequisites
+
+ - Python 3.8 or higher
+ - CUDA-capable GPU (for efficient training)
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```
+ git clone https://github.com/yourusername/llm-finetuner.git
+ cd llm-finetuner
+ ```
+
+ 2. Create a virtual environment (optional but recommended):
+ ```
+ python -m venv venv
+ source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
+ ```
+
+ 3. Install the required packages:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ 1. Run the application:
+ ```
+ python main.py
+ ```
+
+ 2. Open the provided URL in your web browser to access the Gradio interface.
+
+ 3. Follow these steps in the interface:
+    a. Settings: Enter your Hugging Face token and select a model.
+    b. Dataset: Prepare an existing dataset or create a synthetic one.
+    c. Training: Set hyperparameters and start the fine-tuning process.
+    d. Test: Test your fine-tuned model with custom inputs.
+    e. GGUF Conversion: Convert your model to GGUF format if needed.
+
+ ## Notes
+
+ - Ensure you have the necessary API keys for OpenAI or Anthropic if you plan to use them for synthetic dataset creation.
+ - If using Ollama for local generation, make sure it's installed and running on your machine (see the quick check below).
+ - Fine-tuning can be computationally intensive. Ensure you have adequate GPU resources available.
+
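+ A quick way to confirm a local Ollama server is reachable (a minimal sketch; the default port 11434 and the `/api/tags` listing endpoint reflect a standard Ollama install):
+
+ ```python
+ import requests
+
+ # Any 200 response means the server is up; the body lists locally pulled models.
+ r = requests.get("http://localhost:11434/api/tags", timeout=5)
+ print(r.status_code, [m["name"] for m in r.json().get("models", [])])
+ ```
+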
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
+
+ ## License
+
+ This project is licensed under the MIT License.
+
+
+ # Comprehensive Python Setup Guide
+
+ This guide will walk you through setting up Python, creating a virtual environment, and running your LLM Finetuner project on a new system.
+
+ ## 1. Install Python
+
+ ### Windows:
+ 1. Go to https://www.python.org/downloads/windows/
+ 2. Download the latest Python 3.x installer (64-bit version recommended)
+ 3. Run the installer
+ 4. Check "Add Python to PATH" during installation
+ 5. Click "Install Now"
+
+ ### macOS:
+ 1. Install Homebrew if you haven't already:
+ ```
+ /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
+ ```
+ 2. Install Python using Homebrew:
+ ```
+ brew install python
+ ```
+
+ ### Linux (Ubuntu/Debian):
+ 1. Update package list:
+ ```
+ sudo apt update
+ ```
+ 2. Install Python:
+ ```
+ sudo apt install python3 python3-pip python3-venv
+ ```
+
+ ## 2. Verify Python Installation
+
+ Open a terminal (Command Prompt on Windows) and run:
+ ```
+ python --version
+ ```
+ You should see the Python version number. If not, try `python3 --version`.
+
+ ## 3. Install Git
+
+ ### Windows:
+ 1. Go to https://git-scm.com/download/win
+ 2. Download and run the installer
+ 3. Use the default settings during installation
+
+ ### macOS:
+ If you installed Homebrew earlier:
+ ```
+ brew install git
+ ```
+
+ ### Linux (Ubuntu/Debian):
+ ```
+ sudo apt install git
+ ```
+
+ ## 4. Clone the Repository
+
+ 1. Open a terminal
+ 2. Navigate to where you want to store the project
+ 3. Clone the repository:
+ ```
+ git clone https://github.com/yourusername/llm-finetuner.git
+ cd llm-finetuner
+ ```
+
+ ## 5. Create and Activate a Virtual Environment
+
+ ### Windows:
+ ```
+ python -m venv venv
+ venv\Scripts\activate
+ ```
+
+ ### macOS and Linux:
+ ```
+ python3 -m venv venv
+ source venv/bin/activate
+ ```
+
+ Your prompt should change to indicate that the virtual environment is active.
+
+ ## 6. Install Required Packages
+
+ With the virtual environment activated:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ This may take a while as it installs all necessary dependencies.
+
+ ## 7. Set Up CUDA (for GPU support)
+
+ If you have an NVIDIA GPU and want to use it for training:
+
+ 1. Go to https://developer.nvidia.com/cuda-downloads
+ 2. Download and install the CUDA Toolkit appropriate for your system
+ 3. Install the cuDNN library:
+    - Go to https://developer.nvidia.com/cudnn
+    - Download cuDNN (you may need to create an NVIDIA account)
+    - Follow the installation instructions for your system
+
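+ Once the drivers and toolkit are installed, you can confirm that PyTorch sees the GPU (a minimal sketch; it assumes the project's requirements, including `torch`, are already installed):
+
+ ```python
+ import torch
+
+ # True when a CUDA-capable GPU and a matching driver/toolkit are detected.
+ print(torch.cuda.is_available())
+ if torch.cuda.is_available():
+     print(torch.cuda.get_device_name(0))
+ ```
+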
+ ## 8. Run the Application
+
+ With the virtual environment still activated:
+ ```
+ python main.py
+ ```
+
+ This will start the Gradio interface. Open the provided URL in your web browser.
+
+ ## 9. Using the LLM Finetuner
+
+ 1. In the "Settings" tab:
+    - Enter your Hugging Face token
+    - Select a model
+
+ 2. In the "Dataset" tab:
+    - Prepare an existing dataset or create a synthetic one (a sample local dataset file is sketched after this list)
+
+ 3. In the "Training" tab:
+    - Set hyperparameters and start training
+
+ 4. In the "Test" tab:
+    - Test your fine-tuned model
+
+ 5. In the "GGUF Conversion" tab:
+    - Convert your model to GGUF format if needed
+
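+ If you bring your own local dataset, the loader in `dataset_utils.py` expects a JSON list of records with a `conversations` field (or a plain `text` field). A minimal sketch of writing such a file (the conversation content is purely illustrative):
+
+ ```python
+ import json
+
+ # Each record holds one conversation as human/assistant turns,
+ # the same shape used by openai_sample_dataset.json.
+ sample = [
+     {"conversations": [
+         {"human": "What are your opening hours?",
+          "assistant": "We are open 6am-10pm daily."},
+     ]},
+ ]
+ with open("my_dataset.json", "w") as f:
+     json.dump(sample, f, indent=2)
+ ```
+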
+ ## Troubleshooting
+
+ - If `python` doesn't work, try `python3`
+ - Ensure your GPU drivers are up to date for CUDA support
+ - If you encounter "command not found" errors, ensure the relevant programs are in your system's PATH
+
+ ## Closing Notes
+
+ - Always activate the virtual environment before running the project
+ - To deactivate the virtual environment, simply type `deactivate` in the terminal
+ - Keep your Python packages updated with `pip install --upgrade -r requirements.txt`
+
+ Remember to keep your API keys and tokens secure. Happy fine-tuning!
+
+ ## Alternative Installation
+
+ ```
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+ pip install triton
+ pip install unsloth gradio transformers datasets tqdm
+ ```
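+
+ A quick import check (a minimal sketch; version numbers will vary by environment) can confirm the install succeeded:
+
+ ```python
+ import torch
+ import transformers
+
+ # If these imports succeed, the core training stack is in place.
+ print("torch", torch.__version__)
+ print("transformers", transformers.__version__)
+ ```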
dataset_utils.py ADDED
@@ -0,0 +1,203 @@
+ from datasets import load_dataset, Dataset
+ import json
+ import csv
+ import openai
+ import anthropic
+ import requests
+ import os
+ import logging
+ from tqdm import tqdm
+ import time
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def prepare_dataset(dataset_source, dataset_path, tokenizer, hf_token=None):
+     """
+     Prepare a dataset for fine-tuning, either from Hugging Face or a local file.
+
+     Args:
+         dataset_source (str): 'huggingface' or 'local'
+         dataset_path (str): Path or identifier of the dataset
+         tokenizer: The tokenizer associated with the model
+         hf_token (str, optional): Hugging Face token for accessing datasets
+
+     Returns:
+         Dataset: Prepared dataset ready for fine-tuning
+     """
+     if dataset_source == 'huggingface':
+         try:
+             dataset = load_dataset(dataset_path, split="train", use_auth_token=hf_token)
+         except ValueError:
+             # If use_auth_token is not supported, try without it
+             dataset = load_dataset(dataset_path, split="train")
+     elif dataset_source == 'local':
+         if not os.path.exists(dataset_path):
+             raise FileNotFoundError(f"File not found: {dataset_path}")
+
+         if dataset_path.endswith('.json'):
+             with open(dataset_path, 'r') as f:
+                 data = json.load(f)
+             if isinstance(data, list):
+                 dataset = Dataset.from_list(data)
+             elif isinstance(data, dict):
+                 dataset = Dataset.from_dict(data)
+             else:
+                 raise ValueError("JSON file must contain either a list or a dictionary.")
+         elif dataset_path.endswith('.csv'):
+             with open(dataset_path, 'r') as f:
+                 reader = csv.DictReader(f)
+                 data = list(reader)
+             dataset = Dataset.from_list(data)
+         else:
+             raise ValueError("Unsupported file format. Please use JSON or CSV.")
+     else:
+         raise ValueError("Invalid dataset source. Use 'huggingface' or 'local'.")
+
+     # Check if a 'conversations' column exists; if not, try to create it
+     if 'conversations' not in dataset.column_names:
+         if 'text' in dataset.column_names:
+             dataset = dataset.map(lambda example: {'conversations': [{'human': example['text'], 'assistant': ''}]})
+         else:
+             raise ValueError("Dataset does not contain a 'conversations' or 'text' column. Please check your dataset structure.")
+
+     # Only apply standardize_sharegpt if the 'conversations' column exists
+     if 'conversations' in dataset.column_names:
+         dataset = standardize_sharegpt(dataset)
+
+     def formatting_prompts_func(examples):
+         if tokenizer is None:
+             raise ValueError("Tokenizer is not properly initialized. Please load the model and tokenizer before preparing the dataset.")
+         convos = examples["conversations"]
+         texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
+         return {"text": texts}
+
+     dataset = dataset.map(formatting_prompts_func, batched=True)
+
+     if 'text' not in dataset.column_names:
+         def format_conversation(example):
+             formatted_text = ""
+             for turn in example['conversations']:
+                 formatted_text += f"{turn['role']}: {turn['content']}\n"
+             return {"text": formatted_text.strip()}
+
+         dataset = dataset.map(format_conversation)
+
+     return dataset
+
+ def standardize_sharegpt(dataset):
+     # This is a simplified version. You might need to adjust it based on your specific needs.
+     def process_conversation(conversation):
+         standardized = []
+         for turn in conversation:
+             if 'human' in turn:
+                 standardized.append({'role': 'user', 'content': turn['human']})
+             if 'assistant' in turn:
+                 standardized.append({'role': 'assistant', 'content': turn['assistant']})
+         return standardized
+
+     return dataset.map(lambda x: {'conversations': process_conversation(x['conversations'])})
+
+ def create_synthetic_dataset(examples, expected_structure, num_samples, ai_provider, api_key, model_name=None):
+     """
+     Create a synthetic dataset based on example conversations and expected structure.
+
+     Args:
+         examples (str): Example conversations to base the synthetic data on.
+         expected_structure (str): Description of the expected dataset structure.
+         num_samples (int): Number of synthetic samples to generate.
+         ai_provider (str): AI provider to use for generation ('OpenAI', 'Anthropic', or 'Ollama').
+         api_key (str): API key for the chosen AI provider.
+         model_name (str, optional): Model name for Ollama (if applicable).
+
+     Returns:
+         Dataset: Synthetic dataset ready for fine-tuning.
+     """
+     synthetic_data = []
+
+     prompt = f"""
+     You are an AI assistant creating a training dataset for fine-tuning a model.
+     You are given a one-shot or few-shot example of the output the application expects from the AI model, together with
+     the structure the to-be-trained model expects during the training process.
+
+     Examples:
+     {examples}
+
+     Expected structure:
+     {expected_structure}
+
+     Please generate a new dataset in the same style and expected structure. Do not produce any output other than the dataset in the required training structure:
+     """
+
+     if ai_provider == "OpenAI":
+         client = openai.OpenAI(api_key=api_key)
+         for _ in tqdm(range(num_samples), desc="Generating samples"):
+             try:
+                 response = client.chat.completions.create(
+                     model="gpt-4-0125-preview",
+                     messages=[{"role": "user", "content": prompt}],
+                     timeout=30  # 30 seconds timeout
+                 )
+                 conversation = response.choices[0].message.content
+                 synthetic_data.append({"conversations": json.loads(conversation)})
+             except json.JSONDecodeError:
+                 logger.warning(f"Failed to decode response as JSON: {response.choices[0].message.content}")
+             except openai.APITimeoutError:
+                 logger.warning("OpenAI API request timed out")
+             except Exception as e:
+                 logger.error(f"Unexpected error: {str(e)}")
+             time.sleep(1)  # Rate limiting
+
+     elif ai_provider == "Anthropic":
+         client = anthropic.Anthropic(api_key=api_key)
+         for _ in tqdm(range(num_samples), desc="Generating samples"):
+             try:
+                 response = client.completions.create(
+                     model="claude-3-opus-20240229",
+                     prompt=f"Human: {prompt}\n\nAssistant:",
+                     max_tokens_to_sample=1000,
+                     timeout=30  # 30 seconds timeout
+                 )
+                 synthetic_data.append({"conversations": json.loads(response.completion)})
+             except json.JSONDecodeError:
+                 logger.warning(f"Failed to decode response as JSON: {response.completion}")
+             except anthropic.APITimeoutError:
+                 logger.warning("Anthropic API request timed out")
+             except Exception as e:
+                 logger.error(f"Unexpected error: {str(e)}")
+             time.sleep(1)  # Rate limiting
+
+     elif ai_provider == "Ollama":
+         for _ in tqdm(range(num_samples), desc="Generating samples"):
+             try:
+                 response = requests.post('http://localhost:11434/api/generate',
+                                          json={
+                                              "model": model_name,
+                                              "prompt": prompt,
+                                              "stream": False
+                                          },
+                                          timeout=30)  # 30 seconds timeout
+                 response.raise_for_status()
+                 synthetic_data.append({"conversations": json.loads(response.json()["response"])})
+             except json.JSONDecodeError:
+                 logger.warning(f"Failed to decode response as JSON: {response.json()['response']}")
+             except requests.Timeout:
+                 logger.warning("Ollama API request timed out")
+             except Exception as e:
+                 logger.error(f"Unexpected error: {str(e)}")
+             time.sleep(1)  # Rate limiting
+
+     dataset = Dataset.from_list(synthetic_data)
+     dataset = standardize_sharegpt(dataset)
+
+     if 'text' not in dataset.column_names:
+         def format_conversation(example):
+             formatted_text = ""
+             for turn in example['conversations']:
+                 formatted_text += f"{turn['role']}: {turn['content']}\n"
+             return {"text": formatted_text.strip()}
+
+         dataset = dataset.map(format_conversation)
+
+     return dataset
gguf_utils.py ADDED
@@ -0,0 +1,18 @@
+ def convert_to_gguf(model, tokenizer, output_path, quantization_method="q8_0"):
+     """
+     Convert the fine-tuned model to GGUF format.
+
+     Args:
+         model: The fine-tuned model to convert.
+         tokenizer: The tokenizer associated with the model.
+         output_path (str): The path to save the converted model.
+         quantization_method (str): The quantization method to use (e.g., "q8_0", "q4_k_m", "q5_k_m", "f16").
+
+     Returns:
+         str: A message indicating the success or failure of the conversion.
+     """
+     try:
+         model.save_pretrained_gguf(output_path, tokenizer, quantization_method=quantization_method)
+         return f"Model successfully converted to GGUF format: {output_path}-unsloth-{quantization_method}.gguf"
+     except Exception as e:
+         return f"Error converting to GGUF: {str(e)}"
inference_utils.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+ def test_model(model, tokenizer, input_text):
+     """
+     Test the fine-tuned model with a given input.
+
+     Args:
+         model: The fine-tuned model to test.
+         tokenizer: The tokenizer associated with the model.
+         input_text (str): The input text to generate a response for.
+
+     Returns:
+         str: The generated response from the model.
+     """
+     messages = [
+         {"role": "user", "content": input_text},
+     ]
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_tensors="pt"
+     ).to("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present
+
+     outputs = model.generate(input_ids=inputs, max_new_tokens=64, use_cache=True,
+                              temperature=1.5, min_p=0.1)
+     return tokenizer.batch_decode(outputs)[0]
main.py ADDED
@@ -0,0 +1,5 @@
+ from ui import create_gradio_interface
+
+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.launch(share=True)
model_utils.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import importlib.util
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ def load_model(model_path, hf_token):
+     """
+     Load a pre-trained model and tokenizer, using unsloth if available,
+     falling back to standard transformers if necessary.
+
+     Args:
+         model_path (str): Path or identifier of the pre-trained model.
+         hf_token (str): Hugging Face API token for accessing gated models.
+
+     Returns:
+         tuple: Loaded model and tokenizer.
+     """
+     tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)
+
+     # Check if CUDA is available
+     cuda_available = torch.cuda.is_available()
+     if cuda_available:
+         print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
+         device = "cuda"
+     else:
+         print("CUDA is not available. Using CPU.")
+         device = "cpu"
+
+     # Try to use unsloth if it's available
+     if importlib.util.find_spec("unsloth") is not None:
+         try:
+             from unsloth import FastLanguageModel
+             print("Using unsloth for model loading.")
+             model, _ = FastLanguageModel.from_pretrained(
+                 model_name=model_path,
+                 max_seq_length=2048,
+                 dtype=None,  # Automatically choose between float16 and bfloat16
+                 load_in_4bit=cuda_available,  # Only use 4-bit quantization if CUDA is available
+                 token=hf_token
+             )
+         except Exception as e:
+             print(f"Error loading model with unsloth: {e}")
+             print("Falling back to standard transformers.")
+             model = load_with_transformers(model_path, hf_token, device)
+     else:
+         print("unsloth not found. Using standard transformers.")
+         model = load_with_transformers(model_path, hf_token, device)
+
+     # Do not use .to(device) for quantized models;
+     # device placement is handled automatically by unsloth or transformers.
+
+     return model, tokenizer
+
+ def load_with_transformers(model_path, hf_token, device):
+     """Helper function to load a model with the standard transformers library."""
+     return AutoModelForCausalLM.from_pretrained(
+         model_path,
+         device_map="auto",  # This will handle device placement automatically
+         token=hf_token,
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32
+     )
openai_sample_dataset.json ADDED
@@ -0,0 +1,178 @@
+ [
+   {
+     "conversations": [
+       {
+         "human": "I'd like to know the price of silver membership.",
+         "assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"PriceInquiry\", \"operation_id\": \"fetchMembershipPrice\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"membership\": \"Silver-Gym-Membership\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Help me book an appointment.",
+         "assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"OnsiteAppointment\", \"operation_id\": \"bookAppointment\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"url\": \"book-appointment\"}, \"tool_body_parameters\": {\"date\": \"next Monday\", \"time\": \"10:00 AM\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I'm having some back pain that's why considering the gym.",
+         "assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"no\"}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Yes, I'd like to know the yoga classes schedule.",
+         "assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"yes\",\"tool_name\": \"GetClassSchedule\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"yoga\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Help me book a trial session tomorrow.",
+         "assistant": "{\"conversation_stage_id\": 6,\"tool_required\": \"yes\",\"tool_name\": \"ScheduleTrialSession\", \"operation_id\": \"bookTrialSession\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"url\": \"schedule-trial\"}, \"tool_body_parameters\": {\"date\": \"tomorrow\", \"time\": \"10:00 AM\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I want to lose weight and build muscle.",
+         "assistant": "{\"conversation_stage_id\": 4,\"tool_required\": \"no\"}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "What are the benefits of the Gold membership?",
+         "assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"no\"}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Yes, I am available next Monday at 10 AM. Help me book a personal training session.",
+         "assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"PersonalTraining\", \"operation_id\": \"bookTrainerSession\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"trainer_id\": \"johndoe\"}, \"tool_body_parameters\": {\"date\": \"next Monday\", \"time\": \"10:00 AM\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I am free tomorrow afternoon.",
+         "assistant": "{\"conversation_stage_id\": 6,\"tool_required\": \"yes\",\"tool_name\": \"GymAppointmentAvailability\", \"operation_id\": \"checkAvailability\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"date\": \"tomorrow\", \"time\": \"afternoon\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Yes, I'd like to join the spin fitness class.",
+         "assistant": "{\"conversation_stage_id\": 4,\"tool_required\": \"yes\",\"tool_name\": \"GetClassSchedule\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"spin\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Sure, when do zumba classes start?",
+         "assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"yes\",\"tool_name\": \"GetClassSchedule\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"zumba\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I'd like to know if there are any membership discounts available.",
+         "assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"DiscountInquiry\", \"operation_id\": \"checkDiscount\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"membership_type\": \"any\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Can you provide information about the swimming pool facilities?",
+         "assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"no\"}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Please help me cancel my gym appointment for tomorrow.",
+         "assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"AppointmentCancellation\", \"operation_id\": \"cancelAppointment\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"appointment_id\": \"12345\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I need to reschedule my personal training session.",
+         "assistant": "{\"conversation_stage_id\": 7,\"tool_required\": \"yes\",\"tool_name\": \"RescheduleTrainingSession\", \"operation_id\": \"rescheduleSession\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"trainer_id\": \"johndoe\"}, \"tool_body_parameters\": {\"new_date\": \"next Tuesday\", \"new_time\": \"11:00 AM\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Do you have any yoga classes available in the evening?",
+         "assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"ClassScheduleInquiry\", \"operation_id\": \"fetchClassSchedule\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"class_name\": \"yoga\", \"time_of_day\": \"evening\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I'm interested in the family membership plan. Can you provide more details?",
+         "assistant": "{\"conversation_stage_id\": 2,\"tool_required\": \"no\"}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I'd like to upgrade my membership from silver to gold.",
+         "assistant": "{\"conversation_stage_id\": 4,\"tool_required\": \"yes\",\"tool_name\": \"MembershipUpgrade\", \"operation_id\": \"upgradeMembership\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"current_membership\": \"silver\", \"new_membership\": \"gold\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Is there a way to pause my membership while I travel?",
+         "assistant": "{\"conversation_stage_id\": 5,\"tool_required\": \"yes\",\"tool_name\": \"MembershipHold\", \"operation_id\": \"pauseMembership\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"hold_duration\": \"2 months\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "Can you assist me with changing my contact information?",
+         "assistant": "{\"conversation_stage_id\": 5,\"tool_required\": \"yes\",\"tool_name\": \"UpdateContactInfo\", \"operation_id\": \"updateInfo\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {}, \"tool_body_parameters\": {\"new_phone\": \"1234567890\", \"new_email\": \"[email protected]\"}}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "What are the payment options available for membership renewal?",
+         "assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"no\"}"
+       }
+     ]
+   },
+   {
+     "conversations": [
+       {
+         "human": "I want to know the personal trainer's availability next week.",
+         "assistant": "{\"conversation_stage_id\": 3,\"tool_required\": \"yes\",\"tool_name\": \"TrainerAvailability\", \"operation_id\": \"checkTrainerAvailability\", \"tool_headers\": {\"Authorization\": \"sensitive_value\"}, \"tool_parameters\": {\"trainer_id\": \"johndoe\", \"date_range\": \"next week\"}, \"tool_body_parameters\": {}}"
+       }
+     ]
+   }
+ ]
project_plan.md ADDED
@@ -0,0 +1,104 @@
+ # LLM Finetuner Project Plan
+
+ ## 1. Project Overview
+
+ The LLM Finetuner is a user-friendly application designed to simplify the process of fine-tuning Large Language Models (LLMs) using the Unsloth library. The application provides a graphical user interface for dataset preparation, model selection, fine-tuning, testing, and GGUF conversion.
+
+ ## 2. Project Structure
+
+ ```
+ llm_finetuner/
+ ├── main.py
+ ├── ui.py
+ ├── model_utils.py
+ ├── dataset_utils.py
+ ├── training_utils.py
+ ├── inference_utils.py
+ ├── gguf_utils.py
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ## 3. Key Components
+
+ ### 3.1 User Interface (ui.py)
+ - Gradio-based interface with tabs for different functionalities
+ - Handles user inputs and interactions
+ - Coordinates between different modules
+
+ ### 3.2 Model Utilities (model_utils.py)
+ - Handles model loading and initialization
+ - Supports various pre-trained models from Unsloth
+
+ ### 3.3 Dataset Utilities (dataset_utils.py)
+ - Manages dataset preparation from Hugging Face and local files
+ - Implements synthetic dataset creation using AI providers (OpenAI, Anthropic, Ollama)
+
+ ### 3.4 Training Utilities (training_utils.py)
+ - Implements the fine-tuning process using Unsloth and TRL
+
+ ### 3.5 Inference Utilities (inference_utils.py)
+ - Handles model testing and inference
+
+ ### 3.6 GGUF Conversion Utilities (gguf_utils.py)
+ - Manages the conversion of fine-tuned models to GGUF format
+
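+ The modules above are designed to compose in sequence. A minimal sketch of the intended flow (the function names come from the modules listed above; the token and hyperparameter values are illustrative placeholders):
+
+ ```python
+ from model_utils import load_model
+ from dataset_utils import prepare_dataset
+ from training_utils import finetune_model
+ from inference_utils import test_model
+ from gguf_utils import convert_to_gguf
+
+ # Load a base model, prepare data, fine-tune, then test and export.
+ model, tokenizer = load_model("unsloth/Llama-3.2-3B-Instruct", hf_token="hf_...")
+ dataset = prepare_dataset("huggingface", "mlabonne/FineTome-100k", tokenizer, hf_token="hf_...")
+ trainer = finetune_model(model, tokenizer, dataset, learning_rate=2e-4, batch_size=2, num_epochs=1)
+ print(test_model(trainer.model, tokenizer, "Hello!"))
+ print(convert_to_gguf(trainer.model, tokenizer, "model_out", quantization_method="q8_0"))
+ ```
+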
+ ## 4. Implementation Plan
+
+ ### 4.1 Phase 1: Core Functionality
+ - [x] Implement basic UI structure
+ - [x] Develop model loading and initialization
+ - [x] Implement dataset preparation for Hugging Face and local files using the model's tokenizer and chat template
+ - [x] Develop basic fine-tuning functionality using the prepared dataset
+ - [x] Implement model testing
+ - [x] Add GGUF conversion capability
+
+ ### 4.2 Phase 2: Enhanced Features
+ - [x] Implement synthetic dataset creation
+ - [ ] Improve error handling and user feedback
+ - [ ] Implement progress tracking for long-running operations
+ - [ ] Add support for custom model configurations
+
+ ### 4.3 Phase 3: Optimization and Advanced Features
+ - [ ] Optimize performance for large datasets and models
+ - [ ] Implement advanced fine-tuning techniques (e.g., LoRA, QLoRA)
+ - [ ] Add support for distributed training
+ - [ ] Implement model comparison tools
+
+ ## 5. Testing Plan
+
+ ### 5.1 Unit Testing
+ - Develop unit tests for each utility module
+ - Ensure proper error handling and edge case coverage
+
+ ### 5.2 Integration Testing
+ - Test the interaction between different modules
+ - Verify data flow from UI to backend and vice versa
+
+ ### 5.3 User Acceptance Testing
+ - Conduct usability testing with potential users
+ - Gather feedback on UI intuitiveness and feature completeness
+
+ ## 6. Deployment Plan
+
+ ### 6.1 Local Deployment
+ - Provide clear instructions for local installation and setup
+ - Create a comprehensive README with usage guidelines
+
+ ### 6.2 Cloud Deployment (Future Consideration)
+ - Explore options for cloud deployment (e.g., Hugging Face Spaces, Google Cloud)
+ - Implement necessary security measures for cloud deployment
+
+ ## 7. Documentation
+
+ - Create user documentation explaining each feature and its usage
+ - Develop technical documentation for future maintainers
+ - Include examples and use cases in the documentation
+
+ ## 8. Maintenance and Updates
+
+ - Establish a process for regular updates to supported models and libraries
+ - Plan for ongoing bug fixes and feature enhancements based on user feedback
+
+ This project plan provides a roadmap for the development, testing, and deployment of the LLM Finetuner application. It should be reviewed and updated regularly as the project progresses and new requirements or challenges emerge.
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ gradio
+ torch
+ unsloth
+ datasets
+ trl
+ transformers
+ openai
+ anthropic
+ requests
+ tqdm
+ accelerate
+ bitsandbytes
+ huggingface_hub
+ triton
+ peft
training_utils.py ADDED
@@ -0,0 +1,75 @@
+ from unsloth import FastLanguageModel
+ from trl import SFTTrainer
+ from transformers import TrainingArguments, DataCollatorForSeq2Seq
+ from unsloth import is_bfloat16_supported
+ from unsloth.chat_templates import train_on_responses_only
+
+ def finetune_model(model, tokenizer, dataset, learning_rate, batch_size, num_epochs):
+     """
+     Fine-tune a model on a given dataset, using CUDA if available.
+     This version supports fine-tuning of quantized models using PEFT and Unsloth optimizations.
+
+     Args:
+         model: The pre-trained model to fine-tune.
+         tokenizer: The tokenizer associated with the model.
+         dataset: The dataset to use for fine-tuning.
+         learning_rate (float): Learning rate for optimization.
+         batch_size (int): Number of training examples used in one iteration.
+         num_epochs (int): Number of complete passes through the dataset.
+
+     Returns:
+         SFTTrainer: The trained model wrapped in an SFTTrainer object.
+     """
+     # Prepare the model for training
+     model = FastLanguageModel.get_peft_model(
+         model,
+         r=16,
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                         "gate_proj", "up_proj", "down_proj",],
+         lora_alpha=16,
+         lora_dropout=0,
+         bias="none",
+         use_gradient_checkpointing="unsloth",
+         random_state=3407,
+         use_rslora=False,
+         loftq_config=None,
+     )
+
+     # Set up the trainer
+     trainer = SFTTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=dataset,
+         dataset_text_field="text",
+         max_seq_length=model.config.max_position_embeddings,
+         data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
+         dataset_num_proc=2,
+         packing=False,
+         args=TrainingArguments(
+             per_device_train_batch_size=batch_size,
+             gradient_accumulation_steps=4,
+             warmup_steps=5,
+             num_train_epochs=num_epochs,
+             learning_rate=learning_rate,
+             fp16=not is_bfloat16_supported(),
+             bf16=is_bfloat16_supported(),
+             logging_steps=1,
+             optim="adamw_8bit",
+             weight_decay=0.01,
+             lr_scheduler_type="linear",
+             seed=3407,
+             output_dir="outputs",
+         ),
+     )
+
+     # Apply train_on_responses_only so loss is computed only on assistant turns
+     trainer = train_on_responses_only(
+         trainer,
+         instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
+         response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+     )
+
+     # Train the model
+     trainer.train()
+
+     return trainer
ui.py ADDED
@@ -0,0 +1,231 @@
+ import gradio as gr
+ import torch
+ from model_utils import load_model
+ from dataset_utils import prepare_dataset, create_synthetic_dataset
+ from training_utils import finetune_model
+ from inference_utils import test_model
+ from gguf_utils import convert_to_gguf
+ from unsloth import FastLanguageModel
+ from unsloth.chat_templates import get_chat_template
+ from upload_utils import upload_to_huggingface, upload_gguf_to_huggingface
+
+ def create_gradio_interface():
+     models = [
+         "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
+         "unsloth/Mistral-Small-Instruct-2409",
+         "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
+         "unsloth/Phi-3.5-mini-instruct",
+         "unsloth/Phi-3-medium-4k-instruct",
+         "unsloth/gemma-2-9b-bnb-4bit",
+         "unsloth/gemma-2-27b-bnb-4bit",
+         "unsloth/Llama-3.2-3B-Instruct",
+     ]
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# LLM Finetuner")
+
+         model = gr.State(None)
+         tokenizer = gr.State(None)
+         dataset = gr.State(None)
+
+         with gr.Tab("Settings"):
+             hf_token = gr.Textbox(label="Hugging Face Token", type="password")
+             model_path = gr.Dropdown(label="Model", choices=models, value="unsloth/Llama-3.2-3B-Instruct")
+             load_model_btn = gr.Button("Load Model")
+             load_model_output = gr.Textbox(label="Load Model Output")
+
+         with gr.Tab("Dataset"):
+             with gr.Group():
+                 gr.Markdown("## Use Existing Dataset")
+                 dataset_source = gr.Radio(["Hugging Face", "Local File"], label="Dataset Source", value="Hugging Face")
+                 hf_dataset_path = gr.Textbox(label="Hugging Face Dataset Path", value="mlabonne/FineTome-100k")
+                 local_dataset_path = gr.File(label="Upload Local Dataset (JSON or CSV)", visible=False)
+                 prepare_dataset_btn = gr.Button("Prepare Dataset")
+                 prepare_dataset_output = gr.Textbox(label="Prepare Dataset Output")
+
+             with gr.Group():
+                 gr.Markdown("## Create Synthetic Dataset")
+                 examples = gr.Textbox(label="Example Conversations", lines=10, placeholder="Enter example conversations here...")
+                 expected_structure = gr.Textbox(label="Expected Dataset Structure", lines=5, placeholder="Enter the expected structure for the dataset...")
+                 num_samples = gr.Number(label="Number of Samples to Generate", value=100)
+                 ai_provider = gr.Radio(["OpenAI", "Anthropic", "Ollama"], label="AI Provider")
+                 api_key = gr.Textbox(label="API Key", type="password")
+                 ollama_model = gr.Textbox(label="Ollama Model Name", visible=False)
+                 create_dataset_btn = gr.Button("Create Synthetic Dataset")
+                 create_dataset_output = gr.Textbox(label="Create Dataset Output")
+
+         with gr.Tab("Training"):
+             learning_rate = gr.Number(label="Learning Rate", value=2e-4)
+             batch_size = gr.Number(label="Batch Size", value=2)
+             num_epochs = gr.Number(label="Number of Epochs", value=1)
+             train_btn = gr.Button("Start Training")
+             train_output = gr.Textbox(label="Training Output")
+
+         with gr.Tab("Test"):
+             test_input = gr.Textbox(label="Test Input")
+             test_btn = gr.Button("Test Model")
+             test_output = gr.Textbox(label="Model Output")
+
+         with gr.Tab("GGUF Conversion"):
+             gguf_output_path = gr.Textbox(label="GGUF Output Path")
+             gguf_quant_method = gr.Dropdown(
+                 label="Quantization Method",
+                 choices=["q8_0", "q4_k_m", "q5_k_m", "f16"],
+                 value="q8_0"
+             )
+             gguf_convert_btn = gr.Button("Convert to GGUF")
+             gguf_output = gr.Textbox(label="GGUF Conversion Output")
+
+         with gr.Tab("Upload to Hugging Face"):
+             repo_name = gr.Textbox(label="Hugging Face Repository Name")
+             model_type = gr.Radio(["Fine-tuned Model", "GGUF Converted Model"], label="Model Type to Upload", value="Fine-tuned Model")
+             gguf_file_path = gr.Textbox(label="GGUF File Path (if uploading GGUF model)", visible=False)
+             upload_btn = gr.Button("Upload to Hugging Face")
+             upload_output = gr.Textbox(label="Upload Output")
+
+         def load_model_and_tokenizer(model_path, hf_token):
+             model_val, tokenizer_val = load_model(model_path, hf_token)
+             tokenizer_val = get_chat_template(tokenizer_val, chat_template="llama-3.1")
+             return model_val, tokenizer_val, "Model and tokenizer loaded successfully!"
+
+         def update_ollama_visibility(choice):
+             return gr.update(visible=(choice == "Ollama"))
+
+         def update_dataset_input_visibility(choice):
+             return gr.update(visible=(choice == "Hugging Face")), gr.update(visible=(choice == "Local File"))
+
+         def update_gguf_path_visibility(choice):
+             return gr.update(visible=(choice == "GGUF Converted Model"))
+
+         load_model_btn.click(
+             load_model_and_tokenizer,
+             inputs=[model_path, hf_token],
+             outputs=[model, tokenizer, load_model_output]
+         )
+
+         dataset_source.change(
+             update_dataset_input_visibility,
+             inputs=[dataset_source],
+             outputs=[hf_dataset_path, local_dataset_path]
+         )
+
+         model_type.change(
+             update_gguf_path_visibility,
+             inputs=[model_type],
+             outputs=[gguf_file_path]
+         )
+
+         def prepare_dataset_wrapper(source, hf_path, local_file, hf_token, tokenizer_val):
+             # Error paths return a (dataset, message) pair to match the two declared outputs.
+             if tokenizer_val is None:
+                 return None, "Error: Model and tokenizer not loaded. Please load the model first."
+
+             if source == "Hugging Face":
+                 dataset_val = prepare_dataset("huggingface", hf_path, tokenizer_val, hf_token)
+             elif source == "Local File":
+                 if local_file is not None:
+                     dataset_val = prepare_dataset("local", local_file.name, tokenizer_val)
+                 else:
+                     return None, "No file uploaded. Please upload a local dataset file."
+             else:
+                 return None, "Invalid dataset source selected."
+
+             return dataset_val, "Dataset prepared successfully!"
+
+         prepare_dataset_btn.click(
+             prepare_dataset_wrapper,
+             inputs=[dataset_source, hf_dataset_path, local_dataset_path, hf_token, tokenizer],
+             outputs=[dataset, prepare_dataset_output]
+         )
+
+         def create_synthetic_dataset_wrapper(examples, expected_structure, num_samples, ai_provider, api_key, ollama_model, tokenizer_val):
+             if tokenizer_val is None:
+                 return None, "Error: Model and tokenizer not loaded. Please load the model first."
+
+             dataset_val = create_synthetic_dataset(examples, expected_structure, num_samples, ai_provider, api_key, ollama_model)
+             return dataset_val, "Synthetic dataset created successfully!"
+
+         create_dataset_btn.click(
+             create_synthetic_dataset_wrapper,
+             inputs=[examples, expected_structure, num_samples, ai_provider, api_key, ollama_model, tokenizer],
+             outputs=[dataset, create_dataset_output]
+         )
+
+         ai_provider.change(update_ollama_visibility, inputs=[ai_provider], outputs=[ollama_model])
+
+         def train_model_wrapper(model_val, tokenizer_val, dataset_val, learning_rate, batch_size, num_epochs):
+             if model_val is None or tokenizer_val is None:
+                 return "Error: Model and tokenizer not loaded. Please load the model first."
+             if dataset_val is None:
+                 return "Error: Dataset not prepared. Please prepare or create a dataset first."
+
+             try:
+                 trainer = finetune_model(model_val, tokenizer_val, dataset_val, learning_rate, batch_size, num_epochs)
+                 return "Training completed successfully!"
+             except Exception as e:
+                 return f"Error during training: {str(e)}"
+
+         train_btn.click(
+             train_model_wrapper,
+             inputs=[model, tokenizer, dataset, learning_rate, batch_size, num_epochs],
+             outputs=[train_output]
+         )
+
+         def test_model_wrapper(model_val, tokenizer_val, test_input):
+             if model_val is None or tokenizer_val is None:
+                 return "Error: Model and tokenizer not loaded. Please load the model first."
+
+             FastLanguageModel.for_inference(model_val)  # Enable native 2x faster inference
+             messages = [{"role": "user", "content": test_input}]
+             inputs = tokenizer_val.apply_chat_template(
+                 messages,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt"
+             ).to("cuda" if torch.cuda.is_available() else "cpu")
+
+             outputs = model_val.generate(input_ids=inputs, max_new_tokens=128, temperature=1.5, min_p=0.1)
+             return tokenizer_val.batch_decode(outputs)[0]
+
+         test_btn.click(
+             test_model_wrapper,
+             inputs=[model, tokenizer, test_input],
+             outputs=[test_output]
+         )
+
+         def convert_to_gguf_wrapper(model_val, tokenizer_val, gguf_output_path, gguf_quant_method):
+             if model_val is None or tokenizer_val is None:
+                 return "Error: Model and tokenizer not loaded. Please load the model first."
+
+             output = convert_to_gguf(model_val, tokenizer_val, gguf_output_path, gguf_quant_method)
+             return output
+
+         gguf_convert_btn.click(
+             convert_to_gguf_wrapper,
+             inputs=[model, tokenizer, gguf_output_path, gguf_quant_method],
+             outputs=[gguf_output]
+         )
+
+         def upload_to_hf_wrapper(model_val, tokenizer_val, repo_name, hf_token, model_type, gguf_file_path):
+             if model_type == "Fine-tuned Model":
+                 if model_val is None or tokenizer_val is None:
+                     return "Error: Model and tokenizer not loaded. Please load the model first."
+                 result = upload_to_huggingface(model_val, tokenizer_val, repo_name, hf_token)
+             elif model_type == "GGUF Converted Model":
+                 if not gguf_file_path:
+                     return "Error: GGUF file path not provided. Please enter the path to the GGUF file."
+                 result = upload_gguf_to_huggingface(gguf_file_path, repo_name, hf_token)
+             else:
+                 return "Error: Invalid model type selected."
+             return result
+
+         upload_btn.click(
+             upload_to_hf_wrapper,
+             inputs=[model, tokenizer, repo_name, hf_token, model_type, gguf_file_path],
+             outputs=[upload_output]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.launch()
unsloth_compiled_cache/UnslothAlignPropTrainer.py ADDED
@@ -0,0 +1,637 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.3
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+     "epilogue_fusion"   : True,
+     "max_autotune"      : False,
+     "shape_padding"     : True,
+     "trace.enabled"     : False,
+     "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+     logits = logits.to(torch.float32)
+     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+     # loop to reduce peak mem consumption
+     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+     logsumexp_values = torch.logsumexp(logits, dim = -1)
+     per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+     return per_token_logps
+
+ @dataclass
+ class UnslothAlignPropConfig(AlignPropConfig):
+     """
+
+     Configuration class for the [`AlignPropTrainer`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
+             Name of this experiment (defaults to the file name without the extension).
+         run_name (`str`, *optional*, defaults to `""`):
+             Name of this run.
+         seed (`int`, *optional*, defaults to `0`):
+             Random seed for reproducibility.
+         log_with (`str` or `None`, *optional*, defaults to `None`):
+             Log with either `"wandb"` or `"tensorboard"`. Check
+             [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
+         log_image_freq (`int`, *optional*, defaults to `1`):
+             Frequency for logging images.
+         tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+             Keyword arguments for the tracker (e.g., `wandb_project`).
+         accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+             Keyword arguments for the accelerator.
+         project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+             Keyword arguments for the accelerator project config (e.g., `logging_dir`).
+         tracker_project_name (`str`, *optional*, defaults to `"trl"`):
+             Name of project to use for tracking.
+         logdir (`str`, *optional*, defaults to `"logs"`):
+             Top-level logging directory for checkpoint saving.
+         num_epochs (`int`, *optional*, defaults to `100`):
+             Number of epochs to train.
+         save_freq (`int`, *optional*, defaults to `1`):
+             Number of epochs between saving model checkpoints.
+         num_checkpoint_limit (`int`, *optional*, defaults to `5`):
+             Number of checkpoints to keep before overwriting old ones.
+         mixed_precision (`str`, *optional*, defaults to `"fp16"`):
+             Mixed precision training.
+         allow_tf32 (`bool`, *optional*, defaults to `True`):
+             Allow `tf32` on Ampere GPUs.
+         resume_from (`str`, *optional*, defaults to `""`):
+             Path to resume training from a checkpoint.
+         sample_num_steps (`int`, *optional*, defaults to `50`):
+             Number of sampler inference steps.
+         sample_eta (`float`, *optional*, defaults to `1.0`):
+             Eta parameter for the DDIM sampler.
+         sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
+             Classifier-free guidance weight.
+         train_batch_size (`int`, *optional*, defaults to `1`):
+             Batch size for training.
+         train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
+             Whether to use the 8bit Adam optimizer from `bitsandbytes`.
+         train_learning_rate (`float`, *optional*, defaults to `1e-3`):
+             Learning rate.
+         train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
+             Beta1 for Adam optimizer.
+         train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
+             Beta2 for Adam optimizer.
+         train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
+             Weight decay for Adam optimizer.
+         train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
+             Epsilon value for Adam optimizer.
+         train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
+             Number of gradient accumulation steps.
+         train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
+             Maximum gradient norm for gradient clipping.
+         negative_prompts (`str` or `None`, *optional*, defaults to `None`):
+             Comma-separated list of prompts to use as negative examples.
+         truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
+             If `True`, randomized truncation to different diffusion timesteps is used.
+         truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
+             Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
+         truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
+             Range of diffusion timesteps for randomized truncated backpropagation.
+         push_to_hub (`bool`, *optional*, defaults to `False`):
+             Whether to push the final model to the Hub.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
124
+ metadata = {'help': 'vLLM SamplingParams'},
125
+ )
126
+ unsloth_num_chunks : Optional[int] = field(
127
+ default = -1,
128
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
129
+ )
130
+ def __init__(
131
+ self,
132
+ exp_name = 'main',
133
+ run_name = '',
134
+ seed = 3407,
135
+ log_with = None,
136
+ log_image_freq = 1,
137
+ tracker_project_name = 'trl',
138
+ logdir = 'logs',
139
+ num_epochs = 100,
140
+ save_freq = 1,
141
+ num_checkpoint_limit = 5,
142
+ mixed_precision = 'fp16',
143
+ allow_tf32 = True,
144
+ resume_from = '',
145
+ sample_num_steps = 50,
146
+ sample_eta = 1.0,
147
+ sample_guidance_scale = 5.0,
148
+ train_batch_size = 1,
149
+ train_use_8bit_adam = False,
150
+ train_learning_rate = 5e-05,
151
+ train_adam_beta1 = 0.9,
152
+ train_adam_beta2 = 0.999,
153
+ train_adam_weight_decay = 0.01,
154
+ train_adam_epsilon = 1e-08,
155
+ train_gradient_accumulation_steps = 2,
156
+ train_max_grad_norm = 1.0,
157
+ negative_prompts = None,
158
+ truncated_backprop_rand = True,
159
+ truncated_backprop_timestep = 49,
160
+ push_to_hub = False,
161
+ vllm_sampling_params = None,
162
+ unsloth_num_chunks = -1,
163
+ **kwargs,
164
+ ):
165
+
166
+ super().__init__(
167
+ exp_name = exp_name,
168
+ run_name = run_name,
169
+ seed = seed,
170
+ log_with = log_with,
171
+ log_image_freq = log_image_freq,
172
+ tracker_project_name = tracker_project_name,
173
+ logdir = logdir,
174
+ num_epochs = num_epochs,
175
+ save_freq = save_freq,
176
+ num_checkpoint_limit = num_checkpoint_limit,
177
+ mixed_precision = mixed_precision,
178
+ allow_tf32 = allow_tf32,
179
+ resume_from = resume_from,
180
+ sample_num_steps = sample_num_steps,
181
+ sample_eta = sample_eta,
182
+ sample_guidance_scale = sample_guidance_scale,
183
+ train_batch_size = train_batch_size,
184
+ train_use_8bit_adam = train_use_8bit_adam,
185
+ train_learning_rate = train_learning_rate,
186
+ train_adam_beta1 = train_adam_beta1,
187
+ train_adam_beta2 = train_adam_beta2,
188
+ train_adam_weight_decay = train_adam_weight_decay,
189
+ train_adam_epsilon = train_adam_epsilon,
190
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
191
+ train_max_grad_norm = train_max_grad_norm,
192
+ negative_prompts = negative_prompts,
193
+ truncated_backprop_rand = truncated_backprop_rand,
194
+ truncated_backprop_timestep = truncated_backprop_timestep,
195
+ push_to_hub = push_to_hub,**kwargs)
196
+ self.vllm_sampling_params = vllm_sampling_params
197
+ self.unsloth_num_chunks = unsloth_num_chunks
198
+ pass
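+ # A minimal construction sketch (the values shown are illustrative assumptions, not
+ # defaults taken from anywhere else in this file):
+ #
+ #     config = UnslothAlignPropConfig(
+ #         num_epochs = 10,
+ #         mixed_precision = 'bf16',
+ #         train_gradient_accumulation_steps = 4,
+ #     )
+ #     config.unsloth_num_chunks   # -1 keeps the most efficient (unchunked) path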
+
+ class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
+     """"""
+
+     _tag_names = ["trl", "alignprop"]
+
+     def __init__(
+         self,
+         config: AlignPropConfig,
+         reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
+         prompt_function: Callable[[], tuple[str, Any]],
+         sd_pipeline: DDPOStableDiffusionPipeline,
+         image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
+     ):
+         if image_samples_hook is None:
+             warn("No image_samples_hook provided; no images will be logged")
+
+         self.prompt_fn = prompt_function
+         self.reward_fn = reward_function
+         self.config = config
+         self.image_samples_callback = image_samples_hook
+
+         accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
+
+         if self.config.resume_from:
+             self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
+             if "checkpoint_" not in os.path.basename(self.config.resume_from):
+                 # get the most recent checkpoint in this directory
+                 checkpoints = list(
+                     filter(
+                         lambda x: "checkpoint_" in x,
+                         os.listdir(self.config.resume_from),
+                     )
+                 )
+                 if len(checkpoints) == 0:
+                     raise ValueError(f"No checkpoints found in {self.config.resume_from}")
+                 checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
+                 self.config.resume_from = os.path.join(
+                     self.config.resume_from,
+                     f"checkpoint_{checkpoint_numbers[-1]}",
+                 )
+
+                 accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
+
+         self.accelerator = Accelerator(
+             log_with=self.config.log_with,
+             mixed_precision=self.config.mixed_precision,
+             project_config=accelerator_project_config,
+             # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
+             # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
+             # the total number of optimizer steps to accumulate across.
+             gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
+             **self.config.accelerator_kwargs,
+         )
+
+         is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
+
+         if self.accelerator.is_main_process:
+             self.accelerator.init_trackers(
+                 self.config.tracker_project_name,
+                 config=dict(alignprop_trainer_config=config.to_dict())
+                 if not is_using_tensorboard
+                 else config.to_dict(),
+                 init_kwargs=self.config.tracker_kwargs,
+             )
+
+         logger.info(f"\n{config}")
+
+         set_seed(self.config.seed, device_specific=True)
+
+         self.sd_pipeline = sd_pipeline
+
+         self.sd_pipeline.set_progress_bar_config(
+             position=1,
+             disable=not self.accelerator.is_local_main_process,
+             leave=False,
+             desc="Timestep",
+             dynamic_ncols=True,
+         )
+
+         # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+         # as these weights are only used for inference, keeping weights in full precision is not required.
+         if self.accelerator.mixed_precision == "fp16":
+             inference_dtype = torch.float16
+         elif self.accelerator.mixed_precision == "bf16":
+             inference_dtype = torch.bfloat16
+         else:
+             inference_dtype = torch.float32
+
+         self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
+         self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
+         self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
+
+         trainable_layers = self.sd_pipeline.get_trainable_layers()
+
+         self.accelerator.register_save_state_pre_hook(self._save_model_hook)
+         self.accelerator.register_load_state_pre_hook(self._load_model_hook)
+
+         # Enable TF32 for faster training on Ampere GPUs,
+         # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+         if self.config.allow_tf32:
+             torch.backends.cuda.matmul.allow_tf32 = True
+
+         self.optimizer = self._setup_optimizer(
+             trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
+         )
+
+         self.neg_prompt_embed = self.sd_pipeline.text_encoder(
+             self.sd_pipeline.tokenizer(
+                 [""] if self.config.negative_prompts is None else self.config.negative_prompts,
+                 return_tensors="pt",
+                 padding="max_length",
+                 truncation=True,
+                 max_length=self.sd_pipeline.tokenizer.model_max_length,
+             ).input_ids.to(self.accelerator.device)
+         )[0]
+
+         # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+         # more memory
+         self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
+
+         if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
+             unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+             self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
+         else:
+             self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+
+         if config.resume_from:
+             logger.info(f"Resuming from {config.resume_from}")
+             self.accelerator.load_state(config.resume_from)
+             self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
+         else:
+             self.first_epoch = 0
+
+     def compute_rewards(self, prompt_image_pairs):
+         reward, reward_metadata = self.reward_fn(
+             prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
+         )
+         return reward
+
+     def step(self, epoch: int, global_step: int):
+         """
+         Perform a single step of training.
+
+         Args:
+             epoch (int): The current epoch.
+             global_step (int): The current global step.
+
+         Side Effects:
+             - Model weights are updated
+             - Logs the statistics to the accelerator trackers.
+             - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
+
+         Returns:
+             global_step (int): The updated global step.
+         """
+         info = defaultdict(list)
+
+         self.sd_pipeline.unet.train()
+
+         for _ in range(self.config.train_gradient_accumulation_steps):
+             with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
+                 prompt_image_pairs = self._generate_samples(
+                     batch_size=self.config.train_batch_size,
+                 )
+
+                 rewards = self.compute_rewards(prompt_image_pairs)
+
+                 prompt_image_pairs["rewards"] = rewards
+
+                 rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
+
+                 loss = self.calculate_loss(rewards)
+
+                 self.accelerator.backward(loss)
+
+                 if self.accelerator.sync_gradients:
+                     self.accelerator.clip_grad_norm_(
+                         self.trainable_layers.parameters()
+                         if not isinstance(self.trainable_layers, list)
+                         else self.trainable_layers,
+                         self.config.train_max_grad_norm,
+                     )
+
+                 self.optimizer.step()
+                 self.optimizer.zero_grad()
+
+             info["reward_mean"].append(rewards_vis.mean())
+             info["reward_std"].append(rewards_vis.std())
+             info["loss"].append(loss.item())
+
+         # Checks if the accelerator has performed an optimization step behind the scenes
+         if self.accelerator.sync_gradients:
+             # log training-related stuff
+             info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
+             info = self.accelerator.reduce(info, reduction="mean")
+             info.update({"epoch": epoch})
+             self.accelerator.log(info, step=global_step)
+             global_step += 1
+             info = defaultdict(list)
+         else:
+             raise ValueError(
+                 "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
+             )
+         # Logs generated images
+         if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
+             self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
+
+         if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
+             self.accelerator.save_state()
+
+         return global_step
+
+     def calculate_loss(self, rewards):
+         """
+         Calculate the loss for a batch of an unpacked sample
+
+         Args:
+             rewards (torch.Tensor):
+                 Differentiable reward scalars for each generated image, shape: [batch_size]
+
+         Returns:
+             loss (torch.Tensor)
+             (all of these are of shape (1,))
+         """
+         # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
+         loss = 10.0 - (rewards).mean()
+         return loss
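+ # The training signal is deliberately simple: the aesthetic scorer used in the AlignProp
+ # paper rates images on roughly a 1-10 scale, so `10.0 - rewards.mean()` is a non-negative
+ # quantity that shrinks as the mean reward climbs toward the top of that scale.
+ # Worked example with illustrative numbers: rewards = [5.2, 6.0, 5.6] -> mean = 5.6
+ # -> loss = 10.0 - 5.6 = 4.4; driving the loss down is driving the mean reward up.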
+
+     def loss(
+         self,
+         advantages: torch.Tensor,
+         clip_range: float,
+         ratio: torch.Tensor,
+     ):
+         unclipped_loss = -advantages * ratio
+         clipped_loss = -advantages * torch.clamp(
+             ratio,
+             1.0 - clip_range,
+             1.0 + clip_range,
+         )
+         return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
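+ # This is the standard PPO-style clipped surrogate (a pessimistic max of the unclipped and
+ # clipped objectives); it appears to be carried over from the DDPO code path and is not
+ # invoked by `step()` above, which uses `calculate_loss` instead. A worked example with
+ # illustrative numbers: advantage = 2.0, clip_range = 0.2, ratio = 1.5
+ #     unclipped = -2.0 * 1.5 = -3.0
+ #     clipped   = -2.0 * clamp(1.5, 0.8, 1.2) = -2.0 * 1.2 = -2.4
+ #     max(-3.0, -2.4) = -2.4  -> the update is limited once ratio leaves [0.8, 1.2].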
+
+     def _setup_optimizer(self, trainable_layers_parameters):
+         if self.config.train_use_8bit_adam:
+             import bitsandbytes
+
+             optimizer_cls = bitsandbytes.optim.AdamW8bit
+         else:
+             optimizer_cls = torch.optim.AdamW
+
+         return optimizer_cls(
+             trainable_layers_parameters,
+             lr=self.config.train_learning_rate,
+             betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
+             weight_decay=self.config.train_adam_weight_decay,
+             eps=self.config.train_adam_epsilon,
+         )
+
+     def _save_model_hook(self, models, weights, output_dir):
+         self.sd_pipeline.save_checkpoint(models, weights, output_dir)
+         weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+     def _load_model_hook(self, models, input_dir):
+         self.sd_pipeline.load_checkpoint(models, input_dir)
+         models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+     def _generate_samples(self, batch_size, with_grad=True, prompts=None):
+         """
+         Generate samples from the model
+
+         Args:
+             batch_size (int): Batch size to use for sampling
+             with_grad (bool): Whether the generated RGBs should have gradients attached to it.
+
+         Returns:
+             prompt_image_pairs (dict[Any])
+         """
+         prompt_image_pairs = {}
+
+         sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
+
+         if prompts is None:
+             prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
+         else:
+             prompt_metadata = [{} for _ in range(batch_size)]
+
+         prompt_ids = self.sd_pipeline.tokenizer(
+             prompts,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=self.sd_pipeline.tokenizer.model_max_length,
+         ).input_ids.to(self.accelerator.device)
+
+         prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
+
+         if with_grad:
+             sd_output = self.sd_pipeline.rgb_with_grad(
+                 prompt_embeds=prompt_embeds,
+                 negative_prompt_embeds=sample_neg_prompt_embeds,
+                 num_inference_steps=self.config.sample_num_steps,
+                 guidance_scale=self.config.sample_guidance_scale,
+                 eta=self.config.sample_eta,
+                 truncated_backprop_rand=self.config.truncated_backprop_rand,
+                 truncated_backprop_timestep=self.config.truncated_backprop_timestep,
+                 truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
+                 output_type="pt",
+             )
+         else:
+             sd_output = self.sd_pipeline(
+                 prompt_embeds=prompt_embeds,
+                 negative_prompt_embeds=sample_neg_prompt_embeds,
+                 num_inference_steps=self.config.sample_num_steps,
+                 guidance_scale=self.config.sample_guidance_scale,
+                 eta=self.config.sample_eta,
+                 output_type="pt",
+             )
+
+         images = sd_output.images
+
+         prompt_image_pairs["images"] = images
+         prompt_image_pairs["prompts"] = prompts
+         prompt_image_pairs["prompt_metadata"] = prompt_metadata
+
+         return prompt_image_pairs
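+ # Truncated backpropagation is AlignProp's key memory trick: rather than backpropagating the
+ # reward gradient through all `sample_num_steps` denoising steps, `rgb_with_grad` cuts the
+ # graph at a diffusion timestep -- a fixed one (`truncated_backprop_timestep`, default 49)
+ # or, when `truncated_backprop_rand` is enabled, one drawn from
+ # `truncated_rand_backprop_minmax` (default (0, 50)), which the AlignProp authors report
+ # also reduces over-optimization to the reward model.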
+
+     def train(self, epochs: Optional[int] = None):
+         """
+         Train the model for a given number of epochs
+         """
+         global_step = 0
+         if epochs is None:
+             epochs = self.config.num_epochs
+         for epoch in range(self.first_epoch, epochs):
+             global_step = self.step(epoch, global_step)
+
+     def _save_pretrained(self, save_directory):
+         self.sd_pipeline.save_pretrained(save_directory)
+         self.create_model_card()
+
+     def create_model_card(
+         self,
+         model_name: Optional[str] = None,
+         dataset_name: Optional[str] = None,
+         tags: Union[str, list[str], None] = None,
+     ):
+         """
+         Creates a draft of a model card using the information available to the `Trainer`.
+
+         Args:
+             model_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the model.
+             dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the dataset used for training.
+             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                 Tags to be associated with the model card.
+         """
+         if not self.is_world_process_zero():
+             return
+
+         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+             base_model = self.model.config._name_or_path
+         else:
+             base_model = None
+
+         tags = tags or []
+         if isinstance(tags, str):
+             tags = [tags]
+
+         if hasattr(self.model.config, "unsloth_version"):
+             tags.append("unsloth")
+
+         citation = textwrap.dedent("""\
+         @article{prabhudesai2024aligning,
+             title  = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
+             author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
+             year   = 2024,
+             eprint = {arXiv:2310.03739}
+         }""")
+
+         model_card = generate_model_card(
+             base_model=base_model,
+             model_name=model_name,
+             hub_model_id=self.hub_model_id,
+             dataset_name=dataset_name,
+             tags=tags,
+             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+             comet_url=get_comet_experiment_url(),
+             trainer_name="AlignProp",
+             trainer_citation=citation,
+             paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
+             paper_id="2310.03739",
+         )
+
+         model_card.save(os.path.join(self.args.output_dir, "README.md"))
+ class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
+     """
+
+     The AlignPropTrainer optimises diffusion models by backpropagating a differentiable reward
+     through the denoising process (AlignProp).
+     Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
+     As of now only Stable Diffusion based pipelines are supported
+
+     Attributes:
+         config (`AlignPropConfig`):
+             Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
+         reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
+             Reward function to be used
+         prompt_function (`Callable[[], tuple[str, Any]]`):
+             Function to generate prompts to guide model
+         sd_pipeline (`DDPOStableDiffusionPipeline`):
+             Stable Diffusion pipeline to be used for training.
+         image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
+             Hook to be called to log images
+
+     """
+     def __init__(
+         self,
+         config = None,
+         reward_function = None,
+         prompt_function = None,
+         sd_pipeline = None,
+         image_samples_hook = None,
+         **kwargs
+     ):
+         # Fall back to the default Unsloth config when none is provided.
+         if config is None: config = UnslothAlignPropConfig()
+         other_metrics = []
+
+         from unsloth_zoo.logging_utils import PatchRLStatistics
+         PatchRLStatistics('alignprop_trainer', other_metrics)
+
+         super().__init__(
+             config = config,
+             reward_function = reward_function,
+             prompt_function = prompt_function,
+             sd_pipeline = sd_pipeline,
+             image_samples_hook = image_samples_hook, **kwargs)
+
+ pass
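+ # A minimal end-to-end sketch (illustrative only: `pipe`, `aesthetic_reward_fn` and
+ # `prompt_fn` are stand-ins the caller must supply, not objects defined in this file):
+ #
+ #     config  = UnslothAlignPropConfig(num_epochs = 5)
+ #     trainer = UnslothAlignPropTrainer(
+ #         config,
+ #         reward_function = aesthetic_reward_fn,   # (images, prompts, metadata) -> rewards
+ #         prompt_function = prompt_fn,             # () -> (prompt, metadata)
+ #         sd_pipeline = pipe,                      # a DDPOStableDiffusionPipeline
+ #     )
+ #     trainer.train()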
unsloth_compiled_cache/UnslothBCOTrainer.py ADDED
@@ -0,0 +1,1824 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.3
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, deepspeed, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+     "epilogue_fusion"   : True,
+     "max_autotune"      : False,
+     "shape_padding"     : True,
+     "trace.enabled"     : False,
+     "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+     logits = logits.to(torch.float32)
+     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+     # loop to reduce peak mem consumption
+     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+     logsumexp_values = torch.logsumexp(logits, dim = -1)
+     per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+     return per_token_logps
+ @dataclass
+ class UnslothBCOConfig(BCOConfig):
+     """
+
+     Configuration class for the [`BCOTrainer`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         max_length (`int` or `None`, *optional*, defaults to `1024`):
+             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
+             to use the default data collator.
+         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+             Maximum length of the prompt. This argument is required if you want to use the default data collator.
+         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
+             Maximum length of the completion. This argument is required if you want to use the default data collator
+             and your model is an encoder-decoder.
+         beta (`float`, *optional*, defaults to `0.1`):
+             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
+             reference model.
+         label_pad_token_id (`int`, *optional*, defaults to `-100`):
+             Label pad token id. This argument is required if you want to use the default data collator.
+         padding_value (`int` or `None`, *optional*, defaults to `None`):
+             Padding value to use. If `None`, the padding value of the tokenizer is used.
+         truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+             Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
+             This argument is required if you want to use the default data collator.
+         disable_dropout (`bool`, *optional*, defaults to `True`):
+             Whether to disable dropout in the model and reference model.
+         generate_during_eval (`bool`, *optional*, defaults to `False`):
+             If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
+             during evaluation.
+         is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
+             When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
+             you need to specify if the model returned by the callable is an encoder-decoder model.
+         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
+             Whether to precompute reference model log probabilities for training and evaluation datasets. This is
+             useful when training without the reference model to reduce the total GPU memory needed.
+         model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
+             string.
+         ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
+             from a string.
+         dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
+             Number of processes to use for processing the dataset.
+         prompt_sample_size (`int`, *optional*, defaults to `1024`):
+             Number of prompts that are fed to the density ratio classifier.
+         min_density_ratio (`float`, *optional*, defaults to `0.5`):
+             Minimum value of the density ratio. The estimated density ratio is clamped to this value.
+         max_density_ratio (`float`, *optional*, defaults to `10.0`):
+             Maximum value of the density ratio. The estimated density ratio is clamped to this value.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks : Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     def __init__(
+         self,
+         output_dir = None,
+         overwrite_output_dir = None,
+         do_train = False,
+         do_eval = False,
+         do_predict = False,
+         eval_strategy = 'no',
+         prediction_loss_only = False,
+         per_device_train_batch_size = 4,
+         per_device_eval_batch_size = 4,
+         per_gpu_train_batch_size = None,
+         per_gpu_eval_batch_size = None,
+         gradient_accumulation_steps = 2,
+         eval_accumulation_steps = 2,
+         eval_delay = 0,
+         torch_empty_cache_steps = 250,
+         learning_rate = 5e-05,
+         weight_decay = 0.01,
+         adam_beta1 = 0.9,
+         adam_beta2 = 0.999,
+         adam_epsilon = 1e-08,
+         max_grad_norm = 1.0,
+         num_train_epochs = 3.0,
+         max_steps = -1,
+         lr_scheduler_type = 'linear',
+         warmup_ratio = 0.1,
+         warmup_steps = 0,
+         log_level = 'passive',
+         log_level_replica = 'warning',
+         log_on_each_node = True,
+         logging_dir = None,
+         logging_strategy = 'steps',
+         logging_first_step = False,
+         logging_steps = 1,
+         logging_nan_inf_filter = False,
+         save_strategy = 'steps',
+         save_steps = 500,
+         save_total_limit = None,
+         save_safetensors = True,
+         save_on_each_node = False,
+         save_only_model = False,
+         restore_callback_states_from_checkpoint = False,
+         no_cuda = False,
+         use_cpu = False,
+         use_mps_device = False,
+         seed = 3407,
+         data_seed = 3407,
+         jit_mode_eval = False,
+         use_ipex = False,
+         bf16 = False,
+         fp16 = False,
+         fp16_opt_level = 'O1',
+         half_precision_backend = 'auto',
+         bf16_full_eval = False,
+         fp16_full_eval = False,
+         tf32 = None,
+         local_rank = -1,
+         ddp_backend = None,
+         tpu_num_cores = None,
+         tpu_metrics_debug = False,
+         debug = '',
+         dataloader_drop_last = False,
+         eval_steps = None,
+         dataloader_num_workers = 0,
+         dataloader_prefetch_factor = None,
+         past_index = -1,
+         run_name = None,
+         disable_tqdm = None,
+         remove_unused_columns = True,
+         label_names = None,
+         load_best_model_at_end = False,
+         metric_for_best_model = None,
+         greater_is_better = None,
+         ignore_data_skip = False,
+         fsdp = '',
+         fsdp_min_num_params = 0,
+         fsdp_config = None,
+         tp_size = 0,
+         fsdp_transformer_layer_cls_to_wrap = None,
+         accelerator_config = None,
+         deepspeed = None,
+         label_smoothing_factor = 0.0,
+         optim = 'adamw_8bit',
+         optim_args = None,
+         adafactor = False,
+         group_by_length = False,
+         length_column_name = 'length',
+         report_to = None,
+         ddp_find_unused_parameters = None,
+         ddp_bucket_cap_mb = None,
+         ddp_broadcast_buffers = None,
+         dataloader_pin_memory = True,
+         dataloader_persistent_workers = False,
+         skip_memory_metrics = True,
+         use_legacy_prediction_loop = False,
+         push_to_hub = False,
+         resume_from_checkpoint = None,
+         hub_model_id = None,
+         hub_strategy = 'every_save',
+         hub_token = None,
+         hub_private_repo = None,
+         hub_always_push = False,
+         gradient_checkpointing = False,
+         gradient_checkpointing_kwargs = None,
+         include_inputs_for_metrics = False,
+         eval_do_concat_batches = True,
+         fp16_backend = 'auto',
+         evaluation_strategy = None,
+         push_to_hub_model_id = None,
+         push_to_hub_organization = None,
+         push_to_hub_token = None,
+         mp_parameters = '',
+         auto_find_batch_size = False,
+         full_determinism = False,
+         torchdynamo = None,
+         ray_scope = 'last',
+         ddp_timeout = 1800,
+         torch_compile = False,
+         torch_compile_backend = None,
+         torch_compile_mode = None,
+         dispatch_batches = None,
+         split_batches = None,
+         include_tokens_per_second = False,
+         include_num_input_tokens_seen = False,
+         neftune_noise_alpha = None,
+         optim_target_modules = None,
+         batch_eval_metrics = False,
+         eval_on_start = False,
+         use_liger_kernel = False,
+         eval_use_gather_object = False,
+         average_tokens_across_devices = False,
+         max_length = 1024,
+         max_prompt_length = 512,
+         max_completion_length = None,
+         beta = 0.1,
+         label_pad_token_id = -100,
+         padding_value = None,
+         truncation_mode = 'keep_end',
+         disable_dropout = True,
+         generate_during_eval = False,
+         is_encoder_decoder = None,
+         precompute_ref_log_probs = False,
+         model_init_kwargs = None,
+         ref_model_init_kwargs = None,
+         dataset_num_proc = None,
+         prompt_sample_size = 1024,
+         min_density_ratio = 0.5,
+         max_density_ratio = 10.0,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         **kwargs,
+     ):
+         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
+         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (greater than 1)! Consider decreasing it to 1e-1 or lower, otherwise gradient updates will explode!')
+         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+             output_dir = 'unsloth_training_checkpoints'
+             save_strategy = 'no'
+         if dataset_num_proc is None:
+             from multiprocessing import cpu_count
+             dataset_num_proc = cpu_count()
+
+         super().__init__(
+             output_dir = output_dir,
+             overwrite_output_dir = overwrite_output_dir,
+             do_train = do_train,
+             do_eval = do_eval,
+             do_predict = do_predict,
+             eval_strategy = eval_strategy,
+             prediction_loss_only = prediction_loss_only,
+             per_device_train_batch_size = per_device_train_batch_size,
+             per_device_eval_batch_size = per_device_eval_batch_size,
+             per_gpu_train_batch_size = per_gpu_train_batch_size,
+             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+             gradient_accumulation_steps = gradient_accumulation_steps,
+             eval_accumulation_steps = eval_accumulation_steps,
+             eval_delay = eval_delay,
+             torch_empty_cache_steps = torch_empty_cache_steps,
+             learning_rate = learning_rate,
+             weight_decay = weight_decay,
+             adam_beta1 = adam_beta1,
+             adam_beta2 = adam_beta2,
+             adam_epsilon = adam_epsilon,
+             max_grad_norm = max_grad_norm,
+             num_train_epochs = num_train_epochs,
+             max_steps = max_steps,
+             lr_scheduler_type = lr_scheduler_type,
+             warmup_ratio = warmup_ratio,
+             warmup_steps = warmup_steps,
+             log_level = log_level,
+             log_level_replica = log_level_replica,
+             log_on_each_node = log_on_each_node,
+             logging_dir = logging_dir,
+             logging_strategy = logging_strategy,
+             logging_first_step = logging_first_step,
+             logging_steps = logging_steps,
+             logging_nan_inf_filter = logging_nan_inf_filter,
+             save_strategy = save_strategy,
+             save_steps = save_steps,
+             save_total_limit = save_total_limit,
+             save_safetensors = save_safetensors,
+             save_on_each_node = save_on_each_node,
+             save_only_model = save_only_model,
+             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+             no_cuda = no_cuda,
+             use_cpu = use_cpu,
+             use_mps_device = use_mps_device,
+             seed = seed,
+             data_seed = data_seed,
+             jit_mode_eval = jit_mode_eval,
+             use_ipex = use_ipex,
+             bf16 = bf16,
+             fp16 = fp16,
+             fp16_opt_level = fp16_opt_level,
+             half_precision_backend = half_precision_backend,
+             bf16_full_eval = bf16_full_eval,
+             fp16_full_eval = fp16_full_eval,
+             tf32 = tf32,
+             local_rank = local_rank,
+             ddp_backend = ddp_backend,
+             tpu_num_cores = tpu_num_cores,
+             tpu_metrics_debug = tpu_metrics_debug,
+             debug = debug,
+             dataloader_drop_last = dataloader_drop_last,
+             eval_steps = eval_steps,
+             dataloader_num_workers = dataloader_num_workers,
+             dataloader_prefetch_factor = dataloader_prefetch_factor,
+             past_index = past_index,
+             run_name = run_name,
+             disable_tqdm = disable_tqdm,
+             remove_unused_columns = remove_unused_columns,
+             label_names = label_names,
+             load_best_model_at_end = load_best_model_at_end,
+             metric_for_best_model = metric_for_best_model,
+             greater_is_better = greater_is_better,
+             ignore_data_skip = ignore_data_skip,
+             fsdp = fsdp,
+             fsdp_min_num_params = fsdp_min_num_params,
+             fsdp_config = fsdp_config,
+             tp_size = tp_size,
+             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+             accelerator_config = accelerator_config,
+             deepspeed = deepspeed,
+             label_smoothing_factor = label_smoothing_factor,
+             optim = optim,
+             optim_args = optim_args,
+             adafactor = adafactor,
+             group_by_length = group_by_length,
+             length_column_name = length_column_name,
+             report_to = report_to,
+             ddp_find_unused_parameters = ddp_find_unused_parameters,
+             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+             ddp_broadcast_buffers = ddp_broadcast_buffers,
+             dataloader_pin_memory = dataloader_pin_memory,
+             dataloader_persistent_workers = dataloader_persistent_workers,
+             skip_memory_metrics = skip_memory_metrics,
+             use_legacy_prediction_loop = use_legacy_prediction_loop,
+             push_to_hub = push_to_hub,
+             resume_from_checkpoint = resume_from_checkpoint,
+             hub_model_id = hub_model_id,
+             hub_strategy = hub_strategy,
+             hub_token = hub_token,
+             hub_private_repo = hub_private_repo,
+             hub_always_push = hub_always_push,
+             gradient_checkpointing = gradient_checkpointing,
+             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+             include_inputs_for_metrics = include_inputs_for_metrics,
+             eval_do_concat_batches = eval_do_concat_batches,
+             fp16_backend = fp16_backend,
+             evaluation_strategy = evaluation_strategy,
+             push_to_hub_model_id = push_to_hub_model_id,
+             push_to_hub_organization = push_to_hub_organization,
+             push_to_hub_token = push_to_hub_token,
+             mp_parameters = mp_parameters,
+             auto_find_batch_size = auto_find_batch_size,
+             full_determinism = full_determinism,
+             torchdynamo = torchdynamo,
+             ray_scope = ray_scope,
+             ddp_timeout = ddp_timeout,
+             torch_compile = torch_compile,
+             torch_compile_backend = torch_compile_backend,
+             torch_compile_mode = torch_compile_mode,
+             dispatch_batches = dispatch_batches,
+             split_batches = split_batches,
+             include_tokens_per_second = include_tokens_per_second,
+             include_num_input_tokens_seen = include_num_input_tokens_seen,
+             neftune_noise_alpha = neftune_noise_alpha,
+             optim_target_modules = optim_target_modules,
+             batch_eval_metrics = batch_eval_metrics,
+             eval_on_start = eval_on_start,
+             use_liger_kernel = use_liger_kernel,
+             eval_use_gather_object = eval_use_gather_object,
+             average_tokens_across_devices = average_tokens_across_devices,
+             max_length = max_length,
+             max_prompt_length = max_prompt_length,
+             max_completion_length = max_completion_length,
+             beta = beta,
+             label_pad_token_id = label_pad_token_id,
+             padding_value = padding_value,
+             truncation_mode = truncation_mode,
+             disable_dropout = disable_dropout,
+             generate_during_eval = generate_during_eval,
+             is_encoder_decoder = is_encoder_decoder,
+             precompute_ref_log_probs = precompute_ref_log_probs,
+             model_init_kwargs = model_init_kwargs,
+             ref_model_init_kwargs = ref_model_init_kwargs,
+             dataset_num_proc = dataset_num_proc,
+             prompt_sample_size = prompt_sample_size,
+             min_density_ratio = min_density_ratio,
+             max_density_ratio = max_density_ratio, **kwargs)
+         self.vllm_sampling_params = vllm_sampling_params
+         self.unsloth_num_chunks = unsloth_num_chunks
+ pass
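+ # A minimal construction sketch (the values shown are illustrative assumptions, not
+ # recommendations from this file):
+ #
+ #     config = UnslothBCOConfig(
+ #         output_dir = "outputs",
+ #         max_length = 1024,           # prompt + completion token budget
+ #         beta = 0.1,                  # deviation from the reference model
+ #         prompt_sample_size = 1024,   # prompts fed to the density-ratio classifier
+ #     )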
+
+ class _UnslothBCOTrainer(Trainer):
+     r""""""
+
+     _tag_names = ["trl", "bco"]
+
+     def __init__(
+         self,
+         model: Union[PreTrainedModel, nn.Module, str] = None,
+         ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+         args: BCOConfig = None,
+         train_dataset: Optional[Dataset] = None,
+         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+         processing_class: Optional[
+             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+         ] = None,
+         data_collator: Optional[DataCollator] = None,
+         model_init: Optional[Callable[[], PreTrainedModel]] = None,
+         callbacks: Optional[list[TrainerCallback]] = None,
+         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+         peft_config: Optional[dict] = None,
+         compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+         model_adapter_name: Optional[str] = None,
+         ref_adapter_name: Optional[str] = None,
+         embedding_func: Optional[Callable] = None,
+         embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
+     ):
+         if not is_sklearn_available():
+             raise ImportError(
+                 "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
+             )
+
+         if type(args) is TrainingArguments:
+             raise ValueError("Please use `BCOConfig` instead of `TrainingArguments`.")
+
+         if not isinstance(model, str) and ref_model is model:
+             raise ValueError(
+                 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+                 "same as `model`, you must pass a copy of it, or `None` if you use peft."
+             )
+
+         if args.model_init_kwargs is None:
+             model_init_kwargs = {}
+         elif not isinstance(model, str):
+             raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
+         else:
+             model_init_kwargs = args.model_init_kwargs
+             torch_dtype = model_init_kwargs.get("torch_dtype")
+             if torch_dtype is not None:
+                 # Convert to `torch.dtype` if an str is passed
+                 if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                     torch_dtype = getattr(torch, torch_dtype)
+                 if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+                     raise ValueError(
+                         f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+                     )
+                 model_init_kwargs["torch_dtype"] = torch_dtype
+
+         if args.ref_model_init_kwargs is None:
+             ref_model_init_kwargs = {}
+         elif not isinstance(ref_model, str):
+             raise ValueError(
+                 "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
+             )
+         else:
+             ref_model_init_kwargs = args.ref_model_init_kwargs
+             torch_dtype = ref_model_init_kwargs.get("torch_dtype")
+             if torch_dtype is not None:
+                 # Convert to `torch.dtype` if an str is passed
+                 if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                     torch_dtype = getattr(torch, torch_dtype)
+                 if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+                     raise ValueError(
+                         f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
+                     )
+                 ref_model_init_kwargs["torch_dtype"] = torch_dtype
+
+         if isinstance(model, str):
+             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+         if isinstance(ref_model, str):
+             ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
+
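+ # `torch_dtype` may be passed as a string inside `model_init_kwargs` / `ref_model_init_kwargs`;
+ # anything other than "auto" is resolved via `getattr(torch, ...)` before reaching
+ # `from_pretrained`. Illustrative sketch (the config value is an assumption, not a
+ # recommended setting):
+ #
+ #     args = UnslothBCOConfig(model_init_kwargs = {"torch_dtype": "bfloat16"})
+ #     # inside __init__: getattr(torch, "bfloat16") is torch.bfloat16  -> True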
+         # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+         # has been called in order to properly call autocast if needed.
+         self._peft_has_been_casted_to_bf16 = False
+
+         if not is_peft_available() and peft_config is not None:
+             raise ValueError(
+                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
+             )
+         elif is_peft_available() and peft_config is not None:
+             # if model is a peft model and we have a peft_config, we merge and unload it first
+             if isinstance(model, PeftModel):
+                 model = model.merge_and_unload()
+
+             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                 _support_gc_kwargs = hasattr(
+                     args, "gradient_checkpointing_kwargs"
+                 ) and "gradient_checkpointing_kwargs" in list(
+                     inspect.signature(prepare_model_for_kbit_training).parameters
+                 )
+
+                 prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+                 if _support_gc_kwargs:
+                     prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+                 model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+             elif getattr(args, "gradient_checkpointing", False):
+                 # For backward compatibility with older versions of transformers
+                 if hasattr(model, "enable_input_require_grads"):
+                     model.enable_input_require_grads()
+                 else:
+
+                     def make_inputs_require_grad(module, input, output):
+                         output.requires_grad_(True)
+
+                     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+             # get peft model with the given config
+             model = model
+             if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                 peft_module_casting_to_bf16(model)
+                 # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+                 self._peft_has_been_casted_to_bf16 = True
+
+         # For models that use gradient_checkpointing, we need to attach a hook that enables input
+         # to explicitly have `requires_grad=True`, otherwise training will either silently
+         # fail or completely fail.
+         elif getattr(args, "gradient_checkpointing", False):
+             # For backward compatibility with older versions of transformers
+             if hasattr(model, "enable_input_require_grads"):
+                 model.enable_input_require_grads()
+             else:
+
+                 def make_inputs_require_grad(module, input, output):
+                     output.requires_grad_(True)
+
+                 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+         if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+             raise ValueError(
+                 "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+                 " Please install `wandb` or `comet-ml` to resolve."
+             )
+
+         if model is not None:
+             self.is_encoder_decoder = model.config.is_encoder_decoder
+         elif args.is_encoder_decoder is None:
+             raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+         else:
+             self.is_encoder_decoder = args.is_encoder_decoder
+
+         self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+         self.model_adapter_name = model_adapter_name
+         self.ref_adapter_name = ref_adapter_name
+
+         if ref_model:
+             self.ref_model = ref_model
+         elif self.is_peft_model or args.precompute_ref_log_probs:
+             # The `model` with adapters turned off will be used as the reference model
+             self.ref_model = None
+         else:
+             self.ref_model = create_reference_model(model)
+
+         if processing_class is None:
+             raise ValueError(
+                 "A processing_class must be specified when using the default DPODataCollatorWithPadding."
+             )
+         if args.max_length is None:
+             warnings.warn(
+                 "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
+                 "It will be set to `512` by default, but you should do it yourself in the future.",
+                 UserWarning,
+             )
+             max_length = 512
+         if args.max_length is not None:
+             max_length = args.max_length
+
+         if args.max_prompt_length is None:
+             warnings.warn(
+                 "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
+                 "It will be set to `128` by default, but you should do it yourself in the future.",
+                 UserWarning,
+             )
+             max_prompt_length = 128
+         if args.max_prompt_length is not None:
+             max_prompt_length = args.max_prompt_length
+
+         max_completion_length = None
+         if args.max_completion_length is None and self.is_encoder_decoder:
+             warnings.warn(
+                 "When using DPODataCollatorWithPadding with an encoder-decoder architecture, you should set `max_completion_length` in the BCOTrainer's init."
+                 " It will be set to `128` by default, but you should do it yourself in the future.",
+                 UserWarning,
+             )
+             max_completion_length = 128
+         if args.max_completion_length is not None and self.is_encoder_decoder:
+             max_completion_length = args.max_completion_length
+
+         if data_collator is None:
+             data_collator = DPODataCollatorWithPadding(
+                 pad_token_id=processing_class.pad_token_id,
+                 label_pad_token_id=args.label_pad_token_id,
+                 is_encoder_decoder=self.is_encoder_decoder,
+             )
+
+             if args.remove_unused_columns:
+                 args.remove_unused_columns = False
+                 # warn users
+                 warnings.warn(
+                     "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig."
+                     " We have set it for you, but you should do it yourself in the future.",
+                     UserWarning,
+                 )
+
+             self.use_dpo_data_collator = True
+         else:
+             self.use_dpo_data_collator = False
+
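+ # BCO trains on *unpaired* feedback: each row carries a prompt, a single completion, and a
+ # boolean `label` (desirable / undesirable); the dataset preparation below splits the train
+ # split on that flag. A minimal sketch of the expected rows (field values are illustrative):
+ #
+ #     {"prompt": "Write a haiku.", "completion": "An old silent pond...", "label": True}
+ #     {"prompt": "Write a haiku.", "completion": "No.",                   "label": False}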
+         # Disable dropout in the model and reference model
+         if args.disable_dropout:
+             disable_dropout_in_model(model)
+             if self.ref_model is not None:
+                 disable_dropout_in_model(self.ref_model)
+
+         self.max_length = max_length
+         self.generate_during_eval = args.generate_during_eval
+         self.label_pad_token_id = args.label_pad_token_id
+         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+         self.max_prompt_length = max_prompt_length
+         self.truncation_mode = args.truncation_mode
+         self.max_completion_length = max_completion_length
+         self.precompute_ref_log_probs = args.precompute_ref_log_probs
+
+         # Since reference log probs are precomputed on the first call to get_train/eval_dataloader,
+         # we track whether that first call has happened so later calls skip the computation.
+         self._precomputed_train_ref_log_probs = False
+         self._precomputed_eval_ref_log_probs = False
+
+         # metric
+         self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+         # BCO parameter
+         self.beta = args.beta
+         self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+             warnings.warn(
+                 "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+                 "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+                 "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+                 "loss.",
+                 UserWarning,
+             )
+
+         # Underlying Distribution Matching argument
+         self.embedding_func = embedding_func
+         self.embedding_tokenizer = embedding_tokenizer
+
+         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+         # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
+         # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
+         # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
+         # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
+         # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
+         # issued.
+         model.warnings_issued["estimate_tokens"] = True
+
+         with PartialState().local_main_process_first():
+             # Apply the chat template if needed
+             train_dataset = train_dataset.map(
+                 maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
+             )
+             if eval_dataset is not None:
+                 eval_dataset = eval_dataset.map(
+                     maybe_apply_chat_template,
+                     fn_kwargs={"tokenizer": processing_class},
+                     num_proc=args.dataset_num_proc,
+                 )
+             # Shuffle the datasets
+             train_dataset = train_dataset.shuffle(seed=args.data_seed)
+             if eval_dataset is not None:
+                 eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
+             # Tokenize and prepare the training datasets
+             train_dataset = train_dataset.map(
+                 _tokenize,
+                 batched=True,
+                 fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
+                 num_proc=args.dataset_num_proc,
+                 desc="Tokenizing train dataset",
+             )
+
+             # Prepare the datasets
+             fn_kwargs = {
+                 "prefix": "",
+                 "is_encoder_decoder": self.is_encoder_decoder,
+                 "tokenizer": processing_class,
+                 "max_length": self.max_length,
+                 "truncation_mode": self.truncation_mode,
+                 "label_pad_token_id": self.label_pad_token_id,
+                 "max_prompt_length": self.max_prompt_length,
+                 "max_completion_length": self.max_completion_length,
+             }
+             train_dataset = train_dataset.map(
+                 _process_tokens,
+                 fn_kwargs=fn_kwargs,
+                 num_proc=args.dataset_num_proc,
+                 desc="Processing tokenized train dataset",
+             )
+
+             if eval_dataset is not None:
+                 # Tokenize
+                 eval_dataset = eval_dataset.map(
+                     _tokenize,
+                     fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
+                     batched=True,
+                     num_proc=args.dataset_num_proc,
+                     desc="Tokenizing eval dataset",
+                 )
+
+                 # Process
+                 fn_kwargs = {
+                     "prefix": "",
+                     "is_encoder_decoder": self.is_encoder_decoder,
+                     "tokenizer": processing_class,
+                     "max_length": self.max_length,
+                     "truncation_mode": self.truncation_mode,
+                     "label_pad_token_id": self.label_pad_token_id,
+                     "max_prompt_length": self.max_prompt_length,
+                     "max_completion_length": self.max_completion_length,
+                 }
+                 eval_dataset = eval_dataset.map(
+                     _process_tokens,
+                     fn_kwargs=fn_kwargs,
+                     num_proc=args.dataset_num_proc,
+                     desc="Processing tokenized eval dataset",
+                 )
+
+             desirable = train_dataset.filter(
+                 lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
+             )
+             undesirable = train_dataset.filter(
+                 lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
+             )
+
+             desirable = desirable.shuffle(seed=args.data_seed)
+             undesirable = undesirable.shuffle(seed=args.data_seed)
+
+         super().__init__(
+             model=model,
+             args=args,
+             data_collator=data_collator,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             processing_class=processing_class,
+             model_init=model_init,
+             compute_metrics=compute_metrics,
+             callbacks=callbacks,
+             optimizers=optimizers,
+             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+         )
+
+         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
786
+ # self.model_accepts_loss_kwargs to False to enable scaling.
787
+ self.model_accepts_loss_kwargs = False
788
+
789
+ # Add tags for models that have been loaded with the correct transformers version
790
+ if hasattr(self.model, "add_model_tags"):
791
+ self.model.add_model_tags(self._tag_names)
792
+
793
+ if not hasattr(self, "accelerator"):
794
+ raise AttributeError(
795
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
796
+ )
797
+
798
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
799
+ if self.is_deepspeed_enabled:
800
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
801
+ raise ValueError(
802
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
803
+ )
804
+
805
+ if self.ref_model is None:
806
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
807
+ raise ValueError(
808
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
809
+ )
810
+ else:
811
+ if self.is_deepspeed_enabled:
812
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
813
+ else:
814
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
815
+
816
+ self.running = RunningMoments(accelerator=self.accelerator)
817
+
818
+ if self.embedding_func is None:
819
+ return
820
+
821
+ chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
822
+ rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
823
+
824
+ embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
825
+ labels = torch.cat(
826
+ (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
827
+ )
828
+
829
+ self.clf = LogisticRegression(class_weight="balanced").fit(
830
+ embeddings.cpu().float().numpy(), labels.cpu().numpy()
831
+ )
832
+
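+ # Editor note (a sketch, not part of the generated trainer): the classifier fitted above
+ # implements BCO's underlying-distribution-matching (UDM) step. It trains a balanced
+ # logistic regression on sampled prompt embeddings, labeling desirable prompts 1 and
+ # undesirable prompts 0; its predicted probability p is later converted into a
+ # density-ratio weight p / (1 - p) in `_get_udm_weight`.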
833
+ @property
834
+ def match_underlying_distribution(self):
835
+ return self.embedding_func is not None and self.embedding_tokenizer is not None
836
+
837
+ def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
838
+ """
839
+ Calculates the probability that the given prompt embedding comes from the desirable dataset.
840
+ The probability is computed on each process and then ensembled (averaged) across processes.
841
+ """
842
+ dtype = prompt_embeddings.dtype
843
+ device = prompt_embeddings.device
844
+ rank = self.accelerator.process_index
845
+
846
+ padded_prompt_embeddings = self.accelerator.pad_across_processes(
847
+ prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
848
+ )
849
+ sample_size = padded_prompt_embeddings.shape[0]
850
+ nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
851
+ prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
852
+
853
+ # cannot predict for all empty values
854
+ if prompt_embeddings.shape[0] == 0:
855
+ return torch.tensor([], device=device, dtype=dtype)
856
+
857
+ prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
858
+ prob = torch.as_tensor(prob, dtype=dtype, device=device)
859
+ prob = self.accelerator.reduce(prob, reduction="mean")
860
+
861
+ prob = prob[sample_size * rank : sample_size * (rank + 1)]
862
+ prob = prob[nonzero]
863
+
864
+ return prob
865
+
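+ # Editor sketch of the flow above (assuming N processes): each rank pads its local
+ # embeddings to a common size, gathers all ranks' padded batches, predicts probabilities
+ # with the shared classifier, averages the predictions across ranks via `reduce`, and
+ # finally slices out its own shard and drops the padding rows, so every rank returns
+ # only the probabilities for its local examples.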
866
+ def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
867
+ """
868
+ Replaces processing_class.pad_token_id with embedding_tokenizer.pad_token_id
869
+ and applies self.embedding_func to the result
870
+ """
871
+ input_ids = torch.where(
872
+ input_ids == self.processing_class.pad_token_id,
873
+ self.embedding_tokenizer.pad_token_id,
874
+ input_ids,
875
+ )
876
+
877
+ with torch.no_grad():
878
+ embeddings = self.embedding_func(
879
+ input_ids=input_ids,
880
+ attention_mask=attention_mask,
881
+ )
882
+
883
+ return embeddings
884
+
885
+ def _get_prompt_embeddings(
886
+ self, batch: dict[str, Union[list, torch.LongTensor]]
887
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
888
+ """Extract embeddings from frozen embedding model"""
889
+
890
+ if not self.match_underlying_distribution:
891
+ return None, None
892
+
893
+ embeddings = self._vectorize_prompt(
894
+ input_ids=batch["embedding_input_ids"],
895
+ attention_mask=batch["embedding_attention_mask"],
896
+ )
897
+
898
+ chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
899
+ rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
900
+
901
+ chosen_embeddings = embeddings[chosen_idx, ...]
902
+ rejected_embeddings = embeddings[rejected_idx, ...]
903
+
904
+ return (chosen_embeddings, rejected_embeddings)
905
+
906
+ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
907
+ """
908
+ Sample instances from dataset and get prompt embeddings.
909
+ Used for density ratio classifier training.
910
+ """
911
+ n_samples = min(len(dataset), sample_size)
912
+ rand_indices = np.random.choice(len(dataset), size=(n_samples,))
913
+
914
+ embedding_dataset = dataset.select(rand_indices)
915
+
916
+ dataloader_params = {
917
+ "batch_size": self.args.per_device_train_batch_size,
918
+ "collate_fn": self.data_collator,
919
+ "num_workers": self.args.dataloader_num_workers,
920
+ "pin_memory": self.args.dataloader_pin_memory,
921
+ "shuffle": False,
922
+ }
923
+
924
+ # prepare dataloader
925
+ data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
926
+
927
+ with torch.no_grad():
928
+ all_embeddings = torch.empty(0)
929
+ for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
930
+ embeddings = self._vectorize_prompt(
931
+ input_ids=padded_batch["embedding_input_ids"],
932
+ attention_mask=padded_batch["embedding_attention_mask"],
933
+ )
934
+ embeddings = self.accelerator.gather_for_metrics(embeddings)
935
+ all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
936
+
937
+ return all_embeddings
938
+
939
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
940
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
941
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
942
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
943
+
944
+ if model is not None:
945
+ if hasattr(model, "config"):
946
+ hidden_size = (
947
+ max(model.config.hidden_sizes)
948
+ if getattr(model.config, "hidden_sizes", None)
949
+ else getattr(model.config, "hidden_size", None)
950
+ )
951
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
952
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
953
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
954
+ config_kwargs.update(
955
+ {
956
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
957
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
958
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
959
+ }
960
+ )
961
+
962
+ # If ZeRO-3 is used, we shard both the active and reference model.
963
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
964
+ if config_kwargs["zero_optimization"]["stage"] != 3:
965
+ config_kwargs["zero_optimization"]["stage"] = 0
966
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
967
+ model.eval()
968
+ return model
969
+
970
+ def _save_optimizer_and_scheduler(self, output_dir):
971
+ super()._save_optimizer_and_scheduler(output_dir)
972
+
973
+ # When saving optimizer and scheduler to checkpoint, save also the running delta object.
974
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
975
+
976
+ self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
977
+
978
+ if self.match_underlying_distribution:
979
+ torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
980
+
981
+ def _load_optimizer_and_scheduler(self, checkpoint):
982
+ super()._load_optimizer_and_scheduler(checkpoint)
983
+
984
+ if checkpoint is None:
985
+ return
986
+ # when loading optimizer and scheduler from checkpoint, also load the running delta object.
987
+ running_file = os.path.join(checkpoint, RUNNING_NAME)
988
+ if os.path.isfile(running_file):
989
+ self.running = RunningMoments.load_from_json(self.accelerator, running_file)
990
+
991
+ if self.match_underlying_distribution:
992
+ clf_file = os.path.join(checkpoint, CLF_NAME)
993
+ if os.path.isfile(clf_file):
994
+ self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
995
+
996
+ @contextmanager
997
+ def null_ref_context(self):
998
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
999
+ with (
1000
+ self.accelerator.unwrap_model(self.model).disable_adapter()
1001
+ if self.is_peft_model and not self.ref_adapter_name
1002
+ else nullcontext()
1003
+ ):
1004
+ if self.ref_adapter_name:
1005
+ self.model.set_adapter(self.ref_adapter_name)
1006
+ yield
1007
+ if self.ref_adapter_name:
1008
+ self.model.set_adapter(self.model_adapter_name or "default")
1009
+
1010
+ def get_train_dataloader(self) -> DataLoader:
1011
+ """
1012
+ Returns the training [`~torch.utils.data.DataLoader`].
1013
+
1014
+ Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
1015
+ """
1016
+
1017
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
1018
+ dataloader_params = {
1019
+ "batch_size": self.args.per_device_train_batch_size,
1020
+ "collate_fn": self.data_collator,
1021
+ "num_workers": self.args.dataloader_num_workers,
1022
+ "pin_memory": self.args.dataloader_pin_memory,
1023
+ "shuffle": False,
1024
+ }
1025
+
1026
+ # prepare dataloader
1027
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
1028
+ reference_completion_logps = []
1029
+
1030
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
1031
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1032
+
1033
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1034
+ reference_completion_logps.append(reference_completion_logp.cpu())
1035
+
1036
+ self.train_dataset = self.train_dataset.add_column(
1037
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1038
+ )
1039
+
1040
+ self._precomputed_train_ref_log_probs = True
1041
+
1042
+ return super().get_train_dataloader()
1043
+
1044
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
1045
+ """
1046
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
1047
+
1048
+ Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
1049
+
1050
+ Args:
1051
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
1052
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
1053
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
1054
+ """
1055
+ if eval_dataset is None and self.eval_dataset is None:
1056
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
1057
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
1058
+
1059
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
1060
+ dataloader_params = {
1061
+ "batch_size": self.args.per_device_eval_batch_size,
1062
+ "collate_fn": self.data_collator,
1063
+ "num_workers": self.args.dataloader_num_workers,
1064
+ "pin_memory": self.args.dataloader_pin_memory,
1065
+ "shuffle": False,
1066
+ }
1067
+
1068
+ # prepare dataloader
1069
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1070
+
1071
+ reference_completion_logps = []
1072
+
1073
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1074
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1075
+
1076
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1077
+ reference_completion_logps.append(reference_completion_logp.cpu())
1078
+
1079
+ eval_dataset = eval_dataset.add_column(
1080
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1081
+ )
1082
+
1083
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
1084
+ if self.eval_dataset is not None:
1085
+ self.eval_dataset = eval_dataset
1086
+ self._precomputed_eval_ref_log_probs = True
1087
+
1088
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1089
+
1090
+ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1091
+ """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
1092
+ with torch.no_grad():
1093
+ if self.ref_model is None:
1094
+ with self.null_ref_context():
1095
+ if self.is_encoder_decoder:
1096
+ completion_logits = self.model(
1097
+ padded_batch["prompt_input_ids"],
1098
+ attention_mask=padded_batch["prompt_attention_mask"],
1099
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1100
+ labels=padded_batch["completion_labels"],
1101
+ ).logits
1102
+
1103
+ else:
1104
+ completion_logits = self.model(
1105
+ padded_batch["completion_input_ids"],
1106
+ attention_mask=padded_batch["completion_attention_mask"],
1107
+ ).logits
1108
+
1109
+ else:
1110
+ if self.is_encoder_decoder:
1111
+ completion_logits = self.ref_model(
1112
+ padded_batch["prompt_input_ids"],
1113
+ attention_mask=padded_batch["prompt_attention_mask"],
1114
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1115
+ labels=padded_batch["completion_labels"],
1116
+ ).logits
1117
+
1118
+ else:
1119
+ completion_logits = self.ref_model(
1120
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1121
+ ).logits
1122
+
1123
+ completion_logps = self.get_batch_logps(
1124
+ completion_logits,
1125
+ padded_batch["completion_labels"],
1126
+ average_log_prob=False,
1127
+ is_encoder_decoder=self.is_encoder_decoder,
1128
+ label_pad_token_id=self.label_pad_token_id,
1129
+ )
1130
+
1131
+ return completion_logps
1132
+
1133
+ @staticmethod
1134
+ def get_batch_logps(
1135
+ logits: torch.FloatTensor,
1136
+ labels: torch.LongTensor,
1137
+ average_log_prob: bool = False,
1138
+ label_pad_token_id: int = -100,
1139
+ is_encoder_decoder: bool = False,
1140
+ ) -> torch.FloatTensor:
1141
+ """Compute the log probabilities of the given labels under the given logits.
1142
+
1143
+ Args:
1144
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1145
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1146
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+ label_pad_token_id: Label token id whose positions are ignored when computing the log probabilities.
+ is_encoder_decoder: Whether the model is an encoder-decoder; if False, logits and labels are shifted by one position.
1147
+
1148
+ Returns:
1149
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1150
+ """
1151
+ if logits.shape[:-1] != labels.shape:
1152
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1153
+
1154
+ if not is_encoder_decoder:
1155
+ labels = labels[:, 1:].clone()
1156
+ logits = logits[:, :-1, :]
1157
+ else:
1158
+ # Fixes an encoder-decoder RuntimeError: clone labels so the in-place masking below does not modify the caller's tensor
1159
+ labels = labels.clone()
1160
+
1161
+ loss_mask = labels != label_pad_token_id
1162
+
1163
+ # dummy token; we'll ignore the losses on these tokens later
1164
+ labels[labels == label_pad_token_id] = 0
1165
+
1166
+ per_token_logps = selective_log_softmax(logits, labels)
1167
+
1168
+ if average_log_prob:
1169
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1170
+ else:
1171
+ return (per_token_logps * loss_mask).sum(-1)
1172
+
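+ # Minimal usage sketch (editor addition, illustrative values only): for a decoder-only
+ # model, labels are shifted one step relative to logits and padded positions are masked
+ # out before summing the per-token log-probabilities.
+ #
+ # logits = torch.randn(2, 5, 32) # (batch, seq_len, vocab)
+ # labels = torch.randint(0, 32, (2, 5))
+ # labels[:, :2] = -100 # mask prompt/padding tokens
+ # logps = _UnslothBCOTrainer.get_batch_logps(logits, labels) # shape (2,)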
1173
+ def forward(
1174
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1175
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1176
+ model_kwargs = (
1177
+ {
1178
+ "labels": batch["completion_labels"],
1179
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1180
+ }
1181
+ if self.is_encoder_decoder
1182
+ else {}
1183
+ )
1184
+ if self.aux_loss_enabled:
1185
+ model_kwargs["output_router_logits"] = True
1186
+
1187
+ outputs = model(
1188
+ batch["completion_input_ids"],
1189
+ attention_mask=batch["completion_attention_mask"],
1190
+ **model_kwargs,
1191
+ )
1192
+ completion_logits = outputs.logits
1193
+
1194
+ completion_logps = self.get_batch_logps(
1195
+ completion_logits,
1196
+ batch["completion_labels"],
1197
+ average_log_prob=False,
1198
+ is_encoder_decoder=self.is_encoder_decoder,
1199
+ label_pad_token_id=self.label_pad_token_id,
1200
+ )
1201
+
1202
+ if completion_logps.shape[0] != len(batch["label"]):
1203
+ raise ValueError(
1204
+ "There is a mismatch between the number of examples in this batch and the number of "
1205
+ "examples for which an output sequence was predicted."
1206
+ )
1207
+
1208
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1209
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1210
+
1211
+ chosen_logps = completion_logps[chosen_idx, ...]
1212
+ rejected_logps = completion_logps[rejected_idx, ...]
1213
+
1214
+ chosen_logits = completion_logits[chosen_idx, ...]
1215
+ rejected_logits = completion_logits[rejected_idx, ...]
1216
+
1217
+ if self.aux_loss_enabled:
1218
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
1219
+ else:
1220
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1221
+
1222
+ def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
1223
+ prob_desirable = self._get_chosen_prob(rejected_embeddings)
1224
+ min_ratio = self.args.min_density_ratio
1225
+ max_ratio = self.args.max_density_ratio
1226
+
1227
+ weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
1228
+
1229
+ return weight
1230
+
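+ # Editor note: the weight above is the estimated density ratio
+ # p(desirable | prompt) / (1 - p(desirable | prompt)), clamped to
+ # [min_density_ratio, max_density_ratio] to keep the rejected-loss reweighting bounded.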
1231
+ def bco_loss(
1232
+ self,
1233
+ policy_chosen_logps: torch.FloatTensor,
1234
+ policy_rejected_logps: torch.FloatTensor,
1235
+ reference_chosen_logps: torch.FloatTensor,
1236
+ reference_rejected_logps: torch.FloatTensor,
1237
+ chosen_embeddings: Optional[torch.FloatTensor],
1238
+ rejected_embeddings: Optional[torch.FloatTensor],
1239
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1240
+ """Compute the BCO loss for a batch of policy and reference model log probabilities.
1241
+
1242
+ Args:
1243
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1244
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1245
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1246
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1247
+ chosen_embeddings: embeddings of desirable prompts
1248
+ rejected_embeddings: embeddings of undesirable prompts
1249
+
1250
+ Returns:
1251
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
1252
+ The losses tensor contains the BCO loss for each example in the batch.
1253
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1254
+ The delta value contains the moving average of all implicit rewards.
1255
+ """
1256
+
1257
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1258
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1259
+ chosen_rewards = self.beta * chosen_logratios
1260
+ else:
1261
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1262
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1263
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1264
+
1265
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1266
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1267
+ rejected_rewards = self.beta * rejected_logratios
1268
+ else:
1269
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1270
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1271
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1272
+
1273
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
1274
+ self.running.update(rewards)
1275
+ delta = self.running.mean
1276
+
1277
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1278
+ chosen_losses = -F.logsigmoid(chosen_rewards - delta)
1279
+
1280
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1281
+ rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
1282
+
1283
+ if self.match_underlying_distribution:
1284
+ chosen_weight = torch.ones_like(chosen_losses)
1285
+ rejected_weight = self._get_udm_weight(rejected_embeddings)
1286
+
1287
+ losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
1288
+ else:
1289
+ losses = torch.cat((chosen_losses, rejected_losses), dim=0)
1290
+
1291
+ return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
1292
+
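+ # In equation form (editor sketch of what `bco_loss` computes), with σ the logistic
+ # sigmoid and δ the running mean of all implicit rewards:
+ #
+ # r_chosen = β · (log π_θ(y|x) − log π_ref(y|x)) for desirable completions
+ # r_rejected = β · (log π_θ(y|x) − log π_ref(y|x)) for undesirable completions
+ # L_chosen = −log σ(r_chosen − δ)
+ # L_rejected = −log σ(−(r_rejected − δ))
+ #
+ # When UDM is active, L_rejected is additionally scaled by the clamped density ratio
+ # from `_get_udm_weight`.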
1293
+ def get_batch_loss_metrics(
1294
+ self,
1295
+ model,
1296
+ batch: dict[str, Union[list, torch.LongTensor]],
1297
+ ):
1298
+ """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
1299
+ metrics = {}
1300
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1301
+
1302
+ forward_output = self.forward(model, batch)
1303
+ (
1304
+ policy_chosen_logps,
1305
+ policy_rejected_logps,
1306
+ policy_chosen_logits,
1307
+ policy_rejected_logits,
1308
+ ) = forward_output[:4]
1309
+ if self.aux_loss_enabled:
1310
+ aux_loss = forward_output[4]
1311
+
1312
+ # if reference_logps in batch use them, otherwise use the reference model
1313
+ if "reference_logps" in batch:
1314
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1315
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1316
+
1317
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1318
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1319
+ else:
1320
+ with torch.no_grad():
1321
+ if self.ref_model is None:
1322
+ with self.null_ref_context():
1323
+ (
1324
+ reference_chosen_logps,
1325
+ reference_rejected_logps,
1326
+ _,
1327
+ _,
1328
+ ) = self.forward(self.model, batch)[:4]
1329
+ else:
1330
+ (
1331
+ reference_chosen_logps,
1332
+ reference_rejected_logps,
1333
+ _,
1334
+ _,
1335
+ ) = self.forward(self.ref_model, batch)[:4]
1336
+
1337
+ chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
1338
+
1339
+ losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
1340
+ policy_chosen_logps,
1341
+ policy_rejected_logps,
1342
+ reference_chosen_logps,
1343
+ reference_rejected_logps,
1344
+ chosen_embeddings,
1345
+ rejected_embeddings,
1346
+ )
1347
+ metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
1348
+
1349
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1350
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1351
+
1352
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1353
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1354
+
1355
+ if all_num_chosen > 0:
1356
+ metrics["rewards/chosen_sum"] = (
1357
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1358
+ )
1359
+ metrics["logps/chosen_sum"] = (
1360
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1361
+ )
1362
+ metrics["logits/chosen_sum"] = (
1363
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1364
+ )
1365
+ metrics["count/chosen"] = all_num_chosen
1366
+
1367
+ if all_num_rejected > 0:
1368
+ metrics["rewards/rejected_sum"] = (
1369
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1370
+ )
1371
+ metrics["logps/rejected_sum"] = (
1372
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1373
+ )
1374
+ metrics["logits/rejected_sum"] = (
1375
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1376
+ )
1377
+ metrics["count/rejected"] = all_num_rejected
1378
+
1379
+ loss = losses.nanmean()
1380
+ if self.aux_loss_enabled:
1381
+ loss += self.aux_loss_coef * aux_loss
1382
+
1383
+ return loss, metrics
1384
+
1385
+ def compute_loss(
1386
+ self,
1387
+ model: Union[PreTrainedModel, nn.Module],
1388
+ inputs: dict[str, Union[torch.Tensor, Any]],
1389
+ return_outputs=False,
1390
+ num_items_in_batch=None,
1391
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1392
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1393
+
1394
+ with compute_loss_context_manager:
1395
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1396
+
1397
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1398
+ loss = loss.to(self.args.device)
1399
+ # force log the metrics
1400
+ if self.accelerator.is_main_process:
1401
+ self.store_metrics(metrics, train_eval="train")
1402
+
1403
+ if return_outputs:
1404
+ return (loss, metrics)
1405
+ return loss
1406
+
1407
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1408
+ for key, value in metrics.items():
1409
+ self._stored_metrics[train_eval][key].append(value)
1410
+
1411
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1412
+ if self.train_dataset is None or not has_length(self.train_dataset):
1413
+ return None
1414
+ return SequentialSampler(self.train_dataset)
1415
+
1416
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1417
+ """Generate samples from the model and reference model for the given batch of inputs."""
1418
+
1419
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1420
+ # the torch cuda amp context manager, as some hidden states are silently cast to full precision.
1421
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1422
+ with generate_context_manager:
1423
+ policy_output = model.generate(
1424
+ input_ids=batch["prompt_input_ids"],
1425
+ attention_mask=batch["prompt_attention_mask"],
1426
+ max_length=self.max_length,
1427
+ do_sample=True,
1428
+ pad_token_id=self.processing_class.pad_token_id,
1429
+ )
1430
+
1431
+ # if reference_output in batch use that otherwise use the reference model
1432
+ if "reference_output" in batch:
1433
+ reference_output = batch["reference_output"]
1434
+ else:
1435
+ if self.ref_model is None:
1436
+ with self.null_ref_context():
1437
+ reference_output = self.model.generate(
1438
+ input_ids=batch["prompt_input_ids"],
1439
+ attention_mask=batch["prompt_attention_mask"],
1440
+ max_length=self.max_length,
1441
+ do_sample=True,
1442
+ pad_token_id=self.processing_class.pad_token_id,
1443
+ )
1444
+ else:
1445
+ reference_output = self.ref_model.generate(
1446
+ input_ids=batch["prompt_input_ids"],
1447
+ attention_mask=batch["prompt_attention_mask"],
1448
+ max_length=self.max_length,
1449
+ do_sample=True,
1450
+ pad_token_id=self.processing_class.pad_token_id,
1451
+ )
1452
+
1453
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1454
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1455
+
1456
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1457
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1458
+
1459
+ return policy_output_decoded, reference_output_decoded
1460
+
1461
+ def prediction_step(
1462
+ self,
1463
+ model: Union[PreTrainedModel, nn.Module],
1464
+ inputs: dict[str, Union[torch.Tensor, Any]],
1465
+ prediction_loss_only: bool,
1466
+ ignore_keys: Optional[list[str]] = None,
1467
+ ):
1468
+ if ignore_keys is None:
1469
+ if hasattr(model, "config"):
1470
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1471
+ else:
1472
+ ignore_keys = []
1473
+
1474
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1475
+ with torch.no_grad(), prediction_context_manager:
1476
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1477
+
1478
+ # force log the metrics
1479
+ if self.accelerator.is_main_process:
1480
+ self.store_metrics(metrics, train_eval="eval")
1481
+
1482
+ if prediction_loss_only:
1483
+ return (loss.detach(), None, None)
1484
+
1485
+ # logits for the chosen and rejected samples from model
1486
+ logits_dict = {
1487
+ "eval_logits/chosen": metrics["logits/chosen"],
1488
+ "eval_logits/rejected": metrics["logits/rejected"],
1489
+ }
1490
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1491
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1492
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1493
+
1494
+ return (loss.detach(), logits, labels)
1495
+
1496
+ def evaluation_loop(
1497
+ self,
1498
+ dataloader: DataLoader,
1499
+ description: str,
1500
+ prediction_loss_only: Optional[bool] = None,
1501
+ ignore_keys: Optional[list[str]] = None,
1502
+ metric_key_prefix: str = "eval",
1503
+ ) -> EvalLoopOutput:
1504
+ """
1505
+ Overriding built-in evaluation loop to store metrics for each batch.
1506
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1507
+
1508
+ Works both with or without labels.
1509
+ """
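+ # Usage sketch (editor addition, illustrative): when training a PEFT model without a
+ # separate reference model, reference-model computations temporarily drop the adapters:
+ #
+ # with self.null_ref_context():
+ # ref_logits = self.model(input_ids).logits # behaves like the frozen base model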
1510
+
1511
+ # Sample and save to game log if requested (for one batch to save time)
1512
+ if self.generate_during_eval:
1513
+ # Generate random indices within the range of the total number of samples
1514
+ num_samples = len(dataloader.dataset)
1515
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1516
+
1517
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1518
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1519
+ random_batch = self.data_collator(random_batch_dataset)
1520
+ random_batch = self._prepare_inputs(random_batch)
1521
+
1522
+ target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1523
+ target_batch = {
1524
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
1525
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
1526
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
1527
+ }
1528
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1529
+
1530
+ table = pd.DataFrame(
1531
+ columns=["Prompt", "Policy", "Ref Model"],
1532
+ data=[
1533
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1534
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1535
+ ],
1536
+ )
1537
+ if "wandb" in self.args.report_to:
1538
+ wandb.log({"game_log": wandb.Table(data=table)})
1539
+
1540
+ if "comet_ml" in self.args.report_to:
1541
+ log_table_to_comet_experiment(
1542
+ name="game_log.csv",
1543
+ table=table,
1544
+ )
1545
+
1546
+ # Base evaluation
1547
+ initial_output = super().evaluation_loop(
1548
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1549
+ )
1550
+
1551
+ return initial_output
1552
+
1553
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1554
+ """
1555
+ Log `logs` on the various objects watching training, including stored metrics.
1556
+
1557
+ Args:
1558
+ logs (`dict[str, float]`):
1559
+ The values to log.
1560
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1561
+ Start time of the training.
1562
+ """
1563
+ # logs either has 'loss' or 'eval_loss'
1564
+ train_eval = "train" if "loss" in logs else "eval"
1565
+ # train metrics should have no prefix, eval should have 'eval_'
1566
+ prefix = "eval_" if train_eval == "eval" else ""
1567
+ # accumulate average metrics from sums and lengths
1568
+ for split in ["chosen", "rejected"]:
1569
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1570
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1571
+ for metric in ["rewards", "logps", "logits"]:
1572
+ logs[f"{prefix}{metric}/{split}"] = (
1573
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1574
+ / count_sum
1575
+ )
1576
+ # delete obsolete metric
1577
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1578
+ del self._stored_metrics[train_eval][f"count/{split}"]
1579
+ # calculate reward margin
1580
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1581
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1582
+ # Add averaged stored metrics to logs
1583
+ for key, metrics in self._stored_metrics[train_eval].items():
1584
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1585
+ del self._stored_metrics[train_eval]
1586
+
1587
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1588
+ return super().log(logs, start_time)
1589
+ else: # transformers<=4.46
1590
+ return super().log(logs)
1591
+
1592
+ def create_model_card(
1593
+ self,
1594
+ model_name: Optional[str] = None,
1595
+ dataset_name: Optional[str] = None,
1596
+ tags: Union[str, list[str], None] = None,
1597
+ ):
1598
+ """
1599
+ Creates a draft of a model card using the information available to the `Trainer`.
1600
+
1601
+ Args:
1602
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1603
+ Name of the model.
1604
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1605
+ Name of the dataset used for training.
1606
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1607
+ Tags to be associated with the model card.
1608
+ """
1609
+ if not self.is_world_process_zero():
1610
+ return
1611
+
1612
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1613
+ base_model = self.model.config._name_or_path
1614
+ else:
1615
+ base_model = None
1616
+
1617
+ tags = tags or []
1618
+ if isinstance(tags, str):
1619
+ tags = [tags]
1620
+
1621
+ if hasattr(self.model.config, "unsloth_version"):
1622
+ tags.append("unsloth")
1623
+
1624
+ citation = textwrap.dedent("""\
1625
+ @article{jung2024binary,
1626
+ title = {{Binary Classifier Optimization for Large Language Model Alignment}},
1627
+ author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
1628
+ year = 2024,
1629
+ eprint = {arXiv:2404.04656}
1630
+ }""")
1631
+
1632
+ model_card = generate_model_card(
1633
+ base_model=base_model,
1634
+ model_name=model_name,
1635
+ hub_model_id=self.hub_model_id,
1636
+ dataset_name=dataset_name,
1637
+ tags=tags,
1638
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1639
+ comet_url=get_comet_experiment_url(),
1640
+ trainer_name="BCO",
1641
+ trainer_citation=citation,
1642
+ paper_title="Binary Classifier Optimization for Large Language Model Alignment",
1643
+ paper_id="2404.04656",
1644
+ )
1645
+
1646
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1647
+ class UnslothBCOTrainer(_UnslothBCOTrainer):
1648
+ """
1649
+
1650
+ Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
1651
+
1652
+ Args:
1653
+ model (`transformers.PreTrainedModel`):
1654
+ The model to train, preferably an `AutoModelForCausalLM`.
1655
+ ref_model (`PreTrainedModelWrapper`):
1656
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
1657
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1658
+ args (`BCOConfig`):
1659
+ The arguments to use for training.
1660
+ train_dataset (`datasets.Dataset`):
1661
+ The dataset to use for training.
1662
+ eval_dataset (`datasets.Dataset`):
1663
+ The dataset to use for evaluation.
1664
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1665
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1666
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1667
+ reuse the fine-tuned model.
1668
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1669
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1670
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1671
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1672
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1673
+ callbacks (`list[transformers.TrainerCallback]`):
1674
+ The callbacks to use for training.
1675
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1676
+ The optimizer and scheduler to use for training.
1677
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1678
+ The function to use to preprocess the logits before computing the metrics.
1679
+ peft_config (`dict`, defaults to `None`):
1680
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1681
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1682
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1683
+ a dictionary string to metric values.
1684
+ model_adapter_name (`str`, defaults to `None`):
1685
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1686
+ ref_adapter_name (`str`, defaults to `None`):
1687
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1688
+
1689
+ """
1690
+ def __init__(
1691
+ self,
1692
+ model = None,
1693
+ ref_model = None,
1694
+ args = None,
1695
+ train_dataset = None,
1696
+ eval_dataset = None,
1697
+ processing_class = None,
1698
+ data_collator = None,
1699
+ model_init = None,
1700
+ callbacks = None,
1701
+ preprocess_logits_for_metrics = None,
1702
+ peft_config = None,
1703
+ compute_metrics = None,
1704
+ model_adapter_name = None,
1705
+ ref_adapter_name = None,
1706
+ embedding_func = None,
1707
+ embedding_tokenizer = None,
1708
+ **kwargs
1709
+ ):
1710
+ if args is None: args = UnslothBCOConfig()
1711
+ use_bf16 = getattr(args, 'bf16', False)
1712
+ use_fp16 = getattr(args, 'fp16', False)
1713
+ force_float32 = False
1714
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1715
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1716
+ force_float32 = True
1717
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1718
+ dtype = getattr(model.config, 'torch_dtype', None)
1719
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1720
+ from unsloth_zoo.utils import _get_dtype
1721
+ dtype = _get_dtype(dtype)
1722
+ float16 = dtype == torch.float16
1723
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1724
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1725
+ if force_float32:
1726
+ args.fp16 = False
1727
+ args.bf16 = False
1728
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1729
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1730
+ args.fp16 = float16
1731
+ args.bf16 = not float16
1732
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1733
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1734
+ args.eval_strategy = 'steps'
1735
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1736
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1737
+ if ga_steps is not None and ga_steps > 1:
1738
+ from transformers import __version__ as transformers_version
1739
+ if Version(transformers_version) <= Version('4.45.2'):
1740
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1741
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1742
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1743
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1744
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1745
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1746
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1747
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1748
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1749
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1750
+ if force_float32:
1751
+ args.bf16_full_eval = False
1752
+ args.fp16_full_eval = False
1753
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1754
+ args.bf16_full_eval = True
1755
+ args.fp16_full_eval = False
1756
+ elif not bf16_full_eval and not fp16_full_eval:
1757
+ args.bf16_full_eval = args.bf16
1758
+ args.fp16_full_eval = args.fp16
1759
+ _output_logits = False
1760
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1761
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1762
+ if _output_logits:
1763
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1764
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1765
+ pass
1766
+ else:
1767
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1768
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1769
+ if args_max_seq_length is None and model_max_seq_length is not None:
1770
+ max_seq_length = model.max_seq_length
1771
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1772
+ if model is not None and hasattr(model, 'for_training'):
1773
+ model.for_training()
1774
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1775
+ if 'processing_class' in locals():
1776
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1777
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1778
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1779
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1780
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1781
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1782
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1783
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1784
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1785
+ else:
1786
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1787
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1788
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1789
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1790
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1791
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1792
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1793
+ else:
1794
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1795
+ other_metrics = []
1796
+
1797
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1798
+ PatchRLStatistics('bco_trainer', other_metrics)
1799
+
1800
+ super().__init__(
1801
+ model = model,
1802
+ ref_model = ref_model,
1803
+ args = args,
1804
+ train_dataset = train_dataset,
1805
+ eval_dataset = eval_dataset,
1806
+ processing_class = processing_class,
1807
+ data_collator = data_collator,
1808
+ model_init = model_init,
1809
+ callbacks = callbacks,
1810
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1811
+ peft_config = peft_config,
1812
+ compute_metrics = compute_metrics,
1813
+ model_adapter_name = model_adapter_name,
1814
+ ref_adapter_name = ref_adapter_name,
1815
+ embedding_func = embedding_func,
1816
+ embedding_tokenizer = embedding_tokenizer,**kwargs)
1817
+ if hasattr(self, 'neftune_hook_handle'):
1818
+ self.neftune_hook_handle.remove()
1819
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1820
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1821
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1822
+ pass
1823
+
1824
+ pass
unsloth_compiled_cache/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1557 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
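+ # Minimal usage sketch (editor addition, illustrative shapes only): given logits of
+ # shape (B, T, V) and token ids of shape (B, T), this returns the log-probability of
+ # each given token, using log_softmax(x)_i = x_i - logsumexp(x) in float32 for stability.
+ #
+ # logps = selective_log_softmax(torch.randn(2, 4, 8), torch.randint(0, 8, (2, 4)))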
42
+ @dataclass
43
+ class UnslothCPOConfig(CPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`CPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
67
+ the [paper](https://huggingface.co/papers/2310.12036).
68
+ label_smoothing (`float`, *optional*, defaults to `0.0`):
69
+ Label smoothing factor. This argument is required if you want to use the default data collator.
70
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
71
+ Type of loss to use. Possible values are:
72
+
73
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
74
+ - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
75
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
76
+ - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
77
+
78
+ disable_dropout (`bool`, *optional*, defaults to `True`):
79
+ Whether to disable dropout in the model.
80
+ cpo_alpha (`float`, *optional*, defaults to `1.0`):
81
+ Weight of the BC regularizer in CPO training.
82
+ simpo_gamma (`float`, *optional*, defaults to `0.5`):
83
+ Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
84
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
85
+ Label pad token id. This argument is required if you want to use the default data collator.
86
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
87
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
88
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
89
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
90
+ This argument is required if you want to use the default data collator.
91
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
92
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
93
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
94
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
95
+ you need to specify if the model returned by the callable is an encoder-decoder model.
96
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
98
+ string.
99
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
100
+ Number of processes to use for processing the dataset.
101
+
102
+ """
103
+ vllm_sampling_params: Optional[Any] = field(
104
+ default = None,
105
+ metadata = {'help': 'vLLM SamplingParams'},
106
+ )
107
+ unsloth_num_chunks : Optional[int] = field(
108
+ default = -1,
109
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
110
+ )
111
+ def __init__(
112
+ self,
113
+ output_dir = None,
114
+ overwrite_output_dir = None,
115
+ do_train = False,
116
+ do_eval = False,
117
+ do_predict = False,
118
+ eval_strategy = 'no',
119
+ prediction_loss_only = False,
120
+ per_device_train_batch_size = 4,
121
+ per_device_eval_batch_size = 4,
122
+ per_gpu_train_batch_size = None,
123
+ per_gpu_eval_batch_size = None,
124
+ gradient_accumulation_steps = 2,
125
+ eval_accumulation_steps = 2,
126
+ eval_delay = 0,
127
+ torch_empty_cache_steps = 250,
128
+ learning_rate = 5e-05,
129
+ weight_decay = 0.01,
130
+ adam_beta1 = 0.9,
131
+ adam_beta2 = 0.999,
132
+ adam_epsilon = 1e-08,
133
+ max_grad_norm = 1.0,
134
+ num_train_epochs = 3.0,
135
+ max_steps = -1,
136
+ lr_scheduler_type = 'linear',
137
+ warmup_ratio = 0.1,
138
+ warmup_steps = 0,
139
+ log_level = 'passive',
140
+ log_level_replica = 'warning',
141
+ log_on_each_node = True,
142
+ logging_dir = None,
143
+ logging_strategy = 'steps',
144
+ logging_first_step = False,
145
+ logging_steps = 1,
146
+ logging_nan_inf_filter = False,
147
+ save_strategy = 'steps',
148
+ save_steps = 500,
149
+ save_total_limit = None,
150
+ save_safetensors = True,
151
+ save_on_each_node = False,
152
+ save_only_model = False,
153
+ restore_callback_states_from_checkpoint = False,
154
+ no_cuda = False,
155
+ use_cpu = False,
156
+ use_mps_device = False,
157
+ seed = 3407,
158
+ data_seed = 3407,
159
+ jit_mode_eval = False,
160
+ use_ipex = False,
161
+ bf16 = False,
162
+ fp16 = False,
163
+ fp16_opt_level = 'O1',
164
+ half_precision_backend = 'auto',
165
+ bf16_full_eval = False,
166
+ fp16_full_eval = False,
167
+ tf32 = None,
168
+ local_rank = -1,
169
+ ddp_backend = None,
170
+ tpu_num_cores = None,
171
+ tpu_metrics_debug = False,
172
+ debug = '',
173
+ dataloader_drop_last = False,
174
+ eval_steps = None,
175
+ dataloader_num_workers = 0,
176
+ dataloader_prefetch_factor = None,
177
+ past_index = -1,
178
+ run_name = None,
179
+ disable_tqdm = None,
180
+ remove_unused_columns = True,
181
+ label_names = None,
182
+ load_best_model_at_end = False,
183
+ metric_for_best_model = None,
184
+ greater_is_better = None,
185
+ ignore_data_skip = False,
186
+ fsdp = '',
187
+ fsdp_min_num_params = 0,
188
+ fsdp_config = None,
189
+ tp_size = 0,
190
+ fsdp_transformer_layer_cls_to_wrap = None,
191
+ accelerator_config = None,
192
+ deepspeed = None,
193
+ label_smoothing_factor = 0.0,
194
+ optim = 'adamw_8bit',
195
+ optim_args = None,
196
+ adafactor = False,
197
+ group_by_length = False,
198
+ length_column_name = 'length',
199
+ report_to = None,
200
+ ddp_find_unused_parameters = None,
201
+ ddp_bucket_cap_mb = None,
202
+ ddp_broadcast_buffers = None,
203
+ dataloader_pin_memory = True,
204
+ dataloader_persistent_workers = False,
205
+ skip_memory_metrics = True,
206
+ use_legacy_prediction_loop = False,
207
+ push_to_hub = False,
208
+ resume_from_checkpoint = None,
209
+ hub_model_id = None,
210
+ hub_strategy = 'every_save',
211
+ hub_token = None,
212
+ hub_private_repo = None,
213
+ hub_always_push = False,
214
+ gradient_checkpointing = False,
215
+ gradient_checkpointing_kwargs = None,
216
+ include_inputs_for_metrics = False,
217
+ eval_do_concat_batches = True,
218
+ fp16_backend = 'auto',
219
+ evaluation_strategy = None,
220
+ push_to_hub_model_id = None,
221
+ push_to_hub_organization = None,
222
+ push_to_hub_token = None,
223
+ mp_parameters = '',
224
+ auto_find_batch_size = False,
225
+ full_determinism = False,
226
+ torchdynamo = None,
227
+ ray_scope = 'last',
228
+ ddp_timeout = 1800,
229
+ torch_compile = False,
230
+ torch_compile_backend = None,
231
+ torch_compile_mode = None,
232
+ dispatch_batches = None,
233
+ split_batches = None,
234
+ include_tokens_per_second = False,
235
+ include_num_input_tokens_seen = False,
236
+ neftune_noise_alpha = None,
237
+ optim_target_modules = None,
238
+ batch_eval_metrics = False,
239
+ eval_on_start = False,
240
+ use_liger_kernel = False,
241
+ eval_use_gather_object = False,
242
+ average_tokens_across_devices = False,
243
+ max_length = 1024,
244
+ max_prompt_length = 512,
245
+ max_completion_length = None,
246
+ beta = 0.1,
247
+ label_smoothing = 0.0,
248
+ loss_type = 'sigmoid',
249
+ disable_dropout = True,
250
+ cpo_alpha = 1.0,
251
+ simpo_gamma = 0.5,
252
+ label_pad_token_id = -100,
253
+ padding_value = None,
254
+ truncation_mode = 'keep_end',
255
+ generate_during_eval = False,
256
+ is_encoder_decoder = None,
257
+ model_init_kwargs = None,
258
+ dataset_num_proc = None,
259
+ vllm_sampling_params = None,
260
+ unsloth_num_chunks = -1,
261
+ **kwargs,
262
+ ):
263
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
264
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
265
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
266
+ output_dir = 'unsloth_training_checkpoints'
267
+ save_strategy = 'no'
268
+ if dataset_num_proc is None:
269
+ from multiprocessing import cpu_count
270
+ dataset_num_proc = cpu_count()
271
+
272
+ super().__init__(
273
+ output_dir = output_dir,
274
+ overwrite_output_dir = overwrite_output_dir,
275
+ do_train = do_train,
276
+ do_eval = do_eval,
277
+ do_predict = do_predict,
278
+ eval_strategy = eval_strategy,
279
+ prediction_loss_only = prediction_loss_only,
280
+ per_device_train_batch_size = per_device_train_batch_size,
281
+ per_device_eval_batch_size = per_device_eval_batch_size,
282
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
283
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
284
+ gradient_accumulation_steps = gradient_accumulation_steps,
285
+ eval_accumulation_steps = eval_accumulation_steps,
286
+ eval_delay = eval_delay,
287
+ torch_empty_cache_steps = torch_empty_cache_steps,
288
+ learning_rate = learning_rate,
289
+ weight_decay = weight_decay,
290
+ adam_beta1 = adam_beta1,
291
+ adam_beta2 = adam_beta2,
292
+ adam_epsilon = adam_epsilon,
293
+ max_grad_norm = max_grad_norm,
294
+ num_train_epochs = num_train_epochs,
295
+ max_steps = max_steps,
296
+ lr_scheduler_type = lr_scheduler_type,
297
+ warmup_ratio = warmup_ratio,
298
+ warmup_steps = warmup_steps,
299
+ log_level = log_level,
300
+ log_level_replica = log_level_replica,
301
+ log_on_each_node = log_on_each_node,
302
+ logging_dir = logging_dir,
303
+ logging_strategy = logging_strategy,
304
+ logging_first_step = logging_first_step,
305
+ logging_steps = logging_steps,
306
+ logging_nan_inf_filter = logging_nan_inf_filter,
307
+ save_strategy = save_strategy,
308
+ save_steps = save_steps,
309
+ save_total_limit = save_total_limit,
310
+ save_safetensors = save_safetensors,
311
+ save_on_each_node = save_on_each_node,
312
+ save_only_model = save_only_model,
313
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
314
+ no_cuda = no_cuda,
315
+ use_cpu = use_cpu,
316
+ use_mps_device = use_mps_device,
317
+ seed = seed,
318
+ data_seed = data_seed,
319
+ jit_mode_eval = jit_mode_eval,
320
+ use_ipex = use_ipex,
321
+ bf16 = bf16,
322
+ fp16 = fp16,
323
+ fp16_opt_level = fp16_opt_level,
324
+ half_precision_backend = half_precision_backend,
325
+ bf16_full_eval = bf16_full_eval,
326
+ fp16_full_eval = fp16_full_eval,
327
+ tf32 = tf32,
328
+ local_rank = local_rank,
329
+ ddp_backend = ddp_backend,
330
+ tpu_num_cores = tpu_num_cores,
331
+ tpu_metrics_debug = tpu_metrics_debug,
332
+ debug = debug,
333
+ dataloader_drop_last = dataloader_drop_last,
334
+ eval_steps = eval_steps,
335
+ dataloader_num_workers = dataloader_num_workers,
336
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
337
+ past_index = past_index,
338
+ run_name = run_name,
339
+ disable_tqdm = disable_tqdm,
340
+ remove_unused_columns = remove_unused_columns,
341
+ label_names = label_names,
342
+ load_best_model_at_end = load_best_model_at_end,
343
+ metric_for_best_model = metric_for_best_model,
344
+ greater_is_better = greater_is_better,
345
+ ignore_data_skip = ignore_data_skip,
346
+ fsdp = fsdp,
347
+ fsdp_min_num_params = fsdp_min_num_params,
348
+ fsdp_config = fsdp_config,
349
+ tp_size = tp_size,
350
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
351
+ accelerator_config = accelerator_config,
352
+ deepspeed = deepspeed,
353
+ label_smoothing_factor = label_smoothing_factor,
354
+ optim = optim,
355
+ optim_args = optim_args,
356
+ adafactor = adafactor,
357
+ group_by_length = group_by_length,
358
+ length_column_name = length_column_name,
359
+ report_to = report_to,
360
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
361
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
362
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
363
+ dataloader_pin_memory = dataloader_pin_memory,
364
+ dataloader_persistent_workers = dataloader_persistent_workers,
365
+ skip_memory_metrics = skip_memory_metrics,
366
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
367
+ push_to_hub = push_to_hub,
368
+ resume_from_checkpoint = resume_from_checkpoint,
369
+ hub_model_id = hub_model_id,
370
+ hub_strategy = hub_strategy,
371
+ hub_token = hub_token,
372
+ hub_private_repo = hub_private_repo,
373
+ hub_always_push = hub_always_push,
374
+ gradient_checkpointing = gradient_checkpointing,
375
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
376
+ include_inputs_for_metrics = include_inputs_for_metrics,
377
+ eval_do_concat_batches = eval_do_concat_batches,
378
+ fp16_backend = fp16_backend,
379
+ evaluation_strategy = evaluation_strategy,
380
+ push_to_hub_model_id = push_to_hub_model_id,
381
+ push_to_hub_organization = push_to_hub_organization,
382
+ push_to_hub_token = push_to_hub_token,
383
+ mp_parameters = mp_parameters,
384
+ auto_find_batch_size = auto_find_batch_size,
385
+ full_determinism = full_determinism,
386
+ torchdynamo = torchdynamo,
387
+ ray_scope = ray_scope,
388
+ ddp_timeout = ddp_timeout,
389
+ torch_compile = torch_compile,
390
+ torch_compile_backend = torch_compile_backend,
391
+ torch_compile_mode = torch_compile_mode,
392
+ dispatch_batches = dispatch_batches,
393
+ split_batches = split_batches,
394
+ include_tokens_per_second = include_tokens_per_second,
395
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
396
+ neftune_noise_alpha = neftune_noise_alpha,
397
+ optim_target_modules = optim_target_modules,
398
+ batch_eval_metrics = batch_eval_metrics,
399
+ eval_on_start = eval_on_start,
400
+ use_liger_kernel = use_liger_kernel,
401
+ eval_use_gather_object = eval_use_gather_object,
402
+ average_tokens_across_devices = average_tokens_across_devices,
403
+ max_length = max_length,
404
+ max_prompt_length = max_prompt_length,
405
+ max_completion_length = max_completion_length,
406
+ beta = beta,
407
+ label_smoothing = label_smoothing,
408
+ loss_type = loss_type,
409
+ disable_dropout = disable_dropout,
410
+ cpo_alpha = cpo_alpha,
411
+ simpo_gamma = simpo_gamma,
412
+ label_pad_token_id = label_pad_token_id,
413
+ padding_value = padding_value,
414
+ truncation_mode = truncation_mode,
415
+ generate_during_eval = generate_during_eval,
416
+ is_encoder_decoder = is_encoder_decoder,
417
+ model_init_kwargs = model_init_kwargs,
418
+ dataset_num_proc = dataset_num_proc,**kwargs)
419
+ self.vllm_sampling_params = vllm_sampling_params
420
+ self.unsloth_num_chunks = unsloth_num_chunks
421
+ pass
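+
+ # Hedged usage sketch: one plausible way to instantiate the config above.
+ # All values are illustrative assumptions, not recommendations.
+ def _demo_cpo_config():
+     return UnslothCPOConfig(
+         output_dir = "outputs",
+         per_device_train_batch_size = 4,
+         gradient_accumulation_steps = 2,
+         learning_rate = 5e-5,
+         loss_type = "sigmoid",
+         beta = 0.1,
+         max_length = 1024,
+         max_prompt_length = 512,
+     )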
422
+
423
+ class _UnslothCPOTrainer(Trainer):
424
+ r""""""
425
+
426
+ _tag_names = ["trl", "cpo"]
427
+
428
+ def __init__(
429
+ self,
430
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
431
+ args: Optional[CPOConfig] = None,
432
+ data_collator: Optional[DataCollator] = None,
433
+ train_dataset: Optional[Dataset] = None,
434
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
435
+ processing_class: Optional[
436
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
437
+ ] = None,
438
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
439
+ callbacks: Optional[list[TrainerCallback]] = None,
440
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
441
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
442
+ peft_config: Optional[dict] = None,
443
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
444
+ ):
445
+ if args.model_init_kwargs is None:
446
+ model_init_kwargs = {}
447
+ elif not isinstance(model, str):
448
+ raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
449
+ else:
450
+ model_init_kwargs = args.model_init_kwargs
451
+ torch_dtype = model_init_kwargs.get("torch_dtype")
452
+ if torch_dtype is not None:
453
+ # Convert to `torch.dtype` if a str is passed
454
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
455
+ torch_dtype = getattr(torch, torch_dtype)
456
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
457
+ raise ValueError(
458
+ f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
459
+ )
460
+ model_init_kwargs["torch_dtype"] = torch_dtype
461
+
462
+ if isinstance(model, str):
463
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
464
+
465
+ # Initialize this variable to False. This helps track the case when `peft_module_casting_to_bf16`
466
+ # has been called in order to properly call autocast if needed.
467
+ self._peft_has_been_casted_to_bf16 = False
468
+
469
+ if not is_peft_available() and peft_config is not None:
470
+ raise ValueError(
471
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
472
+ )
473
+ elif is_peft_available() and peft_config is not None:
474
+ # if model is a peft model and we have a peft_config, we merge and unload it first
475
+ if isinstance(model, PeftModel):
476
+ model = model.merge_and_unload()
477
+
478
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
479
+ _support_gc_kwargs = hasattr(
480
+ args, "gradient_checkpointing_kwargs"
481
+ ) and "gradient_checkpointing_kwargs" in list(
482
+ inspect.signature(prepare_model_for_kbit_training).parameters
483
+ )
484
+
485
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
486
+
487
+ if _support_gc_kwargs:
488
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
489
+
490
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
491
+ elif getattr(args, "gradient_checkpointing", False):
492
+ # For backward compatibility with older versions of transformers
493
+ if hasattr(model, "enable_input_require_grads"):
494
+ model.enable_input_require_grads()
495
+ else:
496
+
497
+ def make_inputs_require_grad(module, input, output):
498
+ output.requires_grad_(True)
499
+
500
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
501
+
502
+ # get peft model with the given config
503
+ model = model
504
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
505
+ peft_module_casting_to_bf16(model)
506
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
507
+ self._peft_has_been_casted_to_bf16 = True
508
+
509
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
510
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
511
+ # fail or completely fail.
512
+ elif getattr(args, "gradient_checkpointing", False):
513
+ # For backward compatibility with older versions of transformers
514
+ if hasattr(model, "enable_input_require_grads"):
515
+ model.enable_input_require_grads()
516
+ else:
517
+
518
+ def make_inputs_require_grad(module, input, output):
519
+ output.requires_grad_(True)
520
+
521
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
522
+
523
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
524
+ raise ValueError(
525
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
526
+ " Please install `wandb` or `comet-ml` to resolve."
527
+ )
528
+
529
+ if model is not None:
530
+ self.is_encoder_decoder = model.config.is_encoder_decoder
531
+ elif args.is_encoder_decoder is None:
532
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
533
+ else:
534
+ self.is_encoder_decoder = args.is_encoder_decoder
535
+
536
+ if self.is_encoder_decoder:
537
+ self.decoder_start_token_id = model.config.decoder_start_token_id
538
+ self.pad_token_id = model.config.pad_token_id
539
+
540
+ if processing_class is None:
541
+ raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
542
+ if args.max_length is None:
543
+ warnings.warn(
544
+ "`max_length` is not set in the CPOConfig's init;"
545
+ " it will default to `512`, but you should set it yourself in the future.",
546
+ UserWarning,
547
+ )
548
+ max_length = 512
549
+ else:
550
+ max_length = args.max_length
551
+ if args.max_prompt_length is None:
552
+ "`max_prompt_length` is not set in the CPOConfig's init;"
553
+ " it will default to `128`, but you should set it yourself in the future.",
554
+ " it will default to `128` by default, but you should do it yourself in the future.",
555
+ UserWarning,
556
+ )
557
+ max_prompt_length = 128
558
+ else:
559
+ max_prompt_length = args.max_prompt_length
560
+
561
+ if args.max_completion_length is None and self.is_encoder_decoder:
562
+ "When using an encoder-decoder architecture, you should set `max_completion_length` in the CPOConfig's init;"
563
+ " it will default to `128`, but you should set it yourself in the future.",
564
+ " it will default to `128` by default, but you should do it yourself in the future.",
565
+ UserWarning,
566
+ )
567
+ max_completion_length = 128
568
+ else:
569
+ max_completion_length = args.max_completion_length
570
+
571
+ if data_collator is None:
572
+ data_collator = DPODataCollatorWithPadding(
573
+ pad_token_id=processing_class.pad_token_id,
574
+ label_pad_token_id=args.label_pad_token_id,
575
+ is_encoder_decoder=self.is_encoder_decoder,
576
+ )
577
+
578
+ if args.remove_unused_columns:
579
+ args.remove_unused_columns = False
580
+ # warn users
581
+ warnings.warn(
582
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
583
+ " we have set it for you, but you should do it yourself in the future.",
584
+ UserWarning,
585
+ )
586
+
587
+ self.use_dpo_data_collator = True
588
+ else:
589
+ self.use_dpo_data_collator = False
590
+
591
+ # Disable dropout in the model
592
+ if args.disable_dropout:
593
+ disable_dropout_in_model(model)
594
+
595
+ self.max_length = max_length
596
+ self.generate_during_eval = args.generate_during_eval
597
+ self.label_pad_token_id = args.label_pad_token_id
598
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
599
+ self.max_prompt_length = max_prompt_length
600
+ self.truncation_mode = args.truncation_mode
601
+ self.max_completion_length = max_completion_length
602
+ self.processing_class = processing_class
603
+
604
+ if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
605
+ warnings.warn(
606
+ f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
607
+ "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
608
+ UserWarning,
609
+ )
610
+ if args.loss_type == "kto_pair":
611
+ raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
612
+
613
+ self.beta = args.beta
614
+ self.label_smoothing = args.label_smoothing
615
+ self.loss_type = args.loss_type
616
+ self.cpo_alpha = args.cpo_alpha
617
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
618
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
619
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
620
+ warnings.warn(
621
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
622
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
623
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
624
+ "loss.",
625
+ UserWarning,
626
+ )
627
+
628
+ if args.loss_type == "simpo":
629
+ self.simpo_gamma = args.simpo_gamma
630
+
631
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
632
+
633
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
634
+ # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
635
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
636
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
637
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
638
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
639
+ # that the warning has already been issued.
640
+ model.warnings_issued["estimate_tokens"] = True
641
+
642
+ # Compute that only on the main process for faster data processing.
643
+ # see: https://github.com/huggingface/trl/pull/1255
644
+ with PartialState().local_main_process_first():
645
+ # Extract the prompt if needed, and apply the chat template if needed
646
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
647
+ train_dataset = train_dataset.map(
648
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
649
+ )
650
+ if eval_dataset is not None:
651
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
652
+ eval_dataset = eval_dataset.map(
653
+ maybe_apply_chat_template,
654
+ fn_kwargs={"tokenizer": processing_class},
655
+ num_proc=args.dataset_num_proc,
656
+ )
657
+
658
+ # tokenize the dataset
659
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
660
+ if eval_dataset is not None:
661
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
662
+
663
+ super().__init__(
664
+ model=model,
665
+ args=args,
666
+ data_collator=data_collator,
667
+ train_dataset=train_dataset,
668
+ eval_dataset=eval_dataset,
669
+ processing_class=processing_class,
670
+ model_init=model_init,
671
+ compute_metrics=compute_metrics,
672
+ callbacks=callbacks,
673
+ optimizers=optimizers,
674
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
675
+ )
676
+
677
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
678
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
679
+ # self.model_accepts_loss_kwargs to False to enable scaling.
680
+ self.model_accepts_loss_kwargs = False
681
+
682
+ # Add tags for models that have been loaded with the correct transformers version
683
+ if hasattr(self.model, "add_model_tags"):
684
+ self.model.add_model_tags(self._tag_names)
685
+
686
+ if not hasattr(self, "accelerator"):
687
+ raise AttributeError(
688
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
689
+ )
690
+
691
+ def build_tokenized_answer(self, prompt, answer):
692
+ """
693
+ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
694
+ It does, however, ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
695
+ Reference:
696
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
697
+ """
698
+
699
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
700
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
701
+
702
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
703
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
704
+
705
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
706
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
707
+
708
+ # Prepare input tokens for token by token comparison
709
+ full_input_ids = np.array(full_tokenized["input_ids"])
710
+
711
+ if len(full_input_ids) != len(full_concat_input_ids):
712
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
713
+
714
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
715
+ # can be merged together when tokenizing prompt+answer. This could result
716
+ # in the last token from the prompt being different when tokenized on its own
717
+ # vs when done as prompt+answer.
718
+ response_token_ids_start_idx = len(prompt_input_ids)
719
+
720
+ # If the tokenized prompt differs from the start of the tokenized prompt+answer, the
721
+ # last token has changed due to merging.
722
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
723
+ response_token_ids_start_idx -= 1
724
+
725
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
726
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
727
+
728
+ if len(prompt_input_ids) != len(prompt_attention_mask):
729
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
730
+
731
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
732
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
733
+
734
+ return dict(
735
+ prompt_input_ids=prompt_input_ids,
736
+ prompt_attention_mask=prompt_attention_mask,
737
+ input_ids=answer_input_ids,
738
+ attention_mask=answer_attention_mask,
739
+ )
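+
+ # Sketch of the boundary-merge handling documented above. `tok` is an assumed
+ # tokenizer; the back-off mirrors the response_token_ids_start_idx adjustment.
+ def _demo_split_prompt_answer(tok, prompt, answer):
+     full = tok(prompt + answer, add_special_tokens = False)["input_ids"]
+     head = tok(prompt, add_special_tokens = False)["input_ids"]
+     start = len(head)
+     if head != full[:start]:    # boundary tokens were merged
+         start -= 1
+     return full[:start], full[start:]   # (prompt ids, answer ids)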
740
+
741
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
742
+ """Tokenize a single row from a CPO specific dataset.
743
+
744
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
745
+ in case the prompt + chosen or prompt + rejected responses are too long. First
746
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
747
+
748
+ We also create the labels for the chosen/rejected responses, which are of length equal to
749
+ the sum of the length of the prompt and the chosen/rejected response, with
750
+ label_pad_token_id for the prompt tokens.
751
+ """
752
+ batch = {}
753
+ prompt = feature["prompt"]
754
+ chosen = feature["chosen"]
755
+ rejected = feature["rejected"]
756
+
757
+ if not self.is_encoder_decoder:
758
+ # Check issues below for more details
759
+ # 1. https://github.com/huggingface/trl/issues/907
760
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
761
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
762
+
763
+ if not isinstance(prompt, str):
764
+ raise ValueError(f"prompt should be a str but got {type(prompt)}")
765
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
766
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
767
+
768
+ if not isinstance(chosen, str):
769
+ raise ValueError(f"chosen should be a str but got {type(chosen)}")
770
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
771
+
772
+ if not isinstance(rejected, str):
773
+ raise ValueError(f"rejected should be a str but got {type(rejected)}")
774
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
775
+
776
+ # Last prompt token might get merged by tokenizer and
777
+ # it should not be included for generation if that happens
778
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
779
+
780
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
781
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
782
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
783
+
784
+ for k, v in prompt_tokens.items():
785
+ prompt_tokens[k] = v[:prompt_len_input_ids]
786
+
787
+ # Make sure the prompts differ by at most one token,
788
+ # and that their lengths differ by at most 1
789
+ num_diff_tokens = sum(
790
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
791
+ )
792
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
793
+ if num_diff_tokens > 1 or num_diff_len > 1:
794
+ raise ValueError(
795
+ "Chosen and rejected prompt_input_ids might only differ on the "
796
+ "last token due to tokenizer merge ops."
797
+ )
798
+
799
+ # add BOS token to head of prompt. Avoid adding if it's already there
800
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
801
+ self.processing_class.bos_token_id,
802
+ prompt_len_input_ids,
803
+ prompt_tokens,
804
+ chosen_prompt_len_input_ids,
805
+ chosen_tokens,
806
+ rejected_prompt_len_input_ids,
807
+ rejected_tokens,
808
+ )
809
+
810
+ # add EOS token to end of answer. Avoid adding if it's already there
811
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
812
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
813
+ )
814
+
815
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
816
+
817
+ # if combined sequence is too long, truncate the prompt
818
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
819
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
820
+ if self.truncation_mode == "keep_start":
821
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
822
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
823
+ elif self.truncation_mode == "keep_end":
824
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
825
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
826
+ else:
827
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
828
+
829
+ # if that's still too long, truncate the response
830
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
831
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
832
+ for k in ["input_ids", "attention_mask"]:
833
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
834
+
835
+ # Create labels
836
+ chosen_sequence_tokens = {
837
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
838
+ }
839
+ rejected_sequence_tokens = {
840
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
841
+ }
842
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
843
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
844
+ self.label_pad_token_id
845
+ ] * len(chosen_tokens["prompt_input_ids"])
846
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
847
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
848
+ self.label_pad_token_id
849
+ ] * len(rejected_tokens["prompt_input_ids"])
850
+
851
+ for k, toks in {
852
+ "chosen_": chosen_sequence_tokens,
853
+ "rejected_": rejected_sequence_tokens,
854
+ "": prompt_tokens,
855
+ }.items():
856
+ for type_key, tokens in toks.items():
857
+ if type_key == "token_type_ids":
858
+ continue
859
+ batch[f"{k}{type_key}"] = tokens
860
+
861
+ else:
862
+ chosen_tokens = self.processing_class(
863
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
864
+ )
865
+ rejected_tokens = self.processing_class(
866
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
867
+ )
868
+ prompt_tokens = self.processing_class(
869
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
870
+ )
871
+
872
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
873
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
874
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
875
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
876
+
877
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
878
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
879
+ labels=torch.tensor(batch["rejected_labels"])
880
+ )
881
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
882
+ labels=torch.tensor(batch["chosen_labels"])
883
+ )
884
+
885
+ return batch
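+
+ # Illustrative input row for tokenize_row (made-up strings, standard CPO
+ # preference fields):
+ def _demo_cpo_feature():
+     return {"prompt": "What is 2 + 2?", "chosen": " 4.", "rejected": " 5."}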
886
+
887
+ @staticmethod
888
+ def concatenated_inputs(
889
+ batch: dict[str, Union[list, torch.LongTensor]],
890
+ is_encoder_decoder: bool = False,
891
+ label_pad_token_id: int = -100,
892
+ padding_value: int = 0,
893
+ device: Optional[torch.device] = None,
894
+ ) -> dict[str, torch.LongTensor]:
895
+ """Concatenate the chosen and rejected inputs into a single tensor.
896
+
897
+ Args:
898
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
899
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
900
+ label_pad_token_id: The label pad token id.
901
+ padding_value: The padding value to use for the concatenated inputs_ids.
902
+ device: The device for the concatenated inputs.
903
+
904
+ Returns:
905
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
906
+ """
907
+ concatenated_batch = {}
908
+
909
+ if is_encoder_decoder:
910
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
911
+ else:
912
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
913
+
914
+ for k in batch:
915
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
916
+ if "labels" in k or is_encoder_decoder:
917
+ pad_value = label_pad_token_id
918
+ elif k.endswith("_input_ids"):
919
+ pad_value = padding_value
920
+ elif k.endswith("_attention_mask"):
921
+ pad_value = 0
922
+ concatenated_key = k.replace("chosen", "concatenated")
923
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
924
+ for k in batch:
925
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
926
+ if "labels" in k or is_encoder_decoder:
927
+ pad_value = label_pad_token_id
928
+ elif k.endswith("_input_ids"):
929
+ pad_value = padding_value
930
+ elif k.endswith("_attention_mask"):
931
+ pad_value = 0
932
+ concatenated_key = k.replace("rejected", "concatenated")
933
+ concatenated_batch[concatenated_key] = torch.cat(
934
+ (
935
+ concatenated_batch[concatenated_key],
936
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
937
+ ),
938
+ dim=0,
939
+ ).to(device=device)
940
+
941
+ if is_encoder_decoder:
942
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
943
+ concatenated_batch["concatenated_attention_mask"] = (
944
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
945
+ )
946
+
947
+ return concatenated_batch
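+
+ # Hedged shape sketch of the concatenation above: chosen rows are stacked on
+ # top of rejected rows, each padded to the longer length. Shapes are assumed.
+ def _demo_concatenated_shapes():
+     chosen = torch.ones(2, 7, dtype = torch.long)
+     rejected = torch.ones(2, 5, dtype = torch.long)
+     max_len = max(chosen.shape[1], rejected.shape[1])
+     both = torch.cat((pad_to_length(chosen, max_len, pad_value = 0),
+                       pad_to_length(rejected, max_len, pad_value = 0)), dim = 0)
+     assert both.shape == (4, 7)   # 2 chosen + 2 rejected rows, padded to length 7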
948
+
949
+ def cpo_loss(
950
+ self,
951
+ policy_chosen_logps: torch.FloatTensor,
952
+ policy_rejected_logps: torch.FloatTensor,
953
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
954
+ """Compute the CPO loss for a batch of policy and reference model log probabilities.
955
+
956
+ Args:
957
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
958
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
959
+
960
+ Returns:
961
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
962
+ The losses tensor contains the CPO loss for each example in the batch.
963
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
964
+ """
965
+ logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
966
+
967
+ # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
968
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
969
+ # calculates a conservative CPO loss.
970
+
971
+ if self.loss_type == "simpo":
972
+ gamma_logratios = self.simpo_gamma / self.beta
973
+ logits = logits - gamma_logratios
974
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
975
+ losses = (
976
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
977
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
978
+ )
979
+ elif self.loss_type == "sigmoid":
980
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
981
+ losses = (
982
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
983
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
984
+ )
985
+ elif self.loss_type == "hinge":
986
+ losses = torch.relu(1 - self.beta * logits)
987
+ elif self.loss_type == "ipo":
988
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
989
+ losses = (logits - 1 / (2 * self.beta)) ** 2
990
+ else:
991
+ raise ValueError(
992
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
993
+ )
994
+
995
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
996
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
997
+
998
+ return losses, chosen_rewards, rejected_rewards
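+
+ # Worked form of the default "sigmoid" branch above, as a standalone sketch
+ # with made-up log-probabilities (eps is the label smoothing):
+ #     L = -(1 - eps) * log sigmoid(beta * d) - eps * log sigmoid(-beta * d),
+ #     d = logp_chosen - logp_rejected
+ def _demo_sigmoid_cpo_loss(beta = 0.1, eps = 0.0):
+     logp_chosen = torch.tensor([-10.0])
+     logp_rejected = torch.tensor([-12.0])
+     d = logp_chosen - logp_rejected
+     return -F.logsigmoid(beta * d) * (1 - eps) - F.logsigmoid(-beta * d) * eps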
999
+
1000
+ @staticmethod
1001
+ def get_batch_logps(
1002
+ logits: torch.FloatTensor,
1003
+ labels: torch.LongTensor,
1004
+ average_log_prob: bool = False,
1005
+ label_pad_token_id: int = -100,
1006
+ is_encoder_decoder: bool = False,
1007
+ ) -> torch.FloatTensor:
1008
+ """Compute the log probabilities of the given labels under the given logits.
1009
+
1010
+ Args:
1011
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1012
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1013
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1014
+ label_pad_token_id: The label pad token id.
1015
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
1016
+
1017
+ Returns:
1018
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1019
+ """
1020
+ if logits.shape[:-1] != labels.shape:
1021
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1022
+
1023
+ if not is_encoder_decoder:
1024
+ labels = labels[:, 1:].clone()
1025
+ logits = logits[:, :-1, :]
1026
+ loss_mask = labels != label_pad_token_id
1027
+
1028
+ # dummy token; we'll ignore the losses on these tokens later
1029
+ labels[labels == label_pad_token_id] = 0
1030
+
1031
+ per_token_logps = selective_log_softmax(logits, labels)
1032
+
1033
+ if average_log_prob:
1034
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1035
+ else:
1036
+ return (per_token_logps * loss_mask).sum(-1)
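+
+ # Sketch of the masking above: positions labelled with label_pad_token_id
+ # contribute nothing to the per-sequence sum. Values are made up.
+ def _demo_masked_logps():
+     per_token_logps = torch.tensor([[-1.0, -2.0, -3.0]])
+     labels = torch.tensor([[5, 7, -100]])    # last position is padding
+     loss_mask = labels != -100
+     assert (per_token_logps * loss_mask).sum(-1).item() == -3.0   # -1 + -2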
1037
+
1038
+ def concatenated_forward(
1039
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1040
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1041
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1042
+
1043
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1044
+ """
1045
+ concatenated_batch = self.concatenated_inputs(
1046
+ batch,
1047
+ is_encoder_decoder=self.is_encoder_decoder,
1048
+ label_pad_token_id=self.label_pad_token_id,
1049
+ padding_value=self.padding_value,
1050
+ device=self.accelerator.device,
1051
+ )
1052
+ len_chosen = batch["chosen_labels"].shape[0]
1053
+
1054
+ model_kwargs = (
1055
+ {
1056
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1057
+ }
1058
+ if self.is_encoder_decoder
1059
+ else {}
1060
+ )
1061
+
1062
+ if self.aux_loss_enabled:
1063
+ model_kwargs["output_router_logits"] = True
1064
+
1065
+ outputs = model(
1066
+ concatenated_batch["concatenated_input_ids"],
1067
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1068
+ use_cache=False,
1069
+ **model_kwargs,
1070
+ )
1071
+ all_logits = outputs.logits
1072
+
1073
+ def cross_entropy_loss(logits, labels):
1074
+ if not self.is_encoder_decoder:
1075
+ # Shift so that tokens < n predict n
1076
+ logits = logits[..., :-1, :].contiguous()
1077
+ labels = labels[..., 1:].contiguous()
1078
+ # Flatten the tokens
1079
+ loss_fct = nn.CrossEntropyLoss()
1080
+ logits = logits.view(-1, logits.shape[-1])
1081
+ labels = labels.view(-1)
1082
+ # Enable model parallelism
1083
+ labels = labels.to(logits.device)
1084
+ loss = loss_fct(logits, labels)
1085
+ return loss
1086
+
1087
+ labels = concatenated_batch["concatenated_labels"].clone()
1088
+
1089
+ if self.cpo_alpha == 0:
1090
+ nll_loss = torch.tensor(0.0).to(self.accelerator.device)
1091
+ else:
1092
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1093
+
1094
+ all_logps = self.get_batch_logps(
1095
+ all_logits,
1096
+ concatenated_batch["concatenated_labels"],
1097
+ average_log_prob=self.loss_type in ["ipo", "simpo"],
1098
+ is_encoder_decoder=self.is_encoder_decoder,
1099
+ label_pad_token_id=self.label_pad_token_id,
1100
+ )
1101
+
1102
+ chosen_logps = all_logps[:len_chosen]
1103
+ rejected_logps = all_logps[len_chosen:]
1104
+
1105
+ chosen_logits = all_logits[:len_chosen]
1106
+ rejected_logits = all_logits[len_chosen:]
1107
+
1108
+ if self.aux_loss_enabled:
1109
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
1110
+
1111
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
1112
+
1113
+ def get_batch_loss_metrics(
1114
+ self,
1115
+ model,
1116
+ batch: dict[str, Union[list, torch.LongTensor]],
1117
+ train_eval: Literal["train", "eval"] = "train",
1118
+ ):
1119
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
1120
+ metrics = {}
1121
+
1122
+ forward_output = self.concatenated_forward(model, batch)
1123
+ (
1124
+ policy_chosen_logps,
1125
+ policy_rejected_logps,
1126
+ policy_chosen_logits,
1127
+ policy_rejected_logits,
1128
+ policy_nll_loss,
1129
+ ) = forward_output[:5]
1130
+ if self.aux_loss_enabled:
1131
+ aux_loss = forward_output[5]
1132
+
1133
+ losses, chosen_rewards, rejected_rewards = self.cpo_loss(
1134
+ policy_chosen_logps,
1135
+ policy_rejected_logps,
1136
+ )
1137
+
1138
+ loss = losses.mean() + self.cpo_alpha * policy_nll_loss
1139
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1140
+
1141
+ prefix = "eval_" if train_eval == "eval" else ""
1142
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
1143
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
1144
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
1145
+ metrics[f"{prefix}rewards/margins"] = (
1146
+ self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
1147
+ )
1148
+ metrics[f"{prefix}logps/rejected"] = (
1149
+ self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
1150
+ )
1151
+ metrics[f"{prefix}logps/chosen"] = (
1152
+ self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
1153
+ )
1154
+ metrics[f"{prefix}logits/rejected"] = (
1155
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
1156
+ )
1157
+ metrics[f"{prefix}logits/chosen"] = (
1158
+ self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
1159
+ )
1160
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
1161
+
1162
+ if self.aux_loss_enabled:
1163
+ loss += self.aux_loss_coef * aux_loss
1164
+
1165
+ return loss, metrics
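+
+ # The total objective assembled above, written out (a sketch of the same math):
+ #     loss = mean(cpo_losses) + cpo_alpha * nll_loss
+ #            (+ aux_loss_coef * aux_loss when router logits are enabled)
+ # i.e. the pairwise preference term plus a behaviour-cloning NLL term on the
+ # chosen responses, which is what distinguishes CPO from a plain DPO-style loss.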
1166
+
1167
+ def compute_loss(
1168
+ self,
1169
+ model: Union[PreTrainedModel, nn.Module],
1170
+ inputs: dict[str, Union[torch.Tensor, Any]],
1171
+ return_outputs=False,
1172
+ num_items_in_batch=None,
1173
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1174
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1175
+
1176
+ with compute_loss_context_manager:
1177
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1178
+
1179
+ # force log the metrics
1180
+ self.store_metrics(metrics, train_eval="train")
1181
+
1182
+ if return_outputs:
1183
+ return (loss, metrics)
1184
+ return loss
1185
+
1186
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1187
+ """Generate samples from the model for the given batch of inputs."""
1188
+
1189
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1190
+ # the torch cuda amp context manager as some hidden states are silently cast to full precision.
1191
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1192
+
1193
+ with generate_context_manager:
1194
+ policy_output = model.generate(
1195
+ input_ids=batch["prompt_input_ids"],
1196
+ attention_mask=batch["prompt_attention_mask"],
1197
+ max_length=self.max_length,
1198
+ do_sample=True,
1199
+ pad_token_id=self.processing_class.pad_token_id,
1200
+ )
1201
+
1202
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1203
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1204
+
1205
+ return policy_output_decoded
1206
+
1207
+ def prediction_step(
1208
+ self,
1209
+ model: Union[PreTrainedModel, nn.Module],
1210
+ inputs: dict[str, Union[torch.Tensor, Any]],
1211
+ prediction_loss_only: bool,
1212
+ ignore_keys: Optional[list[str]] = None,
1213
+ ):
1214
+ if ignore_keys is None:
1215
+ if hasattr(model, "config"):
1216
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1217
+ else:
1218
+ ignore_keys = []
1219
+
1220
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1221
+
1222
+ with torch.no_grad(), prediction_context_manager:
1223
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1224
+
1225
+ # force log the metrics
1226
+ self.store_metrics(metrics, train_eval="eval")
1227
+
1228
+ if prediction_loss_only:
1229
+ return (loss.detach(), None, None)
1230
+
1231
+ # logits for the chosen and rejected samples from model
1232
+ logits_dict = {
1233
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1234
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1235
+ }
1236
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1237
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1238
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1239
+
1240
+ return (loss.detach(), logits, labels)
1241
+
1242
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1243
+ for key, value in metrics.items():
1244
+ self._stored_metrics[train_eval][key].append(value)
1245
+
1246
+ def evaluation_loop(
1247
+ self,
1248
+ dataloader: DataLoader,
1249
+ description: str,
1250
+ prediction_loss_only: Optional[bool] = None,
1251
+ ignore_keys: Optional[list[str]] = None,
1252
+ metric_key_prefix: str = "eval",
1253
+ ) -> EvalLoopOutput:
1254
+ """
1255
+ Overriding built-in evaluation loop to store metrics for each batch.
1256
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1257
+
1258
+ Works both with or without labels.
1259
+ """
1260
+
1261
+ # Sample and save to game log if requested (for one batch to save time)
1262
+ if self.generate_during_eval:
1263
+ # Generate random indices within the range of the total number of samples
1264
+ num_samples = len(dataloader.dataset)
1265
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1266
+
1267
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1268
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1269
+ random_batch = self.data_collator(random_batch_dataset)
1270
+ random_batch = self._prepare_inputs(random_batch)
1271
+
1272
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1273
+
1274
+ table = pd.DataFrame(
1275
+ columns=["Prompt", "Policy"],
1276
+ data=[
1277
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1278
+ ],
1279
+ )
1280
+ if "wandb" in self.args.report_to:
1281
+ wandb.log({"game_log": wandb.Table(data=table)})
1282
+
1283
+ if "comet_ml" in self.args.report_to:
1284
+ log_table_to_comet_experiment(
1285
+ name="game_log.csv",
1286
+ table=table,
1287
+ )
1288
+
1289
+ # Base evaluation
1290
+ initial_output = super().evaluation_loop(
1291
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1292
+ )
1293
+
1294
+ return initial_output
1295
+
1296
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1297
+ """
1298
+ Log `logs` on the various objects watching training, including stored metrics.
1299
+
1300
+ Args:
1301
+ logs (`dict[str, float]`):
1302
+ The values to log.
1303
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1304
+ Start time of the training.
1305
+ """
1306
+ # logs either has 'loss' or 'eval_loss'
1307
+ train_eval = "train" if "loss" in logs else "eval"
1308
+ # Add averaged stored metrics to logs
1309
+ for key, metrics in self._stored_metrics[train_eval].items():
1310
+ logs[key] = torch.tensor(metrics).mean().item()
1311
+ del self._stored_metrics[train_eval]
1312
+
1313
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1314
+ return super().log(logs, start_time)
1315
+ else: # transformers<=4.46
1316
+ return super().log(logs)
1317
+
1318
+ def _shift_right(self, input_ids):
1319
+ if self.decoder_start_token_id is None:
1320
+ raise ValueError(
1321
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1322
+ )
1323
+
1324
+ # shift inputs to the right
1325
+ if is_torch_fx_proxy(input_ids):
1326
+ # Item assignment is not supported natively for proxies.
1327
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1328
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1329
+ else:
1330
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1331
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1332
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1333
+
1334
+ if self.pad_token_id is None:
1335
+ raise ValueError("model.config.pad_token_id has to be defined.")
1336
+ # replace possible -100 values in labels by `pad_token_id`
1337
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1338
+
1339
+ return shifted_input_ids
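+
+ # Sketch of the right shift above with an assumed decoder_start_token_id of 0:
+ def _demo_shift_right():
+     input_ids = torch.tensor([[5, 6, 7]])
+     shifted = input_ids.new_zeros(input_ids.shape)
+     shifted[..., 1:] = input_ids[..., :-1].clone()
+     shifted[..., 0] = 0    # assumed decoder_start_token_id
+     assert shifted.tolist() == [[0, 5, 6]]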
1340
+
1341
+ def create_model_card(
1342
+ self,
1343
+ model_name: Optional[str] = None,
1344
+ dataset_name: Optional[str] = None,
1345
+ tags: Union[str, list[str], None] = None,
1346
+ ):
1347
+ """
1348
+ Creates a draft of a model card using the information available to the `Trainer`.
1349
+
1350
+ Args:
1351
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1352
+ Name of the model.
1353
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1354
+ Name of the dataset used for training.
1355
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1356
+ Tags to be associated with the model card.
1357
+ """
1358
+ if not self.is_world_process_zero():
1359
+ return
1360
+
1361
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1362
+ base_model = self.model.config._name_or_path
1363
+ else:
1364
+ base_model = None
1365
+
1366
+ tags = tags or []
1367
+ if isinstance(tags, str):
1368
+ tags = [tags]
1369
+
1370
+ if hasattr(self.model.config, "unsloth_version"):
1371
+ tags.append("unsloth")
1372
+
1373
+ citation = textwrap.dedent("""\
1374
+ @inproceedings{xu2024contrastive,
1375
+ title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
1376
+ author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
1377
+ year = 2024,
1378
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
1379
+ publisher = {OpenReview.net},
1380
+ url = {https://openreview.net/forum?id=51iwkioZpn}
1381
+ }""")
1382
+
1383
+ model_card = generate_model_card(
1384
+ base_model=base_model,
1385
+ model_name=model_name,
1386
+ hub_model_id=self.hub_model_id,
1387
+ dataset_name=dataset_name,
1388
+ tags=tags,
1389
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1390
+ comet_url=get_comet_experiment_url(),
1391
+ trainer_name="CPO",
1392
+ trainer_citation=citation,
1393
+ paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
1394
+ paper_id="2401.08417",
1395
+ )
1396
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1397
+ class UnslothCPOTrainer(_UnslothCPOTrainer):
1398
+ """
1399
+
1400
+ Initialize CPOTrainer.
1401
+
1402
+ Args:
1403
+ model (`transformers.PreTrainedModel`):
1404
 +            The model to train, preferably an `AutoModelForCausalLM`.
1405
+ args (`CPOConfig`):
1406
+ The CPO config arguments to use for training.
1407
+ data_collator (`transformers.DataCollator`):
1408
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1409
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1410
+ train_dataset (`datasets.Dataset`):
1411
+ The dataset to use for training.
1412
+ eval_dataset (`datasets.Dataset`):
1413
+ The dataset to use for evaluation.
1414
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1415
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1416
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1417
+ reuse the fine-tuned model.
1418
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1419
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1420
+ callbacks (`list[transformers.TrainerCallback]`):
1421
+ The callbacks to use for training.
1422
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1423
+ The optimizer and scheduler to use for training.
1424
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1425
+ The function to use to preprocess the logits before computing the metrics.
1426
+ peft_config (`dict`, defaults to `None`):
1427
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1428
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1429
 +            The function to use to compute the metrics. Must take an `EvalPrediction` and return
1430
 +            a dictionary mapping strings to metric values.
1431
+
1432
+ """
1433
+ def __init__(
1434
+ self,
1435
+ model = None,
1436
+ args = None,
1437
+ data_collator = None,
1438
+ train_dataset = None,
1439
+ eval_dataset = None,
1440
+ processing_class = None,
1441
+ model_init = None,
1442
+ callbacks = None,
1443
+ preprocess_logits_for_metrics = None,
1444
+ peft_config = None,
1445
+ compute_metrics = None,
1446
+ **kwargs
1447
+ ):
1448
+ if args is None: args = UnslothCPOConfig()
1449
+ use_bf16 = getattr(args, 'bf16', False)
1450
+ use_fp16 = getattr(args, 'fp16', False)
1451
+ force_float32 = False
1452
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1453
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1454
+ force_float32 = True
1455
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1456
+ dtype = getattr(model.config, 'torch_dtype', None)
1457
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1458
+ from unsloth_zoo.utils import _get_dtype
1459
+ dtype = _get_dtype(dtype)
1460
+ float16 = dtype == torch.float16
1461
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1462
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1463
+ if force_float32:
1464
+ args.fp16 = False
1465
+ args.bf16 = False
1466
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1467
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1468
+ args.fp16 = float16
1469
+ args.bf16 = not float16
1470
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1471
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1472
+ args.eval_strategy = 'steps'
1473
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1474
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1475
+ if ga_steps is not None and ga_steps > 1:
1476
+ from transformers import __version__ as transformers_version
1477
+ if Version(transformers_version) <= Version('4.45.2'):
1478
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1479
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1480
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1481
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1482
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1483
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1484
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1485
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1486
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1487
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1488
+ if force_float32:
1489
+ args.bf16_full_eval = False
1490
+ args.fp16_full_eval = False
1491
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1492
+ args.bf16_full_eval = True
1493
+ args.fp16_full_eval = False
1494
+ elif not bf16_full_eval and not fp16_full_eval:
1495
+ args.bf16_full_eval = args.bf16
1496
+ args.fp16_full_eval = args.fp16
1497
+ _output_logits = False
1498
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1499
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1500
+ if _output_logits:
1501
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1502
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1503
+ pass
1504
+ else:
1505
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1506
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1507
+ if args_max_seq_length is None and model_max_seq_length is not None:
1508
+ max_seq_length = model.max_seq_length
1509
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1510
+ if model is not None and hasattr(model, 'for_training'):
1511
+ model.for_training()
1512
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1513
+ if 'processing_class' in locals():
1514
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1515
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1516
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1517
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1518
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1519
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1520
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1521
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1522
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1523
+ else:
1524
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1525
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1526
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1527
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1528
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1529
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1530
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1531
+ else:
1532
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1533
+ other_metrics = []
1534
+
1535
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1536
+ PatchRLStatistics('cpo_trainer', other_metrics)
1537
+
1538
+ super().__init__(
1539
+ model = model,
1540
+ args = args,
1541
+ data_collator = data_collator,
1542
+ train_dataset = train_dataset,
1543
+ eval_dataset = eval_dataset,
1544
+ processing_class = processing_class,
1545
+ model_init = model_init,
1546
+ callbacks = callbacks,
1547
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1548
+ peft_config = peft_config,
1549
+ compute_metrics = compute_metrics,**kwargs)
1550
+ if hasattr(self, 'neftune_hook_handle'):
1551
+ self.neftune_hook_handle.remove()
1552
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1553
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1554
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1555
+ pass
1556
+
1557
+ pass
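(Editor's note: a hypothetical usage sketch for the wrapper above. The model id, dataset rows, and hyperparameters are placeholders, not part of this commit.)

```python
# Hypothetical sketch, assuming Unsloth and TRL are installed.
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import CPOConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/llama-3-8b-bnb-4bit",  # placeholder model id
    max_seq_length=2048,
)

# CPO consumes paired preference data: one chosen and one rejected completion per prompt.
train_dataset = Dataset.from_list([
    {"prompt": "Translate to French: Hello",
     "chosen": "Bonjour.",
     "rejected": "Hello."},
])

trainer = UnslothCPOTrainer(
    model=model,
    args=CPOConfig(output_dir="outputs", per_device_train_batch_size=1, max_steps=1),
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```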
unsloth_compiled_cache/UnslothDDPOTrainer.py ADDED
@@ -0,0 +1,872 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
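(Editor's note: a quick equivalence check for `selective_log_softmax`; the shapes are illustrative. It shows why gathering the chosen logits first and subtracting `logsumexp` matches a full `log_softmax` followed by `gather`, without materialising the full log-probability tensor.)

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10)         # (batch, seq, vocab), illustrative shapes
index = torch.randint(0, 10, (2, 4))   # token ids whose log-probs we want

# Reference: full log_softmax over the vocab, then gather the chosen ids.
ref = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

# Fused form used above: gather the raw logits, then subtract logsumexp.
out = (logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1)
       - torch.logsumexp(logits.float(), dim=-1))

assert torch.allclose(ref, out, atol=1e-5)
```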
42
+ @dataclass
43
+ class UnslothDDPOConfig(DDPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`DDPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
54
+ Name of this experiment (by default is the file name without the extension name).
55
+ run_name (`str`, *optional*, defaults to `""`):
56
+ Name of this run.
57
+ seed (`int`, *optional*, defaults to `0`):
58
+ Random seed.
59
 +        log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
60
+ Log with either 'wandb' or 'tensorboard', check
61
+ https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
62
+ tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
63
+ Keyword arguments for the tracker (e.g. wandb_project).
64
+ accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
65
+ Keyword arguments for the accelerator.
66
+ project_kwargs (`Dict`, *optional*, defaults to `{}`):
67
+ Keyword arguments for the accelerator project config (e.g. `logging_dir`).
68
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
69
+ Name of project to use for tracking.
70
+ logdir (`str`, *optional*, defaults to `"logs"`):
71
+ Top-level logging directory for checkpoint saving.
72
+ num_epochs (`int`, *optional*, defaults to `100`):
73
+ Number of epochs to train.
74
+ save_freq (`int`, *optional*, defaults to `1`):
75
+ Number of epochs between saving model checkpoints.
76
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
77
+ Number of checkpoints to keep before overwriting old ones.
78
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
79
+ Mixed precision training.
80
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
81
+ Allow `tf32` on Ampere GPUs.
82
+ resume_from (`str`, *optional*, defaults to `""`):
83
+ Resume training from a checkpoint.
84
+ sample_num_steps (`int`, *optional*, defaults to `50`):
85
+ Number of sampler inference steps.
86
+ sample_eta (`float`, *optional*, defaults to `1.0`):
87
+ Eta parameter for the DDIM sampler.
88
+ sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
89
+ Classifier-free guidance weight.
90
+ sample_batch_size (`int`, *optional*, defaults to `1`):
91
+ Batch size (per GPU) to use for sampling.
92
+ sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
93
+ Number of batches to sample per epoch.
94
+ train_batch_size (`int`, *optional*, defaults to `1`):
95
+ Batch size (per GPU) to use for training.
96
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
97
+ Use 8bit Adam optimizer from bitsandbytes.
98
+ train_learning_rate (`float`, *optional*, defaults to `3e-4`):
99
+ Learning rate.
100
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
101
+ Adam beta1.
102
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
103
+ Adam beta2.
104
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
105
+ Adam weight decay.
106
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
107
+ Adam epsilon.
108
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
109
+ Number of gradient accumulation steps.
110
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
111
+ Maximum gradient norm for gradient clipping.
112
+ train_num_inner_epochs (`int`, *optional*, defaults to `1`):
113
+ Number of inner epochs per outer epoch.
114
+ train_cfg (`bool`, *optional*, defaults to `True`):
115
+ Whether to use classifier-free guidance during training.
116
+ train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
117
+ Clip advantages to the range.
118
+ train_clip_range (`float`, *optional*, defaults to `1e-4`):
119
+ PPO clip range.
120
+ train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
121
+ Fraction of timesteps to train on.
122
+ per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
123
+ Whether to track statistics for each prompt separately.
124
+ per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
125
+ Number of reward values to store in the buffer for each prompt.
126
+ per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
127
+ Minimum number of reward values to store in the buffer.
128
+ async_reward_computation (`bool`, *optional*, defaults to `False`):
129
+ Whether to compute rewards asynchronously.
130
+ max_workers (`int`, *optional*, defaults to `2`):
131
+ Maximum number of workers to use for async reward computation.
132
+ negative_prompts (`str`, *optional*, defaults to `""`):
133
+ Comma-separated list of prompts to use as negative examples.
134
+ push_to_hub (`bool`, *optional*, defaults to `False`):
135
+ Whether to push the final model checkpoint to the Hub.
136
+
137
+ """
138
+ vllm_sampling_params: Optional[Any] = field(
139
+ default = None,
140
+ metadata = {'help': 'vLLM SamplingParams'},
141
+ )
142
+ unsloth_num_chunks : Optional[int] = field(
143
+ default = -1,
144
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
145
+ )
146
+ def __init__(
147
+ self,
148
+ exp_name = 'main',
149
+ run_name = '',
150
+ seed = 3407,
151
+ log_with = None,
152
+ tracker_project_name = 'trl',
153
+ logdir = 'logs',
154
+ num_epochs = 100,
155
+ save_freq = 1,
156
+ num_checkpoint_limit = 5,
157
+ mixed_precision = 'fp16',
158
+ allow_tf32 = True,
159
+ resume_from = '',
160
+ sample_num_steps = 50,
161
+ sample_eta = 1.0,
162
+ sample_guidance_scale = 5.0,
163
+ sample_batch_size = 1,
164
+ sample_num_batches_per_epoch = 2,
165
+ train_batch_size = 1,
166
+ train_use_8bit_adam = False,
167
+ train_learning_rate = 5e-05,
168
+ train_adam_beta1 = 0.9,
169
+ train_adam_beta2 = 0.999,
170
+ train_adam_weight_decay = 0.01,
171
+ train_adam_epsilon = 1e-08,
172
+ train_gradient_accumulation_steps = 2,
173
+ train_max_grad_norm = 1.0,
174
+ train_num_inner_epochs = 1,
175
+ train_cfg = True,
176
+ train_adv_clip_max = 5.0,
177
+ train_clip_range = 0.0001,
178
+ train_timestep_fraction = 1.0,
179
+ per_prompt_stat_tracking = False,
180
+ per_prompt_stat_tracking_buffer_size = 16,
181
+ per_prompt_stat_tracking_min_count = 16,
182
+ async_reward_computation = False,
183
+ max_workers = 2,
184
+ negative_prompts = '',
185
+ push_to_hub = False,
186
+ vllm_sampling_params = None,
187
+ unsloth_num_chunks = -1,
188
+ **kwargs,
189
+ ):
190
+
191
+ super().__init__(
192
+ exp_name = exp_name,
193
+ run_name = run_name,
194
+ seed = seed,
195
+ log_with = log_with,
196
+ tracker_project_name = tracker_project_name,
197
+ logdir = logdir,
198
+ num_epochs = num_epochs,
199
+ save_freq = save_freq,
200
+ num_checkpoint_limit = num_checkpoint_limit,
201
+ mixed_precision = mixed_precision,
202
+ allow_tf32 = allow_tf32,
203
+ resume_from = resume_from,
204
+ sample_num_steps = sample_num_steps,
205
+ sample_eta = sample_eta,
206
+ sample_guidance_scale = sample_guidance_scale,
207
+ sample_batch_size = sample_batch_size,
208
+ sample_num_batches_per_epoch = sample_num_batches_per_epoch,
209
+ train_batch_size = train_batch_size,
210
+ train_use_8bit_adam = train_use_8bit_adam,
211
+ train_learning_rate = train_learning_rate,
212
+ train_adam_beta1 = train_adam_beta1,
213
+ train_adam_beta2 = train_adam_beta2,
214
+ train_adam_weight_decay = train_adam_weight_decay,
215
+ train_adam_epsilon = train_adam_epsilon,
216
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
217
+ train_max_grad_norm = train_max_grad_norm,
218
+ train_num_inner_epochs = train_num_inner_epochs,
219
+ train_cfg = train_cfg,
220
+ train_adv_clip_max = train_adv_clip_max,
221
+ train_clip_range = train_clip_range,
222
+ train_timestep_fraction = train_timestep_fraction,
223
+ per_prompt_stat_tracking = per_prompt_stat_tracking,
224
+ per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
225
+ per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
226
+ async_reward_computation = async_reward_computation,
227
+ max_workers = max_workers,
228
+ negative_prompts = negative_prompts,
229
+ push_to_hub = push_to_hub,**kwargs)
230
+ self.vllm_sampling_params = vllm_sampling_params
231
+ self.unsloth_num_chunks = unsloth_num_chunks
232
+ pass
233
+
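(Editor's note: a hedged construction sketch for the config class above; every value is illustrative. The comments point at the constraints that `_config_check` enforces later in this file.)

```python
# Illustrative only -- these values satisfy the divisibility constraints below.
config = UnslothDDPOConfig(
    sample_num_steps=50,                  # DDIM inference steps per trajectory
    sample_batch_size=2,                  # must be >= and divisible by train_batch_size
    sample_num_batches_per_epoch=2,
    train_batch_size=1,
    train_gradient_accumulation_steps=2,
    mixed_precision="fp16",
)
```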
234
+ class _UnslothDDPOTrainer(PyTorchModelHubMixin):
235
+ """"""
236
+
237
+ _tag_names = ["trl", "ddpo"]
238
+
239
+ def __init__(
240
+ self,
241
+ config: DDPOConfig,
242
+ reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
243
+ prompt_function: Callable[[], tuple[str, Any]],
244
+ sd_pipeline: DDPOStableDiffusionPipeline,
245
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
246
+ ):
247
+ if image_samples_hook is None:
248
+ warn("No image_samples_hook provided; no images will be logged")
249
+
250
+ self.prompt_fn = prompt_function
251
+ self.reward_fn = reward_function
252
+ self.config = config
253
+ self.image_samples_callback = image_samples_hook
254
+
255
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
256
+
257
+ if self.config.resume_from:
258
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
259
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
260
+ # get the most recent checkpoint in this directory
261
+ checkpoints = list(
262
+ filter(
263
+ lambda x: "checkpoint_" in x,
264
+ os.listdir(self.config.resume_from),
265
+ )
266
+ )
267
+ if len(checkpoints) == 0:
268
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
269
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
270
+ self.config.resume_from = os.path.join(
271
+ self.config.resume_from,
272
+ f"checkpoint_{checkpoint_numbers[-1]}",
273
+ )
274
+
275
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
276
+
277
+ # number of timesteps within each trajectory to train on
278
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
279
+
280
+ self.accelerator = Accelerator(
281
+ log_with=self.config.log_with,
282
+ mixed_precision=self.config.mixed_precision,
283
+ project_config=accelerator_project_config,
284
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
285
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
286
+ # the total number of optimizer steps to accumulate across.
287
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
288
+ **self.config.accelerator_kwargs,
289
+ )
290
+
291
+ is_okay, message = self._config_check()
292
+ if not is_okay:
293
+ raise ValueError(message)
294
+
295
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
296
+
297
+ if self.accelerator.is_main_process:
298
+ self.accelerator.init_trackers(
299
+ self.config.tracker_project_name,
300
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
301
+ init_kwargs=self.config.tracker_kwargs,
302
+ )
303
+
304
+ logger.info(f"\n{config}")
305
+
306
+ set_seed(self.config.seed, device_specific=True)
307
+
308
+ self.sd_pipeline = sd_pipeline
309
+
310
+ self.sd_pipeline.set_progress_bar_config(
311
+ position=1,
312
+ disable=not self.accelerator.is_local_main_process,
313
+ leave=False,
314
+ desc="Timestep",
315
+ dynamic_ncols=True,
316
+ )
317
+
318
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
319
+ # as these weights are only used for inference, keeping weights in full precision is not required.
320
+ if self.accelerator.mixed_precision == "fp16":
321
+ inference_dtype = torch.float16
322
+ elif self.accelerator.mixed_precision == "bf16":
323
+ inference_dtype = torch.bfloat16
324
+ else:
325
+ inference_dtype = torch.float32
326
+
327
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
328
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
329
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
330
+
331
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
332
+
333
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
334
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
335
+
336
+ # Enable TF32 for faster training on Ampere GPUs,
337
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
338
+ if self.config.allow_tf32:
339
+ torch.backends.cuda.matmul.allow_tf32 = True
340
+
341
+ self.optimizer = self._setup_optimizer(
342
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
343
+ )
344
+
345
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
346
+ self.sd_pipeline.tokenizer(
347
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
348
+ return_tensors="pt",
349
+ padding="max_length",
350
+ truncation=True,
351
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
352
+ ).input_ids.to(self.accelerator.device)
353
+ )[0]
354
+
355
+ if config.per_prompt_stat_tracking:
356
+ self.stat_tracker = PerPromptStatTracker(
357
+ config.per_prompt_stat_tracking_buffer_size,
358
+ config.per_prompt_stat_tracking_min_count,
359
+ )
360
+
361
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
362
+ # more memory
363
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
364
+
365
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
366
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
367
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
368
+ else:
369
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
370
+
371
+ if self.config.async_reward_computation:
372
+ self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
373
+
374
+ if config.resume_from:
375
+ logger.info(f"Resuming from {config.resume_from}")
376
+ self.accelerator.load_state(config.resume_from)
377
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
378
+ else:
379
+ self.first_epoch = 0
380
+
381
+ def compute_rewards(self, prompt_image_pairs, is_async=False):
382
+ if not is_async:
383
+ rewards = []
384
+ for images, prompts, prompt_metadata in prompt_image_pairs:
385
+ reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
386
+ rewards.append(
387
+ (
388
+ torch.as_tensor(reward, device=self.accelerator.device),
389
+ reward_metadata,
390
+ )
391
+ )
392
+ else:
393
+ rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
394
+ rewards = [
395
+ (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
396
+ for reward, reward_metadata in rewards
397
+ ]
398
+
399
+ return zip(*rewards)
400
+
401
+ def step(self, epoch: int, global_step: int):
402
+ """
403
+ Perform a single step of training.
404
+
405
+ Args:
406
+ epoch (int): The current epoch.
407
+ global_step (int): The current global step.
408
+
409
+ Side Effects:
410
+ - Model weights are updated
411
+ - Logs the statistics to the accelerator trackers.
412
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
413
+
414
+ Returns:
415
+ global_step (int): The updated global step.
416
+
417
+ """
418
+ samples, prompt_image_data = self._generate_samples(
419
+ iterations=self.config.sample_num_batches_per_epoch,
420
+ batch_size=self.config.sample_batch_size,
421
+ )
422
+
423
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
424
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
425
+ rewards, rewards_metadata = self.compute_rewards(
426
+ prompt_image_data, is_async=self.config.async_reward_computation
427
+ )
428
+
429
+ for i, image_data in enumerate(prompt_image_data):
430
+ image_data.extend([rewards[i], rewards_metadata[i]])
431
+
432
+ if self.image_samples_callback is not None:
433
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
434
+
435
+ rewards = torch.cat(rewards)
436
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
437
+
438
+ self.accelerator.log(
439
+ {
440
+ "reward": rewards,
441
+ "epoch": epoch,
442
+ "reward_mean": rewards.mean(),
443
+ "reward_std": rewards.std(),
444
+ },
445
+ step=global_step,
446
+ )
447
+
448
+ if self.config.per_prompt_stat_tracking:
449
+ # gather the prompts across processes
450
+ prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
451
+ prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
452
+ advantages = self.stat_tracker.update(prompts, rewards)
453
+ else:
454
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
455
+
456
+ # ungather advantages; keep the entries corresponding to the samples on this process
457
+ samples["advantages"] = (
458
+ torch.as_tensor(advantages)
459
+ .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
460
+ .to(self.accelerator.device)
461
+ )
462
+
463
+ del samples["prompt_ids"]
464
+
465
+ total_batch_size, num_timesteps = samples["timesteps"].shape
466
+
467
+ for inner_epoch in range(self.config.train_num_inner_epochs):
468
+ # shuffle samples along batch dimension
469
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
470
+ samples = {k: v[perm] for k, v in samples.items()}
471
+
472
+ # shuffle along time dimension independently for each sample
473
+ # still trying to understand the code below
474
+ perms = torch.stack(
475
+ [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
476
+ )
477
+
478
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
479
+ samples[key] = samples[key][
480
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
481
+ perms,
482
+ ]
483
+
484
+ original_keys = samples.keys()
485
+ original_values = samples.values()
486
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
487
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
488
+
489
+ # Transpose the list of original values
490
+ transposed_values = zip(*reshaped_values)
491
+ # Create new dictionaries for each row of transposed values
492
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
493
+
494
+ self.sd_pipeline.unet.train()
495
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
496
+ # ensure optimization step at the end of the inner epoch
497
+ if not self.accelerator.sync_gradients:
498
+ raise ValueError(
499
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
500
+ )
501
+
502
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
503
+ self.accelerator.save_state()
504
+
505
+ return global_step
506
+
507
+ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
508
+ """
509
+ Calculate the loss for a batch of an unpacked sample
510
+
511
+ Args:
512
+ latents (torch.Tensor):
513
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
514
+ timesteps (torch.Tensor):
515
+ The timesteps sampled from the diffusion model, shape: [batch_size]
516
+ next_latents (torch.Tensor):
517
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
518
+ log_probs (torch.Tensor):
519
+ The log probabilities of the latents, shape: [batch_size]
520
+ advantages (torch.Tensor):
521
+ The advantages of the latents, shape: [batch_size]
522
+ embeds (torch.Tensor):
523
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
524
+ Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
525
+
526
+ Returns:
527
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
528
+ (all of these are of shape (1,))
529
+ """
530
+ with self.autocast():
531
+ if self.config.train_cfg:
532
+ noise_pred = self.sd_pipeline.unet(
533
+ torch.cat([latents] * 2),
534
+ torch.cat([timesteps] * 2),
535
+ embeds,
536
+ ).sample
537
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
538
+ noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
539
+ noise_pred_text - noise_pred_uncond
540
+ )
541
+ else:
542
+ noise_pred = self.sd_pipeline.unet(
543
+ latents,
544
+ timesteps,
545
+ embeds,
546
+ ).sample
547
+ # compute the log prob of next_latents given latents under the current model
548
+
549
+ scheduler_step_output = self.sd_pipeline.scheduler_step(
550
+ noise_pred,
551
+ timesteps,
552
+ latents,
553
+ eta=self.config.sample_eta,
554
+ prev_sample=next_latents,
555
+ )
556
+
557
+ log_prob = scheduler_step_output.log_probs
558
+
559
+ advantages = torch.clamp(
560
+ advantages,
561
+ -self.config.train_adv_clip_max,
562
+ self.config.train_adv_clip_max,
563
+ )
564
+
565
+ ratio = torch.exp(log_prob - log_probs)
566
+
567
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
568
+
569
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
570
+
571
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
572
+
573
+ return loss, approx_kl, clipfrac
574
+
575
+ def loss(
576
+ self,
577
+ advantages: torch.Tensor,
578
+ clip_range: float,
579
+ ratio: torch.Tensor,
580
+ ):
581
+ unclipped_loss = -advantages * ratio
582
+ clipped_loss = -advantages * torch.clamp(
583
+ ratio,
584
+ 1.0 - clip_range,
585
+ 1.0 + clip_range,
586
+ )
587
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
588
+
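(Editor's note: the method above is the standard PPO-style clipped surrogate objective. A toy numeric trace, purely editorial:)

```python
import torch

advantages = torch.tensor([1.0, -1.0])
ratio = torch.tensor([1.5, 0.5])    # new-policy / old-policy probability ratio
clip_range = 0.2

unclipped = -advantages * ratio                                      # [-1.5, 0.5]
clipped = -advantages * ratio.clamp(1 - clip_range, 1 + clip_range)  # [-1.2, 0.8]
loss = torch.maximum(unclipped, clipped).mean()                      # (-1.2 + 0.8) / 2 = -0.2
```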
589
+ def _setup_optimizer(self, trainable_layers_parameters):
590
+ if self.config.train_use_8bit_adam:
591
+ import bitsandbytes
592
+
593
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
594
+ else:
595
+ optimizer_cls = torch.optim.AdamW
596
+
597
+ return optimizer_cls(
598
+ trainable_layers_parameters,
599
+ lr=self.config.train_learning_rate,
600
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
601
+ weight_decay=self.config.train_adam_weight_decay,
602
+ eps=self.config.train_adam_epsilon,
603
+ )
604
+
605
+ def _save_model_hook(self, models, weights, output_dir):
606
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
607
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
608
+
609
+ def _load_model_hook(self, models, input_dir):
610
+ self.sd_pipeline.load_checkpoint(models, input_dir)
611
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
612
+
613
+ def _generate_samples(self, iterations, batch_size):
614
+ """
615
+ Generate samples from the model
616
+
617
+ Args:
618
+ iterations (int): Number of iterations to generate samples for
619
+ batch_size (int): Batch size to use for sampling
620
+
621
+ Returns:
622
+ samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
623
+ """
624
+ samples = []
625
+ prompt_image_pairs = []
626
+ self.sd_pipeline.unet.eval()
627
+
628
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
629
+
630
+ for _ in range(iterations):
631
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
632
+
633
+ prompt_ids = self.sd_pipeline.tokenizer(
634
+ prompts,
635
+ return_tensors="pt",
636
+ padding="max_length",
637
+ truncation=True,
638
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
639
+ ).input_ids.to(self.accelerator.device)
640
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
641
+
642
+ with self.autocast():
643
+ sd_output = self.sd_pipeline(
644
+ prompt_embeds=prompt_embeds,
645
+ negative_prompt_embeds=sample_neg_prompt_embeds,
646
+ num_inference_steps=self.config.sample_num_steps,
647
+ guidance_scale=self.config.sample_guidance_scale,
648
+ eta=self.config.sample_eta,
649
+ output_type="pt",
650
+ )
651
+
652
+ images = sd_output.images
653
+ latents = sd_output.latents
654
+ log_probs = sd_output.log_probs
655
+
656
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
657
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
658
+ timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
659
+
660
+ samples.append(
661
+ {
662
+ "prompt_ids": prompt_ids,
663
+ "prompt_embeds": prompt_embeds,
664
+ "timesteps": timesteps,
665
+ "latents": latents[:, :-1], # each entry is the latent before timestep t
666
+ "next_latents": latents[:, 1:], # each entry is the latent after timestep t
667
+ "log_probs": log_probs,
668
+ "negative_prompt_embeds": sample_neg_prompt_embeds,
669
+ }
670
+ )
671
+ prompt_image_pairs.append([images, prompts, prompt_metadata])
672
+
673
+ return samples, prompt_image_pairs
674
+
675
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
676
+ """
677
+ Train on a batch of samples. Main training segment
678
+
679
+ Args:
680
+ inner_epoch (int): The current inner epoch
681
+ epoch (int): The current epoch
682
+ global_step (int): The current global step
683
+ batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
684
+
685
+ Side Effects:
686
+ - Model weights are updated
687
+ - Logs the statistics to the accelerator trackers.
688
+
689
+ Returns:
690
+ global_step (int): The updated global step
691
+ """
692
+ info = defaultdict(list)
693
+ for _i, sample in enumerate(batched_samples):
694
+ if self.config.train_cfg:
695
+ # concat negative prompts to sample prompts to avoid two forward passes
696
+ embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
697
+ else:
698
+ embeds = sample["prompt_embeds"]
699
+
700
+ for j in range(self.num_train_timesteps):
701
+ with self.accelerator.accumulate(self.sd_pipeline.unet):
702
+ loss, approx_kl, clipfrac = self.calculate_loss(
703
+ sample["latents"][:, j],
704
+ sample["timesteps"][:, j],
705
+ sample["next_latents"][:, j],
706
+ sample["log_probs"][:, j],
707
+ sample["advantages"],
708
+ embeds,
709
+ )
710
+ info["approx_kl"].append(approx_kl)
711
+ info["clipfrac"].append(clipfrac)
712
+ info["loss"].append(loss)
713
+
714
+ self.accelerator.backward(loss)
715
+ if self.accelerator.sync_gradients:
716
+ self.accelerator.clip_grad_norm_(
717
+ self.trainable_layers.parameters()
718
+ if not isinstance(self.trainable_layers, list)
719
+ else self.trainable_layers,
720
+ self.config.train_max_grad_norm,
721
+ )
722
+ self.optimizer.step()
723
+ self.optimizer.zero_grad()
724
+
725
+ # Checks if the accelerator has performed an optimization step behind the scenes
726
+ if self.accelerator.sync_gradients:
727
+ # log training-related stuff
728
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
729
+ info = self.accelerator.reduce(info, reduction="mean")
730
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
731
+ self.accelerator.log(info, step=global_step)
732
+ global_step += 1
733
+ info = defaultdict(list)
734
+ return global_step
735
+
736
+ def _config_check(self) -> tuple[bool, str]:
737
+ samples_per_epoch = (
738
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
739
+ )
740
+ total_train_batch_size = (
741
+ self.config.train_batch_size
742
+ * self.accelerator.num_processes
743
+ * self.config.train_gradient_accumulation_steps
744
+ )
745
+
746
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
747
+ return (
748
+ False,
749
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
750
+ )
751
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
752
+ return (
753
+ False,
754
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
755
+ )
756
+ if not samples_per_epoch % total_train_batch_size == 0:
757
+ return (
758
+ False,
759
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
760
+ )
761
+ return True, ""
762
+
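(Editor's note: the arithmetic behind `_config_check`, spelled out with toy numbers for a single process.)

```python
# Editorial example: these settings pass all three checks above.
sample_batch_size = 4
train_batch_size = 2
sample_num_batches_per_epoch = 2
train_gradient_accumulation_steps = 2
num_processes = 1

samples_per_epoch = sample_batch_size * num_processes * sample_num_batches_per_epoch            # 8
total_train_batch_size = train_batch_size * num_processes * train_gradient_accumulation_steps   # 4

assert sample_batch_size >= train_batch_size
assert sample_batch_size % train_batch_size == 0
assert samples_per_epoch % total_train_batch_size == 0
```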
763
+ def train(self, epochs: Optional[int] = None):
764
+ """
765
+ Train the model for a given number of epochs
766
+ """
767
+ global_step = 0
768
+ if epochs is None:
769
+ epochs = self.config.num_epochs
770
+ for epoch in range(self.first_epoch, epochs):
771
+ global_step = self.step(epoch, global_step)
772
+
773
+ def _save_pretrained(self, save_directory):
774
+ self.sd_pipeline.save_pretrained(save_directory)
775
+ self.create_model_card()
776
+
777
+ def create_model_card(
778
+ self,
779
+ model_name: Optional[str] = None,
780
+ dataset_name: Optional[str] = None,
781
+ tags: Union[str, list[str], None] = None,
782
+ ):
783
+ """
784
+ Creates a draft of a model card using the information available to the `Trainer`.
785
+
786
+ Args:
787
+ model_name (`str` or `None`, *optional*, defaults to `None`):
788
+ Name of the model.
789
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
790
+ Name of the dataset used for training.
791
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
792
+ Tags to be associated with the model card.
793
+ """
794
+ if not self.is_world_process_zero():
795
+ return
796
+
797
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
798
+ base_model = self.model.config._name_or_path
799
+ else:
800
+ base_model = None
801
+
802
+ tags = tags or []
803
+ if isinstance(tags, str):
804
+ tags = [tags]
805
+
806
+ if hasattr(self.model.config, "unsloth_version"):
807
+ tags.append("unsloth")
808
+
809
+ citation = textwrap.dedent("""\
810
+ @inproceedings{black2024training,
811
+ title = {{Training Diffusion Models with Reinforcement Learning}},
812
+ author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
813
+ year = 2024,
814
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
815
+ publisher = {OpenReview.net},
816
+ url = {https://openreview.net/forum?id=YCWjhGrJFD},
817
+ }""")
818
+
819
+ model_card = generate_model_card(
820
+ base_model=base_model,
821
+ model_name=model_name,
822
+ hub_model_id=self.hub_model_id,
823
+ dataset_name=dataset_name,
824
+ tags=tags,
825
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
826
+ comet_url=get_comet_experiment_url(),
827
+ trainer_name="DDPO",
828
+ trainer_citation=citation,
829
+ paper_title="Training Diffusion Models with Reinforcement Learning",
830
+ paper_id="2305.13301",
831
+ )
832
+
833
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
834
+ class UnslothDDPOTrainer(_UnslothDDPOTrainer):
835
+ """
836
+
837
 +    The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models.
838
+ Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
839
+ As of now only Stable Diffusion based pipelines are supported
840
+
841
+ Attributes:
842
 +        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
843
+ details.
844
+ **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
845
+ **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
846
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
847
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
848
+
849
+ """
850
+ def __init__(
851
+ self,
852
+ config,
853
+ reward_function,
854
+ prompt_function,
855
+ sd_pipeline,
856
+ image_samples_hook = None,
857
+ **kwargs
858
+ ):
859
 +        if config is None: config = UnslothDDPOConfig()
860
+ other_metrics = []
861
+
862
+ from unsloth_zoo.logging_utils import PatchRLStatistics
863
+ PatchRLStatistics('ddpo_trainer', other_metrics)
864
+
865
+ super().__init__(
866
+ config = config,
867
+ reward_function = reward_function,
868
+ prompt_function = prompt_function,
869
+ sd_pipeline = sd_pipeline,
870
+ image_samples_hook = image_samples_hook,**kwargs)
871
+
872
+ pass
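(Editor's note: a hypothetical end-to-end sketch for the DDPO wrapper above. The pipeline id, prompt function, and constant reward are placeholders; a real run needs a meaningful reward function.)

```python
# Hypothetical sketch, assuming trl's Stable Diffusion DDPO pipeline is available.
import torch
from trl import DefaultDDPOStableDiffusionPipeline

def prompt_fn():
    return "a photo of a cat", {}          # (prompt, prompt_metadata)

def reward_fn(images, prompts, metadata):
    return torch.ones(len(images)), {}     # constant dummy reward per image

trainer = UnslothDDPOTrainer(
    config=UnslothDDPOConfig(sample_batch_size=1, train_batch_size=1,
                             sample_num_batches_per_epoch=2),
    reward_function=reward_fn,
    prompt_function=prompt_fn,
    sd_pipeline=DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5"),
)
trainer.train()
```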
unsloth_compiled_cache/UnslothDPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,863 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, deepspeed, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothGKDConfig(GKDConfig):
44
+ """
45
+
46
+ Configuration class for [`GKDTrainer`].
47
+
48
+ Args:
49
+ temperature (`float`, *optional*, defaults to `0.9`):
50
+ Temperature for sampling. The higher the temperature, the more random the completions.
51
+ lmbda (`float`, *optional*, defaults to `0.5`):
52
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
53
+ student-generated outputs).
54
+ beta (`float`, *optional*, defaults to `0.5`):
55
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
56
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
57
+ max_new_tokens (`int`, *optional*, defaults to `128`):
58
+ Maximum number of tokens to generate per completion.
59
+ teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
60
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
61
+ being trained.
62
 +        teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
63
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
64
+ from a string.
65
+ disable_dropout (`bool`, *optional*, defaults to `True`):
66
+ Whether to disable dropout in the model.
67
+ seq_kd (`bool`, *optional*, defaults to `False`):
68
+ Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
69
+ on teacher-generated output).
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ tp_size = 0,
159
+ fsdp_transformer_layer_cls_to_wrap = None,
160
+ accelerator_config = None,
161
+ deepspeed = None,
162
+ label_smoothing_factor = 0.0,
163
+ optim = 'adamw_8bit',
164
+ optim_args = None,
165
+ adafactor = False,
166
+ group_by_length = False,
167
+ length_column_name = 'length',
168
+ report_to = None,
169
+ ddp_find_unused_parameters = None,
170
+ ddp_bucket_cap_mb = None,
171
+ ddp_broadcast_buffers = None,
172
+ dataloader_pin_memory = True,
173
+ dataloader_persistent_workers = False,
174
+ skip_memory_metrics = True,
175
+ use_legacy_prediction_loop = False,
176
+ push_to_hub = False,
177
+ resume_from_checkpoint = None,
178
+ hub_model_id = None,
179
+ hub_strategy = 'every_save',
180
+ hub_token = None,
181
+ hub_private_repo = None,
182
+ hub_always_push = False,
183
+ gradient_checkpointing = False,
184
+ gradient_checkpointing_kwargs = None,
185
+ include_inputs_for_metrics = False,
186
+ eval_do_concat_batches = True,
187
+ fp16_backend = 'auto',
188
+ evaluation_strategy = None,
189
+ push_to_hub_model_id = None,
190
+ push_to_hub_organization = None,
191
+ push_to_hub_token = None,
192
+ mp_parameters = '',
193
+ auto_find_batch_size = False,
194
+ full_determinism = False,
195
+ torchdynamo = None,
196
+ ray_scope = 'last',
197
+ ddp_timeout = 1800,
198
+ torch_compile = False,
199
+ torch_compile_backend = None,
200
+ torch_compile_mode = None,
201
+ dispatch_batches = None,
202
+ split_batches = None,
203
+ include_tokens_per_second = False,
204
+ include_num_input_tokens_seen = False,
205
+ neftune_noise_alpha = None,
206
+ optim_target_modules = None,
207
+ batch_eval_metrics = False,
208
+ eval_on_start = False,
209
+ use_liger_kernel = False,
210
+ eval_use_gather_object = False,
211
+ average_tokens_across_devices = False,
212
+ model_init_kwargs = None,
213
+ use_liger = False,
214
+ dataset_text_field = 'text',
215
+ dataset_kwargs = None,
216
+ dataset_num_proc = None,
217
+ max_seq_length = None,
218
+ packing = False,
219
+ eval_packing = None,
220
+ dataset_batch_size = None,
221
+ num_of_sequences = None,
222
+ chars_per_token = None,
223
+ temperature = 0.9,
224
+ lmbda = 0.5,
225
+ beta = 0.5,
226
+ max_new_tokens = 128,
227
+ teacher_model_name_or_path = None,
228
+ teacher_model_init_kwargs = None,
229
+ disable_dropout = True,
230
+ seq_kd = False,
231
+ vllm_sampling_params = None,
232
+ unsloth_num_chunks = -1,
233
+ **kwargs,
234
+ ):
235
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
236
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1 or lower, otherwise gradient updates will explode!')
237
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
238
+ output_dir = 'unsloth_training_checkpoints'
239
+ save_strategy = 'no'
240
+ if dataset_num_proc is None:
241
+ from multiprocessing import cpu_count
242
+ dataset_num_proc = cpu_count()
243
+
244
+ super().__init__(
245
+ output_dir = output_dir,
246
+ overwrite_output_dir = overwrite_output_dir,
247
+ do_train = do_train,
248
+ do_eval = do_eval,
249
+ do_predict = do_predict,
250
+ eval_strategy = eval_strategy,
251
+ prediction_loss_only = prediction_loss_only,
252
+ per_device_train_batch_size = per_device_train_batch_size,
253
+ per_device_eval_batch_size = per_device_eval_batch_size,
254
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
255
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
256
+ gradient_accumulation_steps = gradient_accumulation_steps,
257
+ eval_accumulation_steps = eval_accumulation_steps,
258
+ eval_delay = eval_delay,
259
+ torch_empty_cache_steps = torch_empty_cache_steps,
260
+ learning_rate = learning_rate,
261
+ weight_decay = weight_decay,
262
+ adam_beta1 = adam_beta1,
263
+ adam_beta2 = adam_beta2,
264
+ adam_epsilon = adam_epsilon,
265
+ max_grad_norm = max_grad_norm,
266
+ num_train_epochs = num_train_epochs,
267
+ max_steps = max_steps,
268
+ lr_scheduler_type = lr_scheduler_type,
269
+ warmup_ratio = warmup_ratio,
270
+ warmup_steps = warmup_steps,
271
+ log_level = log_level,
272
+ log_level_replica = log_level_replica,
273
+ log_on_each_node = log_on_each_node,
274
+ logging_dir = logging_dir,
275
+ logging_strategy = logging_strategy,
276
+ logging_first_step = logging_first_step,
277
+ logging_steps = logging_steps,
278
+ logging_nan_inf_filter = logging_nan_inf_filter,
279
+ save_strategy = save_strategy,
280
+ save_steps = save_steps,
281
+ save_total_limit = save_total_limit,
282
+ save_safetensors = save_safetensors,
283
+ save_on_each_node = save_on_each_node,
284
+ save_only_model = save_only_model,
285
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
286
+ no_cuda = no_cuda,
287
+ use_cpu = use_cpu,
288
+ use_mps_device = use_mps_device,
289
+ seed = seed,
290
+ data_seed = data_seed,
291
+ jit_mode_eval = jit_mode_eval,
292
+ use_ipex = use_ipex,
293
+ bf16 = bf16,
294
+ fp16 = fp16,
295
+ fp16_opt_level = fp16_opt_level,
296
+ half_precision_backend = half_precision_backend,
297
+ bf16_full_eval = bf16_full_eval,
298
+ fp16_full_eval = fp16_full_eval,
299
+ tf32 = tf32,
300
+ local_rank = local_rank,
301
+ ddp_backend = ddp_backend,
302
+ tpu_num_cores = tpu_num_cores,
303
+ tpu_metrics_debug = tpu_metrics_debug,
304
+ debug = debug,
305
+ dataloader_drop_last = dataloader_drop_last,
306
+ eval_steps = eval_steps,
307
+ dataloader_num_workers = dataloader_num_workers,
308
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
309
+ past_index = past_index,
310
+ run_name = run_name,
311
+ disable_tqdm = disable_tqdm,
312
+ remove_unused_columns = remove_unused_columns,
313
+ label_names = label_names,
314
+ load_best_model_at_end = load_best_model_at_end,
315
+ metric_for_best_model = metric_for_best_model,
316
+ greater_is_better = greater_is_better,
317
+ ignore_data_skip = ignore_data_skip,
318
+ fsdp = fsdp,
319
+ fsdp_min_num_params = fsdp_min_num_params,
320
+ fsdp_config = fsdp_config,
321
+ tp_size = tp_size,
322
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
323
+ accelerator_config = accelerator_config,
324
+ deepspeed = deepspeed,
325
+ label_smoothing_factor = label_smoothing_factor,
326
+ optim = optim,
327
+ optim_args = optim_args,
328
+ adafactor = adafactor,
329
+ group_by_length = group_by_length,
330
+ length_column_name = length_column_name,
331
+ report_to = report_to,
332
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
333
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
334
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
335
+ dataloader_pin_memory = dataloader_pin_memory,
336
+ dataloader_persistent_workers = dataloader_persistent_workers,
337
+ skip_memory_metrics = skip_memory_metrics,
338
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
339
+ push_to_hub = push_to_hub,
340
+ resume_from_checkpoint = resume_from_checkpoint,
341
+ hub_model_id = hub_model_id,
342
+ hub_strategy = hub_strategy,
343
+ hub_token = hub_token,
344
+ hub_private_repo = hub_private_repo,
345
+ hub_always_push = hub_always_push,
346
+ gradient_checkpointing = gradient_checkpointing,
347
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
348
+ include_inputs_for_metrics = include_inputs_for_metrics,
349
+ eval_do_concat_batches = eval_do_concat_batches,
350
+ fp16_backend = fp16_backend,
351
+ evaluation_strategy = evaluation_strategy,
352
+ push_to_hub_model_id = push_to_hub_model_id,
353
+ push_to_hub_organization = push_to_hub_organization,
354
+ push_to_hub_token = push_to_hub_token,
355
+ mp_parameters = mp_parameters,
356
+ auto_find_batch_size = auto_find_batch_size,
357
+ full_determinism = full_determinism,
358
+ torchdynamo = torchdynamo,
359
+ ray_scope = ray_scope,
360
+ ddp_timeout = ddp_timeout,
361
+ torch_compile = torch_compile,
362
+ torch_compile_backend = torch_compile_backend,
363
+ torch_compile_mode = torch_compile_mode,
364
+ dispatch_batches = dispatch_batches,
365
+ split_batches = split_batches,
366
+ include_tokens_per_second = include_tokens_per_second,
367
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
368
+ neftune_noise_alpha = neftune_noise_alpha,
369
+ optim_target_modules = optim_target_modules,
370
+ batch_eval_metrics = batch_eval_metrics,
371
+ eval_on_start = eval_on_start,
372
+ use_liger_kernel = use_liger_kernel,
373
+ eval_use_gather_object = eval_use_gather_object,
374
+ average_tokens_across_devices = average_tokens_across_devices,
375
+ model_init_kwargs = model_init_kwargs,
376
+ use_liger = use_liger,
377
+ dataset_text_field = dataset_text_field,
378
+ dataset_kwargs = dataset_kwargs,
379
+ dataset_num_proc = dataset_num_proc,
380
+ max_seq_length = max_seq_length,
381
+ packing = packing,
382
+ eval_packing = eval_packing,
383
+ dataset_batch_size = dataset_batch_size,
384
+ num_of_sequences = num_of_sequences,
385
+ chars_per_token = chars_per_token,
386
+ temperature = temperature,
387
+ lmbda = lmbda,
388
+ beta = beta,
389
+ max_new_tokens = max_new_tokens,
390
+ teacher_model_name_or_path = teacher_model_name_or_path,
391
+ teacher_model_init_kwargs = teacher_model_init_kwargs,
392
+ disable_dropout = disable_dropout,
393
+ seq_kd = seq_kd, **kwargs)
394
+ self.vllm_sampling_params = vllm_sampling_params
395
+ self.unsloth_num_chunks = unsloth_num_chunks
396
+ pass
397
+
398
+ class _UnslothGKDTrainer(SFTTrainer):
399
+ _tag_names = ["trl", "gkd"]
400
+
401
+ def __init__(
402
+ self,
403
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
404
+ teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
405
+ args: Optional[GKDConfig] = None,
406
+ data_collator: Optional[DataCollator] = None, # type: ignore
407
+ train_dataset: Optional[Dataset] = None,
408
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
409
+ processing_class: Optional[
410
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
411
+ ] = None,
412
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
413
+ callbacks: Optional[list[TrainerCallback]] = None,
414
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
415
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
416
+ peft_config: Optional["PeftConfig"] = None,
417
+ formatting_func: Optional[Callable] = None,
418
+ ):
419
+ # add remove_unused_columns=False to the dataclass args
420
+ args.remove_unused_columns = False
421
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
422
+
423
+ super().__init__(
424
+ model,
425
+ args=args,
426
+ data_collator=data_collator,
427
+ train_dataset=train_dataset,
428
+ eval_dataset=eval_dataset,
429
+ processing_class=processing_class,
430
+ compute_metrics=compute_metrics,
431
+ callbacks=callbacks,
432
+ optimizers=optimizers,
433
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
434
+ peft_config=peft_config,
435
+ formatting_func=formatting_func,
436
+ )
437
+
438
+ if args.teacher_model_init_kwargs is None:
439
+ teacher_model_init_kwargs = {}
440
+ elif not isinstance(teacher_model, str):
441
+ raise ValueError(
442
+ "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
443
+ )
444
+ else:
445
+ teacher_model_init_kwargs = args.teacher_model_init_kwargs
446
+ teacher_model_init_kwargs["torch_dtype"] = (
447
+ teacher_model_init_kwargs["torch_dtype"]
448
+ if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
449
+ else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
450
+ )
451
+
452
+ if isinstance(teacher_model, str):
453
+ if args.use_liger:
454
+ teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
455
+ else:
456
+ teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
457
+
458
+ # Disable dropout in the model
459
+ if args.disable_dropout:
460
+ disable_dropout_in_model(self.model)
461
+
462
+ if self.is_deepspeed_enabled:
463
+ self.teacher_model = self._prepare_deepspeed(teacher_model)
464
+ else:
465
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
466
+
467
+ self.lmbda = args.lmbda
468
+ self.beta = args.beta
469
+ self.temperature = args.temperature
470
+ self.seq_kd = args.seq_kd
471
+
472
+ self.generation_config = GenerationConfig(
473
+ max_new_tokens=args.max_new_tokens,
474
+ temperature=args.temperature,
475
+ do_sample=True,
476
+ top_k=0,
477
+ use_cache=False if args.gradient_checkpointing else True,
478
+ pad_token_id=self.processing_class.pad_token_id,
479
+ )
480
+ # Set custom EOS tokens if they are specified by the model's generation
481
+ # config. This is important for models with the Llama 3 chat template,
482
+ # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
483
+ # turns or messages.
484
+ if (
485
+ hasattr(self.model.generation_config, "eos_token_id")
486
+ and self.model.generation_config.eos_token_id is not None
487
+ ):
488
+ self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
489
+
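+ # As a concrete illustration of the EOS handling above (a sketch: the token
+ # IDs below are hypothetical examples for a Llama-3-style tokenizer, not read
+ # from any real config), a model may declare several EOS tokens that should
+ # all stop generation:
+ #
+ #     from transformers import GenerationConfig
+ #     gen_cfg = GenerationConfig(max_new_tokens = 128, temperature = 0.9, do_sample = True)
+ #     gen_cfg.eos_token_id = [128001, 128008, 128009]  # e.g. <|end_of_text|>, <|eom_id|>, <|eot_id|>
+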
490
+ def _prepare_dataset(self, dataset, *args):
491
+ # SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However, we
492
+ # need to keep the messages column as it is, so the workaround below stashes it under a temporary name and restores it afterwards.
493
+ dataset = dataset.add_column("_messages", dataset["messages"])
494
+ dataset = super()._prepare_dataset(dataset, *args)
495
+ dataset = dataset.rename_column("_messages", "messages")
496
+ return dataset
497
+
498
+ @staticmethod
499
+ def generalized_jsd_loss(
500
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
501
+ ):
502
+ """
503
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
504
+ of https://huggingface.co/papers/2306.13649 for the definition.
505
+
506
+ Args:
507
+ student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
508
+ teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
509
+ labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
510
+ beta: Interpolation coefficient between 0 and 1 (default: 0.5)
511
+ temperature: Softmax temperature (default: 1.0)
512
+ reduction: Specifies the reduction to apply to the output (default: 'batchmean')
513
+
514
+ Returns:
515
+ loss: Scalar tensor with the generalized JSD loss
516
+ """
517
+
518
+ # Apply temperature scaling
519
+ student_logits = student_logits / temperature
520
+ teacher_logits = teacher_logits / temperature
521
+
522
+ # Compute log probabilities for student and probabilities for teacher
523
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
524
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
525
+
526
+ # Compute the log of the mixture distribution
527
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
528
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
529
+ mixture_log_probs = torch.logsumexp(
530
+ torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
531
+ dim=0,
532
+ )
533
+
534
+ # Compute KL divergences using F.kl_div
535
+ # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
536
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
537
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
538
+
539
+ # Compute the Generalized Jensen-Shannon Divergence
540
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
541
+
542
+ # Masking
543
+ if labels is not None:
544
+ mask = labels != -100
545
+ jsd = jsd[mask]
546
+
547
+ # Apply reduction
548
+ if reduction == "batchmean":
549
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
550
+ elif reduction == "sum":
551
+ return jsd.sum()
552
+ elif reduction == "mean":
553
+ return jsd.mean()
554
+ else:
555
+ return jsd
556
+
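+ # Minimal usage sketch for the loss above (shapes and values are made up for
+ # illustration):
+ #
+ #     import torch
+ #     student = torch.randn(2, 5, 32)               # (batch, seq, vocab)
+ #     teacher = torch.randn(2, 5, 32)
+ #     labels  = torch.ones(2, 5, dtype = torch.long)
+ #     labels[:, 0] = -100                           # ignore the first position
+ #     loss = _UnslothGKDTrainer.generalized_jsd_loss(student, teacher, labels, beta = 0.5)
+ #
+ # beta interpolates between the two KL directions: as beta -> 0 the mixture
+ # collapses onto the teacher and the loss tends to KL(student || teacher);
+ # as beta -> 1 it tends to KL(teacher || student).
+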
557
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
558
+ # compute student output
559
+ outputs_student = model(
560
+ input_ids=inputs["input_ids"],
561
+ attention_mask=inputs["attention_mask"],
562
+ )
563
+
564
+ # compute teacher output in eval mode
565
+ self.teacher_model.eval()
566
+ with torch.no_grad():
567
+ outputs_teacher = self.teacher_model(
568
+ input_ids=inputs["input_ids"],
569
+ attention_mask=inputs["attention_mask"],
570
+ )
571
+
572
+ # slice the logits for the generated tokens using the inputs["prompts"] lengths
573
+ prompt_lengths = inputs["prompts"].shape[1]
574
+ shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
575
+ shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
576
+ shifted_labels = inputs["labels"][:, prompt_lengths:]
577
+
578
+ # compute loss
579
+ loss = self.generalized_jsd_loss(
580
+ student_logits=shifted_student_logits,
581
+ teacher_logits=shifted_teacher_logits,
582
+ labels=shifted_labels,
583
+ beta=self.beta,
584
+ )
585
+
586
+ # empty cache
587
+ empty_cache()
588
+
589
+ # Return loss
590
+ return (loss, outputs_student) if return_outputs else loss
591
+
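+ # Shape sketch for the slicing above (sizes are hypothetical): with a prompt
+ # of length P and a full sequence of length T, logits[:, P-1:-1, :] holds the
+ # predictions *for* the completion tokens, which line up one-for-one with
+ # labels[:, P:]:
+ #
+ #     P, T = 4, 10
+ #     logits = torch.randn(1, T, 32)
+ #     completion_preds = logits[:, P - 1 : -1, :]   # shape (1, T - P, 32)
+ #     assert completion_preds.shape[1] == T - P
+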
592
+ @staticmethod
593
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
594
+ # Generate output with respect to the prompt only
595
+ generated_outputs = model.generate(
596
+ input_ids=inputs["prompts"],
597
+ attention_mask=inputs.get("prompt_attention_mask", None),
598
+ generation_config=generation_config,
599
+ return_dict_in_generate=True,
600
+ )
601
+
602
+ # Get the generated token IDs
603
+ generated_tokens = generated_outputs.sequences
604
+ # Calculate new attention mask
605
+ new_attention_mask = torch.ones_like(generated_tokens)
606
+ new_labels = generated_tokens.clone()
607
+
608
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
609
+ if pad_token_id is not None:
610
+ new_labels[new_labels == pad_token_id] = -100
611
+ new_attention_mask[generated_tokens == pad_token_id] = 0
612
+
613
+ return generated_tokens, new_attention_mask, new_labels
614
+
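+ # Tiny worked example of the masking above (pad_id and token values are made
+ # up): pad positions are dropped from both the loss and the attention mask.
+ #
+ #     pad_id = 0
+ #     generated = torch.tensor([[5, 7, pad_id, pad_id]])
+ #     labels = generated.clone(); labels[labels == pad_id] = -100
+ #     attn   = torch.ones_like(generated); attn[generated == pad_id] = 0
+ #     # labels -> [[5, 7, -100, -100]], attn -> [[1, 1, 0, 0]]
+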
615
+ def training_step(
616
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
617
+ ) -> torch.Tensor:
618
+ """
619
+ Perform a training step for the Generalized Knowledge Distillation (GKD) model.
620
+
621
+ This method implements the on-policy learning approach described in the GKD paper.
622
+ With probability `self.lmbda`, it generates new responses using the student model,
623
+ which are then used for training instead of the original inputs.
624
+ """
625
+ if self.seq_kd:
626
+ with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
627
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
628
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
629
+ )
630
+ inputs["input_ids"] = new_input_ids
631
+ inputs["attention_mask"] = new_attention_mask
632
+ inputs["labels"] = new_labels
633
+ if random.random() <= self.lmbda:
634
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
635
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
636
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
637
+ )
638
+ inputs["input_ids"] = new_input_ids
639
+ inputs["attention_mask"] = new_attention_mask
640
+ inputs["labels"] = new_labels
641
+
642
+ loss = super().training_step(model, inputs, num_items_in_batch)
643
+ return loss
644
+
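+ # The mixing in training_step can be summarised as (a rough sketch, not the
+ # exact control flow above):
+ #
+ #     if random.random() <= lmbda:      # lmbda = 1.0 -> fully on-policy GKD
+ #         batch = student_generations   # fresh samples from the student
+ #     else:                             # lmbda = 0.0 -> plain supervised KD
+ #         batch = original_batch        # ground-truth / teacher sequences
+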
645
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
646
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
647
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
648
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
649
+
650
+ if model is not None:
651
+ if hasattr(model, "config"):
652
+ hidden_size = (
653
+ max(model.config.hidden_sizes)
654
+ if getattr(model.config, "hidden_sizes", None)
655
+ else getattr(model.config, "hidden_size", None)
656
+ )
657
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
658
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
659
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
660
+ config_kwargs.update(
661
+ {
662
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
663
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
664
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
665
+ }
666
+ )
667
+
668
+ # If ZeRO-3 is used, we shard both the active and reference model.
669
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
670
+ if config_kwargs["zero_optimization"]["stage"] != 3:
671
+ config_kwargs["zero_optimization"]["stage"] = 0
672
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
673
+ model.eval()
674
+ return model
675
+
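+ # Worked numbers for the ZeRO-3 bucket sizing above (hidden_size = 4096 is a
+ # hypothetical example):
+ #
+ #     reduce_bucket_size                 = 4096 * 4096       = 16_777_216
+ #     stage3_param_persistence_threshold = 10 * 4096         = 40_960
+ #     stage3_prefetch_bucket_size        = 0.9 * 4096 * 4096 = 15_099_494.4
+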
676
+ def create_model_card(
677
+ self,
678
+ model_name: Optional[str] = None,
679
+ dataset_name: Optional[str] = None,
680
+ tags: Union[str, list[str], None] = None,
681
+ ):
682
+ """
683
+ Creates a draft of a model card using the information available to the `Trainer`.
684
+
685
+ Args:
686
+ model_name (`str` or `None`, *optional*, defaults to `None`):
687
+ Name of the model.
688
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
689
+ Name of the dataset used for training.
690
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
691
+ Tags to be associated with the model card.
692
+ """
693
+ if not self.is_world_process_zero():
694
+ return
695
+
696
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
697
+ base_model = self.model.config._name_or_path
698
+ else:
699
+ base_model = None
700
+
701
+ tags = tags or []
702
+ if isinstance(tags, str):
703
+ tags = [tags]
704
+
705
+ if hasattr(self.model.config, "unsloth_version"):
706
+ tags.append("unsloth")
707
+
708
+ citation = textwrap.dedent("""\
709
+ @inproceedings{agarwal2024on-policy,
710
+ title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
711
+ author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
712
+ year = 2024,
713
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
714
+ publisher = {OpenReview.net},
715
+ url = {https://openreview.net/forum?id=3zKtaqxLhW},
716
+ }""")
717
+
718
+ model_card = generate_model_card(
719
+ base_model=base_model,
720
+ model_name=model_name,
721
+ hub_model_id=self.hub_model_id,
722
+ dataset_name=dataset_name,
723
+ tags=tags,
724
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
725
+ comet_url=get_comet_experiment_url(),
726
+ trainer_name="GKD",
727
+ trainer_citation=citation,
728
+ paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
729
+ paper_id="2306.13649",
730
+ )
731
+
732
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
733
+ class UnslothGKDTrainer(_UnslothGKDTrainer):
734
+ """
735
+
736
+ """
737
+ def __init__(
738
+ self,
739
+ model = None,
740
+ teacher_model = None,
741
+ args = None,
742
+ data_collator = None,
743
+ train_dataset = None,
744
+ eval_dataset = None,
745
+ processing_class = None,
746
+ compute_metrics = None,
747
+ callbacks = None,
748
+ preprocess_logits_for_metrics = None,
749
+ peft_config = None,
750
+ formatting_func = None,
751
+ **kwargs
752
+ ):
753
+ if args is None: args = UnslothGKDConfig()
754
+ use_bf16 = getattr(args, 'bf16', False)
755
+ use_fp16 = getattr(args, 'fp16', False)
756
+ force_float32 = False
757
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
758
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
759
+ force_float32 = True
760
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
761
+ dtype = getattr(model.config, 'torch_dtype', None)
762
+ if dtype is None: dtype = model.get_input_embeddings().dtype
763
+ from unsloth_zoo.utils import _get_dtype
764
+ dtype = _get_dtype(dtype)
765
+ float16 = dtype == torch.float16
766
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
767
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
768
+ if force_float32:
769
+ args.fp16 = False
770
+ args.bf16 = False
771
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
772
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
773
+ args.fp16 = float16
774
+ args.bf16 = not float16
775
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
776
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
777
+ args.eval_strategy = 'steps'
778
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
779
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
780
+ if ga_steps is not None and ga_steps > 1:
781
+ from transformers import __version__ as transformers_version
782
+ if Version(transformers_version) <= Version('4.45.2'):
783
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
784
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
785
+ if getattr(args, 'eval_strategy', 'no') != 'no':
786
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
787
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
788
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
789
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
790
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
791
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
792
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
793
+ if force_float32:
794
+ args.bf16_full_eval = False
795
+ args.fp16_full_eval = False
796
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
797
+ args.bf16_full_eval = True
798
+ args.fp16_full_eval = False
799
+ elif not bf16_full_eval and not fp16_full_eval:
800
+ args.bf16_full_eval = args.bf16
801
+ args.fp16_full_eval = args.fp16
802
+ _output_logits = False
803
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
804
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
805
+ if _output_logits:
806
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
807
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
808
+ pass
809
+ else:
810
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
811
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
812
+ if args_max_seq_length is None and model_max_seq_length is not None:
813
+ max_seq_length = model.max_seq_length
814
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
815
+ if model is not None and hasattr(model, 'for_training'):
816
+ model.for_training()
817
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
818
+ if 'processing_class' in locals():
819
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
820
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
821
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
822
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
823
+ if not isinstance(data_collator, UnslothVisionDataCollator):
824
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
825
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
826
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
827
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
828
+ else:
829
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
830
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
831
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
832
+ if not isinstance(data_collator, UnslothVisionDataCollator):
833
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
834
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
835
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
836
+ else:
837
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
838
+ other_metrics = []
839
+
840
+ from unsloth_zoo.logging_utils import PatchRLStatistics
841
+ PatchRLStatistics('gkd_trainer', other_metrics)
842
+
843
+ super().__init__(
844
+ model = model,
845
+ teacher_model = teacher_model,
846
+ args = args,
847
+ data_collator = data_collator,
848
+ train_dataset = train_dataset,
849
+ eval_dataset = eval_dataset,
850
+ processing_class = processing_class,
851
+ compute_metrics = compute_metrics,
852
+ callbacks = callbacks,
853
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
854
+ peft_config = peft_config,
855
+ formatting_func = formatting_func, **kwargs)
856
+ if hasattr(self, 'neftune_hook_handle'):
857
+ self.neftune_hook_handle.remove()
858
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
859
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
860
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
861
+ pass
862
+
863
+ pass
unsloth_compiled_cache/UnslothGRPOTrainer.py ADDED
@@ -0,0 +1,1438 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # a per-row loop (kept commented out below) could reduce peak memory consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+
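+ # Minimal sketch of the helper above (shapes are illustrative): it returns
+ # log-softmax values only at the gathered token positions, avoiding a full
+ # (batch, seq, vocab) log-probability tensor.
+ #
+ #     logits = torch.randn(2, 7, 32)               # (batch, seq, vocab)
+ #     ids    = torch.randint(0, 32, (2, 7))
+ #     lp     = selective_log_softmax(logits, ids)  # (2, 7)
+ #     # reference (more memory-hungry) equivalent:
+ #     ref = torch.log_softmax(logits.float(), dim = -1).gather(-1, ids.unsqueeze(-1)).squeeze(-1)
+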
43
+ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
44
+ # All Unsloth Zoo code licensed under LGPLv3
45
+ old_logits = old_logits.to(torch.float32)
46
+ new_logits = new_logits.to(torch.float32)
47
+ input_ids = input_ids.unsqueeze(-1)
48
+
49
+ # x_i - logsumexp(x_i)
50
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
51
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
52
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
53
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
54
+
55
+ # Reverse KL
56
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
57
+ # Full correct reverse KL divergence?? Missing term maybe?
58
+ # kl_i = torch.exp(new) * kl_i
59
+
60
+ # Below is forward KL (normal KL)
61
+ # kl_i = torch.exp(old) * (old - new)
62
+
63
+ # Must detach - otherwise gradients are not propagated correctly!
64
+ # exp(x - x) == 1
65
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
66
+ loss_i = -(loss_i - beta * kl_i)
67
+
68
+ mask = mask.to(torch.float32)
69
+ n_mask_per_reward = mask.sum(1)
70
+
71
+ # See https://github.com/huggingface/trl/pull/2881
72
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
73
+ loss = loss_per_reward.mean()
74
+ # loss = (loss_i * mask).sum() / mask.sum()
75
+
76
+ # Get metrics as well which are folded
77
+ with torch.inference_mode():
78
+ completion_length = n_mask_per_reward.mean()
79
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
80
+ mean_kl = mean_kl_per_reward.mean()
81
+ pass
82
+ return loss, completion_length, mean_kl
83
+
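+ # The per-token KL penalty above is the estimator
+ #     kl_i = exp(old - new) - (old - new) - 1  >=  0,
+ # with old/new being per-token log-probs; it is zero exactly when both
+ # policies assign a token the same probability. Tiny check (values made up):
+ #
+ #     old, new = torch.tensor(-1.2), torch.tensor(-1.2)
+ #     assert torch.exp(old - new) - (old - new) - 1.0 == 0.0
+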
84
+ class UnslothEfficientGRPO(torch.autograd.Function):
85
+ # All Unsloth Zoo code licensed under LGPLv3
86
+ @staticmethod
87
+ def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
88
+ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
89
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
90
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
91
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
92
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
93
+ loss, completion_length, mean_kl = grpo_compute_loss(
94
+ old_logits, new_logits, input_ids, mask, beta, advantages,
95
+ )
96
+ # Scale loss if needed for mixed precision training
97
+ scaled_loss = loss * scaling
98
+ # Must add .loss.detach otherwise autograd uses 2x VRAM
99
+ return scaled_loss, (loss.detach(), completion_length, mean_kl,)
100
+ pass
101
+
102
+ device = _new_hidden_states.device
103
+ grad_inputs = torch.empty_like(_new_hidden_states)
104
+ accumulated_loss = torch.zeros(1, device = device)
105
+ accumulated_completion_length = torch.zeros(1, device = device)
106
+ accumulated_mean_kl = torch.zeros(1, device = device)
107
+
108
+ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
109
+ (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
110
+ compute_loss,
111
+ argnums = (0,),
112
+ has_aux = True,
113
+ )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
114
+ accumulated_loss .add_(unscaled_loss)
115
+ accumulated_completion_length.add_(chunk_completion_length)
116
+ accumulated_mean_kl .add_(chunk_mean_kl)
117
+ return chunk_grad_input
118
+ pass
119
+
120
+ accumulate_chunk = torch.compile(
121
+ accumulate_chunk,
122
+ fullgraph = True,
123
+ options = torch_compile_options,
124
+ )
125
+
126
+ grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
127
+ new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
128
+ old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
129
+ input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
130
+ mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
131
+ advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
132
+
133
+ # Get the mixed precision loss scale if a scaler is present
134
+ scaling = scaler.get_scale() if scaler is not None else 1.0
135
+
136
+ # Force torch.compile to use dynamic shapes for seqlen dim
137
+ mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
138
+
139
+ for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
140
+ zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
141
+
142
+ mark_dynamic(new_hidden_states_j)
143
+ mark_dynamic(old_hidden_states_j)
144
+ mark_dynamic(input_ids_j)
145
+ mark_dynamic(mask_j)
146
+
147
+ grad_inputs_j.copy_(
148
+ accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
149
+ )
150
+ pass
151
+
152
+ grad_inputs .div_(n_chunks)
153
+ accumulated_loss .div_(n_chunks)
154
+ accumulated_completion_length.div_(n_chunks)
155
+ accumulated_mean_kl .div_(n_chunks)
156
+ ctx.save_for_backward(grad_inputs)
157
+
158
+ return (
159
+ accumulated_loss,
160
+ accumulated_completion_length,
161
+ accumulated_mean_kl,
162
+ )
163
+ pass
164
+
165
+ @staticmethod
166
+ def backward(ctx, grad_output, dcompletion_length, dmean_kl):
167
+ (grad_input,) = ctx.saved_tensors
168
+ return (grad_input, None, None, None, None, None, None, None, None,)
169
+ pass
170
+
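+ # Usage sketch for the chunked autograd function above (all tensor names are
+ # hypothetical): splitting the batch into n_chunks trades peak activation
+ # memory for extra lm_head matmuls; n_chunks equal to the batch size
+ # processes one row at a time.
+ #
+ #     loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
+ #         new_hidden, old_hidden, lm_head_weight,
+ #         completion_ids, completion_mask, advantages,
+ #         0.04,   # beta (KL coefficient)
+ #         None,   # scaler (no mixed-precision loss scaling)
+ #         4,      # n_chunks
+ #     )
+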
171
+ def grpo_accumulated_loss(
172
+ trainer,
173
+ input_ids,
174
+ logits_to_keep,
175
+ completion_mask,
176
+ advantages,
177
+ n_chunks = -1,
178
+ ):
179
+ # All Unsloth Zoo code licensed under LGPLv3
180
+ bsz, qlen = input_ids.shape
181
+ # Snap n_chunks to the closest divisor of the batch size
182
+ factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
183
+ if n_chunks == -1: n_chunks = bsz
184
+ n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
185
+
186
+ mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
187
+ os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
188
+
189
+ completion_input_ids = input_ids[:, -logits_to_keep:]
190
+ lm_head = trainer.model.get_output_embeddings().weight
191
+
192
+ with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
193
+ with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
194
+ old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
195
+ pass
196
+
197
+ new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
198
+
199
+ loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
200
+ new_hidden_states, old_hidden_states, lm_head,
201
+ completion_input_ids, completion_mask, advantages, trainer.beta,
202
+ trainer.accelerator.scaler,
203
+ n_chunks,
204
+ )
205
+ return loss, completion_length, mean_kl
206
+
207
+ # Old non efficient code path
208
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
209
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
210
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
211
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
212
+ loss, completion_length, mean_kl = grpo_compute_loss(
213
+ old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
214
+ )
215
+ return loss, completion_length, mean_kl
216
+ pass
217
+
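+ # The divisor search in grpo_accumulated_loss rounds the requested n_chunks
+ # up to the next factor of the batch size, so every chunk is equally sized.
+ # Worked example (bsz is hypothetical):
+ #
+ #     bsz = 12  ->  factors = [1, 2, 3, 4, 6, 12]
+ #     requested n_chunks = 5  ->  np.searchsorted(factors, 5) = 4  ->  factors[4] = 6
+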
218
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
219
+ def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
220
+ # All Unsloth Zoo code licensed under LGPLv3
221
+ old_logits = old_logits.to(torch.float32)
222
+ new_logits = new_logits.to(torch.float32)
223
+ input_ids = input_ids.unsqueeze(-1)
224
+
225
+ # x_i - logsumexp(x_i)
226
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
227
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
228
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
229
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
230
+
231
+ # Reverse KL
232
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
233
+ # Full correct reverse KL divergence?? Missing term maybe?
234
+ # kl_i = torch.exp(new) * kl_i
235
+
236
+ # Below is forward KL (normal KL)
237
+ # kl_i = torch.exp(old) * (old - new)
238
+
239
+ # Must detach - otherwise gradients are not propagated correctly!
240
+ # exp(x - x) == 1
241
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
242
+ loss_i = -(loss_i - beta * kl_i)
243
+
244
+ mask = mask.to(torch.float32)
245
+ n_mask_per_reward = mask.sum(1)
246
+
247
+ # See https://github.com/huggingface/trl/pull/2881
248
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
249
+ loss = loss_per_reward.mean()
250
+ # loss = (loss_i * mask).sum() / mask.sum()
251
+
252
+ # Get metrics as well which are folded
253
+ with torch.inference_mode():
254
+ completion_length = n_mask_per_reward.mean()
255
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
256
+ mean_kl = mean_kl_per_reward.mean()
257
+ pass
258
+ return loss, completion_length, mean_kl
259
+
260
+ def vLLMSamplingParams(**kwargs):
261
+ from vllm import SamplingParams
262
+ sampling_params = SamplingParams(**kwargs)
263
+ sampling_params._set_kwargs = kwargs
264
+ return sampling_params
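+ # Usage sketch (parameter values are illustrative): the wrapper builds a
+ # regular vllm.SamplingParams and remembers the kwargs so they can be
+ # re-applied later.
+ #
+ #     params = vLLMSamplingParams(temperature = 0.8, top_p = 0.95, max_tokens = 256)
+ #     params._set_kwargs   # {'temperature': 0.8, 'top_p': 0.95, 'max_tokens': 256}
+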
265
+ @dataclass
266
+ class UnslothGRPOConfig(GRPOConfig):
267
+ """
268
+
269
+ Configuration class for the [`GRPOTrainer`].
270
+
271
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
272
+ [`~transformers.TrainingArguments`] documentation.
273
+
274
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
275
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
276
+ command line.
277
+
278
+ Parameters:
279
+ > Parameters that control the model and reference model
280
+
281
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
282
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
283
+ argument of the [`GRPOTrainer`] is provided as a string.
284
+
285
+ > Parameters that control the data preprocessing
286
+
287
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
288
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
289
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
290
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
291
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
292
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
293
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
294
+ must be divisible by this value.
295
+ temperature (`float`, *optional*, defaults to `0.9`):
296
+ Temperature for sampling. The higher the temperature, the more random the completions.
297
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
298
+ Maximum length of the generated completion.
299
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
300
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
301
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
302
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
303
+ with vLLM generation.
304
+
305
+ > Parameters that control generation acceleration powered by vLLM
306
+
307
+ use_vllm (`bool`, *optional*, defaults to `False`):
308
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
309
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
310
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
311
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
312
+ automatically select the next available GPU after the last one used for training. This assumes that
313
+ training has not already occupied all available GPUs. If only one device is available, the device will be
314
+ shared between both training and vLLM.
315
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
316
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
317
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
318
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
319
+ during initialization.
320
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
321
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
322
+ based on the model configuration. Find the supported values in the vLLM documentation.
323
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
324
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
325
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
326
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
327
+
328
+ > Parameters that control the training
329
+
330
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
331
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
332
+ [`~transformers.TrainingArguments`].
333
+ beta (`float`, *optional*, defaults to `0.04`):
334
+ KL coefficient.
335
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
336
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
337
+ weighted equally with weight `1.0`.
338
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
339
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
340
+ the `ref_model_mixup_alpha` parameter. This synchronization originates from the
341
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
342
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
343
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
344
+ between the current policy and the previous reference policy during updates. The reference policy is
345
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
346
+ must set `sync_ref_model=True`.
347
+ ref_model_sync_steps (`int`, *optional*, defaults to `64`):
348
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
349
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
350
+ set `sync_ref_model=True`.
351
+
352
+ > Parameters that control the logging
353
+
354
+ log_completions (`bool`, *optional*, defaults to `False`):
355
+ Whether to log the completions during training.
356
+
357
+ """
358
+ vllm_sampling_params: Optional[Any] = field(
359
+ default = None,
360
+ metadata = {'help': 'vLLM SamplingParams'},
361
+ )
362
+ unsloth_num_chunks : Optional[int] = field(
363
+ default = -1,
364
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
365
+ )
366
+ def __init__(
367
+ self,
368
+ output_dir = None,
369
+ overwrite_output_dir = None,
370
+ do_train = False,
371
+ do_eval = False,
372
+ do_predict = False,
373
+ eval_strategy = 'no',
374
+ prediction_loss_only = False,
375
+ per_device_train_batch_size = 4,
376
+ per_device_eval_batch_size = 4,
377
+ per_gpu_train_batch_size = None,
378
+ per_gpu_eval_batch_size = None,
379
+ gradient_accumulation_steps = 2,
380
+ eval_accumulation_steps = 2,
381
+ eval_delay = 0,
382
+ torch_empty_cache_steps = 250,
383
+ learning_rate = 5e-05,
384
+ weight_decay = 0.01,
385
+ adam_beta1 = 0.9,
386
+ adam_beta2 = 0.999,
387
+ adam_epsilon = 1e-08,
388
+ max_grad_norm = 1.0,
389
+ num_train_epochs = 3.0,
390
+ max_steps = -1,
391
+ lr_scheduler_type = 'linear',
392
+ warmup_ratio = 0.1,
393
+ warmup_steps = 0,
394
+ log_level = 'passive',
395
+ log_level_replica = 'warning',
396
+ log_on_each_node = True,
397
+ logging_dir = None,
398
+ logging_strategy = 'steps',
399
+ logging_first_step = False,
400
+ logging_steps = 1,
401
+ logging_nan_inf_filter = False,
402
+ save_strategy = 'steps',
403
+ save_steps = 500,
404
+ save_total_limit = None,
405
+ save_safetensors = True,
406
+ save_on_each_node = False,
407
+ save_only_model = False,
408
+ restore_callback_states_from_checkpoint = False,
409
+ no_cuda = False,
410
+ use_cpu = False,
411
+ use_mps_device = False,
412
+ seed = 3407,
413
+ data_seed = 3407,
414
+ jit_mode_eval = False,
415
+ use_ipex = False,
416
+ bf16 = False,
417
+ fp16 = False,
418
+ fp16_opt_level = 'O1',
419
+ half_precision_backend = 'auto',
420
+ bf16_full_eval = False,
421
+ fp16_full_eval = False,
422
+ tf32 = None,
423
+ local_rank = -1,
424
+ ddp_backend = None,
425
+ tpu_num_cores = None,
426
+ tpu_metrics_debug = False,
427
+ debug = '',
428
+ dataloader_drop_last = False,
429
+ eval_steps = None,
430
+ dataloader_num_workers = 0,
431
+ dataloader_prefetch_factor = None,
432
+ past_index = -1,
433
+ run_name = None,
434
+ disable_tqdm = None,
435
+ remove_unused_columns = False,
436
+ label_names = None,
437
+ load_best_model_at_end = False,
438
+ metric_for_best_model = None,
439
+ greater_is_better = None,
440
+ ignore_data_skip = False,
441
+ fsdp = '',
442
+ fsdp_min_num_params = 0,
443
+ fsdp_config = None,
444
+ tp_size = 0,
445
+ fsdp_transformer_layer_cls_to_wrap = None,
446
+ accelerator_config = None,
447
+ deepspeed = None,
448
+ label_smoothing_factor = 0.0,
449
+ optim = 'adamw_8bit',
450
+ optim_args = None,
451
+ adafactor = False,
452
+ group_by_length = False,
453
+ length_column_name = 'length',
454
+ report_to = None,
455
+ ddp_find_unused_parameters = None,
456
+ ddp_bucket_cap_mb = None,
457
+ ddp_broadcast_buffers = None,
458
+ dataloader_pin_memory = True,
459
+ dataloader_persistent_workers = False,
460
+ skip_memory_metrics = True,
461
+ use_legacy_prediction_loop = False,
462
+ push_to_hub = False,
463
+ resume_from_checkpoint = None,
464
+ hub_model_id = None,
465
+ hub_strategy = 'every_save',
466
+ hub_token = None,
467
+ hub_private_repo = None,
468
+ hub_always_push = False,
469
+ gradient_checkpointing = False,
470
+ gradient_checkpointing_kwargs = None,
471
+ include_inputs_for_metrics = False,
472
+ eval_do_concat_batches = True,
473
+ fp16_backend = 'auto',
474
+ evaluation_strategy = None,
475
+ push_to_hub_model_id = None,
476
+ push_to_hub_organization = None,
477
+ push_to_hub_token = None,
478
+ mp_parameters = '',
479
+ auto_find_batch_size = False,
480
+ full_determinism = False,
481
+ torchdynamo = None,
482
+ ray_scope = 'last',
483
+ ddp_timeout = 1800,
484
+ torch_compile = False,
485
+ torch_compile_backend = None,
486
+ torch_compile_mode = None,
487
+ dispatch_batches = None,
488
+ split_batches = None,
489
+ include_tokens_per_second = False,
490
+ include_num_input_tokens_seen = False,
491
+ neftune_noise_alpha = None,
492
+ optim_target_modules = None,
493
+ batch_eval_metrics = False,
494
+ eval_on_start = False,
495
+ use_liger_kernel = False,
496
+ eval_use_gather_object = False,
497
+ average_tokens_across_devices = False,
498
+ model_init_kwargs = None,
499
+ max_prompt_length = 512,
500
+ num_generations = 8,
501
+ temperature = 0.9,
502
+ max_completion_length = 256,
503
+ ds3_gather_for_generation = True,
504
+ use_vllm = False,
505
+ vllm_device = 'auto',
506
+ vllm_gpu_memory_utilization = 0.9,
507
+ vllm_dtype = 'auto',
508
+ vllm_max_model_len = None,
509
+ beta = 0.04,
510
+ reward_weights = None,
511
+ sync_ref_model = False,
512
+ ref_model_mixup_alpha = 0.9,
513
+ ref_model_sync_steps = 64,
514
+ log_completions = False,
515
+ vllm_sampling_params = None,
516
+ unsloth_num_chunks = -1,
517
+ **kwargs,
518
+ ):
519
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
520
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1 or lower, otherwise gradient updates will explode!')
521
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
522
+ output_dir = 'unsloth_training_checkpoints'
523
+ save_strategy = 'no'
524
+ div = per_device_train_batch_size // num_generations
525
+ if div * num_generations != per_device_train_batch_size:
526
+ print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size from ' + str(per_device_train_batch_size) + ' to `num_generations` = ' + str(num_generations))
527
+ per_device_train_batch_size = num_generations
528
+
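+ # Worked example of the check above (numbers are hypothetical): with
+ # per_device_train_batch_size = 6 and num_generations = 8, 6 // 8 = 0 and
+ # 0 * 8 != 6, so the batch size is bumped to 8 so that every step holds one
+ # full group of generations for a single prompt.
+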
529
+ super().__init__(
530
+ output_dir = output_dir,
531
+ overwrite_output_dir = overwrite_output_dir,
532
+ do_train = do_train,
533
+ do_eval = do_eval,
534
+ do_predict = do_predict,
535
+ eval_strategy = eval_strategy,
536
+ prediction_loss_only = prediction_loss_only,
537
+ per_device_train_batch_size = per_device_train_batch_size,
538
+ per_device_eval_batch_size = per_device_eval_batch_size,
539
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
540
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
541
+ gradient_accumulation_steps = gradient_accumulation_steps,
542
+ eval_accumulation_steps = eval_accumulation_steps,
543
+ eval_delay = eval_delay,
544
+ torch_empty_cache_steps = torch_empty_cache_steps,
545
+ learning_rate = learning_rate,
546
+ weight_decay = weight_decay,
547
+ adam_beta1 = adam_beta1,
548
+ adam_beta2 = adam_beta2,
549
+ adam_epsilon = adam_epsilon,
550
+ max_grad_norm = max_grad_norm,
551
+ num_train_epochs = num_train_epochs,
552
+ max_steps = max_steps,
553
+ lr_scheduler_type = lr_scheduler_type,
554
+ warmup_ratio = warmup_ratio,
555
+ warmup_steps = warmup_steps,
556
+ log_level = log_level,
557
+ log_level_replica = log_level_replica,
558
+ log_on_each_node = log_on_each_node,
559
+ logging_dir = logging_dir,
560
+ logging_strategy = logging_strategy,
561
+ logging_first_step = logging_first_step,
562
+ logging_steps = logging_steps,
563
+ logging_nan_inf_filter = logging_nan_inf_filter,
564
+ save_strategy = save_strategy,
565
+ save_steps = save_steps,
566
+ save_total_limit = save_total_limit,
567
+ save_safetensors = save_safetensors,
568
+ save_on_each_node = save_on_each_node,
569
+ save_only_model = save_only_model,
570
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
571
+ no_cuda = no_cuda,
572
+ use_cpu = use_cpu,
573
+ use_mps_device = use_mps_device,
574
+ seed = seed,
575
+ data_seed = data_seed,
576
+ jit_mode_eval = jit_mode_eval,
577
+ use_ipex = use_ipex,
578
+ bf16 = bf16,
579
+ fp16 = fp16,
580
+ fp16_opt_level = fp16_opt_level,
581
+ half_precision_backend = half_precision_backend,
582
+ bf16_full_eval = bf16_full_eval,
583
+ fp16_full_eval = fp16_full_eval,
584
+ tf32 = tf32,
585
+ local_rank = local_rank,
586
+ ddp_backend = ddp_backend,
587
+ tpu_num_cores = tpu_num_cores,
588
+ tpu_metrics_debug = tpu_metrics_debug,
589
+ debug = debug,
590
+ dataloader_drop_last = dataloader_drop_last,
591
+ eval_steps = eval_steps,
592
+ dataloader_num_workers = dataloader_num_workers,
593
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
594
+ past_index = past_index,
595
+ run_name = run_name,
596
+ disable_tqdm = disable_tqdm,
597
+ remove_unused_columns = remove_unused_columns,
598
+ label_names = label_names,
599
+ load_best_model_at_end = load_best_model_at_end,
600
+ metric_for_best_model = metric_for_best_model,
601
+ greater_is_better = greater_is_better,
602
+ ignore_data_skip = ignore_data_skip,
603
+ fsdp = fsdp,
604
+ fsdp_min_num_params = fsdp_min_num_params,
605
+ fsdp_config = fsdp_config,
606
+ tp_size = tp_size,
607
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
608
+ accelerator_config = accelerator_config,
609
+ deepspeed = deepspeed,
610
+ label_smoothing_factor = label_smoothing_factor,
611
+ optim = optim,
612
+ optim_args = optim_args,
613
+ adafactor = adafactor,
614
+ group_by_length = group_by_length,
615
+ length_column_name = length_column_name,
616
+ report_to = report_to,
617
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
618
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
619
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
620
+ dataloader_pin_memory = dataloader_pin_memory,
621
+ dataloader_persistent_workers = dataloader_persistent_workers,
622
+ skip_memory_metrics = skip_memory_metrics,
623
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
624
+ push_to_hub = push_to_hub,
625
+ resume_from_checkpoint = resume_from_checkpoint,
626
+ hub_model_id = hub_model_id,
627
+ hub_strategy = hub_strategy,
628
+ hub_token = hub_token,
629
+ hub_private_repo = hub_private_repo,
630
+ hub_always_push = hub_always_push,
631
+ gradient_checkpointing = gradient_checkpointing,
632
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
633
+ include_inputs_for_metrics = include_inputs_for_metrics,
634
+ eval_do_concat_batches = eval_do_concat_batches,
635
+ fp16_backend = fp16_backend,
636
+ evaluation_strategy = evaluation_strategy,
637
+ push_to_hub_model_id = push_to_hub_model_id,
638
+ push_to_hub_organization = push_to_hub_organization,
639
+ push_to_hub_token = push_to_hub_token,
640
+ mp_parameters = mp_parameters,
641
+ auto_find_batch_size = auto_find_batch_size,
642
+ full_determinism = full_determinism,
643
+ torchdynamo = torchdynamo,
644
+ ray_scope = ray_scope,
645
+ ddp_timeout = ddp_timeout,
646
+ torch_compile = torch_compile,
647
+ torch_compile_backend = torch_compile_backend,
648
+ torch_compile_mode = torch_compile_mode,
649
+ dispatch_batches = dispatch_batches,
650
+ split_batches = split_batches,
651
+ include_tokens_per_second = include_tokens_per_second,
652
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
653
+ neftune_noise_alpha = neftune_noise_alpha,
654
+ optim_target_modules = optim_target_modules,
655
+ batch_eval_metrics = batch_eval_metrics,
656
+ eval_on_start = eval_on_start,
657
+ use_liger_kernel = use_liger_kernel,
658
+ eval_use_gather_object = eval_use_gather_object,
659
+ average_tokens_across_devices = average_tokens_across_devices,
660
+ model_init_kwargs = model_init_kwargs,
661
+ max_prompt_length = max_prompt_length,
662
+ num_generations = num_generations,
663
+ temperature = temperature,
664
+ max_completion_length = max_completion_length,
665
+ ds3_gather_for_generation = ds3_gather_for_generation,
666
+ use_vllm = use_vllm,
667
+ vllm_device = vllm_device,
668
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
669
+ vllm_dtype = vllm_dtype,
670
+ vllm_max_model_len = vllm_max_model_len,
671
+ beta = beta,
672
+ reward_weights = reward_weights,
673
+ sync_ref_model = sync_ref_model,
674
+ ref_model_mixup_alpha = ref_model_mixup_alpha,
675
+ ref_model_sync_steps = ref_model_sync_steps,
676
+ log_completions = log_completions,**kwargs)
677
+ self.vllm_sampling_params = vllm_sampling_params
678
+ self.unsloth_num_chunks = unsloth_num_chunks
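+        # Note (illustrative): `vllm_sampling_params` is only consumed when `use_vllm = True`;
+        # e.g. UnslothGRPOConfig(..., vllm_sampling_params = vLLMSamplingParams(min_p = 0.1))
+        # would forward those extra sampling kwargs to vLLM's SamplingParams at generation time.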
+        pass
+
+class _UnslothGRPOTrainer(Trainer):
+    """"""
+
+    _tag_names = ["trl", "grpo"]
+
+    def __init__(
+        self,
+        model: Union[str, PreTrainedModel],
+        reward_funcs: Union[RewardFunc, list[RewardFunc]],
+        args: GRPOConfig = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
+        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        peft_config: Optional["PeftConfig"] = None,
+    ):
+
+        if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = GRPOConfig(f"{model_name}-GRPO")
+
+        # Models
+        # Trained model
+        model_init_kwargs = args.model_init_kwargs or {}
+        if isinstance(model, str):
+            model_id = model
+            torch_dtype = model_init_kwargs.get("torch_dtype")
+            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+                pass  # torch_dtype is already a torch.dtype or "auto" or None
+            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
+                torch_dtype = getattr(torch, torch_dtype)
+                model_init_kwargs["torch_dtype"] = torch_dtype
+            else:
+                raise ValueError(
+                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
+                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+                )
+            # Disable caching if gradient checkpointing is enabled (not supported)
+            model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            )
+            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                raise ValueError(
+                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
+                    "This argument can only be used when the `model` argument is a string."
+                )
+
+        if False:
+            model = model
+
+        # Reference model
+        if is_deepspeed_zero3_enabled():
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+        elif not is_peft_model(model):
+            # If PEFT configuration is not provided, create a reference model based on the initial model.
+            self.ref_model = create_reference_model(model)
+        else:
+            # If PEFT is used, the reference model is not needed since the adapter can be disabled
+            # to revert to the initial model.
+            self.ref_model = None
+
+        # Processing class
+        if processing_class is None:
+            processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+
+        # Reward functions
+        if not isinstance(reward_funcs, list):
+            reward_funcs = [reward_funcs]
+        for i, reward_func in enumerate(reward_funcs):
+            if isinstance(reward_func, str):
+                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                    reward_func, num_labels=1, **model_init_kwargs
+                )
+        self.reward_funcs = reward_funcs
+
+        # Reward weights
+        if args.reward_weights is not None:
+            if len(args.reward_weights) != len(reward_funcs):
+                raise ValueError(
+                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
+                    f"functions ({len(reward_funcs)})"
+                )
+            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
+        else:
+            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
+
+        # Reward processing class
+        if reward_processing_classes is None:
+            reward_processing_classes = [None] * len(reward_funcs)
+        elif not isinstance(reward_processing_classes, list):
+            reward_processing_classes = [reward_processing_classes]
+        else:
+            if len(reward_processing_classes) != len(reward_funcs):
+                raise ValueError("The number of reward processing classes must match the number of reward functions.")
+
+        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
+            if isinstance(reward_func, PreTrainedModel):
+                if reward_processing_class is None:
+                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+                if reward_processing_class.pad_token_id is None:
+                    reward_processing_class.pad_token = reward_processing_class.eos_token
+                # The reward model computes the reward for the latest non-padded token in the input sequence.
+                # So it's important to set the pad token ID to the padding token ID of the processing class.
+                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+                reward_processing_classes[i] = reward_processing_class
+        self.reward_processing_classes = reward_processing_classes
+
+        # Data collator
+        def data_collator(features):  # No data collation is needed in GRPO
+            return features
+
+        # Training arguments
+        self.max_prompt_length = args.max_prompt_length
+        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
+        self.num_generations = args.num_generations  # = G in the GRPO paper
+        self.use_vllm = args.use_vllm
+
+        self.beta = args.beta
+
+        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
+        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+        # This acts as a flag to indicate that the warning has already been issued.
+        model.warnings_issued["estimate_tokens"] = True
+
+        # Initialize the metrics
+        self._metrics = defaultdict(list)
+        self.log_completions = args.log_completions
+
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            callbacks=callbacks,
+            optimizers=optimizers,
+        )
+
+        # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
+        num_processes = self.accelerator.num_processes
+        global_batch_size = args.per_device_train_batch_size * num_processes
+        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+        if self.num_generations not in possible_values:
+            raise ValueError(
+                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
+                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
+                f"batch size, the valid values for the number of generations are: {possible_values}."
+            )
+        if self.args.eval_strategy != "no":
+            global_batch_size = args.per_device_eval_batch_size * num_processes
+            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+            if self.num_generations not in possible_values:
+                raise ValueError(
+                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
+                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
+                )
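+        # Worked example: with 2 processes and per_device_train_batch_size = 4, the global batch
+        # size is 8, so valid choices for num_generations are {2, 4, 8}; num_generations = 3 would
+        # raise the error above because 8 rows cannot be tiled into equal groups of 3.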
+
+        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
+        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
+        # it's safer to set it in all cases.
+        set_seed(args.seed, device_specific=True)
+
+        if self.use_vllm:
+            self.llm = model.vllm_engine
+            self._last_loaded_step = 0
+            self.sampling_params = SamplingParams(
+                temperature=args.temperature,
+                max_tokens=self.max_completion_length,
+                **getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),
+            )
+        else:
+            self.generation_config = GenerationConfig(
+                max_new_tokens=self.max_completion_length,
+                do_sample=True,
+                temperature=args.temperature,
+                pad_token_id=processing_class.pad_token_id,
+            )
+
+        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+        # self.model_accepts_loss_kwargs to False to enable scaling.
+        self.model_accepts_loss_kwargs = False
+
+        # Add tags to the model
+        self.model.add_model_tags(self._tag_names)
+
+        if self.ref_model is not None:
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+        if args.sync_ref_model:
+            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
+
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs.
+        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+        if self._signature_columns is None:
+            self._signature_columns = ["prompt"]
+
+    def _get_train_sampler(self) -> Sampler:
+        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
+        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
+        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
+        # preventing discrepancies in group formation.
+        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
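+        # Illustrative: with num_generations = 4, the sampler yields each shuffled dataset index
+        # four times in a row (e.g. [7, 7, 7, 7, 2, 2, 2, 2, ...]), so consecutive batch rows that
+        # share a prompt form one reward-normalization group.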
+
+    def _get_eval_sampler(self, eval_dataset) -> Sampler:
+        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
+        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
+        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
+        # preventing discrepancies in group formation.
+        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
+
+    # Get the per-token log probabilities for the completions for the model and the reference model
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
+            return None  # Unsloth efficient GRPO
+        # Otherwise, calculate normally:
+        if not hasattr(self, '_autocast_dtype'):
+            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
+            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
+        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
+            # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
+            logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+            input_ids = input_ids[:, -logits_to_keep:]
+            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
+            # See https://github.com/huggingface/trl/issues/2770
+            logits = logits[:, -logits_to_keep:]
+            return logits
+            # return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
+        pass
+
+    def _move_model_to_vllm(self, *args, **kwargs): return None
+
+    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+        device = self.accelerator.device
+        prompts = [x["prompt"] for x in inputs]
+        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+        prompt_inputs = self.processing_class(
+            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
+        )
+        prompt_inputs = super()._prepare_inputs(prompt_inputs)
+        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+        if self.max_prompt_length is not None:
+            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+        # Generate completions using either vLLM or regular generation
+        if self.args.use_vllm:
+            # First, have main process load weights if needed
+            if self.state.global_step != self._last_loaded_step:
+                self._move_model_to_vllm()
+                self._last_loaded_step = self.state.global_step
+
+            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+            all_prompts_text = gather_object(prompts_text)
+            if self.accelerator.is_main_process:
+                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
+                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
+            else:
+                completion_ids = [None] * len(all_prompts_text)
+            # Broadcast the completions from the main process to all processes, ensuring each process receives its
+            # corresponding slice.
+            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+            process_slice = slice(
+                self.accelerator.process_index * len(prompts),
+                (self.accelerator.process_index + 1) * len(prompts),
+            )
+            completion_ids = completion_ids[process_slice]
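+            # e.g. with 2 processes and 4 local prompts each, 8 completions are gathered and
+            # broadcast; process 0 keeps slice [0:4] and process 1 keeps slice [4:8].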
+
+            # Pad the completions, and concatenate them with the prompts
+            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        else:
+            # Regular generation path
+            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+                prompt_completion_ids = unwrapped_model.generate(
+                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
+                )
+
+            # Compute prompt length and extract completion ids
+            prompt_length = prompt_ids.size(1)
+            prompt_ids = prompt_completion_ids[:, :prompt_length]
+            completion_ids = prompt_completion_ids[:, prompt_length:]
+
+        # Mask everything after the first EOS token
+        is_eos = completion_ids == self.processing_class.eos_token_id
+        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
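+        # e.g. completion_ids = [tok, tok, EOS, pad, pad] gives eos_idx = 2 and
+        # completion_mask = [1, 1, 1, 0, 0]: the EOS itself is kept, everything after it is
+        # masked out. Rows with no EOS default to eos_idx = seq_len, i.e. an all-ones mask.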
+
+        # Concatenate prompt_mask with completion_mask for logit computation
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
+
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+
+        with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext()) if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
+            if self.ref_model is not None:
+                ref_per_token_logps = self._get_per_token_logps(
+                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
+                )
+            else:
+                with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
+                    ref_per_token_logps = self._get_per_token_logps(
+                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
+                    )
+
+        # Decode the generated completions
+        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        if is_conversational(inputs[0]):
+            completions = []
+            for prompt, completion in zip(prompts, completions_text):
+                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+                completions.append([{"role": "assistant", "content": bootstrap + completion}])
+        else:
+            completions = completions_text
+
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes)
+        ):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                if is_conversational(inputs[0]):
+                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                else:
+                    texts = [p + c for p, c in zip(prompts, completions)]
+                reward_inputs = reward_processing_class(
+                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                )
+                reward_inputs = super()._prepare_inputs(reward_inputs)
+                with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext()) if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
+                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+            else:
+                # Repeat all input columns (except "prompt" and "completion") to match the number of generations
+                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
+                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+        # completions may be distributed across processes
+        rewards_per_func = gather(rewards_per_func)
+
+        # Apply weights to each reward function's output and sum
+        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
+
+        # Compute group-wise rewards
+        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+        # Normalize the rewards to compute the advantages
+        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
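+        # Worked example for one group of G = 3 completions with rewards [1.0, 2.0, 3.0]:
+        # mean = 2.0 and sample std = 1.0, so advantages ≈ [-1.0, 0.0, 1.0]. Each completion is
+        # scored only relative to its siblings for the same prompt, not on an absolute scale.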
+
+        # Slice to keep only the local part of the data
+        process_slice = slice(
+            self.accelerator.process_index * len(prompts),
+            (self.accelerator.process_index + 1) * len(prompts),
+        )
+        advantages = advantages[process_slice]
+
+        # Log the metrics
+        reward_per_func = rewards_per_func.mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+        self._metrics["reward"].append(rewards.mean().item())
+        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
+
+        if (
+            self.log_completions
+            and self.state.global_step % self.args.logging_steps == 0
+            and "wandb" in self.args.report_to
+        ):
+            import pandas as pd
+
+            # For logging
+            table = {
+                "step": [str(self.state.global_step)] * len(rewards),
+                "prompt": gather_object(prompts_text),
+                "completion": gather_object(completions_text),
+                "reward": rewards.tolist(),
+            }
+            df = pd.DataFrame(table)
+
+            if wandb.run is not None and self.accelerator.is_main_process:
+                wandb.log({"completions": wandb.Table(dataframe=df)})
+
+        return {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+            "ref_per_token_logps": ref_per_token_logps,
+            "advantages": advantages,
+        }
+
+    def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
+        # Compute the per-token log probabilities for the model
+
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        bsz, qlen = input_ids.shape
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        # attention_mask = None
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        _input_ids = input_ids
+        _logits_to_keep = logits_to_keep
+
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+
+        # Compute the KL divergence between the model and the reference model
+        ref_per_token_logps = inputs["ref_per_token_logps"]
+        # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
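+        # The (commented-out) estimator above is the k3 approximation of KL(policy || ref):
+        # with r = ref_logp - logp, exp(r) - r - 1 is always non-negative and has low variance,
+        # which is why GRPO uses it instead of the naive -r estimate.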
+
+        # x - x.detach() allows for preserving gradients from x
+        advantages = inputs["advantages"]
+        # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
+        # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        input_ids = input_ids[:, -logits_to_keep:]
+        if per_token_logps is not None:
+            loss, completion_length, mean_kl = grpo_compute_loss_slow(
+                ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
+            )
+        else:
+            loss, completion_length, mean_kl = grpo_accumulated_loss(
+                self, _input_ids, logits_to_keep, completion_mask, advantages,
+                n_chunks = self.args.unsloth_num_chunks,
+            )
+
+        # Log the metrics
+        # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+
+        # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+        if "train" in self._metrics:
+            mode = "eval" if self.control.should_evaluate else "train"
+            self._metrics[mode]["completion_length"].append(completion_length.item())
+            self._metrics[mode]["kl"].append(mean_kl.item())
+        else:
+            self._metrics["completion_length"].append(completion_length.item())
+            self._metrics["kl"].append(mean_kl.item())
+        return loss
+
+    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
+        inputs = self._prepare_inputs(inputs)
+        with torch.no_grad():
+            with self.compute_loss_context_manager():
+                loss = self.compute_loss(model, inputs)
+            loss = loss.mean().detach()
+        return loss, None, None
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if next(iter(logs.keys())).startswith("eval_"):
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs = {**logs, **metrics}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            super().log(logs, start_time)
+        else:  # transformers<=4.46
+            super().log(logs)
+        self._metrics.clear()
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent(
+            """\
+            @article{zhihong2024deepseekmath,
+                title  = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
+                author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
+                year   = 2024,
+                eprint = {arXiv:2402.03300},
+            }
+            """
+        )
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="GRPO",
+            trainer_citation=citation,
+            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
+            paper_id="2402.03300",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+class UnslothGRPOTrainer(_UnslothGRPOTrainer):
+    """
+
+    Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
+    paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
+
+    Example:
+
+    ```python
+    from datasets import load_dataset
+    from trl import GRPOTrainer
+
+    dataset = load_dataset("trl-lib/tldr", split="train")
+
+    def reward_func(completions, **kwargs):
+        # Dummy reward function that rewards completions with more unique letters.
+        return [float(len(set(completion))) for completion in completions]
+
+    trainer = GRPOTrainer(
+        model="Qwen/Qwen2-0.5B-Instruct",
+        reward_funcs=reward_func,
+        train_dataset=dataset,
+    )
+
+    trainer.train()
+    ```
+
+    Args:
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
+              a path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
+              loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
+              in `args.model_init_kwargs`.
+            - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
+        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
+            Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
+            functions with the prompts and completions and sum the rewards. Can be either:
+
+            - A single reward function, such as:
+                - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
+                  path to a *directory* containing model weights saved using
+                  [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
+                  loaded using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1`
+                  and the keyword arguments in `args.model_init_kwargs`.
+                - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
+                - A custom reward function: The function is provided with the prompts and the generated completions,
+                  plus any additional columns in the dataset. It should return a list of rewards. For more details, see
+                  [Using a custom reward function](#using-a-custom-reward-function).
+            - A list of reward functions, where each item can independently be any of the above types. Mixing different
+              types within the list (e.g., a string model ID and a custom reward function) is allowed.
+        args ([`GRPOConfig`], *optional*, defaults to `None`):
+            Configuration for this trainer. If `None`, a default configuration is used.
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
+            ignored. The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
+            Processing class used to process the data. The padding side must be set to "left". If `None`, the
+            processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
+        reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
+            Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
+
+            - A single processing class: Used when `reward_funcs` contains only one reward function.
+            - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
+
+            If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
+            `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
+            For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
+            the corresponding entries in `reward_processing_classes` are ignored.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks
+            detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+        peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
+            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
+
+    """
+    def __init__(
+        self,
+        model,
+        reward_funcs,
+        args = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        reward_processing_classes = None,
+        callbacks = None,
+        peft_config = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothGRPOConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        use_fp16 = getattr(args, 'fp16', False)
+        force_float32 = False
+        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training()
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        other_metrics = []
+        if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
+        else: _reward_funcs = reward_funcs
+        for reward_func in _reward_funcs:
+            try:
+                reward_func_name = reward_func.__name__
+                other_metrics.append(f'rewards/{reward_func_name}')
+            except: pass
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('grpo_trainer', other_metrics)
+
+        super().__init__(
+            model = model,
+            reward_funcs = reward_funcs,
+            args = args,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            reward_processing_classes = reward_processing_classes,
+            callbacks = callbacks,
+            peft_config = peft_config,
+            **kwargs,
+        )
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+        pass
+
+pass
unsloth_compiled_cache/UnslothKTOTrainer.py ADDED
@@ -0,0 +1,1840 @@
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, deepspeed, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
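+
+# Note: the function above relies on the identity log_softmax(x)_i = x_i - logsumexp(x);
+# gathering the selected logit first and then subtracting logsumexp avoids materializing
+# the full (B, L, V) log-softmax tensor, keeping only the chosen tokens' log-probabilities.
+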
42
+ @dataclass
43
+ class UnslothKTOConfig(KTOConfig):
44
+ """
45
+
46
+ Configuration class for the [`KTOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
+ reference model.
67
+ loss_type (`str`, *optional*, defaults to `"kto"`):
68
+ Type of loss to use. Possible values are:
69
+
70
+ - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
71
+ - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
72
+
73
+ desirable_weight (`float`, *optional*, defaults to `1.0`):
74
+ Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
75
+ undesirable_weight (`float`, *optional*, defaults to `1.0`):
76
+ Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
77
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
78
+ Label pad token id. This argument is required if you want to use the default data collator.
79
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
80
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
81
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
82
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
83
+ This argument is required if you want to use the default data collator.
84
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
85
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
86
+ evaluation.
87
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
88
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
89
+ you need to specify if the model returned by the callable is an encoder-decoder model.
90
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
91
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
92
+ useful when training without the reference model to reduce the total GPU memory needed.
93
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
94
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
95
+ string.
96
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
98
+ from a string.
99
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
100
+ Number of processes to use for processing the dataset.
101
+ disable_dropout (`bool`, *optional*, defaults to `True`):
102
+ Whether to disable dropout in the model and reference model.
103
+
104
+ """
105
+ vllm_sampling_params: Optional[Any] = field(
106
+ default = None,
107
+ metadata = {'help': 'vLLM SamplingParams'},
108
+ )
109
+ unsloth_num_chunks : Optional[int] = field(
110
+ default = -1,
111
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
112
+ )
113
+ def __init__(
114
+ self,
115
+ output_dir = None,
116
+ overwrite_output_dir = None,
117
+ do_train = False,
118
+ do_eval = False,
119
+ do_predict = False,
120
+ eval_strategy = 'no',
121
+ prediction_loss_only = False,
122
+ per_device_train_batch_size = 4,
123
+ per_device_eval_batch_size = 4,
124
+ per_gpu_train_batch_size = None,
125
+ per_gpu_eval_batch_size = None,
126
+ gradient_accumulation_steps = 2,
127
+ eval_accumulation_steps = 2,
128
+ eval_delay = 0,
129
+ torch_empty_cache_steps = 250,
130
+ learning_rate = 5e-05,
131
+ weight_decay = 0.01,
132
+ adam_beta1 = 0.9,
133
+ adam_beta2 = 0.999,
134
+ adam_epsilon = 1e-08,
135
+ max_grad_norm = 1.0,
136
+ num_train_epochs = 3.0,
137
+ max_steps = -1,
138
+ lr_scheduler_type = 'linear',
139
+ warmup_ratio = 0.1,
140
+ warmup_steps = 0,
141
+ log_level = 'passive',
142
+ log_level_replica = 'warning',
143
+ log_on_each_node = True,
144
+ logging_dir = None,
145
+ logging_strategy = 'steps',
146
+ logging_first_step = False,
147
+ logging_steps = 1,
148
+ logging_nan_inf_filter = False,
149
+ save_strategy = 'steps',
150
+ save_steps = 500,
151
+ save_total_limit = None,
152
+ save_safetensors = True,
153
+ save_on_each_node = False,
154
+ save_only_model = False,
155
+ restore_callback_states_from_checkpoint = False,
156
+ no_cuda = False,
157
+ use_cpu = False,
158
+ use_mps_device = False,
159
+ seed = 3407,
160
+ data_seed = 3407,
161
+ jit_mode_eval = False,
162
+ use_ipex = False,
163
+ bf16 = False,
164
+ fp16 = False,
165
+ fp16_opt_level = 'O1',
166
+ half_precision_backend = 'auto',
167
+ bf16_full_eval = False,
168
+ fp16_full_eval = False,
169
+ tf32 = None,
170
+ local_rank = -1,
171
+ ddp_backend = None,
172
+ tpu_num_cores = None,
173
+ tpu_metrics_debug = False,
174
+ debug = '',
175
+ dataloader_drop_last = False,
176
+ eval_steps = None,
177
+ dataloader_num_workers = 0,
178
+ dataloader_prefetch_factor = None,
179
+ past_index = -1,
180
+ run_name = None,
181
+ disable_tqdm = None,
182
+ remove_unused_columns = True,
183
+ label_names = None,
184
+ load_best_model_at_end = False,
185
+ metric_for_best_model = None,
186
+ greater_is_better = None,
187
+ ignore_data_skip = False,
188
+ fsdp = '',
189
+ fsdp_min_num_params = 0,
190
+ fsdp_config = None,
191
+ tp_size = 0,
192
+ fsdp_transformer_layer_cls_to_wrap = None,
193
+ accelerator_config = None,
194
+ deepspeed = None,
195
+ label_smoothing_factor = 0.0,
196
+ optim = 'adamw_8bit',
197
+ optim_args = None,
198
+ adafactor = False,
199
+ group_by_length = False,
200
+ length_column_name = 'length',
201
+ report_to = None,
202
+ ddp_find_unused_parameters = None,
203
+ ddp_bucket_cap_mb = None,
204
+ ddp_broadcast_buffers = None,
205
+ dataloader_pin_memory = True,
206
+ dataloader_persistent_workers = False,
207
+ skip_memory_metrics = True,
208
+ use_legacy_prediction_loop = False,
209
+ push_to_hub = False,
210
+ resume_from_checkpoint = None,
211
+ hub_model_id = None,
212
+ hub_strategy = 'every_save',
213
+ hub_token = None,
214
+ hub_private_repo = None,
215
+ hub_always_push = False,
216
+ gradient_checkpointing = False,
217
+ gradient_checkpointing_kwargs = None,
218
+ include_inputs_for_metrics = False,
219
+ eval_do_concat_batches = True,
220
+ fp16_backend = 'auto',
221
+ evaluation_strategy = None,
222
+ push_to_hub_model_id = None,
223
+ push_to_hub_organization = None,
224
+ push_to_hub_token = None,
225
+ mp_parameters = '',
226
+ auto_find_batch_size = False,
227
+ full_determinism = False,
228
+ torchdynamo = None,
229
+ ray_scope = 'last',
230
+ ddp_timeout = 1800,
231
+ torch_compile = False,
232
+ torch_compile_backend = None,
233
+ torch_compile_mode = None,
234
+ dispatch_batches = None,
235
+ split_batches = None,
236
+ include_tokens_per_second = False,
237
+ include_num_input_tokens_seen = False,
238
+ neftune_noise_alpha = None,
239
+ optim_target_modules = None,
240
+ batch_eval_metrics = False,
241
+ eval_on_start = False,
242
+ use_liger_kernel = False,
243
+ eval_use_gather_object = False,
244
+ average_tokens_across_devices = False,
245
+ max_length = 1024,
246
+ max_prompt_length = 512,
247
+ max_completion_length = None,
248
+ beta = 0.1,
249
+ loss_type = 'kto',
250
+ desirable_weight = 1.0,
251
+ undesirable_weight = 1.0,
252
+ label_pad_token_id = -100,
253
+ padding_value = None,
254
+ truncation_mode = 'keep_end',
255
+ generate_during_eval = False,
256
+ is_encoder_decoder = None,
257
+ disable_dropout = True,
258
+ precompute_ref_log_probs = False,
259
+ model_init_kwargs = None,
260
+ ref_model_init_kwargs = None,
261
+ dataset_num_proc = None,
262
+ vllm_sampling_params = None,
263
+ unsloth_num_chunks = -1,
264
+ **kwargs,
265
+ ):
266
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is less than 1e-7, which is too small! Consider increasing it, otherwise gradient updates will be close to 0!')
267
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
268
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
269
+ output_dir = 'unsloth_training_checkpoints'
270
+ save_strategy = 'no'
271
+ if dataset_num_proc is None:
272
+ from multiprocessing import cpu_count
273
+ dataset_num_proc = cpu_count()
274
+
275
+ super().__init__(
276
+ output_dir = output_dir,
277
+ overwrite_output_dir = overwrite_output_dir,
278
+ do_train = do_train,
279
+ do_eval = do_eval,
280
+ do_predict = do_predict,
281
+ eval_strategy = eval_strategy,
282
+ prediction_loss_only = prediction_loss_only,
283
+ per_device_train_batch_size = per_device_train_batch_size,
284
+ per_device_eval_batch_size = per_device_eval_batch_size,
285
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
286
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
287
+ gradient_accumulation_steps = gradient_accumulation_steps,
288
+ eval_accumulation_steps = eval_accumulation_steps,
289
+ eval_delay = eval_delay,
290
+ torch_empty_cache_steps = torch_empty_cache_steps,
291
+ learning_rate = learning_rate,
292
+ weight_decay = weight_decay,
293
+ adam_beta1 = adam_beta1,
294
+ adam_beta2 = adam_beta2,
295
+ adam_epsilon = adam_epsilon,
296
+ max_grad_norm = max_grad_norm,
297
+ num_train_epochs = num_train_epochs,
298
+ max_steps = max_steps,
299
+ lr_scheduler_type = lr_scheduler_type,
300
+ warmup_ratio = warmup_ratio,
301
+ warmup_steps = warmup_steps,
302
+ log_level = log_level,
303
+ log_level_replica = log_level_replica,
304
+ log_on_each_node = log_on_each_node,
305
+ logging_dir = logging_dir,
306
+ logging_strategy = logging_strategy,
307
+ logging_first_step = logging_first_step,
308
+ logging_steps = logging_steps,
309
+ logging_nan_inf_filter = logging_nan_inf_filter,
310
+ save_strategy = save_strategy,
311
+ save_steps = save_steps,
312
+ save_total_limit = save_total_limit,
313
+ save_safetensors = save_safetensors,
314
+ save_on_each_node = save_on_each_node,
315
+ save_only_model = save_only_model,
316
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
317
+ no_cuda = no_cuda,
318
+ use_cpu = use_cpu,
319
+ use_mps_device = use_mps_device,
320
+ seed = seed,
321
+ data_seed = data_seed,
322
+ jit_mode_eval = jit_mode_eval,
323
+ use_ipex = use_ipex,
324
+ bf16 = bf16,
325
+ fp16 = fp16,
326
+ fp16_opt_level = fp16_opt_level,
327
+ half_precision_backend = half_precision_backend,
328
+ bf16_full_eval = bf16_full_eval,
329
+ fp16_full_eval = fp16_full_eval,
330
+ tf32 = tf32,
331
+ local_rank = local_rank,
332
+ ddp_backend = ddp_backend,
333
+ tpu_num_cores = tpu_num_cores,
334
+ tpu_metrics_debug = tpu_metrics_debug,
335
+ debug = debug,
336
+ dataloader_drop_last = dataloader_drop_last,
337
+ eval_steps = eval_steps,
338
+ dataloader_num_workers = dataloader_num_workers,
339
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
340
+ past_index = past_index,
341
+ run_name = run_name,
342
+ disable_tqdm = disable_tqdm,
343
+ remove_unused_columns = remove_unused_columns,
344
+ label_names = label_names,
345
+ load_best_model_at_end = load_best_model_at_end,
346
+ metric_for_best_model = metric_for_best_model,
347
+ greater_is_better = greater_is_better,
348
+ ignore_data_skip = ignore_data_skip,
349
+ fsdp = fsdp,
350
+ fsdp_min_num_params = fsdp_min_num_params,
351
+ fsdp_config = fsdp_config,
352
+ tp_size = tp_size,
353
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
354
+ accelerator_config = accelerator_config,
355
+ deepspeed = deepspeed,
356
+ label_smoothing_factor = label_smoothing_factor,
357
+ optim = optim,
358
+ optim_args = optim_args,
359
+ adafactor = adafactor,
360
+ group_by_length = group_by_length,
361
+ length_column_name = length_column_name,
362
+ report_to = report_to,
363
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
364
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
365
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
366
+ dataloader_pin_memory = dataloader_pin_memory,
367
+ dataloader_persistent_workers = dataloader_persistent_workers,
368
+ skip_memory_metrics = skip_memory_metrics,
369
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
370
+ push_to_hub = push_to_hub,
371
+ resume_from_checkpoint = resume_from_checkpoint,
372
+ hub_model_id = hub_model_id,
373
+ hub_strategy = hub_strategy,
374
+ hub_token = hub_token,
375
+ hub_private_repo = hub_private_repo,
376
+ hub_always_push = hub_always_push,
377
+ gradient_checkpointing = gradient_checkpointing,
378
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
379
+ include_inputs_for_metrics = include_inputs_for_metrics,
380
+ eval_do_concat_batches = eval_do_concat_batches,
381
+ fp16_backend = fp16_backend,
382
+ evaluation_strategy = evaluation_strategy,
383
+ push_to_hub_model_id = push_to_hub_model_id,
384
+ push_to_hub_organization = push_to_hub_organization,
385
+ push_to_hub_token = push_to_hub_token,
386
+ mp_parameters = mp_parameters,
387
+ auto_find_batch_size = auto_find_batch_size,
388
+ full_determinism = full_determinism,
389
+ torchdynamo = torchdynamo,
390
+ ray_scope = ray_scope,
391
+ ddp_timeout = ddp_timeout,
392
+ torch_compile = torch_compile,
393
+ torch_compile_backend = torch_compile_backend,
394
+ torch_compile_mode = torch_compile_mode,
395
+ dispatch_batches = dispatch_batches,
396
+ split_batches = split_batches,
397
+ include_tokens_per_second = include_tokens_per_second,
398
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
399
+ neftune_noise_alpha = neftune_noise_alpha,
400
+ optim_target_modules = optim_target_modules,
401
+ batch_eval_metrics = batch_eval_metrics,
402
+ eval_on_start = eval_on_start,
403
+ use_liger_kernel = use_liger_kernel,
404
+ eval_use_gather_object = eval_use_gather_object,
405
+ average_tokens_across_devices = average_tokens_across_devices,
406
+ max_length = max_length,
407
+ max_prompt_length = max_prompt_length,
408
+ max_completion_length = max_completion_length,
409
+ beta = beta,
410
+ loss_type = loss_type,
411
+ desirable_weight = desirable_weight,
412
+ undesirable_weight = undesirable_weight,
413
+ label_pad_token_id = label_pad_token_id,
414
+ padding_value = padding_value,
415
+ truncation_mode = truncation_mode,
416
+ generate_during_eval = generate_during_eval,
417
+ is_encoder_decoder = is_encoder_decoder,
418
+ disable_dropout = disable_dropout,
419
+ precompute_ref_log_probs = precompute_ref_log_probs,
420
+ model_init_kwargs = model_init_kwargs,
421
+ ref_model_init_kwargs = ref_model_init_kwargs,
422
+ dataset_num_proc = dataset_num_proc, **kwargs)
423
+ self.vllm_sampling_params = vllm_sampling_params
424
+ self.unsloth_num_chunks = unsloth_num_chunks
425
+ pass
426
+
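+ # --- Illustrative sketch (not part of the generated trainer code) ---
+ # One way to pick the desirable/undesirable weights documented above for an imbalanced
+ # dataset, following the bounds from Eq. (8) of the KTO paper. The counts and the helper
+ # name are hypothetical; the trainer below derives the same bounds from train_dataset["label"].
+ def suggest_kto_weights(num_desirable: int, num_undesirable: int) -> tuple:
+     # Keep desirable_weight * num_desirable within [1, 1.33] x (undesirable_weight * num_undesirable)
+     lower = num_undesirable / num_desirable
+     upper = 1.33 * num_undesirable / num_desirable
+     # Take the midpoint of the recommended range and leave undesirable_weight at 1.0
+     return round((lower + upper) / 2, 2), 1.0
+
+ # Example: 800 desirable vs. 200 undesirable examples -> (0.29, 1.0)
+ # desirable_weight, undesirable_weight = suggest_kto_weights(800, 200)
+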
427
+ class _UnslothKTOTrainer(Trainer):
428
+ r""""""
429
+
430
+ _tag_names = ["trl", "kto"]
431
+
432
+ def __init__(
433
+ self,
434
+ model: Union[PreTrainedModel, nn.Module, str] = None,
435
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
436
+ args: KTOConfig = None,
437
+ train_dataset: Optional[Dataset] = None,
438
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
439
+ processing_class: Optional[
440
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
441
+ ] = None,
442
+ data_collator: Optional[DataCollator] = None,
443
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
444
+ callbacks: Optional[list[TrainerCallback]] = None,
445
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
446
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
447
+ peft_config: Optional[dict] = None,
448
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
449
+ model_adapter_name: Optional[str] = None,
450
+ ref_adapter_name: Optional[str] = None,
451
+ ):
452
+ if type(args) is TrainingArguments:
453
+ raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
454
+
455
+ if not isinstance(model, str) and ref_model is model:
456
+ raise ValueError(
457
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
458
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
459
+ )
460
+
461
+ if args.model_init_kwargs is None:
462
+ model_init_kwargs = {}
463
+ elif not isinstance(model, str):
464
+ raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
465
+ else:
466
+ model_init_kwargs = args.model_init_kwargs
467
+ torch_dtype = model_init_kwargs.get("torch_dtype")
468
+ if torch_dtype is not None:
469
+ # Convert to `torch.dtype` if a str is passed
470
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
471
+ torch_dtype = getattr(torch, torch_dtype)
472
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
473
+ raise ValueError(
474
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
475
+ )
476
+ model_init_kwargs["torch_dtype"] = torch_dtype
477
+
478
+ if args.ref_model_init_kwargs is None:
479
+ ref_model_init_kwargs = {}
480
+ elif not isinstance(ref_model, str):
481
+ raise ValueError(
482
+ "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
483
+ )
484
+ else:
485
+ ref_model_init_kwargs = args.ref_model_init_kwargs
486
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
487
+ if torch_dtype is not None:
488
+ # Convert to `torch.dtype` if a str is passed
489
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
490
+ torch_dtype = getattr(torch, torch_dtype)
491
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
492
+ raise ValueError(
493
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
494
+ )
495
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
496
+
497
+ if isinstance(model, str):
498
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
499
+
500
+ if isinstance(ref_model, str):
501
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
502
+
503
+ # Initialize this variable to False. This helps track whether `peft_module_casting_to_bf16`
504
+ # has been called, so that autocast can be invoked properly when needed.
505
+ self._peft_has_been_casted_to_bf16 = False
506
+
507
+ if not is_peft_available() and peft_config is not None:
508
+ raise ValueError(
509
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
510
+ )
511
+ elif is_peft_available() and peft_config is not None:
512
+ # if model is a peft model and we have a peft_config, we merge and unload it first
513
+ if isinstance(model, PeftModel):
514
+ model = model.merge_and_unload()
515
+
516
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
517
+ _support_gc_kwargs = hasattr(
518
+ args, "gradient_checkpointing_kwargs"
519
+ ) and "gradient_checkpointing_kwargs" in list(
520
+ inspect.signature(prepare_model_for_kbit_training).parameters
521
+ )
522
+
523
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
524
+
525
+ if _support_gc_kwargs:
526
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
527
+
528
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
529
+ elif getattr(args, "gradient_checkpointing", False):
530
+ # For backward compatibility with older versions of transformers
531
+ if hasattr(model, "enable_input_require_grads"):
532
+ model.enable_input_require_grads()
533
+ else:
534
+
535
+ def make_inputs_require_grad(module, input, output):
536
+ output.requires_grad_(True)
537
+
538
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
539
+
540
+ # get peft model with the given config (upstream TRL wraps the model with get_peft_model
541
+ # here; this compiled variant leaves the model unchanged)
+ model = model
542
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
543
+ peft_module_casting_to_bf16(model)
544
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
545
+ self._peft_has_been_casted_to_bf16 = True
546
+
547
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
548
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
549
+ # fail or completely fail.
550
+ elif getattr(args, "gradient_checkpointing", False):
551
+ # For backward compatibility with older versions of transformers
552
+ if hasattr(model, "enable_input_require_grads"):
553
+ model.enable_input_require_grads()
554
+ else:
555
+
556
+ def make_inputs_require_grad(module, input, output):
557
+ output.requires_grad_(True)
558
+
559
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
560
+
561
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
562
+ raise ValueError(
563
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
564
+ " Please install `wandb` or `comet-ml` to resolve."
565
+ )
566
+
567
+ if model is not None:
568
+ self.is_encoder_decoder = model.config.is_encoder_decoder
569
+ elif args.is_encoder_decoder is None:
570
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
571
+ else:
572
+ self.is_encoder_decoder = args.is_encoder_decoder
573
+
574
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
575
+ self.model_adapter_name = model_adapter_name
576
+ self.ref_adapter_name = ref_adapter_name
577
+
578
+ if ref_model:
579
+ self.ref_model = ref_model
580
+ elif self.is_peft_model or args.precompute_ref_log_probs:
581
+ # The `model` with adapters turned off will be used as the reference model
582
+ self.ref_model = None
583
+ else:
584
+ self.ref_model = create_reference_model(model)
585
+
586
+ if processing_class is None:
587
+ raise ValueError(
588
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
589
+ )
590
+ if args.max_length is None:
591
+ warnings.warn(
592
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
593
+ " it will be set to `512` by default, but you should do it yourself in the future.",
594
+ UserWarning,
595
+ )
596
+ max_length = 512
597
+ if args.max_length is not None:
598
+ max_length = args.max_length
599
+
600
+ if args.max_prompt_length is None:
601
+ warnings.warn(
602
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
603
+ " it will be set to `128` by default, but you should do it yourself in the future.",
604
+ UserWarning,
605
+ )
606
+ max_prompt_length = 128
607
+ if args.max_prompt_length is not None:
608
+ max_prompt_length = args.max_prompt_length
609
+
610
+ max_completion_length = None
611
+ if args.max_completion_length is None and self.is_encoder_decoder:
612
+ warnings.warn(
613
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
614
+ " it will be set to `128` by default, but you should do it yourself in the future.",
615
+ UserWarning,
616
+ )
617
+ max_completion_length = 128
618
+ if args.max_completion_length is not None and self.is_encoder_decoder:
619
+ max_completion_length = args.max_completion_length
620
+
621
+ if data_collator is None:
622
+ data_collator = DPODataCollatorWithPadding(
623
+ pad_token_id=processing_class.pad_token_id,
624
+ label_pad_token_id=args.label_pad_token_id,
625
+ is_encoder_decoder=self.is_encoder_decoder,
626
+ )
627
+
628
+ if args.remove_unused_columns:
629
+ args.remove_unused_columns = False
630
+ # warn users
631
+ warnings.warn(
632
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
633
+ " we have set it for you, but you should do it yourself in the future.",
634
+ UserWarning,
635
+ )
636
+
637
+ self.use_dpo_data_collator = True
638
+ else:
639
+ self.use_dpo_data_collator = False
640
+
641
+ # Disable dropout in the model and reference model
642
+ if args.disable_dropout:
643
+ disable_dropout_in_model(model)
644
+ if self.ref_model is not None:
645
+ disable_dropout_in_model(self.ref_model)
646
+
647
+ self.loss_type = args.loss_type
648
+ self.max_length = max_length
649
+ self.generate_during_eval = args.generate_during_eval
650
+ self.label_pad_token_id = args.label_pad_token_id
651
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
652
+ self.max_prompt_length = max_prompt_length
653
+ self.truncation_mode = args.truncation_mode
654
+ self.max_completion_length = max_completion_length
655
+ self.processing_class = processing_class
656
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
657
+
658
+ # Not all losses require a KL calculation
659
+ self.calculate_KL = True
660
+ if self.loss_type in ["apo_zero_unpaired"]:
661
+ self.calculate_KL = False
662
+
663
+ # Since ref log probs are precomputed on the first call to get_train/eval_dataloader,
664
+ # keep track of the first call to avoid recomputing them on subsequent calls
665
+ self._precomputed_train_ref_log_probs = False
666
+ self._precomputed_eval_ref_log_probs = False
667
+
668
+ # metric
669
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
670
+
671
+ # KTO parameter
672
+ self.beta = args.beta
673
+ self.desirable_weight = args.desirable_weight
674
+ self.undesirable_weight = args.undesirable_weight
675
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
676
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
677
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
678
+ warnings.warn(
679
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
680
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
681
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
682
+ "loss.",
683
+ UserWarning,
684
+ )
685
+
686
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
687
+ # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
688
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
689
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
690
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
691
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
692
+ # issued.
693
+ model.warnings_issued["estimate_tokens"] = True
694
+
695
+ # Compute that only on the main process for faster data processing.
696
+ # see: https://github.com/huggingface/trl/pull/1255
697
+ with PartialState().local_main_process_first():
698
+ # Extract the prompt if needed
699
+ train_dataset = train_dataset.map(
700
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
701
+ )
702
+ # Unpair the dataset if needed
703
+ train_dataset = maybe_unpair_preference_dataset(
704
+ train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
705
+ )
706
+ # Apply the chat template if needed
707
+ train_dataset = train_dataset.map(
708
+ maybe_apply_chat_template,
709
+ fn_kwargs={"tokenizer": processing_class},
710
+ num_proc=args.dataset_num_proc,
711
+ desc="Applying chat template to train dataset",
712
+ )
713
+ if eval_dataset is not None:
714
+ eval_dataset = eval_dataset.map(
715
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
716
+ )
717
+ eval_dataset = maybe_unpair_preference_dataset(
718
+ eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
719
+ )
720
+ eval_dataset = eval_dataset.map(
721
+ maybe_apply_chat_template,
722
+ fn_kwargs={"tokenizer": processing_class},
723
+ num_proc=args.dataset_num_proc,
724
+ desc="Applying chat template to eval dataset",
725
+ )
726
+
727
+ # Tokenize and prepare the training datasets
728
+ train_dataset = train_dataset.map(
729
+ _tokenize,
730
+ batched=True,
731
+ fn_kwargs={"tokenizer": self.processing_class},
732
+ num_proc=args.dataset_num_proc,
733
+ desc="Tokenizing train dataset",
734
+ )
735
+
736
+ fn_kwargs = {
737
+ "prefix": "",
738
+ "is_encoder_decoder": self.is_encoder_decoder,
739
+ "tokenizer": self.processing_class,
740
+ "max_length": self.max_length,
741
+ "truncation_mode": self.truncation_mode,
742
+ "label_pad_token_id": self.label_pad_token_id,
743
+ "max_prompt_length": self.max_prompt_length,
744
+ "max_completion_length": self.max_completion_length,
745
+ }
746
+
747
+ train_dataset = train_dataset.map(
748
+ _process_tokens,
749
+ fn_kwargs=fn_kwargs,
750
+ num_proc=args.dataset_num_proc,
751
+ desc="Processing tokenized train dataset",
752
+ )
753
+
754
+ # Tokenize and prepare the eval datasets
755
+ if eval_dataset is not None:
756
+ eval_dataset = eval_dataset.map(
757
+ _tokenize,
758
+ fn_kwargs={"tokenizer": self.processing_class},
759
+ batched=True,
760
+ num_proc=args.dataset_num_proc,
761
+ desc="Tokenizing eval dataset",
762
+ )
763
+
764
+ eval_dataset = eval_dataset.map(
765
+ _process_tokens,
766
+ fn_kwargs=fn_kwargs,
767
+ num_proc=args.dataset_num_proc,
768
+ desc="Processing tokenized eval dataset",
769
+ )
770
+
771
+ # Get KL datasets if needed
772
+ if self.calculate_KL:
773
+ if args.per_device_train_batch_size <= 1:
774
+ raise ValueError(
775
+ "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
776
+ )
777
+
778
+ # create pairs for estimating the KL term by mismatching the matched pairs in each batch of size total_batch_size
779
+ # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
+ # (see the illustrative sketch after __init__ below)
780
+ train_kl_dataset = train_dataset.map(
781
+ _get_kl_dataset,
782
+ batched=True,
783
+ batch_size=args.per_device_train_batch_size,
784
+ num_proc=args.dataset_num_proc,
785
+ desc="Extracting KL train dataset",
786
+ )
787
+
788
+ fn_kwargs["prefix"] = "KL_"
789
+ train_kl_dataset = train_kl_dataset.map(
790
+ _process_tokens,
791
+ fn_kwargs=fn_kwargs,
792
+ num_proc=args.dataset_num_proc,
793
+ remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
794
+ desc="Processing tokenized train KL dataset",
795
+ )
796
+
797
+ # merge the datasets
798
+ train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
799
+
800
+ if eval_dataset is not None:
801
+ # Get KL dataset
802
+ eval_kl_dataset = eval_dataset.map(
803
+ _get_kl_dataset,
804
+ batched=True,
805
+ batch_size=args.per_device_train_batch_size,
806
+ num_proc=args.dataset_num_proc,
807
+ desc="Extracting eval KL dataset",
808
+ )
809
+
810
+ eval_kl_dataset = eval_kl_dataset.map(
811
+ _process_tokens,
812
+ fn_kwargs=fn_kwargs,
813
+ num_proc=args.dataset_num_proc,
814
+ remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
815
+ desc="Processing tokenized eval KL dataset",
816
+ )
817
+
818
+ # merge the datasets
819
+ eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
820
+
821
+ # calculate dataset desirability balance
822
+ num_desirable = max(sum(train_dataset["label"]), 1)
823
+ num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
824
+
825
+ if num_desirable != num_undesirable:
826
+ # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
827
+ des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
828
+ des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
829
+ und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
830
+ und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
831
+
832
+ des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
833
+ und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
834
+
835
+ if not (des_weight_in_range or und_weight_in_range):
836
+ warnings.warn(
837
+ "You have different amounts of desirable/positive and undesirable/negative examples but the "
838
+ "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
839
+ f"on your data, we recommend EITHER "
840
+ f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
841
+ f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
842
+ "See the documentation on how to optimally set these weights.",
843
+ UserWarning,
844
+ )
845
+
846
+ super().__init__(
847
+ model=model,
848
+ args=args,
849
+ data_collator=data_collator,
850
+ train_dataset=train_dataset,
851
+ eval_dataset=eval_dataset,
852
+ processing_class=processing_class,
853
+ model_init=model_init,
854
+ compute_metrics=compute_metrics,
855
+ callbacks=callbacks,
856
+ optimizers=optimizers,
857
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
858
+ )
859
+
860
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
861
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
862
+ # self.model_accepts_loss_kwargs to False to enable scaling.
863
+ self.model_accepts_loss_kwargs = False
864
+
865
+ # Add tags for models that have been loaded with the correct transformers version
866
+ if hasattr(self.model, "add_model_tags"):
867
+ self.model.add_model_tags(self._tag_names)
868
+
869
+ if not hasattr(self, "accelerator"):
870
+ raise AttributeError(
871
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
872
+ )
873
+
874
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
875
+ if self.is_deepspeed_enabled:
876
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
877
+ raise ValueError(
878
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
879
+ )
880
+
881
+ if self.ref_model is None:
882
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
883
+ raise ValueError(
884
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
885
+ )
886
+ else:
887
+ if self.is_deepspeed_enabled:
888
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
889
+ else:
890
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
891
+
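+ # --- Illustrative sketch (not part of the generated trainer code) ---
+ # What the KL-pair construction in __init__ does, on toy data: prompts are paired with
+ # mismatched completions from the same batch, mirroring the schematic comment above.
+ # The helper name is hypothetical; the trainer uses _get_kl_dataset over tokenized columns,
+ # whose exact mismatching scheme may differ from the reversal shown here.
+ def make_kl_pairs(prompts: list, completions: list) -> list:
+     # (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1)
+     # (note: with an odd batch size, a pure reversal leaves the middle pair matched)
+     return list(zip(prompts, completions[::-1]))
+
+ # make_kl_pairs(["p1", "p2", "p3"], ["c1", "c2", "c3"])
+ # --> [("p1", "c3"), ("p2", "c2"), ("p3", "c1")]
+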
892
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
893
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
894
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
895
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
896
+
897
+ if model is not None:
898
+ if hasattr(model, "config"):
899
+ hidden_size = (
900
+ max(model.config.hidden_sizes)
901
+ if getattr(model.config, "hidden_sizes", None)
902
+ else getattr(model.config, "hidden_size", None)
903
+ )
904
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
905
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
906
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
907
+ config_kwargs.update(
908
+ {
909
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
910
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
911
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
912
+ }
913
+ )
914
+
915
+ # If ZeRO-3 is used, we shard both the active and reference model.
916
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
917
+ if config_kwargs["zero_optimization"]["stage"] != 3:
918
+ config_kwargs["zero_optimization"]["stage"] = 0
919
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
920
+ model.eval()
921
+ return model
922
+
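+ # --- Illustrative sketch (not part of the generated trainer code) ---
+ # The ZeRO-3 bucket-size overrides computed in _prepare_deepspeed, evaluated for a
+ # hypothetical hidden_size of 4096 (the keys follow the dict built above):
+ _hidden_size = 4096
+ _zero3_overrides = {
+     "zero_optimization.reduce_bucket_size": _hidden_size * _hidden_size,                 # 16_777_216
+     "zero_optimization.stage3_param_persistence_threshold": 10 * _hidden_size,           # 40_960
+     "zero_optimization.stage3_prefetch_bucket_size": 0.9 * _hidden_size * _hidden_size,  # 15_099_494.4
+ }
+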
923
+ @contextmanager
924
+ def null_ref_context(self):
925
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
926
+ with (
927
+ self.accelerator.unwrap_model(self.model).disable_adapter()
928
+ if self.is_peft_model and not self.ref_adapter_name
929
+ else nullcontext()
930
+ ):
931
+ if self.ref_adapter_name:
932
+ self.model.set_adapter(self.ref_adapter_name)
933
+ yield
934
+ if self.ref_adapter_name:
935
+ self.model.set_adapter(self.model_adapter_name or "default")
936
+
937
+ def get_train_dataloader(self) -> DataLoader:
938
+ """
939
+ Returns the training [`~torch.utils.data.DataLoader`].
940
+
941
+ Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
942
+ """
943
+
944
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
945
+ dataloader_params = {
946
+ "batch_size": self.args.per_device_train_batch_size,
947
+ "collate_fn": self.data_collator,
948
+ "num_workers": self.args.dataloader_num_workers,
949
+ "pin_memory": self.args.dataloader_pin_memory,
950
+ "shuffle": False,
951
+ }
952
+
953
+ # prepare dataloader
954
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
955
+ reference_completion_logps = []
956
+ reference_KL_logps = []
957
+
958
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
959
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
960
+
961
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
962
+ reference_completion_logps.append(reference_completion_logp.cpu())
963
+
964
+ if self.calculate_KL:
965
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
966
+ reference_KL_logps.append(reference_KL_logp.cpu())
967
+
968
+ self.train_dataset = self.train_dataset.add_column(
969
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
970
+ )
971
+
972
+ if self.calculate_KL:
973
+ self.train_dataset = self.train_dataset.add_column(
974
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
975
+ )
976
+
977
+ self._precomputed_train_ref_log_probs = True
978
+
979
+ return super().get_train_dataloader()
980
+
981
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
982
+ """
983
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
984
+
985
+ Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
986
+
987
+ Args:
988
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
989
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
990
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
991
+ """
992
+ if eval_dataset is None and self.eval_dataset is None:
993
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
994
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
995
+
996
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
997
+ dataloader_params = {
998
+ "batch_size": self.args.per_device_eval_batch_size,
999
+ "collate_fn": self.data_collator,
1000
+ "num_workers": self.args.dataloader_num_workers,
1001
+ "pin_memory": self.args.dataloader_pin_memory,
1002
+ "shuffle": False,
1003
+ }
1004
+
1005
+ # prepare dataloader
1006
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1007
+
1008
+ reference_completion_logps = []
1009
+ reference_KL_logps = []
1010
+
1011
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1012
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
1013
+
1014
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1015
+ reference_completion_logps.append(reference_completion_logp.cpu())
1016
+
1017
+ if self.calculate_KL:
1018
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
1019
+ reference_KL_logps.append(reference_KL_logp.cpu())
1020
+
1021
+ eval_dataset = eval_dataset.add_column(
1022
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1023
+ )
1024
+ if self.calculate_KL:
1025
+ eval_dataset = eval_dataset.add_column(
1026
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
1027
+ )
1028
+
1029
+ # Save the computed reference_logps (and reference_KL_logps) to the eval_dataset for subsequent runs
1030
+ if self.eval_dataset is not None:
1031
+ self.eval_dataset = eval_dataset
1032
+ self._precomputed_eval_ref_log_probs = True
1033
+
1034
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1035
+
1036
+ def compute_reference_log_probs(self, padded_batch: dict) -> tuple:
1037
+ """Computes log probabilities of the reference model for a single padded batch of a KTO-specific dataset."""
1038
+ with torch.no_grad():
1039
+ if self.ref_model is None:
1040
+ with self.null_ref_context():
1041
+ if self.is_encoder_decoder:
1042
+ completion_logits = self.model(
1043
+ padded_batch["prompt_input_ids"],
1044
+ attention_mask=padded_batch["prompt_attention_mask"],
1045
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1046
+ labels=padded_batch["completion_labels"],
1047
+ ).logits
1048
+
1049
+ if self.calculate_KL:
1050
+ KL_logits = self.model(
1051
+ padded_batch["KL_prompt_input_ids"],
1052
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
1053
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1054
+ labels=padded_batch["KL_completion_labels"],
1055
+ ).logits
1056
+ else:
1057
+ completion_logits = self.model(
1058
+ padded_batch["completion_input_ids"],
1059
+ attention_mask=padded_batch["completion_attention_mask"],
1060
+ ).logits
1061
+
1062
+ if self.calculate_KL:
1063
+ KL_logits = self.model(
1064
+ padded_batch["KL_completion_input_ids"],
1065
+ attention_mask=padded_batch["KL_completion_attention_mask"],
1066
+ ).logits
1067
+ else:
1068
+ if self.is_encoder_decoder:
1069
+ completion_logits = self.ref_model(
1070
+ padded_batch["prompt_input_ids"],
1071
+ attention_mask=padded_batch["prompt_attention_mask"],
1072
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1073
+ labels=padded_batch["completion_labels"],
1074
+ ).logits
1075
+
1076
+ if self.calculate_KL:
1077
+ KL_logits = self.ref_model(
1078
+ padded_batch["KL_prompt_input_ids"],
1079
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
1080
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1081
+ labels=padded_batch["KL_completion_labels"],
1082
+ ).logits
1083
+ else:
1084
+ completion_logits = self.ref_model(
1085
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1086
+ ).logits
1087
+
1088
+ if self.calculate_KL:
1089
+ KL_logits = self.ref_model(
1090
+ padded_batch["KL_completion_input_ids"],
1091
+ attention_mask=padded_batch["KL_completion_attention_mask"],
1092
+ ).logits
1093
+
1094
+ completion_logps = self.get_batch_logps(
1095
+ completion_logits,
1096
+ padded_batch["completion_labels"],
1097
+ average_log_prob=False,
1098
+ is_encoder_decoder=self.is_encoder_decoder,
1099
+ label_pad_token_id=self.label_pad_token_id,
1100
+ )
1101
+
1102
+ if self.calculate_KL:
1103
+ KL_logps = self.get_batch_logps(
1104
+ KL_logits,
1105
+ padded_batch["KL_completion_labels"],
1106
+ average_log_prob=False,
1107
+ is_encoder_decoder=self.is_encoder_decoder,
1108
+ label_pad_token_id=self.label_pad_token_id,
1109
+ )
1110
+ else:
1111
+ KL_logps = None
1112
+
1113
+ return completion_logps, KL_logps
1114
+
1115
+ @staticmethod
1116
+ def get_batch_logps(
1117
+ logits: torch.FloatTensor,
1118
+ labels: torch.LongTensor,
1119
+ average_log_prob: bool = False,
1120
+ label_pad_token_id: int = -100,
1121
+ is_encoder_decoder: bool = False,
1122
+ ) -> torch.FloatTensor:
1123
+ """Compute the log probabilities of the given labels under the given logits.
1124
+
1125
+ Args:
1126
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1127
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1128
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1129
+
1130
+ Returns:
1131
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1132
+ """
1133
+ if logits.shape[:-1] != labels.shape:
1134
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1135
+
1136
+ if not is_encoder_decoder:
1137
+ labels = labels[:, 1:].clone()
1138
+ logits = logits[:, :-1, :]
1139
+ else:
1140
+ # Fixes encoder-decoder (enc-dec) RuntimeError
1141
+ labels = labels.clone()
1142
+
1143
+ loss_mask = labels != label_pad_token_id
1144
+
1145
+ # dummy token; we'll ignore the losses on these tokens later
1146
+ labels[labels == label_pad_token_id] = 0
1147
+
1148
+ per_token_logps = selective_log_softmax(logits, labels)
1149
+
1150
+ if average_log_prob:
1151
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1152
+ else:
1153
+ return (per_token_logps * loss_mask).sum(-1)
1154
+
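+ # --- Illustrative sketch (not part of the generated trainer code) ---
+ # What get_batch_logps computes, on toy tensors: gather each label token's log probability
+ # from log-softmaxed logits, mask out pad labels, and sum per sequence. This assumes
+ # selective_log_softmax is equivalent to the log_softmax + gather shown here.
+ import torch
+ import torch.nn.functional as F
+
+ logits = torch.randn(2, 5, 10)             # (batch, seq_len, vocab)
+ labels = torch.randint(0, 10, (2, 5))      # (batch, seq_len)
+ labels[:, :2] = -100                       # pretend the prompt tokens are masked out
+ labels, logits = labels[:, 1:].clone(), logits[:, :-1, :]   # shift for next-token prediction
+ loss_mask = labels != -100
+ labels[~loss_mask] = 0                     # dummy index so gather stays in range
+ per_token_logps = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+ seq_logps = (per_token_logps * loss_mask).sum(-1)   # shape (2,), i.e. average_log_prob=False
+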
1155
+ def forward(
1156
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1157
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1158
+ if self.calculate_KL:
1159
+ KL_logps = None
1160
+ KL_model_kwargs = (
1161
+ {
1162
+ "input_ids": batch["KL_prompt_input_ids"],
1163
+ "attention_mask": batch["KL_prompt_attention_mask"],
1164
+ "labels": batch["KL_completion_labels"],
1165
+ "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
1166
+ }
1167
+ if self.is_encoder_decoder
1168
+ else {
1169
+ "input_ids": batch["KL_completion_input_ids"],
1170
+ "attention_mask": batch["KL_completion_attention_mask"],
1171
+ }
1172
+ )
1173
+ with torch.no_grad():
1174
+ KL_logits = model(
1175
+ **KL_model_kwargs,
1176
+ ).logits
1177
+
1178
+ KL_logps = self.get_batch_logps(
1179
+ KL_logits,
1180
+ batch["KL_completion_labels"],
1181
+ average_log_prob=False,
1182
+ is_encoder_decoder=self.is_encoder_decoder,
1183
+ label_pad_token_id=self.label_pad_token_id,
1184
+ )
1185
+ else:
1186
+ KL_logps = None
1187
+
1188
+ model_kwargs = (
1189
+ {
1190
+ "labels": batch["completion_labels"],
1191
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1192
+ }
1193
+ if self.is_encoder_decoder
1194
+ else {}
1195
+ )
1196
+ if self.aux_loss_enabled:
1197
+ model_kwargs["output_router_logits"] = True
1198
+
1199
+ outputs = model(
1200
+ batch["completion_input_ids"],
1201
+ attention_mask=batch["completion_attention_mask"],
1202
+ **model_kwargs,
1203
+ )
1204
+ completion_logits = outputs.logits
1205
+
1206
+ completion_logps = self.get_batch_logps(
1207
+ completion_logits,
1208
+ batch["completion_labels"],
1209
+ average_log_prob=False,
1210
+ is_encoder_decoder=self.is_encoder_decoder,
1211
+ label_pad_token_id=self.label_pad_token_id,
1212
+ )
1213
+
1214
+ if completion_logps.shape[0] != len(batch["label"]):
1215
+ raise ValueError(
1216
+ "There is a mismatch between the number of examples in this batch and the number of "
1217
+ "examples for which an output sequence was predicted."
1218
+ )
1219
+
1220
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1221
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1222
+
1223
+ chosen_logps = completion_logps[chosen_idx, ...]
1224
+ rejected_logps = completion_logps[rejected_idx, ...]
1225
+
1226
+ chosen_logits = completion_logits[chosen_idx, ...]
1227
+ rejected_logits = completion_logits[rejected_idx, ...]
1228
+
1229
+ if self.aux_loss_enabled:
1230
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
1231
+ else:
1232
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
1233
+
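+ # --- Illustrative sketch (not part of the generated trainer code) ---
+ # How forward() splits per-example log probs into chosen/rejected halves using the binary
+ # "label" column, on toy values (the index lists mirror chosen_idx/rejected_idx above):
+ import torch
+
+ completion_logps = torch.tensor([-1.2, -3.4, -0.8, -2.0])
+ label = [True, False, True, False]         # True = desirable, False = undesirable
+ chosen_idx = [i for i, l in enumerate(label) if l]
+ rejected_idx = [i for i, l in enumerate(label) if not l]
+ chosen_logps = completion_logps[chosen_idx]        # tensor([-1.2000, -0.8000])
+ rejected_logps = completion_logps[rejected_idx]    # tensor([-3.4000, -2.0000])
+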
1234
+ def kto_loss(
1235
+ self,
1236
+ policy_chosen_logps: torch.FloatTensor,
1237
+ policy_rejected_logps: torch.FloatTensor,
1238
+ policy_KL_logps: torch.FloatTensor,
1239
+ reference_chosen_logps: torch.FloatTensor,
1240
+ reference_rejected_logps: torch.FloatTensor,
1241
+ reference_KL_logps: torch.FloatTensor,
1242
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1243
+ """Compute the KTO loss for a batch of policy and reference model log probabilities.
1244
+
1245
+ Args:
1246
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1247
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1248
+ policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
1249
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1250
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1251
+ reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
1252
+
1253
+ Returns:
1254
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
1255
+ The losses tensor contains the KTO loss for each example in the batch.
1256
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1257
+ The KL tensor contains the detached KL divergence estimate between the policy and reference models.
1258
+ """
1259
+ if self.calculate_KL:
1260
+ kl = (policy_KL_logps - reference_KL_logps).mean().detach()
1261
+ kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
1262
+ else:
1263
+ kl = torch.zeros(1).to(policy_chosen_logps.device)
1264
+
1265
+ # Chosen losses
1266
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1267
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1268
+
1269
+ if self.loss_type == "kto":
1270
+ # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
1271
+ chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
1272
+ elif self.loss_type == "apo_zero_unpaired":
1273
+ # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
1274
+ # Use this loss when you believe the chosen outputs are better than your model's default output
1275
+ chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
1276
+
1277
+ chosen_rewards = self.beta * chosen_logratios.detach()
1278
+
1279
+ else:
1280
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1281
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1282
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1283
+
1284
+ # Rejected losses
1285
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1286
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1287
+
1288
+ if self.loss_type == "kto":
1289
+ rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
1290
+ elif self.loss_type == "apo_zero_unpaired":
1291
+ rejected_losses = F.sigmoid(self.beta * rejected_logratios)
1292
+
1293
+ rejected_rewards = self.beta * rejected_logratios.detach()
1294
+ else:
1295
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1296
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1297
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1298
+
1299
+ losses = torch.cat(
1300
+ (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
1301
+ 0,
1302
+ )
1303
+
1304
+ return losses, chosen_rewards, rejected_rewards, kl
1305
+
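+ # --- Illustrative sketch (not part of the generated trainer code) ---
+ # The "kto" branch of kto_loss on toy numbers, per Eqn (7) of the KTO paper:
+ #   chosen_loss   = 1 - sigmoid(beta * (chosen_logratio - kl))
+ #   rejected_loss = 1 - sigmoid(beta * (kl - rejected_logratio))
+ # All log-ratio values below are hypothetical.
+ import torch
+
+ beta, kl = 0.1, torch.tensor(0.05)
+ chosen_logratios = torch.tensor([0.7, 1.1])    # policy logps minus reference logps
+ rejected_logratios = torch.tensor([-0.4])
+ chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
+ rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
+ # Weighted concatenation, as in the trainer (here desirable_weight = undesirable_weight = 1.0):
+ losses = torch.cat((1.0 * chosen_losses, 1.0 * rejected_losses), 0)
+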
1306
+ def get_batch_loss_metrics(
1307
+ self,
1308
+ model,
1309
+ batch: dict[str, Union[list, torch.LongTensor]],
1310
+ ):
1311
+ """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
1312
+ metrics = {}
1313
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1314
+
1315
+ forward_output = self.forward(model, batch)
1316
+ (
1317
+ policy_chosen_logps,
1318
+ policy_rejected_logps,
1319
+ policy_chosen_logits,
1320
+ policy_rejected_logits,
1321
+ policy_KL_logps,
1322
+ ) = forward_output[:5]
1323
+ if self.aux_loss_enabled:
1324
+ aux_loss = forward_output[5]
1325
+
1326
+ # if reference_logps in batch use them, otherwise use the reference model
1327
+ if "reference_logps" in batch:
1328
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1329
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1330
+
1331
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1332
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1333
+ if self.calculate_KL:
1334
+ reference_KL_logps = batch["reference_KL_logps"]
1335
+ else:
1336
+ reference_KL_logps = None
1337
+ else:
1338
+ with torch.no_grad():
1339
+ if self.ref_model is None:
1340
+ with self.null_ref_context():
1341
+ (
1342
+ reference_chosen_logps,
1343
+ reference_rejected_logps,
1344
+ _,
1345
+ _,
1346
+ reference_KL_logps,
1347
+ ) = self.forward(self.model, batch)[:5]
1348
+ else:
1349
+ (
1350
+ reference_chosen_logps,
1351
+ reference_rejected_logps,
1352
+ _,
1353
+ _,
1354
+ reference_KL_logps,
1355
+ ) = self.forward(self.ref_model, batch)[:5]
1356
+
1357
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
1358
+ policy_chosen_logps,
1359
+ policy_rejected_logps,
1360
+ policy_KL_logps,
1361
+ reference_chosen_logps,
1362
+ reference_rejected_logps,
1363
+ reference_KL_logps,
1364
+ )
1365
+ metrics["kl"] = kl.item()
1366
+
1367
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1368
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1369
+
1370
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1371
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1372
+
1373
+ if all_num_chosen > 0:
1374
+ metrics["rewards/chosen_sum"] = (
1375
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1376
+ )
1377
+ metrics["logps/chosen_sum"] = (
1378
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1379
+ )
1380
+ metrics["logits/chosen_sum"] = (
1381
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1382
+ )
1383
+ metrics["count/chosen"] = all_num_chosen
1384
+
1385
+ if all_num_rejected > 0:
1386
+ metrics["rewards/rejected_sum"] = (
1387
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1388
+ )
1389
+ metrics["logps/rejected_sum"] = (
1390
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1391
+ )
1392
+ metrics["logits/rejected_sum"] = (
1393
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1394
+ )
1395
+ metrics["count/rejected"] = all_num_rejected
1396
+
1397
+ loss = losses.nanmean()
1398
+ if self.aux_loss_enabled:
1399
+ loss += self.aux_loss_coef * aux_loss
1400
+
1401
+ return loss, metrics
+
+    def compute_loss(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs=False,
+        num_items_in_batch=None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
+        compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+
+        with compute_loss_context_manager:
+            loss, metrics = self.get_batch_loss_metrics(model, inputs)
+
+        # Make sure to move the loss to the device the original accumulating loss is at, back in the `Trainer` class:
+        loss = loss.to(self.args.device)
+        # force log the metrics
+        if self.accelerator.is_main_process:
+            self.store_metrics(metrics, train_eval="train")
+
+        if return_outputs:
+            return (loss, metrics)
+        return loss
+
+    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.train_dataset is None or not has_length(self.train_dataset):
+            return None
+        return SequentialSampler(self.train_dataset)
+
+    def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
+        """Generate samples from the model and reference model for the given batch of inputs."""
+
+        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
+        # the torch cuda amp context manager as some hidden states are silently cast to full precision.
+        generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+
+        with generate_context_manager:
+            policy_output = model.generate(
+                input_ids=batch["prompt_input_ids"],
+                attention_mask=batch["prompt_attention_mask"],
+                max_length=self.max_length,
+                do_sample=True,
+                pad_token_id=self.processing_class.pad_token_id,
+            )
+
+        # if reference_output is in the batch, use that; otherwise use the reference model
+        if "reference_output" in batch:
+            reference_output = batch["reference_output"]
+        else:
+            if self.ref_model is None:
+                with self.null_ref_context():
+                    reference_output = self.model.generate(
+                        input_ids=batch["prompt_input_ids"],
+                        attention_mask=batch["prompt_attention_mask"],
+                        max_length=self.max_length,
+                        do_sample=True,
+                        pad_token_id=self.processing_class.pad_token_id,
+                    )
+            else:
+                reference_output = self.ref_model.generate(
+                    input_ids=batch["prompt_input_ids"],
+                    attention_mask=batch["prompt_attention_mask"],
+                    max_length=self.max_length,
+                    do_sample=True,
+                    pad_token_id=self.processing_class.pad_token_id,
+                )
+
+        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
+        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
+
+        reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
+        reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
+
+        return policy_output_decoded, reference_output_decoded
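A hedged sketch of what TRL's `pad_to_length`, used just above, does: right-pad a tensor along its last dimension so policy and reference batches line up before decoding. The helper name below is made up for illustration.

import torch

def pad_to_length_sketch(tensor, length, pad_value, dim=-1):
    # Already long enough: return unchanged, as pad_to_length does.
    if tensor.size(dim) >= length:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - tensor.size(dim)
    pad = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, pad], dim=dim)

print(pad_to_length_sketch(torch.tensor([[1, 2, 3]]), 5, 0).tolist())  # [[1, 2, 3, 0, 0]]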
+
+    def prediction_step(
+        self,
+        model: Union[PreTrainedModel, nn.Module],
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[list[str]] = None,
+    ):
+        if ignore_keys is None:
+            if hasattr(model, "config"):
+                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
+
+        prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+        with torch.no_grad(), prediction_context_manager:
+            loss, metrics = self.get_batch_loss_metrics(model, inputs)
+
+        # force log the metrics
+        if self.accelerator.is_main_process:
+            self.store_metrics(metrics, train_eval="eval")
+
+        if prediction_loss_only:
+            return (loss.detach(), None, None)
+
+        # logits for the chosen and rejected samples from model
+        logits_dict = {
+            "eval_logits/chosen": metrics["logits/chosen"],
+            "eval_logits/rejected": metrics["logits/rejected"],
+        }
+        logits = torch.tensor(
+            [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
+        )
+        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
+
+        return (loss.detach(), logits, labels)
+
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Overriding built-in evaluation loop to store metrics for each batch.
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works with or without labels.
+        """
+
+        # Sample and save to game log if requested (for one batch to save time)
+        if self.generate_during_eval:
+            # Generate random indices within the range of the total number of samples
+            num_samples = len(dataloader.dataset)
+            random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
+
+            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+            random_batch_dataset = dataloader.dataset.select(random_indices)
+            random_batch = self.data_collator(random_batch_dataset)
+            random_batch = self._prepare_inputs(random_batch)
+
+            target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
+            target_batch = {
+                "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
+                "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
+                "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
+            }
+            policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
+
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy", "Ref Model"],
+                data=[
+                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+                    for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
+                ],
+            )
+            if "wandb" in self.args.report_to:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
+
+        # Base evaluation
+        initial_output = super().evaluation_loop(
+            dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
+        )
+
+        return initial_output
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`dict[str, float]`):
+                The values to log.
+            start_time (`float` or `None`, *optional*, defaults to `None`):
+                Start time of the training.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # train metrics should have no prefix, eval should have 'eval_'
+        prefix = "eval_" if train_eval == "eval" else ""
+        # accumulate average metrics from sums and lengths
+        for split in ["chosen", "rejected"]:
+            if f"count/{split}" in self._stored_metrics[train_eval]:
+                count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
+                for metric in ["rewards", "logps", "logits"]:
+                    logs[f"{prefix}{metric}/{split}"] = (
+                        torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
+                        / count_sum
+                    )
+                    # delete obsolete metric
+                    del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
+                del self._stored_metrics[train_eval][f"count/{split}"]
+        # calculate reward margin
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
+            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super().log(logs, start_time)
+        else:  # transformers<=4.46
+            return super().log(logs)
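A small numeric sketch (made-up values) of the sum/count bookkeeping that `log` performs above: per-batch sums accumulated by `store_metrics` are divided by the total example count at log time, giving a per-example average rather than a per-batch average.

import torch

stored = {"rewards/chosen_sum": [1.8, 2.4], "count/chosen": [3, 4]}   # two logged batches
count_sum = torch.Tensor(stored["count/chosen"]).sum().item()          # 7 examples total
mean_reward = torch.Tensor(stored["rewards/chosen_sum"]).sum().item() / count_sum
print(round(mean_reward, 3))                                           # 0.6 per-example average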
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent("""\
+            @article{ethayarajh2024kto,
+                title  = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
+                author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
+                year   = 2024,
+                eprint = {arXiv:2402.01306},
+            }""")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="KTO",
+            trainer_citation=citation,
+            paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
+            paper_id="2402.01306",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+class UnslothKTOTrainer(_UnslothKTOTrainer):
+    """
+
+    Initialize KTOTrainer.
+
+    Args:
+        model (`transformers.PreTrainedModel`):
+            The model to train, preferably an `AutoModelForCausalLM`.
+        ref_model (`PreTrainedModelWrapper`):
+            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
+            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
+        args (`KTOConfig`):
+            The arguments to use for training.
+        train_dataset (`datasets.Dataset`):
+            The dataset to use for training.
+        eval_dataset (`datasets.Dataset`):
+            The dataset to use for evaluation.
+        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+            Processing class used to process the data. If provided, will be used to automatically process the inputs
+            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+            reuse the fine-tuned model.
+        data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
+            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
+            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+        model_init (`Callable[[], transformers.PreTrainedModel]`):
+            The model initializer to use for training. If None is specified, the default model initializer will be used.
+        callbacks (`list[transformers.TrainerCallback]`):
+            The callbacks to use for training.
+        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+            The optimizer and scheduler to use for training.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+            The function to use to preprocess the logits before computing the metrics.
+        peft_config (`dict`, defaults to `None`):
+            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function to use to compute the metrics. Must take an `EvalPrediction` and return
+            a dictionary mapping strings to metric values.
+        model_adapter_name (`str`, defaults to `None`):
+            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+        ref_adapter_name (`str`, defaults to `None`):
+            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
+
+    """
+    def __init__(
+        self,
+        model = None,
+        ref_model = None,
+        args = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        data_collator = None,
+        model_init = None,
+        callbacks = None,
+        preprocess_logits_for_metrics = None,
+        peft_config = None,
+        compute_metrics = None,
+        model_adapter_name = None,
+        ref_adapter_name = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothKTOConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        use_fp16 = getattr(args, 'fp16', False)
+        force_float32 = False
+        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training()
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(__tokenizer)
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+                else:
+                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('kto_trainer', other_metrics)
+
+        super().__init__(
+            model = model,
+            ref_model = ref_model,
+            args = args,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            data_collator = data_collator,
+            model_init = model_init,
+            callbacks = callbacks,
+            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+            peft_config = peft_config,
+            compute_metrics = compute_metrics,
+            model_adapter_name = model_adapter_name,
+            ref_adapter_name = ref_adapter_name,**kwargs)
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+        pass
+
+pass
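A standalone sketch of the mixed-precision resolution implemented at the top of `__init__` above; the helper name `resolve_mixed_precision` is made up for illustration, and the return values mirror the `ACCELERATE_MIXED_PRECISION` settings the wrapper writes.

import torch

def resolve_mixed_precision(model_dtype, bf16, fp16, force_float32=False):
    if force_float32:
        return 'no'                        # full float32, no autocast
    float16 = model_dtype == torch.float16
    if float16 and bf16:
        raise TypeError('float16 model cannot train with bf16 mixed precision')
    if not float16 and fp16:
        raise TypeError('bfloat16 model cannot train with fp16 mixed precision')
    if not bf16 and not fp16:              # nothing requested: follow the model dtype
        return 'fp16' if float16 else 'bf16'
    return 'fp16' if fp16 else 'bf16'

print(resolve_mixed_precision(torch.bfloat16, bf16=False, fp16=False))  # bf16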
unsloth_compiled_cache/UnslothNashMDTrainer.py ADDED
@@ -0,0 +1,955 @@
+"""
+2025.3.17
+2025.3.19
+4.50.3
+0.15.2
+__UNSLOTH_VERSIONING__
+"""
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
+
+
+import os
+from typing import *
+from dataclasses import dataclass, field
+from packaging.version import Version
+import torch
+import numpy as np
+from contextlib import nullcontext
+from torch.nn import functional as F
+from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+torch_compile_options = {
+    "epilogue_fusion"   : True,
+    "max_autotune"      : False,
+    "shape_padding"     : True,
+    "trace.enabled"     : False,
+    "triton.cudagraphs" : False,
+}
+
+@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+def selective_log_softmax(logits, index):
+    logits = logits.to(torch.float32)
+    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+    # loop to reduce peak mem consumption
+    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+    logsumexp_values = torch.logsumexp(logits, dim = -1)
+    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+    return per_token_logps
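A quick sanity check of the identity used above: gathering from a full log_softmax gives the same per-token logps, just with a larger peak memory footprint since the whole vocab-sized log-probability tensor is materialized. Tensor sizes below are made up.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)                 # (batch, seq, vocab)
index = torch.randint(0, 11, (2, 5))           # token ids whose logps we want
fast = logits.gather(-1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, dim=-1)
slow = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(fast, slow, atol=1e-5))   # True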
+@dataclass
+class UnslothNashMDConfig(NashMDConfig):
+    """
+
+    Configuration class for the [`NashMDTrainer`].
+
+    Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
+
+    Parameters:
+        mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
+            Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
+            mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
+            epochs.
+
+    """
+    vllm_sampling_params: Optional[Any] = field(
+        default = None,
+        metadata = {'help': 'vLLM SamplingParams'},
+    )
+    unsloth_num_chunks : Optional[int] = field(
+        default = -1,
+        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+    )
+    def __init__(
+        self,
+        output_dir = None,
+        overwrite_output_dir = None,
+        do_train = False,
+        do_eval = False,
+        do_predict = False,
+        eval_strategy = 'no',
+        prediction_loss_only = False,
+        per_device_train_batch_size = 4,
+        per_device_eval_batch_size = 4,
+        per_gpu_train_batch_size = None,
+        per_gpu_eval_batch_size = None,
+        gradient_accumulation_steps = 2,
+        eval_accumulation_steps = 2,
+        eval_delay = 0,
+        torch_empty_cache_steps = 250,
+        learning_rate = 5e-05,
+        weight_decay = 0.01,
+        adam_beta1 = 0.9,
+        adam_beta2 = 0.999,
+        adam_epsilon = 1e-08,
+        max_grad_norm = 1.0,
+        num_train_epochs = 3.0,
+        max_steps = -1,
+        lr_scheduler_type = 'linear',
+        warmup_ratio = 0.1,
+        warmup_steps = 0,
+        log_level = 'passive',
+        log_level_replica = 'warning',
+        log_on_each_node = True,
+        logging_dir = None,
+        logging_strategy = 'steps',
+        logging_first_step = False,
+        logging_steps = 1,
+        logging_nan_inf_filter = False,
+        save_strategy = 'steps',
+        save_steps = 500,
+        save_total_limit = None,
+        save_safetensors = True,
+        save_on_each_node = False,
+        save_only_model = False,
+        restore_callback_states_from_checkpoint = False,
+        no_cuda = False,
+        use_cpu = False,
+        use_mps_device = False,
+        seed = 3407,
+        data_seed = 3407,
+        jit_mode_eval = False,
+        use_ipex = False,
+        bf16 = False,
+        fp16 = False,
+        fp16_opt_level = 'O1',
+        half_precision_backend = 'auto',
+        bf16_full_eval = False,
+        fp16_full_eval = False,
+        tf32 = None,
+        local_rank = -1,
+        ddp_backend = None,
+        tpu_num_cores = None,
+        tpu_metrics_debug = False,
+        debug = '',
+        dataloader_drop_last = False,
+        eval_steps = None,
+        dataloader_num_workers = 0,
+        dataloader_prefetch_factor = None,
+        past_index = -1,
+        run_name = None,
+        disable_tqdm = None,
+        remove_unused_columns = True,
+        label_names = None,
+        load_best_model_at_end = False,
+        metric_for_best_model = None,
+        greater_is_better = None,
+        ignore_data_skip = False,
+        fsdp = '',
+        fsdp_min_num_params = 0,
+        fsdp_config = None,
+        tp_size = 0,
+        fsdp_transformer_layer_cls_to_wrap = None,
+        accelerator_config = None,
+        deepspeed = None,
+        label_smoothing_factor = 0.0,
+        optim = 'adamw_8bit',
+        optim_args = None,
+        adafactor = False,
+        group_by_length = False,
+        length_column_name = 'length',
+        report_to = None,
+        ddp_find_unused_parameters = None,
+        ddp_bucket_cap_mb = None,
+        ddp_broadcast_buffers = None,
+        dataloader_pin_memory = True,
+        dataloader_persistent_workers = False,
+        skip_memory_metrics = True,
+        use_legacy_prediction_loop = False,
+        push_to_hub = False,
+        resume_from_checkpoint = None,
+        hub_model_id = None,
+        hub_strategy = 'every_save',
+        hub_token = None,
+        hub_private_repo = None,
+        hub_always_push = False,
+        gradient_checkpointing = False,
+        gradient_checkpointing_kwargs = None,
+        include_inputs_for_metrics = False,
+        eval_do_concat_batches = True,
+        fp16_backend = 'auto',
+        evaluation_strategy = None,
+        push_to_hub_model_id = None,
+        push_to_hub_organization = None,
+        push_to_hub_token = None,
+        mp_parameters = '',
+        auto_find_batch_size = False,
+        full_determinism = False,
+        torchdynamo = None,
+        ray_scope = 'last',
+        ddp_timeout = 1800,
+        torch_compile = False,
+        torch_compile_backend = None,
+        torch_compile_mode = None,
+        dispatch_batches = None,
+        split_batches = None,
+        include_tokens_per_second = False,
+        include_num_input_tokens_seen = False,
+        neftune_noise_alpha = None,
+        optim_target_modules = None,
+        batch_eval_metrics = False,
+        eval_on_start = False,
+        use_liger_kernel = False,
+        eval_use_gather_object = False,
+        average_tokens_across_devices = False,
+        reward_model_path = None,
+        judge = None,
+        max_new_tokens = 64,
+        max_length = 512,
+        temperature = 0.9,
+        missing_eos_penalty = None,
+        loss_type = 'sigmoid',
+        dataset_num_proc = None,
+        disable_dropout = True,
+        use_vllm = False,
+        ds3_gather_for_generation = True,
+        vllm_sampling_params = None,
+        unsloth_num_chunks = -1,
+        **kwargs,
+    ):
+        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+            output_dir = 'unsloth_training_checkpoints'
+            save_strategy = 'no'
+        if dataset_num_proc is None:
+            from multiprocessing import cpu_count
+            dataset_num_proc = cpu_count()
+
+        super().__init__(
+            output_dir = output_dir,
+            overwrite_output_dir = overwrite_output_dir,
+            do_train = do_train,
+            do_eval = do_eval,
+            do_predict = do_predict,
+            eval_strategy = eval_strategy,
+            prediction_loss_only = prediction_loss_only,
+            per_device_train_batch_size = per_device_train_batch_size,
+            per_device_eval_batch_size = per_device_eval_batch_size,
+            per_gpu_train_batch_size = per_gpu_train_batch_size,
+            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+            gradient_accumulation_steps = gradient_accumulation_steps,
+            eval_accumulation_steps = eval_accumulation_steps,
+            eval_delay = eval_delay,
+            torch_empty_cache_steps = torch_empty_cache_steps,
+            learning_rate = learning_rate,
+            weight_decay = weight_decay,
+            adam_beta1 = adam_beta1,
+            adam_beta2 = adam_beta2,
+            adam_epsilon = adam_epsilon,
+            max_grad_norm = max_grad_norm,
+            num_train_epochs = num_train_epochs,
+            max_steps = max_steps,
+            lr_scheduler_type = lr_scheduler_type,
+            warmup_ratio = warmup_ratio,
+            warmup_steps = warmup_steps,
+            log_level = log_level,
+            log_level_replica = log_level_replica,
+            log_on_each_node = log_on_each_node,
+            logging_dir = logging_dir,
+            logging_strategy = logging_strategy,
+            logging_first_step = logging_first_step,
+            logging_steps = logging_steps,
+            logging_nan_inf_filter = logging_nan_inf_filter,
+            save_strategy = save_strategy,
+            save_steps = save_steps,
+            save_total_limit = save_total_limit,
+            save_safetensors = save_safetensors,
+            save_on_each_node = save_on_each_node,
+            save_only_model = save_only_model,
+            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+            no_cuda = no_cuda,
+            use_cpu = use_cpu,
+            use_mps_device = use_mps_device,
+            seed = seed,
+            data_seed = data_seed,
+            jit_mode_eval = jit_mode_eval,
+            use_ipex = use_ipex,
+            bf16 = bf16,
+            fp16 = fp16,
+            fp16_opt_level = fp16_opt_level,
+            half_precision_backend = half_precision_backend,
+            bf16_full_eval = bf16_full_eval,
+            fp16_full_eval = fp16_full_eval,
+            tf32 = tf32,
+            local_rank = local_rank,
+            ddp_backend = ddp_backend,
+            tpu_num_cores = tpu_num_cores,
+            tpu_metrics_debug = tpu_metrics_debug,
+            debug = debug,
+            dataloader_drop_last = dataloader_drop_last,
+            eval_steps = eval_steps,
+            dataloader_num_workers = dataloader_num_workers,
+            dataloader_prefetch_factor = dataloader_prefetch_factor,
+            past_index = past_index,
+            run_name = run_name,
+            disable_tqdm = disable_tqdm,
+            remove_unused_columns = remove_unused_columns,
+            label_names = label_names,
+            load_best_model_at_end = load_best_model_at_end,
+            metric_for_best_model = metric_for_best_model,
+            greater_is_better = greater_is_better,
+            ignore_data_skip = ignore_data_skip,
+            fsdp = fsdp,
+            fsdp_min_num_params = fsdp_min_num_params,
+            fsdp_config = fsdp_config,
+            tp_size = tp_size,
+            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+            accelerator_config = accelerator_config,
+            deepspeed = deepspeed,
+            label_smoothing_factor = label_smoothing_factor,
+            optim = optim,
+            optim_args = optim_args,
+            adafactor = adafactor,
+            group_by_length = group_by_length,
+            length_column_name = length_column_name,
+            report_to = report_to,
+            ddp_find_unused_parameters = ddp_find_unused_parameters,
+            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+            ddp_broadcast_buffers = ddp_broadcast_buffers,
+            dataloader_pin_memory = dataloader_pin_memory,
+            dataloader_persistent_workers = dataloader_persistent_workers,
+            skip_memory_metrics = skip_memory_metrics,
+            use_legacy_prediction_loop = use_legacy_prediction_loop,
+            push_to_hub = push_to_hub,
+            resume_from_checkpoint = resume_from_checkpoint,
+            hub_model_id = hub_model_id,
+            hub_strategy = hub_strategy,
+            hub_token = hub_token,
+            hub_private_repo = hub_private_repo,
+            hub_always_push = hub_always_push,
+            gradient_checkpointing = gradient_checkpointing,
+            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+            include_inputs_for_metrics = include_inputs_for_metrics,
+            eval_do_concat_batches = eval_do_concat_batches,
+            fp16_backend = fp16_backend,
+            evaluation_strategy = evaluation_strategy,
+            push_to_hub_model_id = push_to_hub_model_id,
+            push_to_hub_organization = push_to_hub_organization,
+            push_to_hub_token = push_to_hub_token,
+            mp_parameters = mp_parameters,
+            auto_find_batch_size = auto_find_batch_size,
+            full_determinism = full_determinism,
+            torchdynamo = torchdynamo,
+            ray_scope = ray_scope,
+            ddp_timeout = ddp_timeout,
+            torch_compile = torch_compile,
+            torch_compile_backend = torch_compile_backend,
+            torch_compile_mode = torch_compile_mode,
+            dispatch_batches = dispatch_batches,
+            split_batches = split_batches,
+            include_tokens_per_second = include_tokens_per_second,
+            include_num_input_tokens_seen = include_num_input_tokens_seen,
+            neftune_noise_alpha = neftune_noise_alpha,
+            optim_target_modules = optim_target_modules,
+            batch_eval_metrics = batch_eval_metrics,
+            eval_on_start = eval_on_start,
+            use_liger_kernel = use_liger_kernel,
+            eval_use_gather_object = eval_use_gather_object,
+            average_tokens_across_devices = average_tokens_across_devices,
+            reward_model_path = reward_model_path,
+            judge = judge,
+            max_new_tokens = max_new_tokens,
+            max_length = max_length,
+            temperature = temperature,
+            missing_eos_penalty = missing_eos_penalty,
+            loss_type = loss_type,
+            dataset_num_proc = dataset_num_proc,
+            disable_dropout = disable_dropout,
+            use_vllm = use_vllm,
+            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
+        self.vllm_sampling_params = vllm_sampling_params
+        self.unsloth_num_chunks = unsloth_num_chunks
+        pass
+
+class _UnslothNashMDTrainer(OnlineDPOTrainer):
+    r""""""
+
+    _tag_names = ["trl", "nash-md"]
+
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, nn.Module] = None,
+        ref_model: Union[PreTrainedModel, nn.Module] = None,
+        reward_model: Union[PreTrainedModel, nn.Module, None] = None,
+        judge: Optional[BasePairwiseJudge] = None,
+        args: Optional[NashMDConfig] = None,
+        data_collator: Optional[Callable] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        peft_config: Optional[dict] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+    ) -> None:
+        super().__init__(
+            model=model,
+            ref_model=ref_model,
+            reward_model=reward_model,
+            judge=judge,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            reward_processing_class=processing_class,  # for now, NashMDTrainer can't use any reward model
+            peft_config=peft_config,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        self._mixture_coef = self.args.mixture_coef
+
+        # Overwrite the stats dictionary to include NashMD specific statistics
+        self.stats = {
+            # Remove "non_score_reward", "rlhf_reward", "scores_margin"
+            # Add "mixture_coef"
+            "loss/kl": [],
+            "objective/entropy": [],
+            "loss/score": [],
+            "rewards/probabilities": [],
+            "rewards/accuracies": [],
+            "rewards/margins": [],
+            "logps/chosen": [],
+            "logps/rejected": [],
+            "val/model_contain_eos_token": [],
+            "val/ref_contain_eos_token": [],
+            "beta": [],
+            "mixture_coef": [],
+        }
+        if self.reward_model is not None:
+            self.stats["rewards/chosen"] = []
+            self.stats["rewards/rejected"] = []
+
+    @property
+    def mixture_coef(self):
+        if isinstance(self._mixture_coef, list):
+            epoch = self.state.epoch
+            return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
+        else:
+            return self._mixture_coef
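A sketch of the per-epoch schedule the property above implements when `mixture_coef` is a list: one coefficient per epoch, then the last one for the remainder of training. The schedule values below are made up; note that `Trainer.state.epoch` is a float, so an integer epoch is used here for indexing.

mixture_coef = [0.9, 0.7, 0.5]   # hypothetical per-epoch schedule
for epoch in range(5):
    coef = mixture_coef[epoch] if epoch < len(mixture_coef) else mixture_coef[-1]
    print(epoch, coef)            # 0.9, 0.7, 0.5, then 0.5 for the rest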
+
+    def _generate_completions(self, model, prompts):
+        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+            model_output = unwrapped_model.generate(
+                input_ids=prompts["input_ids"],
+                attention_mask=prompts["attention_mask"],
+                generation_config=self.generation_config,
+            )
+
+            ref_model = model if self.ref_model is None else self.ref_model
+            with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
+                mixture_model = GeometricMixtureWrapper(
+                    model=unwrapped_model,
+                    ref_model=unwrapped_ref_model,
+                    generation_config=self.generation_config,
+                    mixture_coef=self.mixture_coef,
+                    device=self.accelerator.device,
+                )
+
+                mixture_output = mixture_model.generate(
+                    input_ids=prompts["input_ids"],
+                    attention_mask=prompts["attention_mask"],
+                    generation_config=self.generation_config,
+                )
+
+        return model_output, mixture_output
+
+    def _process_completions(self, model_output, mixture_output, prompts):
+        context_length = prompts["input_ids"].shape[1]
+
+        # Process model completions
+        model_completion_ids = model_output[:, context_length:]
+        model_completion_ids, model_completion_mask = truncate_right(
+            model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
+        )
+        model_data = {
+            "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
+            "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
+            "raw": prompts["raw"],
+        }
+
+        # Process reference model completions
+        mixture_completion_ids = mixture_output[:, context_length:]
+        mixture_completion_ids, mixture_completion_mask = truncate_right(
+            mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
+        )
+        mixture_data = {
+            "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
+            "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
+            "raw": prompts["raw"],
+        }
+
+        return model_data, mixture_data
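A hedged sketch of what TRL's `truncate_right`, called twice above, does to a completion: keep tokens up to and including the first EOS, pad everything after it, and return a matching attention mask. The token ids below are made up (eos=2, pad=0); TRL's real helper uses its own first-index utility.

import torch

ids = torch.tensor([[5, 7, 2, 9, 4]])                          # EOS (2) appears mid-sequence
first_eos = (ids == 2).int().cumsum(dim=1).eq(0).sum(dim=1, keepdim=True)  # index of first EOS
idx = torch.arange(ids.shape[1]).unsqueeze(0)
out = ids.masked_fill(idx > first_eos, 0)                      # pad everything after the EOS
mask = (idx <= first_eos).long()                               # and zero out the mask there
print(out.tolist(), mask.tolist())                             # [[5, 7, 2, 0, 0]] [[1, 1, 1, 0, 0]]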
+
+    def _compute_rewards(self, model_data, mixture_data, context_length):
+        with torch.no_grad():
+            _, model_scores, _ = get_reward(
+                self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
+            )
+            _, mixture_scores, _ = get_reward(
+                self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
+            )
+
+        # Apply EOS penalty if needed
+        if self.args.missing_eos_penalty is not None:
+            model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
+            mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
+            model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
+            mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
+
+        return model_scores, mixture_scores
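A tiny illustration of the missing-EOS penalty above, with made-up scores: completions that never emitted EOS (they ran to the generation limit) get their reward docked by a fixed amount, discouraging unterminated outputs.

import torch

scores = torch.tensor([1.2, 0.4])           # reward-model scores for two completions
contain_eos = torch.tensor([True, False])   # the second one never produced EOS
missing_eos_penalty = 1.0
scores[~contain_eos] -= missing_eos_penalty
print(scores.tolist())                      # [1.2, -0.6]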
+
+    def _compute_judge(self, model_data, mixture_data, context_length):
+        prompts = model_data["raw"]
+        model_data_completions = self.processing_class.batch_decode(
+            model_data["input_ids"][:, context_length:], skip_special_tokens=True
+        )
+        model_data_completions = [completion.strip() for completion in model_data_completions]
+
+        mixture_data_completions = self.processing_class.batch_decode(
+            mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
+        )
+        mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
+        if is_conversational({"prompt": prompts[0]}):
+            model_data_completions = [
+                [{"role": "assistant", "content": completion}] for completion in model_data_completions
+            ]
+            environment = jinja2.Environment()
+            template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
+            prompts = [template.render(messages=message) for message in prompts]
+            model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
+
+            mixture_data_completions = [
+                [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
+            ]
+            mixture_data_completions = [
+                template.render(messages=completion) for completion in mixture_data_completions
+            ]
+
+        probability = self.judge.judge(
+            prompts,
+            list(zip(model_data_completions, mixture_data_completions)),
+            return_scores=True,
+        )
+        return torch.tensor(probability, device=model_data["input_ids"].device)
+
+    def _compute_logprobs(self, model, model_data, context_length):
+        def compute_logprobs_for_data(m, data):
+            output = m(data["input_ids"], attention_mask=data["attention_mask"])
+            logits = output.logits[:, context_length - 1 : -1]
+            token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
+            return token_logprobs
+
+        # Compute logprobs for model completions under the model
+        model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
+
+        # Compute logprobs of model completions under the reference model
+        with torch.no_grad():
+            if self.ref_model is None:
+                with model.disable_adapter():
+                    ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
+            else:
+                ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
+
+        # Mask padding tokens
+        model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
+        model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
+        ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
+
+        return (model_logprobs_model_data, ref_logprobs_model_data)
+
+    def _compute_losses(
+        self,
+        model_logprobs_model_data,
+        ref_logprobs_model_data,
+        probability,
+    ):
+        # reinforce score where 0.5 is a control variate
+        score = (probability - 0.5) * model_logprobs_model_data.sum(1)
+
+        # kl divergence via reinforce
+        with torch.no_grad():
+            log_ratio = model_logprobs_model_data - ref_logprobs_model_data
+            kl_div_log = log_ratio.sum(1)
+        kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
+
+        # final loss
+        loss = self.beta * kl_div_loss - score
+
+        return loss.mean(), score, kl_div_log
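A scalar numeric sketch of the Nash-MD objective above, with made-up values: the score term pushes up the sequence log-probability when the judge prefers the model's completion over the mixture's (probability above the 0.5 control variate), while the REINFORCE-style KL term, weighted by beta, pulls the policy back toward the reference model.

import torch

beta = 0.1
probability = torch.tensor([0.8])                    # judge: model beats mixture
model_logps = torch.tensor([-12.0])                  # summed completion logprob
log_ratio_sum = torch.tensor([1.5])                  # sum of (model - ref) logps
score = (probability - 0.5) * model_logps            # 0.3 * -12 = -3.6
loss = beta * (log_ratio_sum * model_logps) - score  # 0.1 * -18 - (-3.6) = 1.8
print(round(loss.item(), 4))                         # 1.8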
+
+    def _log_statistics(
+        self,
+        model_data,
+        mixture_data,
+        model_logprobs_model_data,
+        ref_logprobs_model_data,
+        probability,
+        score,
+        kl_div,
+        context_length,
+        model_scores=None,
+        mixture_scores=None,
+    ):
+        # Helper function to gather and compute mean
+        def gather_mean(tensor):
+            return self.accelerator.gather_for_metrics(tensor).mean().item()
+
+        # Log score
+        self.stats["loss/score"].append(gather_mean(score))
+        # Log KL divergence
+        self.stats["loss/kl"].append(gather_mean(kl_div))
+
+        # Log logprobs
+        model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
+        ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
+
+        self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
+        self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
+
+        # Log rewards
+        if self.reward_model is not None:
+            self.stats["rewards/chosen"].append(gather_mean(model_scores))
+            self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
+
+        # Log probabilities
+        self.stats["rewards/probabilities"].append(gather_mean(probability))
+
+        # Calculate entropy for model data
+        entropy_model_data = -model_logprobs_model_data.sum(1)
+        self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
+
+        # Calculate margins
+        margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
+        self.stats["rewards/margins"].append(gather_mean(margin))
+
+        # Calculate accuracy
+        accuracy = (margin > 0).float()
+        self.stats["rewards/accuracies"].append(gather_mean(accuracy))
+
+        # Log EOS token statistics
+        model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
+        mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
+        self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
+        self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
+
+        # Log beta and mixture coef
+        self.stats["beta"].append(self.beta)
+        self.stats["mixture_coef"].append(self.mixture_coef)
+
+    def training_step(
+        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
+        model.train()
+
+        # Apply chat template and tokenize the input
+        batch_size = len(next(iter(inputs.values())))
+        prompts = inputs["prompt"]
+        inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
+        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
+        inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
+        inputs = self.data_collator(inputs)
+
+        # Only the prompt_* fields are needed from here on
+        inputs = self._prepare_inputs(inputs)
+        context_length = inputs["prompt_input_ids"].shape[1]
+        prompts = {
+            "input_ids": inputs["prompt_input_ids"],
+            "attention_mask": inputs["prompt_attention_mask"],
+            "raw": prompts,
+        }
+        del inputs
+
+        # Sample completions from both the model and the reference model
+        model_output, mixture_output = self._generate_completions(model, prompts)
+
+        # Process model completions
+        model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
+
+        # Compute rewards
+        if self.reward_model is not None:
+            model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
+            # probability of the model data vs the mixture data
+            probability = F.sigmoid(model_scores - mixture_scores)
+        else:
+            model_scores, mixture_scores = None, None
+            probability = self._compute_judge(model_data, mixture_data, context_length)
+
+        # Compute logprobs
+        model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
+
+        # Compute loss
+        loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
+
+        # Log everything
+        self._log_statistics(
+            model_data,
+            mixture_data,
+            model_logprobs_model_data.detach(),
+            ref_logprobs_model_data,
+            probability,
+            score.detach(),
+            kl_div.detach(),
+            context_length,
+            model_scores,
+            mixture_scores,
+        )
+
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            empty_cache()
+
+        kwargs = {}
+        # For LOMO optimizers you need to explicitly use the learning rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            self.accelerator.backward(loss, **kwargs)
+
+        return loss.detach() / self.args.gradient_accumulation_steps
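A short note on the division at the end of `training_step`: the returned detached value only feeds the Trainer's running-loss display, which accumulates one value per micro-batch, so dividing by the accumulation steps keeps the logged loss on a per-optimizer-step scale. A sketch with made-up numbers:

ga_steps = 4
micro_losses = [2.0, 1.6, 2.4, 2.0]               # one value per micro-batch
logged = sum(l / ga_steps for l in micro_losses)  # 2.0 - the average, not the sum 8.0
print(logged)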
+
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent("""\
+            @inproceedings{munos2024nash,
+                title     = {{Nash Learning from Human Feedback}},
+                author    = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
+                year      = 2024,
+                booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
+                publisher = {OpenReview.net},
+                url       = {https://openreview.net/forum?id=Y5AmNYiyCQ}
+            }""")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="Nash-MD",
+            trainer_citation=citation,
+            paper_title="Nash Learning from Human Feedback",
+            paper_id="2312.00886",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+class UnslothNashMDTrainer(_UnslothNashMDTrainer):
+    """
+
+    Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
+
+    Args:
+        model (`transformers.PreTrainedModel`):
+            The model to train, preferably an `AutoModelForCausalLM`.
+        ref_model (`PreTrainedModelWrapper`):
+            Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
+            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
+        reward_model (`transformers.PreTrainedModel`):
+            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
+        judge (`BasePairwiseJudge`):
+            The judge to use for pairwise comparison of model completions.
+        args (`NashMDConfig`):
+            The NashMD config arguments to use for training.
+        data_collator (`transformers.DataCollator`):
+            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
+            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+        train_dataset (`datasets.Dataset`):
+            The dataset to use for training.
+        eval_dataset (`datasets.Dataset`):
+            The dataset to use for evaluation.
+        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+            Processing class used to process the data. If provided, will be used to automatically process the inputs
+            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+            reuse the fine-tuned model.
+        peft_config (`dict`):
+            The peft config to use for training.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function to use to compute the metrics. Must take an `EvalPrediction` and return
+            a dictionary mapping strings to metric values.
+        callbacks (`list[transformers.TrainerCallback]`):
+            The callbacks to use for training.
+        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+            The optimizer and scheduler to use for training.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+            The function to use to preprocess the logits before computing the metrics.
+
+    """
+    def __init__(
+        self,
+        model = None,
+        ref_model = None,
+        reward_model = None,
+        judge = None,
+        args = None,
+        data_collator = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        peft_config = None,
+        compute_metrics = None,
+        callbacks = None,
+        preprocess_logits_for_metrics = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothNashMDConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        use_fp16 = getattr(args, 'fp16', False)
+        force_float32 = False
+        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training()
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(__tokenizer)
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
926
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
927
+ else:
928
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
929
+ other_metrics = []
930
+
931
+ from unsloth_zoo.logging_utils import PatchRLStatistics
932
+ PatchRLStatistics('nash_md_trainer', other_metrics)
933
+
934
+ super().__init__(
935
+ model = model,
936
+ ref_model = ref_model,
937
+ reward_model = reward_model,
938
+ judge = judge,
939
+ args = args,
940
+ data_collator = data_collator,
941
+ train_dataset = train_dataset,
942
+ eval_dataset = eval_dataset,
943
+ processing_class = processing_class,
944
+ peft_config = peft_config,
945
+ compute_metrics = compute_metrics,
946
+ callbacks = callbacks,
947
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics, **kwargs)
948
+ if hasattr(self, 'neftune_hook_handle'):
949
+ self.neftune_hook_handle.remove()
950
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
951
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
952
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
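+ # NEFTune: the Trainer-installed forward hook is removed above and the noise
+ # alpha is re-attached directly to the input embeddings, presumably so
+ # Unsloth's patched embedding forward applies the noise itself.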
953
+ pass
954
+
955
+ pass
unsloth_compiled_cache/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1543 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, deepspeed, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
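+ # Illustrative use of selective_log_softmax (the shapes below are hypothetical,
+ # chosen only for the example):
+ #   logits = torch.randn(2, 5, 32000)            # (batch, seq_len, vocab)
+ #   ids    = torch.randint(0, 32000, (2, 5))     # token ids to score
+ #   logps  = selective_log_softmax(logits, ids)  # -> (2, 5) per-token log-probs
+ # Mathematically equal to logits.log_softmax(-1).gather(-1, ids.unsqueeze(-1)).squeeze(-1),
+ # but it reduces with logsumexp instead of materializing the full log-softmax output.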
42
+ @dataclass
43
+ class UnslothORPOConfig(ORPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`ORPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
66
+ it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
67
+ disable_dropout (`bool`, *optional*, defaults to `True`):
68
+ Whether to disable dropout in the model.
69
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
70
+ Label pad token id. This argument is required if you want to use the default data collator.
71
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
72
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
73
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
74
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
75
+ This argument is required if you want to use the default data collator.
76
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
77
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
78
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
79
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
80
+ you need to specify if the model returned by the callable is an encoder-decoder model.
81
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
82
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
83
+ string.
84
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
85
+ Number of processes to use for processing the dataset.
86
+
87
+ """
88
+ vllm_sampling_params: Optional[Any] = field(
89
+ default = None,
90
+ metadata = {'help': 'vLLM SamplingParams'},
91
+ )
92
+ unsloth_num_chunks : Optional[int] = field(
93
+ default = -1,
94
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
95
+ )
96
+ def __init__(
97
+ self,
98
+ output_dir = None,
99
+ overwrite_output_dir = None,
100
+ do_train = False,
101
+ do_eval = False,
102
+ do_predict = False,
103
+ eval_strategy = 'no',
104
+ prediction_loss_only = False,
105
+ per_device_train_batch_size = 4,
106
+ per_device_eval_batch_size = 4,
107
+ per_gpu_train_batch_size = None,
108
+ per_gpu_eval_batch_size = None,
109
+ gradient_accumulation_steps = 2,
110
+ eval_accumulation_steps = 2,
111
+ eval_delay = 0,
112
+ torch_empty_cache_steps = 250,
113
+ learning_rate = 5e-05,
114
+ weight_decay = 0.01,
115
+ adam_beta1 = 0.9,
116
+ adam_beta2 = 0.999,
117
+ adam_epsilon = 1e-08,
118
+ max_grad_norm = 1.0,
119
+ num_train_epochs = 3.0,
120
+ max_steps = -1,
121
+ lr_scheduler_type = 'linear',
122
+ warmup_ratio = 0.1,
123
+ warmup_steps = 0,
124
+ log_level = 'passive',
125
+ log_level_replica = 'warning',
126
+ log_on_each_node = True,
127
+ logging_dir = None,
128
+ logging_strategy = 'steps',
129
+ logging_first_step = False,
130
+ logging_steps = 1,
131
+ logging_nan_inf_filter = False,
132
+ save_strategy = 'steps',
133
+ save_steps = 500,
134
+ save_total_limit = None,
135
+ save_safetensors = True,
136
+ save_on_each_node = False,
137
+ save_only_model = False,
138
+ restore_callback_states_from_checkpoint = False,
139
+ no_cuda = False,
140
+ use_cpu = False,
141
+ use_mps_device = False,
142
+ seed = 3407,
143
+ data_seed = 3407,
144
+ jit_mode_eval = False,
145
+ use_ipex = False,
146
+ bf16 = False,
147
+ fp16 = False,
148
+ fp16_opt_level = 'O1',
149
+ half_precision_backend = 'auto',
150
+ bf16_full_eval = False,
151
+ fp16_full_eval = False,
152
+ tf32 = None,
153
+ local_rank = -1,
154
+ ddp_backend = None,
155
+ tpu_num_cores = None,
156
+ tpu_metrics_debug = False,
157
+ debug = '',
158
+ dataloader_drop_last = False,
159
+ eval_steps = None,
160
+ dataloader_num_workers = 0,
161
+ dataloader_prefetch_factor = None,
162
+ past_index = -1,
163
+ run_name = None,
164
+ disable_tqdm = None,
165
+ remove_unused_columns = True,
166
+ label_names = None,
167
+ load_best_model_at_end = False,
168
+ metric_for_best_model = None,
169
+ greater_is_better = None,
170
+ ignore_data_skip = False,
171
+ fsdp = '',
172
+ fsdp_min_num_params = 0,
173
+ fsdp_config = None,
174
+ tp_size = 0,
175
+ fsdp_transformer_layer_cls_to_wrap = None,
176
+ accelerator_config = None,
177
+ deepspeed = None,
178
+ label_smoothing_factor = 0.0,
179
+ optim = 'adamw_8bit',
180
+ optim_args = None,
181
+ adafactor = False,
182
+ group_by_length = False,
183
+ length_column_name = 'length',
184
+ report_to = None,
185
+ ddp_find_unused_parameters = None,
186
+ ddp_bucket_cap_mb = None,
187
+ ddp_broadcast_buffers = None,
188
+ dataloader_pin_memory = True,
189
+ dataloader_persistent_workers = False,
190
+ skip_memory_metrics = True,
191
+ use_legacy_prediction_loop = False,
192
+ push_to_hub = False,
193
+ resume_from_checkpoint = None,
194
+ hub_model_id = None,
195
+ hub_strategy = 'every_save',
196
+ hub_token = None,
197
+ hub_private_repo = None,
198
+ hub_always_push = False,
199
+ gradient_checkpointing = False,
200
+ gradient_checkpointing_kwargs = None,
201
+ include_inputs_for_metrics = False,
202
+ eval_do_concat_batches = True,
203
+ fp16_backend = 'auto',
204
+ evaluation_strategy = None,
205
+ push_to_hub_model_id = None,
206
+ push_to_hub_organization = None,
207
+ push_to_hub_token = None,
208
+ mp_parameters = '',
209
+ auto_find_batch_size = False,
210
+ full_determinism = False,
211
+ torchdynamo = None,
212
+ ray_scope = 'last',
213
+ ddp_timeout = 1800,
214
+ torch_compile = False,
215
+ torch_compile_backend = None,
216
+ torch_compile_mode = None,
217
+ dispatch_batches = None,
218
+ split_batches = None,
219
+ include_tokens_per_second = False,
220
+ include_num_input_tokens_seen = False,
221
+ neftune_noise_alpha = None,
222
+ optim_target_modules = None,
223
+ batch_eval_metrics = False,
224
+ eval_on_start = False,
225
+ use_liger_kernel = False,
226
+ eval_use_gather_object = False,
227
+ average_tokens_across_devices = False,
228
+ max_length = 1024,
229
+ max_prompt_length = 512,
230
+ max_completion_length = None,
231
+ beta = 0.1,
232
+ disable_dropout = True,
233
+ label_pad_token_id = -100,
234
+ padding_value = None,
235
+ truncation_mode = 'keep_end',
236
+ generate_during_eval = False,
237
+ is_encoder_decoder = None,
238
+ model_init_kwargs = None,
239
+ dataset_num_proc = None,
240
+ vllm_sampling_params = None,
241
+ unsloth_num_chunks = -1,
242
+ **kwargs,
243
+ ):
244
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
245
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
246
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
247
+ output_dir = 'unsloth_training_checkpoints'
248
+ save_strategy = 'no'
249
+ if dataset_num_proc is None:
250
+ from multiprocessing import cpu_count
251
+ dataset_num_proc = cpu_count()
252
+
253
+ super().__init__(
254
+ output_dir = output_dir,
255
+ overwrite_output_dir = overwrite_output_dir,
256
+ do_train = do_train,
257
+ do_eval = do_eval,
258
+ do_predict = do_predict,
259
+ eval_strategy = eval_strategy,
260
+ prediction_loss_only = prediction_loss_only,
261
+ per_device_train_batch_size = per_device_train_batch_size,
262
+ per_device_eval_batch_size = per_device_eval_batch_size,
263
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
264
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
265
+ gradient_accumulation_steps = gradient_accumulation_steps,
266
+ eval_accumulation_steps = eval_accumulation_steps,
267
+ eval_delay = eval_delay,
268
+ torch_empty_cache_steps = torch_empty_cache_steps,
269
+ learning_rate = learning_rate,
270
+ weight_decay = weight_decay,
271
+ adam_beta1 = adam_beta1,
272
+ adam_beta2 = adam_beta2,
273
+ adam_epsilon = adam_epsilon,
274
+ max_grad_norm = max_grad_norm,
275
+ num_train_epochs = num_train_epochs,
276
+ max_steps = max_steps,
277
+ lr_scheduler_type = lr_scheduler_type,
278
+ warmup_ratio = warmup_ratio,
279
+ warmup_steps = warmup_steps,
280
+ log_level = log_level,
281
+ log_level_replica = log_level_replica,
282
+ log_on_each_node = log_on_each_node,
283
+ logging_dir = logging_dir,
284
+ logging_strategy = logging_strategy,
285
+ logging_first_step = logging_first_step,
286
+ logging_steps = logging_steps,
287
+ logging_nan_inf_filter = logging_nan_inf_filter,
288
+ save_strategy = save_strategy,
289
+ save_steps = save_steps,
290
+ save_total_limit = save_total_limit,
291
+ save_safetensors = save_safetensors,
292
+ save_on_each_node = save_on_each_node,
293
+ save_only_model = save_only_model,
294
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
295
+ no_cuda = no_cuda,
296
+ use_cpu = use_cpu,
297
+ use_mps_device = use_mps_device,
298
+ seed = seed,
299
+ data_seed = data_seed,
300
+ jit_mode_eval = jit_mode_eval,
301
+ use_ipex = use_ipex,
302
+ bf16 = bf16,
303
+ fp16 = fp16,
304
+ fp16_opt_level = fp16_opt_level,
305
+ half_precision_backend = half_precision_backend,
306
+ bf16_full_eval = bf16_full_eval,
307
+ fp16_full_eval = fp16_full_eval,
308
+ tf32 = tf32,
309
+ local_rank = local_rank,
310
+ ddp_backend = ddp_backend,
311
+ tpu_num_cores = tpu_num_cores,
312
+ tpu_metrics_debug = tpu_metrics_debug,
313
+ debug = debug,
314
+ dataloader_drop_last = dataloader_drop_last,
315
+ eval_steps = eval_steps,
316
+ dataloader_num_workers = dataloader_num_workers,
317
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
318
+ past_index = past_index,
319
+ run_name = run_name,
320
+ disable_tqdm = disable_tqdm,
321
+ remove_unused_columns = remove_unused_columns,
322
+ label_names = label_names,
323
+ load_best_model_at_end = load_best_model_at_end,
324
+ metric_for_best_model = metric_for_best_model,
325
+ greater_is_better = greater_is_better,
326
+ ignore_data_skip = ignore_data_skip,
327
+ fsdp = fsdp,
328
+ fsdp_min_num_params = fsdp_min_num_params,
329
+ fsdp_config = fsdp_config,
330
+ tp_size = tp_size,
331
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
332
+ accelerator_config = accelerator_config,
333
+ deepspeed = deepspeed,
334
+ label_smoothing_factor = label_smoothing_factor,
335
+ optim = optim,
336
+ optim_args = optim_args,
337
+ adafactor = adafactor,
338
+ group_by_length = group_by_length,
339
+ length_column_name = length_column_name,
340
+ report_to = report_to,
341
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
342
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
343
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
344
+ dataloader_pin_memory = dataloader_pin_memory,
345
+ dataloader_persistent_workers = dataloader_persistent_workers,
346
+ skip_memory_metrics = skip_memory_metrics,
347
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
348
+ push_to_hub = push_to_hub,
349
+ resume_from_checkpoint = resume_from_checkpoint,
350
+ hub_model_id = hub_model_id,
351
+ hub_strategy = hub_strategy,
352
+ hub_token = hub_token,
353
+ hub_private_repo = hub_private_repo,
354
+ hub_always_push = hub_always_push,
355
+ gradient_checkpointing = gradient_checkpointing,
356
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
357
+ include_inputs_for_metrics = include_inputs_for_metrics,
358
+ eval_do_concat_batches = eval_do_concat_batches,
359
+ fp16_backend = fp16_backend,
360
+ evaluation_strategy = evaluation_strategy,
361
+ push_to_hub_model_id = push_to_hub_model_id,
362
+ push_to_hub_organization = push_to_hub_organization,
363
+ push_to_hub_token = push_to_hub_token,
364
+ mp_parameters = mp_parameters,
365
+ auto_find_batch_size = auto_find_batch_size,
366
+ full_determinism = full_determinism,
367
+ torchdynamo = torchdynamo,
368
+ ray_scope = ray_scope,
369
+ ddp_timeout = ddp_timeout,
370
+ torch_compile = torch_compile,
371
+ torch_compile_backend = torch_compile_backend,
372
+ torch_compile_mode = torch_compile_mode,
373
+ dispatch_batches = dispatch_batches,
374
+ split_batches = split_batches,
375
+ include_tokens_per_second = include_tokens_per_second,
376
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
377
+ neftune_noise_alpha = neftune_noise_alpha,
378
+ optim_target_modules = optim_target_modules,
379
+ batch_eval_metrics = batch_eval_metrics,
380
+ eval_on_start = eval_on_start,
381
+ use_liger_kernel = use_liger_kernel,
382
+ eval_use_gather_object = eval_use_gather_object,
383
+ average_tokens_across_devices = average_tokens_across_devices,
384
+ max_length = max_length,
385
+ max_prompt_length = max_prompt_length,
386
+ max_completion_length = max_completion_length,
387
+ beta = beta,
388
+ disable_dropout = disable_dropout,
389
+ label_pad_token_id = label_pad_token_id,
390
+ padding_value = padding_value,
391
+ truncation_mode = truncation_mode,
392
+ generate_during_eval = generate_during_eval,
393
+ is_encoder_decoder = is_encoder_decoder,
394
+ model_init_kwargs = model_init_kwargs,
395
+ dataset_num_proc = dataset_num_proc, **kwargs)
396
+ self.vllm_sampling_params = vllm_sampling_params
397
+ self.unsloth_num_chunks = unsloth_num_chunks
398
+ pass
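+ # Hypothetical usage sketch of this generated config (names like `model`,
+ # `tokenizer` and `dataset` are placeholders, not defined at this point in the file):
+ #   args = UnslothORPOConfig(output_dir = "outputs", per_device_train_batch_size = 2)
+ #   trainer = UnslothORPOTrainer(model = model, args = args, train_dataset = dataset,
+ #                                processing_class = tokenizer)
+ #   trainer.train()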
399
+
400
+ class _UnslothORPOTrainer(Trainer):
401
+ r""""""
402
+
403
+ _tag_names = ["trl", "orpo"]
404
+
405
+ def __init__(
406
+ self,
407
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
408
+ args: Optional[ORPOConfig] = None,
409
+ data_collator: Optional[DataCollator] = None,
410
+ train_dataset: Optional[Dataset] = None,
411
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
412
+ processing_class: Optional[
413
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
414
+ ] = None,
415
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
416
+ callbacks: Optional[list[TrainerCallback]] = None,
417
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
418
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
419
+ peft_config: Optional[dict] = None,
420
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
421
+ ):
422
+ if args.model_init_kwargs is None:
423
+ model_init_kwargs = {}
424
+ elif not isinstance(model, str):
425
+ raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
426
+ else:
427
+ model_init_kwargs = args.model_init_kwargs
428
+ torch_dtype = model_init_kwargs.get("torch_dtype")
429
+ if torch_dtype is not None:
430
+ # Convert to `torch.dtype` if an str is passed
431
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
432
+ torch_dtype = getattr(torch, torch_dtype)
433
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
434
+ raise ValueError(
435
+ f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
436
+ )
437
+ model_init_kwargs["torch_dtype"] = torch_dtype
438
+
439
+ if isinstance(model, str):
440
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
441
+
442
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
443
+ # has been called in order to properly call autocast if needed.
444
+ self._peft_has_been_casted_to_bf16 = False
445
+
446
+ if not is_peft_available() and peft_config is not None:
447
+ raise ValueError(
448
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
449
+ )
450
+ elif is_peft_available() and peft_config is not None:
451
+ # if model is a peft model and we have a peft_config, we merge and unload it first
452
+ if isinstance(model, PeftModel):
453
+ model = model.merge_and_unload()
454
+
455
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
456
+ _support_gc_kwargs = hasattr(
457
+ args, "gradient_checkpointing_kwargs"
458
+ ) and "gradient_checkpointing_kwargs" in list(
459
+ inspect.signature(prepare_model_for_kbit_training).parameters
460
+ )
461
+
462
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
463
+
464
+ if _support_gc_kwargs:
465
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
466
+
467
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
468
+ elif getattr(args, "gradient_checkpointing", False):
469
+ # For backward compatibility with older versions of transformers
470
+ if hasattr(model, "enable_input_require_grads"):
471
+ model.enable_input_require_grads()
472
+ else:
473
+
474
+ def make_inputs_require_grad(module, input, output):
475
+ output.requires_grad_(True)
476
+
477
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
478
+
479
+ # get peft model with the given config (this generated file keeps the model as-is here)
480
+ model = model
481
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
482
+ peft_module_casting_to_bf16(model)
483
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
484
+ self._peft_has_been_casted_to_bf16 = True
485
+
486
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
487
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
488
+ # fail or completely fail.
489
+ elif getattr(args, "gradient_checkpointing", False):
490
+ # For backward compatibility with older versions of transformers
491
+ if hasattr(model, "enable_input_require_grads"):
492
+ model.enable_input_require_grads()
493
+ else:
494
+
495
+ def make_inputs_require_grad(module, input, output):
496
+ output.requires_grad_(True)
497
+
498
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
499
+
500
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
501
+ raise ValueError(
502
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
503
+ " Please install `wandb` or `comet-ml` to resolve."
504
+ )
505
+
506
+ if model is not None:
507
+ self.is_encoder_decoder = model.config.is_encoder_decoder
508
+ elif args.is_encoder_decoder is None:
509
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
510
+ else:
511
+ self.is_encoder_decoder = args.is_encoder_decoder
512
+
513
+ if self.is_encoder_decoder:
514
+ self.decoder_start_token_id = model.config.decoder_start_token_id
515
+ self.pad_token_id = model.config.pad_token_id
516
+
517
+ if processing_class is None:
518
+ raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
519
+ if args.max_length is None:
520
+ warnings.warn(
521
+ "`max_length` is not set in the ORPOConfig's init"
522
+ " it will default to `512` by default, but you should do it yourself in the future.",
523
+ UserWarning,
524
+ )
525
+ max_length = 512
526
+ else:
527
+ max_length = args.max_length
528
+ if args.max_prompt_length is None:
529
+ warnings.warn(
530
+ "`max_prompt_length` is not set in the ORPOConfig's init"
531
+ " it will default to `128` by default, but you should do it yourself in the future.",
532
+ UserWarning,
533
+ )
534
+ max_prompt_length = 128
535
+ else:
536
+ max_prompt_length = args.max_prompt_length
537
+
538
+ if args.max_completion_length is None and self.is_encoder_decoder:
539
+ warnings.warn(
540
+ "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
541
+ " it will default to `128` by default, but you should do it yourself in the future.",
542
+ UserWarning,
543
+ )
544
+ self.max_completion_length = 128
545
+ else:
546
+ self.max_completion_length = args.max_completion_length
547
+
548
+ if data_collator is None:
549
+ data_collator = DPODataCollatorWithPadding(
550
+ pad_token_id=processing_class.pad_token_id,
551
+ label_pad_token_id=args.label_pad_token_id,
552
+ is_encoder_decoder=self.is_encoder_decoder,
553
+ )
554
+
555
+ if args.remove_unused_columns:
556
+ args.remove_unused_columns = False
557
+ # warn users
558
+ warnings.warn(
559
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
560
+ " we have set it for you, but you should do it yourself in the future.",
561
+ UserWarning,
562
+ )
563
+
564
+ self.use_dpo_data_collator = True
565
+ else:
566
+ self.use_dpo_data_collator = False
567
+
568
+ # Disable dropout in the model and reference model
569
+ if args.disable_dropout:
570
+ disable_dropout_in_model(model)
571
+
572
+ self.max_length = max_length
573
+ self.generate_during_eval = args.generate_during_eval
574
+ self.label_pad_token_id = args.label_pad_token_id
575
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
576
+ self.max_prompt_length = max_prompt_length
577
+ self.truncation_mode = args.truncation_mode
578
+ self.processing_class = processing_class
579
+
580
+ self.beta = args.beta
581
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
582
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
583
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
584
+ warnings.warn(
585
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
586
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
587
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
588
+ "loss.",
589
+ UserWarning,
590
+ )
591
+
592
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
593
+
594
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
595
+ # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
596
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
597
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
598
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
599
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
600
+ # that the warning has already been issued.
601
+ model.warnings_issued["estimate_tokens"] = True
602
+
603
+ # Compute that only on the main process for faster data processing.
604
+ # see: https://github.com/huggingface/trl/pull/1255
605
+ with PartialState().local_main_process_first():
606
+ # Extract the prompt if needed, and apply the chat template if needed
607
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
608
+ train_dataset = train_dataset.map(
609
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
610
+ )
611
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
612
+ if eval_dataset is not None:
613
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
614
+ eval_dataset = eval_dataset.map(
615
+ maybe_apply_chat_template,
616
+ fn_kwargs={"tokenizer": processing_class},
617
+ num_proc=args.dataset_num_proc,
618
+ )
619
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
620
+
621
+ super().__init__(
622
+ model=model,
623
+ args=args,
624
+ data_collator=data_collator,
625
+ train_dataset=train_dataset,
626
+ eval_dataset=eval_dataset,
627
+ processing_class=processing_class,
628
+ model_init=model_init,
629
+ compute_metrics=compute_metrics,
630
+ callbacks=callbacks,
631
+ optimizers=optimizers,
632
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
633
+ )
634
+
635
+ # Add tags for models that have been loaded with the correct transformers version
636
+ if hasattr(self.model, "add_model_tags"):
637
+ self.model.add_model_tags(self._tag_names)
638
+
639
+ if not hasattr(self, "accelerator"):
640
+ raise AttributeError(
641
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
642
+ )
643
+
644
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
645
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
646
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
647
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
648
+
649
+ if model is not None:
650
+ if hasattr(model, "config"):
651
+ hidden_size = (
652
+ max(model.config.hidden_sizes)
653
+ if getattr(model.config, "hidden_sizes", None)
654
+ else getattr(model.config, "hidden_size", None)
655
+ )
656
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
657
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
658
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
659
+ config_kwargs.update(
660
+ {
661
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
662
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
663
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
664
+ }
665
+ )
666
+
667
+ # If ZeRO-3 is used, we shard both the active and reference model.
668
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
669
+ if config_kwargs["zero_optimization"]["stage"] != 3:
670
+ config_kwargs["zero_optimization"]["stage"] = 0
671
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
672
+ model.eval()
673
+ return model
674
+
675
+ def build_tokenized_answer(self, prompt, answer):
676
+ """
677
+ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
678
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
679
+ Reference:
680
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
681
+ """
682
+
683
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
684
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
685
+
686
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
687
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
688
+
689
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
690
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
691
+
692
+ # Prepare input tokens for token by token comparison
693
+ full_input_ids = np.array(full_tokenized["input_ids"])
694
+
695
+ if len(full_input_ids) != len(full_concat_input_ids):
696
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
697
+
698
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
699
+ # can be merged together when tokenizing prompt+answer. This could result
700
+ # in the last token from the prompt being different when tokenized on its own
701
+ # vs when done as prompt+answer.
702
+ response_token_ids_start_idx = len(prompt_input_ids)
703
+
704
+ # If the tokenized prompt differs from the start of the tokenized prompt+answer, then the
705
+ # last token has changed due to merging.
706
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
707
+ response_token_ids_start_idx -= 1
708
+
709
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
710
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
711
+
712
+ if len(prompt_input_ids) != len(prompt_attention_mask):
713
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
714
+
715
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
716
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
717
+
718
+ return dict(
719
+ prompt_input_ids=prompt_input_ids,
720
+ prompt_attention_mask=prompt_attention_mask,
721
+ input_ids=answer_input_ids,
722
+ attention_mask=answer_attention_mask,
723
+ )
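+     # Hypothetical illustration of the merge issue handled above (the token ids
+     # are made up): a BPE tokenizer may merge the last prompt token with the
+     # first answer token, e.g.
+     #   enc(prompt)          -> [.., 17]
+     #   enc(prompt + answer) -> [.., 42, ..]   # 17 merged into a new token 42
+     # in which case the prompt boundary is moved back by one token and the
+     # answer is taken as enc(prompt + answer)[len(enc(prompt)) - 1:].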
724
+
725
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
726
+ """Tokenize a single row from a ORPO specific dataset.
727
+
728
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
729
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
730
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
731
+
732
+ We also create the labels for the chosen/rejected responses, which are of length equal to
733
+ the sum of the length of the prompt and the chosen/rejected response, with
734
+ label_pad_token_id for the prompt tokens.
735
+ """
736
+ batch = {}
737
+ prompt = feature["prompt"]
738
+ chosen = feature["chosen"]
739
+ rejected = feature["rejected"]
740
+
741
+ if not self.is_encoder_decoder:
742
+ # Check issues below for more details
743
+ # 1. https://github.com/huggingface/trl/issues/907
744
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
745
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
746
+
747
+ if not isinstance(prompt, str):
748
+ raise ValueError(f"prompt should be an str but got {type(prompt)}")
749
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
750
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
751
+
752
+ if not isinstance(chosen, str):
753
+ raise ValueError(f"chosen should be an str but got {type(chosen)}")
754
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
755
+
756
+ if not isinstance(rejected, str):
757
+ raise ValueError(f"rejected should be an str but got {type(rejected)}")
758
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
759
+
760
+ # Last prompt token might get merged by tokenizer and
761
+ # it should not be included for generation if that happens
762
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
763
+
764
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
765
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
766
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
767
+
768
+ for k, v in prompt_tokens.items():
769
+ prompt_tokens[k] = v[:prompt_len_input_ids]
770
+
771
+ # Make sure prompts only have one different token at most,
772
+ # and length only differs by 1 at most
773
+ num_diff_tokens = sum(
774
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
775
+ )
776
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
777
+ if num_diff_tokens > 1 or num_diff_len > 1:
778
+ raise ValueError(
779
+ "Chosen and rejected prompt_input_ids might only differ on the "
780
+ "last token due to tokenizer merge ops."
781
+ )
782
+
783
+ # add BOS token to head of prompt. Avoid adding if it's already there
784
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
785
+ self.processing_class.bos_token_id,
786
+ prompt_len_input_ids,
787
+ prompt_tokens,
788
+ chosen_prompt_len_input_ids,
789
+ chosen_tokens,
790
+ rejected_prompt_len_input_ids,
791
+ rejected_tokens,
792
+ )
793
+
794
+ # add EOS token to end of answer. Avoid adding if it's already there
795
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
796
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
797
+ )
798
+
799
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
800
+
801
+ # if combined sequence is too long, truncate the prompt
802
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
803
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
804
+ if self.truncation_mode == "keep_start":
805
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
806
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
807
+ elif self.truncation_mode == "keep_end":
808
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
809
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
810
+ else:
811
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
812
+
813
+ # if that's still too long, truncate the response
814
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
815
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
816
+ for k in ["input_ids", "attention_mask"]:
817
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
818
+
819
+ # Create labels
820
+ chosen_sequence_tokens = {
821
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
822
+ }
823
+ rejected_sequence_tokens = {
824
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
825
+ }
826
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
827
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
828
+ self.label_pad_token_id
829
+ ] * len(chosen_tokens["prompt_input_ids"])
830
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
831
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
832
+ self.label_pad_token_id
833
+ ] * len(rejected_tokens["prompt_input_ids"])
834
+
835
+ for k, toks in {
836
+ "chosen_": chosen_sequence_tokens,
837
+ "rejected_": rejected_sequence_tokens,
838
+ "": prompt_tokens,
839
+ }.items():
840
+ for type_key, tokens in toks.items():
841
+ if type_key == "token_type_ids":
842
+ continue
843
+ batch[f"{k}{type_key}"] = tokens
844
+
845
+ else:
846
+ chosen_tokens = self.processing_class(
847
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
848
+ )
849
+ rejected_tokens = self.processing_class(
850
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
851
+ )
852
+ prompt_tokens = self.processing_class(
853
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
854
+ )
855
+
856
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
857
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
858
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
859
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
860
+
861
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
862
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
863
+ labels=torch.tensor(batch["rejected_labels"])
864
+ )
865
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
866
+ labels=torch.tensor(batch["chosen_labels"])
867
+ )
868
+
869
+ if is_torch_xla_available():
870
+ # Pad the sequences to global max_length to avoid TorchXLA recompilation
871
+ for k in batch:
872
+ if "labels" in k or self.is_encoder_decoder:
873
+ pad_value = self.label_pad_token_id
874
+ elif k.endswith("_input_ids"):
875
+ pad_value = self.padding_value
876
+ elif k.endswith("_attention_mask"):
877
+ pad_value = 0
878
+ batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
879
+ return batch
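+     # For the decoder-only branch, the returned batch holds:
+     #   prompt_input_ids, prompt_attention_mask,
+     #   chosen_input_ids, chosen_attention_mask, chosen_labels,
+     #   rejected_input_ids, rejected_attention_mask, rejected_labels,
+     # where the labels copy the input ids with every prompt position replaced by
+     # label_pad_token_id, so only the response tokens contribute to the loss.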
880
+
881
+ @staticmethod
882
+ def concatenated_inputs(
883
+ batch: dict[str, Union[list, torch.LongTensor]],
884
+ is_encoder_decoder: bool = False,
885
+ label_pad_token_id: int = -100,
886
+ padding_value: int = 0,
887
+ device: Optional[torch.device] = None,
888
+ ) -> dict[str, torch.LongTensor]:
889
+ """Concatenate the chosen and rejected inputs into a single tensor.
890
+
891
+ Args:
892
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
893
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
894
+ label_pad_token_id: The label pad token id.
895
+ padding_value: The padding value to use for the concatenated inputs_ids.
896
+ device: The device for the concatenated inputs.
897
+
898
+ Returns:
899
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
900
+ """
901
+ concatenated_batch = {}
902
+
903
+ if is_encoder_decoder:
904
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
905
+ else:
906
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
907
+
908
+ for k in batch:
909
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
910
+ if "labels" in k or is_encoder_decoder:
911
+ pad_value = label_pad_token_id
912
+ elif k.endswith("_input_ids"):
913
+ pad_value = padding_value
914
+ elif k.endswith("_attention_mask"):
915
+ pad_value = 0
916
+ concatenated_key = k.replace("chosen", "concatenated")
917
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
918
+ for k in batch:
919
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
920
+ if "labels" in k or is_encoder_decoder:
921
+ pad_value = label_pad_token_id
922
+ elif k.endswith("_input_ids"):
923
+ pad_value = padding_value
924
+ elif k.endswith("_attention_mask"):
925
+ pad_value = 0
926
+ concatenated_key = k.replace("rejected", "concatenated")
927
+ concatenated_batch[concatenated_key] = torch.cat(
928
+ (
929
+ concatenated_batch[concatenated_key],
930
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
931
+ ),
932
+ dim=0,
933
+ ).to(device=device)
934
+
935
+ if is_encoder_decoder:
936
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
937
+ concatenated_batch["concatenated_attention_mask"] = (
938
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
939
+ )
940
+
941
+ return concatenated_batch
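+     # Shape sketch (B = batch size, Lc/Lr = chosen/rejected lengths; an
+     # illustration, not executed code): chosen_input_ids (B, Lc) and
+     # rejected_input_ids (B, Lr) are each padded to L = max(Lc, Lr) and stacked
+     # along dim 0, giving concatenated_input_ids of shape (2B, L). Labels pad
+     # with label_pad_token_id, attention masks with 0, input ids with padding_value.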
942
+
943
+ def odds_ratio_loss(
944
+ self,
945
+ policy_chosen_logps: torch.FloatTensor,
946
+ policy_rejected_logps: torch.FloatTensor,
947
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
948
+ """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
949
+
950
+ Args:
951
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
952
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
953
+
954
+ Returns:
955
+ A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
956
+ The losses tensor contains the ORPO loss for each example in the batch.
957
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
958
+ The mean of `log(sigmoid(log_odds))` of the chosen over the rejected responses, logged as `log_odds_ratio`.
959
+ The mean of the log odds of the chosen over the rejected responses, logged as `log_odds_chosen`.
960
+ """
961
+
962
+ # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
963
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
964
+ torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
965
+ )
966
+ ratio = F.logsigmoid(log_odds)
967
+ losses = self.beta * ratio
968
+
969
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
970
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
971
+
972
+ return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
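+     # Worked form of the loss above: with p_c = exp(policy_chosen_logps) and
+     # p_r = exp(policy_rejected_logps),
+     #   log_odds = log(p_c / (1 - p_c)) - log(p_r / (1 - p_r))
+     #   losses   = beta * log(sigmoid(log_odds))
+     # The caller subtracts losses.mean() from the NLL loss, so increasing the
+     # odds of the chosen response over the rejected one lowers the total loss.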
973
+
974
+ @staticmethod
975
+ def get_batch_logps(
976
+ logits: torch.FloatTensor,
977
+ labels: torch.LongTensor,
978
+ average_log_prob: bool = False,
979
+ label_pad_token_id: int = -100,
980
+ is_encoder_decoder: bool = False,
981
+ ) -> torch.FloatTensor:
982
+ """Compute the log probabilities of the given labels under the given logits.
983
+
984
+ Args:
985
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
986
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
987
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
988
+ label_pad_token_id: The label pad token id.
989
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
990
+
991
+ Returns:
992
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
993
+ """
994
+ if logits.shape[:-1] != labels.shape:
995
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
996
+
997
+ if not is_encoder_decoder:
998
+ labels = labels[:, 1:].clone()
999
+ logits = logits[:, :-1, :]
1000
+ loss_mask = labels != label_pad_token_id
1001
+
1002
+ # dummy token; we'll ignore the losses on these tokens later
1003
+ labels = torch.where(labels == label_pad_token_id, 0, labels)
1004
+
1005
+ per_token_logps = selective_log_softmax(logits, labels)
1006
+
1007
+ if average_log_prob:
1008
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1009
+ else:
1010
+ return (per_token_logps * loss_mask).sum(-1)
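+     # Example (hypothetical values): for labels [[P, P, a, b]] with
+     # P = label_pad_token_id, the loss mask keeps only positions a and b, so this
+     # returns log p(a) + log p(b), or their mean when average_log_prob=True.
+     # Shapes: logits (B, T, V) and labels (B, T) in, a (B,) tensor out.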
1011
+
1012
+ def concatenated_forward(
1013
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1014
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1015
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1016
+
1017
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1018
+ """
1019
+ concatenated_batch = self.concatenated_inputs(
1020
+ batch,
1021
+ is_encoder_decoder=self.is_encoder_decoder,
1022
+ label_pad_token_id=self.label_pad_token_id,
1023
+ padding_value=self.padding_value,
1024
+ device=self.accelerator.device,
1025
+ )
1026
+ len_chosen = batch["chosen_labels"].shape[0]
1027
+
1028
+ model_kwargs = (
1029
+ {
1030
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1031
+ }
1032
+ if self.is_encoder_decoder
1033
+ else {}
1034
+ )
1035
+
1036
+ if self.aux_loss_enabled:
1037
+ model_kwargs["output_router_logits"] = True
1038
+
1039
+ outputs = model(
1040
+ concatenated_batch["concatenated_input_ids"],
1041
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1042
+ use_cache=False,
1043
+ **model_kwargs,
1044
+ )
1045
+ all_logits = outputs.logits
1046
+
1047
+ def cross_entropy_loss(logits, labels):
1048
+ if not self.is_encoder_decoder:
1049
+ # Shift so that tokens < n predict n
1050
+ logits = logits[..., :-1, :].contiguous()
1051
+ labels = labels[..., 1:].contiguous()
1052
+ # Flatten the tokens
1053
+ loss_fct = nn.CrossEntropyLoss()
1054
+ logits = logits.view(-1, logits.shape[-1])
1055
+ labels = labels.view(-1)
1056
+ # Enable model parallelism
1057
+ labels = labels.to(logits.device)
1058
+ loss = loss_fct(logits, labels)
1059
+ return loss
1060
+
1061
+ if self.is_encoder_decoder:
1062
+ labels = concatenated_batch["concatenated_labels"].clone()
1063
+ else:
1064
+ labels = concatenated_batch["concatenated_input_ids"].clone()
1065
+ attention_mask = concatenated_batch["concatenated_attention_mask"]
1066
+ labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
1067
+ # The ORPO chosen NLL loss is computed over the full prompt and response
1068
+ chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1069
+
1070
+ all_logps = self.get_batch_logps(
1071
+ all_logits,
1072
+ concatenated_batch["concatenated_labels"],
1073
+ average_log_prob=True,
1074
+ is_encoder_decoder=self.is_encoder_decoder,
1075
+ label_pad_token_id=self.label_pad_token_id,
1076
+ )
1077
+
1078
+ chosen_logps = all_logps[:len_chosen]
1079
+ rejected_logps = all_logps[len_chosen:]
1080
+
1081
+ if not self.is_encoder_decoder:
1082
+ chosen_logits = all_logits[:len_chosen, :-1, :]
1083
+ rejected_logits = all_logits[len_chosen:, :-1, :]
1084
+ else:
1085
+ chosen_logits = all_logits[:len_chosen]
1086
+ rejected_logits = all_logits[len_chosen:]
1087
+
1088
+ if self.aux_loss_enabled:
1089
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
1090
+
1091
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1092
+
1093
+ def get_batch_loss_metrics(
1094
+ self,
1095
+ model,
1096
+ batch: dict[str, Union[list, torch.LongTensor]],
1097
+ train_eval: Literal["train", "eval"] = "train",
1098
+ ):
1099
+ """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1100
+ metrics = {}
1101
+
1102
+ forward_output = self.concatenated_forward(model, batch)
1103
+ (
1104
+ policy_chosen_logps,
1105
+ policy_rejected_logps,
1106
+ policy_chosen_logits,
1107
+ policy_rejected_logits,
1108
+ policy_nll_loss,
1109
+ ) = forward_output[:5]
1110
+ if self.aux_loss_enabled:
1111
+ aux_loss = forward_output[5]
1112
+
1113
+ losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1114
+ policy_chosen_logps, policy_rejected_logps
1115
+ )
1116
+ # full ORPO loss
1117
+ loss = policy_nll_loss - losses.mean()
1118
+
1119
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1120
+
1121
+ prefix = "eval_" if train_eval == "eval" else ""
1122
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
1123
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
1124
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
1125
+ metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
1126
+ chosen_rewards - rejected_rewards
1127
+ ).mean()
1128
+ metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
1129
+ metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
1130
+ metrics[f"{prefix}logits/rejected"] = (
1131
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
1132
+ )
1133
+ metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
1134
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
1135
+ metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
1136
+ metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
1137
+ if is_torch_xla_available():
1138
+ xm.mark_step() # needed because of the .item() calls below
1139
+ for k, v in metrics.items():
1140
+ metrics[k] = v.item()
1141
+ if self.aux_loss_enabled:
1142
+ loss += self.aux_loss_coef * aux_loss
1143
+
1144
+ return loss, metrics
1145
+
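+ # The loss composed above is L_ORPO = L_NLL - mean(beta * logsigmoid(log_odds)).
+ # A sketch of the odds-ratio term, assuming the standard TRL formulation
+ # (`odds_ratio_loss` itself is defined earlier in this file):
+ #
+ # >>> log_odds = (chosen_logps - rejected_logps) - (
+ # ...     torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
+ # ... )
+ # >>> losses = beta * F.logsigmoid(log_odds)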
1146
+ def compute_loss(
1147
+ self,
1148
+ model: Union[PreTrainedModel, nn.Module],
1149
+ inputs: dict[str, Union[torch.Tensor, Any]],
1150
+ return_outputs=False,
1151
+ num_items_in_batch=None,
1152
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1153
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1154
+
1155
+ with compute_loss_context_manager:
1156
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1157
+
1158
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1159
+ loss = loss.to(self.args.device)
1160
+
1161
+ # force log the metrics
1162
+ self.store_metrics(metrics, train_eval="train")
1163
+
1164
+ if return_outputs:
1165
+ return (loss, metrics)
1166
+ return loss
1167
+
1168
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1169
+ """Generate samples from the model and reference model for the given batch of inputs."""
1170
+
1171
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1172
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1173
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1174
+
1175
+ with generate_context_manager:
1176
+ policy_output = model.generate(
1177
+ input_ids=batch["prompt_input_ids"],
1178
+ attention_mask=batch["prompt_attention_mask"],
1179
+ max_length=self.max_length,
1180
+ do_sample=True,
1181
+ pad_token_id=self.processing_class.pad_token_id,
1182
+ )
1183
+
1184
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1185
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1186
+
1187
+ return policy_output_decoded
1188
+
1189
+ def prediction_step(
1190
+ self,
1191
+ model: Union[PreTrainedModel, nn.Module],
1192
+ inputs: dict[str, Union[torch.Tensor, Any]],
1193
+ prediction_loss_only: bool,
1194
+ ignore_keys: Optional[list[str]] = None,
1195
+ ):
1196
+ if not self.use_dpo_data_collator:
1197
+ warnings.warn(
1198
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1199
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1200
+ )
1201
+ if ignore_keys is None:
1202
+ if hasattr(model, "config"):
1203
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1204
+ else:
1205
+ ignore_keys = []
1206
+
1207
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1208
+
1209
+ with torch.no_grad(), prediction_context_manager:
1210
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1211
+
1212
+ # force log the metrics
1213
+ self.store_metrics(metrics, train_eval="eval")
1214
+
1215
+ if prediction_loss_only:
1216
+ return (loss.detach(), None, None)
1217
+
1218
+ # logits for the chosen and rejected samples from model
1219
+ logits_dict = {
1220
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1221
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1222
+ }
1223
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1224
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1225
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1226
+
1227
+ return (loss.detach(), logits, labels)
1228
+
1229
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1230
+ for key, value in metrics.items():
1231
+ self._stored_metrics[train_eval][key].append(value)
1232
+
1233
+ def evaluation_loop(
1234
+ self,
1235
+ dataloader: DataLoader,
1236
+ description: str,
1237
+ prediction_loss_only: Optional[bool] = None,
1238
+ ignore_keys: Optional[list[str]] = None,
1239
+ metric_key_prefix: str = "eval",
1240
+ ) -> EvalLoopOutput:
1241
+ """
1242
+ Overriding built-in evaluation loop to store metrics for each batch.
1243
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1244
+
1245
+ Works both with or without labels.
1246
+ """
1247
+
1248
+ # Sample and save to game log if requested (for one batch to save time)
1249
+ if self.generate_during_eval:
1250
+ # Generate random indices within the range of the total number of samples
1251
+ num_samples = len(dataloader.dataset)
1252
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1253
+
1254
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1255
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1256
+ random_batch = self.data_collator(random_batch_dataset)
1257
+ random_batch = self._prepare_inputs(random_batch)
1258
+
1259
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1260
+
1261
+ table = pd.DataFrame(
1262
+ columns=["Prompt", "Policy"],
1263
+ data=[
1264
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1265
+ ],
1266
+ )
1267
+ if "wandb" in self.args.report_to:
1268
+ wandb.log({"game_log": wandb.Table(data=table)})
1269
+
1270
+ if "comet_ml" in self.args.report_to:
1271
+ log_table_to_comet_experiment(
1272
+ name="game_log.csv",
1273
+ table=table,
1274
+ )
1275
+
1276
+ # Base evaluation
1277
+ initial_output = super().evaluation_loop(
1278
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1279
+ )
1280
+
1281
+ return initial_output
1282
+
1283
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1284
+ """
1285
+ Log `logs` on the various objects watching training, including stored metrics.
1286
+
1287
+ Args:
1288
+ logs (`dict[str, float]`):
1289
+ The values to log.
1290
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1291
+ Start time of the training.
1292
+ """
1293
+ # logs either has 'loss' or 'eval_loss'
1294
+ train_eval = "train" if "loss" in logs else "eval"
1295
+ # Add averaged stored metrics to logs
1296
+ for key, metrics in self._stored_metrics[train_eval].items():
1297
+ logs[key] = torch.tensor(metrics).mean().item()
1298
+ del self._stored_metrics[train_eval]
1299
+
1300
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1301
+ return super().log(logs, start_time)
1302
+ else: # transformers<=4.46
1303
+ return super().log(logs)
1304
+
1305
+ def _shift_right(self, input_ids):
1306
+ if self.decoder_start_token_id is None:
1307
+ raise ValueError(
1308
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1309
+ )
1310
+
1311
+ # shift inputs to the right
1312
+ if is_torch_fx_proxy(input_ids):
1313
+ # Item assignment is not supported natively for proxies.
1314
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1315
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1316
+ else:
1317
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1318
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1319
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1320
+
1321
+ if self.pad_token_id is None:
1322
+ raise ValueError("model.config.pad_token_id has to be defined.")
1323
+ # replace possible -100 values in labels by `pad_token_id`
1324
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1325
+
1326
+ return shifted_input_ids
1327
+
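+ # A worked example of the shift above, assuming decoder_start_token_id = 0:
+ #
+ # >>> input_ids = torch.tensor([[11, 12, 13]])
+ # >>> shifted = input_ids.new_zeros(input_ids.shape)
+ # >>> shifted[..., 1:] = input_ids[..., :-1].clone()
+ # >>> shifted[..., 0] = 0
+ # >>> shifted
+ # tensor([[ 0, 11, 12]])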
1328
+ def create_model_card(
1329
+ self,
1330
+ model_name: Optional[str] = None,
1331
+ dataset_name: Optional[str] = None,
1332
+ tags: Union[str, list[str], None] = None,
1333
+ ):
1334
+ """
1335
+ Creates a draft of a model card using the information available to the `Trainer`.
1336
+
1337
+ Args:
1338
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1339
+ Name of the model.
1340
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1341
+ Name of the dataset used for training.
1342
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1343
+ Tags to be associated with the model card.
1344
+ """
1345
+ if not self.is_world_process_zero():
1346
+ return
1347
+
1348
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1349
+ base_model = self.model.config._name_or_path
1350
+ else:
1351
+ base_model = None
1352
+
1353
+ tags = tags or []
1354
+ if isinstance(tags, str):
1355
+ tags = [tags]
1356
+
1357
+ if hasattr(self.model.config, "unsloth_version"):
1358
+ tags.append("unsloth")
1359
+
1360
+ citation = textwrap.dedent("""\
1361
+ @article{hong2024orpo,
1362
+ title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
1363
+ author = {Jiwoo Hong and Noah Lee and James Thorne},
1364
+ year = 2024,
1365
+ eprint = {arXiv:2403.07691}
1366
+ }""")
1367
+
1368
+ model_card = generate_model_card(
1369
+ base_model=base_model,
1370
+ model_name=model_name,
1371
+ hub_model_id=self.hub_model_id,
1372
+ dataset_name=dataset_name,
1373
+ tags=tags,
1374
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1375
+ comet_url=get_comet_experiment_url(),
1376
+ trainer_name="ORPO",
1377
+ trainer_citation=citation,
1378
+ paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
1379
+ paper_id="2403.07691",
1380
+ )
1381
+
1382
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1383
+ class UnslothORPOTrainer(_UnslothORPOTrainer):
1384
+ """
1385
+
1386
+ Initialize ORPOTrainer.
1387
+
1388
+ Args:
1389
+ model (`transformers.PreTrainedModel`):
1390
+ The model to train, preferably an `AutoModelForCausalLM`.
1391
+ args (`ORPOConfig`):
1392
+ The ORPO config arguments to use for training.
1393
+ data_collator (`transformers.DataCollator`):
1394
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1395
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1396
+ train_dataset (`datasets.Dataset`):
1397
+ The dataset to use for training.
1398
+ eval_dataset (`datasets.Dataset`):
1399
+ The dataset to use for evaluation.
1400
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1401
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1402
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1403
+ reuse the fine-tuned model.
1404
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1405
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1406
+ callbacks (`list[transformers.TrainerCallback]`):
1407
+ The callbacks to use for training.
1408
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1409
+ The optimizer and scheduler to use for training.
1410
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1411
+ The function to use to preprocess the logits before computing the metrics.
1412
+ peft_config (`dict`, defaults to `None`):
1413
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1414
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1415
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1416
+ a dictionary mapping strings to metric values.
1417
+
1418
+ """
1419
+ def __init__(
1420
+ self,
1421
+ model = None,
1422
+ args = None,
1423
+ data_collator = None,
1424
+ train_dataset = None,
1425
+ eval_dataset = None,
1426
+ processing_class = None,
1427
+ model_init = None,
1428
+ callbacks = None,
1429
+ preprocess_logits_for_metrics = None,
1430
+ peft_config = None,
1431
+ compute_metrics = None,
1432
+ **kwargs
1433
+ ):
1434
+ if args is None: args = UnslothORPOConfig()
1435
+ use_bf16 = getattr(args, 'bf16', False)
1436
+ use_fp16 = getattr(args, 'fp16', False)
1437
+ force_float32 = False
1438
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1439
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1440
+ force_float32 = True
1441
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1442
+ dtype = getattr(model.config, 'torch_dtype', None)
1443
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1444
+ from unsloth_zoo.utils import _get_dtype
1445
+ dtype = _get_dtype(dtype)
1446
+ float16 = dtype == torch.float16
1447
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1448
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1449
+ if force_float32:
1450
+ args.fp16 = False
1451
+ args.bf16 = False
1452
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1453
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1454
+ args.fp16 = float16
1455
+ args.bf16 = not float16
1456
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1457
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1458
+ args.eval_strategy = 'steps'
1459
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1460
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1461
+ if ga_steps is not None and ga_steps > 1:
1462
+ from transformers import __version__ as transformers_version
1463
+ if Version(transformers_version) <= Version('4.45.2'):
1464
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1465
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1466
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1467
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1468
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1469
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1470
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1471
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1472
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1473
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1474
+ if force_float32:
1475
+ args.bf16_full_eval = False
1476
+ args.fp16_full_eval = False
1477
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1478
+ args.bf16_full_eval = True
1479
+ args.fp16_full_eval = False
1480
+ elif not bf16_full_eval and not fp16_full_eval:
1481
+ args.bf16_full_eval = args.bf16
1482
+ args.fp16_full_eval = args.fp16
1483
+ _output_logits = False
1484
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1485
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1486
+ if _output_logits:
1487
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1488
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1489
+ pass
1490
+ else:
1491
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1492
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1493
+ if args_max_seq_length is None and model_max_seq_length is not None:
1494
+ max_seq_length = model.max_seq_length
1495
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1496
+ if model is not None and hasattr(model, 'for_training'):
1497
+ model.for_training()
1498
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1499
+ if 'processing_class' in locals():
1500
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1501
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1502
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1503
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1504
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1505
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1506
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1507
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1508
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1509
+ else:
1510
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1511
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1512
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1513
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1514
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1515
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1516
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1517
+ else:
1518
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1519
+ other_metrics = []
1520
+
1521
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1522
+ PatchRLStatistics('orpo_trainer', other_metrics)
1523
+
1524
+ super().__init__(
1525
+ model = model,
1526
+ args = args,
1527
+ data_collator = data_collator,
1528
+ train_dataset = train_dataset,
1529
+ eval_dataset = eval_dataset,
1530
+ processing_class = processing_class,
1531
+ model_init = model_init,
1532
+ callbacks = callbacks,
1533
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1534
+ peft_config = peft_config,
1535
+ compute_metrics = compute_metrics,**kwargs)
1536
+ if hasattr(self, 'neftune_hook_handle'):
1537
+ self.neftune_hook_handle.remove()
1538
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1539
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1540
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1541
+ pass
1542
+
1543
+ pass
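+ # A minimal usage sketch; `model`, `tokenizer` and `dataset` are placeholders,
+ # and the dataset is assumed to carry prompt / chosen / rejected columns:
+ #
+ # >>> trainer = UnslothORPOTrainer(
+ # ...     model = model,
+ # ...     args = UnslothORPOConfig(output_dir = "outputs", beta = 0.1),
+ # ...     train_dataset = dataset,
+ # ...     processing_class = tokenizer,
+ # ... )
+ # >>> trainer.train()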
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py ADDED
@@ -0,0 +1,1269 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # a row-wise loop (commented out below) can reduce peak memory consumption:
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
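+ # A minimal check sketch for the identity log_softmax(x)_i = x_i - logsumexp(x)
+ # used above (toy shapes are assumptions):
+ #
+ # >>> logits = torch.randn(2, 5, 10)
+ # >>> index = torch.randint(0, 10, (2, 5))
+ # >>> expected = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
+ # >>> torch.testing.assert_close(selective_log_softmax(logits, index), expected)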
42
+ def vLLMSamplingParams(**kwargs):
43
+ from vllm import SamplingParams
44
+ sampling_params = SamplingParams(**kwargs)
45
+ sampling_params._set_kwargs = kwargs
46
+ return sampling_params
47
+ @dataclass
48
+ class UnslothOnlineDPOConfig(OnlineDPOConfig):
49
+ """
50
+
51
+ Configuration class for the [`OnlineDPOTrainer`].
52
+
53
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
54
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
55
+ command line.
56
+
57
+ Parameters:
58
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
59
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
60
+ [`~transformers.TrainingArguments`].
61
+ reward_model_path (`str` or `None`, *optional*, defaults to `None`):
62
+ Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
63
+ judge (`str` or `None`, *optional*, defaults to `None`):
64
+ Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
65
+ max_new_tokens (`int`, *optional*, defaults to `64`):
66
+ Maximum number of tokens to generate per completion.
67
+ max_length (`int`, *optional*, defaults to `256`):
68
+ Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
69
+ sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
70
+ possible.
71
+ temperature (`float`, *optional*, defaults to `0.9`):
72
+ Temperature for sampling. The higher the temperature, the more random the completions.
73
+ missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
74
+ Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage the model
75
+ to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
76
+ value.
77
+ beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
78
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
79
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
80
+ the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
81
+ selected for each new epoch and the last β is used for the rest of the epochs.
82
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
83
+ Type of loss to use. Possible values are:
84
+
85
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
86
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
87
+
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ disable_dropout (`bool`, *optional*, defaults to `True`):
91
+ Whether to disable dropout in the model and reference model.
92
+ use_vllm (`bool`, *optional*, defaults to `False`):
93
+ Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
94
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
95
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
96
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
97
+ capacity of a single GPU, albeit at the cost of slower generation.
98
+
99
+ """
100
+ vllm_sampling_params: Optional[Any] = field(
101
+ default = None,
102
+ metadata = {'help': 'vLLM SamplingParams'},
103
+ )
104
+ unsloth_num_chunks : Optional[int] = field(
105
+ default = -1,
106
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
107
+ )
108
+ def __init__(
109
+ self,
110
+ output_dir = None,
111
+ overwrite_output_dir = None,
112
+ do_train = False,
113
+ do_eval = False,
114
+ do_predict = False,
115
+ eval_strategy = 'no',
116
+ prediction_loss_only = False,
117
+ per_device_train_batch_size = 4,
118
+ per_device_eval_batch_size = 4,
119
+ per_gpu_train_batch_size = None,
120
+ per_gpu_eval_batch_size = None,
121
+ gradient_accumulation_steps = 2,
122
+ eval_accumulation_steps = 2,
123
+ eval_delay = 0,
124
+ torch_empty_cache_steps = 250,
125
+ learning_rate = 5e-05,
126
+ weight_decay = 0.01,
127
+ adam_beta1 = 0.9,
128
+ adam_beta2 = 0.999,
129
+ adam_epsilon = 1e-08,
130
+ max_grad_norm = 1.0,
131
+ num_train_epochs = 3.0,
132
+ max_steps = -1,
133
+ lr_scheduler_type = 'linear',
134
+ warmup_ratio = 0.1,
135
+ warmup_steps = 0,
136
+ log_level = 'passive',
137
+ log_level_replica = 'warning',
138
+ log_on_each_node = True,
139
+ logging_dir = None,
140
+ logging_strategy = 'steps',
141
+ logging_first_step = False,
142
+ logging_steps = 1,
143
+ logging_nan_inf_filter = False,
144
+ save_strategy = 'steps',
145
+ save_steps = 500,
146
+ save_total_limit = None,
147
+ save_safetensors = True,
148
+ save_on_each_node = False,
149
+ save_only_model = False,
150
+ restore_callback_states_from_checkpoint = False,
151
+ no_cuda = False,
152
+ use_cpu = False,
153
+ use_mps_device = False,
154
+ seed = 3407,
155
+ data_seed = 3407,
156
+ jit_mode_eval = False,
157
+ use_ipex = False,
158
+ bf16 = False,
159
+ fp16 = False,
160
+ fp16_opt_level = 'O1',
161
+ half_precision_backend = 'auto',
162
+ bf16_full_eval = False,
163
+ fp16_full_eval = False,
164
+ tf32 = None,
165
+ local_rank = -1,
166
+ ddp_backend = None,
167
+ tpu_num_cores = None,
168
+ tpu_metrics_debug = False,
169
+ debug = '',
170
+ dataloader_drop_last = False,
171
+ eval_steps = None,
172
+ dataloader_num_workers = 0,
173
+ dataloader_prefetch_factor = None,
174
+ past_index = -1,
175
+ run_name = None,
176
+ disable_tqdm = None,
177
+ remove_unused_columns = True,
178
+ label_names = None,
179
+ load_best_model_at_end = False,
180
+ metric_for_best_model = None,
181
+ greater_is_better = None,
182
+ ignore_data_skip = False,
183
+ fsdp = '',
184
+ fsdp_min_num_params = 0,
185
+ fsdp_config = None,
186
+ tp_size = 0,
187
+ fsdp_transformer_layer_cls_to_wrap = None,
188
+ accelerator_config = None,
189
+ deepspeed = None,
190
+ label_smoothing_factor = 0.0,
191
+ optim = 'adamw_8bit',
192
+ optim_args = None,
193
+ adafactor = False,
194
+ group_by_length = False,
195
+ length_column_name = 'length',
196
+ report_to = None,
197
+ ddp_find_unused_parameters = None,
198
+ ddp_bucket_cap_mb = None,
199
+ ddp_broadcast_buffers = None,
200
+ dataloader_pin_memory = True,
201
+ dataloader_persistent_workers = False,
202
+ skip_memory_metrics = True,
203
+ use_legacy_prediction_loop = False,
204
+ push_to_hub = False,
205
+ resume_from_checkpoint = None,
206
+ hub_model_id = None,
207
+ hub_strategy = 'every_save',
208
+ hub_token = None,
209
+ hub_private_repo = None,
210
+ hub_always_push = False,
211
+ gradient_checkpointing = False,
212
+ gradient_checkpointing_kwargs = None,
213
+ include_inputs_for_metrics = False,
214
+ eval_do_concat_batches = True,
215
+ fp16_backend = 'auto',
216
+ evaluation_strategy = None,
217
+ push_to_hub_model_id = None,
218
+ push_to_hub_organization = None,
219
+ push_to_hub_token = None,
220
+ mp_parameters = '',
221
+ auto_find_batch_size = False,
222
+ full_determinism = False,
223
+ torchdynamo = None,
224
+ ray_scope = 'last',
225
+ ddp_timeout = 1800,
226
+ torch_compile = False,
227
+ torch_compile_backend = None,
228
+ torch_compile_mode = None,
229
+ dispatch_batches = None,
230
+ split_batches = None,
231
+ include_tokens_per_second = False,
232
+ include_num_input_tokens_seen = False,
233
+ neftune_noise_alpha = None,
234
+ optim_target_modules = None,
235
+ batch_eval_metrics = False,
236
+ eval_on_start = False,
237
+ use_liger_kernel = False,
238
+ eval_use_gather_object = False,
239
+ average_tokens_across_devices = False,
240
+ reward_model_path = None,
241
+ judge = None,
242
+ max_new_tokens = 64,
243
+ max_length = 512,
244
+ temperature = 0.9,
245
+ missing_eos_penalty = None,
246
+ loss_type = 'sigmoid',
247
+ dataset_num_proc = None,
248
+ disable_dropout = True,
249
+ use_vllm = False,
250
+ ds3_gather_for_generation = True,
251
+ vllm_sampling_params = None,
252
+ unsloth_num_chunks = -1,
253
+ **kwargs,
254
+ ):
255
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
256
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
257
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
258
+ output_dir = 'unsloth_training_checkpoints'
259
+ save_strategy = 'no'
260
+ if dataset_num_proc is None:
261
+ from multiprocessing import cpu_count
262
+ dataset_num_proc = cpu_count()
263
+
264
+ super().__init__(
265
+ output_dir = output_dir,
266
+ overwrite_output_dir = overwrite_output_dir,
267
+ do_train = do_train,
268
+ do_eval = do_eval,
269
+ do_predict = do_predict,
270
+ eval_strategy = eval_strategy,
271
+ prediction_loss_only = prediction_loss_only,
272
+ per_device_train_batch_size = per_device_train_batch_size,
273
+ per_device_eval_batch_size = per_device_eval_batch_size,
274
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
275
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
276
+ gradient_accumulation_steps = gradient_accumulation_steps,
277
+ eval_accumulation_steps = eval_accumulation_steps,
278
+ eval_delay = eval_delay,
279
+ torch_empty_cache_steps = torch_empty_cache_steps,
280
+ learning_rate = learning_rate,
281
+ weight_decay = weight_decay,
282
+ adam_beta1 = adam_beta1,
283
+ adam_beta2 = adam_beta2,
284
+ adam_epsilon = adam_epsilon,
285
+ max_grad_norm = max_grad_norm,
286
+ num_train_epochs = num_train_epochs,
287
+ max_steps = max_steps,
288
+ lr_scheduler_type = lr_scheduler_type,
289
+ warmup_ratio = warmup_ratio,
290
+ warmup_steps = warmup_steps,
291
+ log_level = log_level,
292
+ log_level_replica = log_level_replica,
293
+ log_on_each_node = log_on_each_node,
294
+ logging_dir = logging_dir,
295
+ logging_strategy = logging_strategy,
296
+ logging_first_step = logging_first_step,
297
+ logging_steps = logging_steps,
298
+ logging_nan_inf_filter = logging_nan_inf_filter,
299
+ save_strategy = save_strategy,
300
+ save_steps = save_steps,
301
+ save_total_limit = save_total_limit,
302
+ save_safetensors = save_safetensors,
303
+ save_on_each_node = save_on_each_node,
304
+ save_only_model = save_only_model,
305
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
306
+ no_cuda = no_cuda,
307
+ use_cpu = use_cpu,
308
+ use_mps_device = use_mps_device,
309
+ seed = seed,
310
+ data_seed = data_seed,
311
+ jit_mode_eval = jit_mode_eval,
312
+ use_ipex = use_ipex,
313
+ bf16 = bf16,
314
+ fp16 = fp16,
315
+ fp16_opt_level = fp16_opt_level,
316
+ half_precision_backend = half_precision_backend,
317
+ bf16_full_eval = bf16_full_eval,
318
+ fp16_full_eval = fp16_full_eval,
319
+ tf32 = tf32,
320
+ local_rank = local_rank,
321
+ ddp_backend = ddp_backend,
322
+ tpu_num_cores = tpu_num_cores,
323
+ tpu_metrics_debug = tpu_metrics_debug,
324
+ debug = debug,
325
+ dataloader_drop_last = dataloader_drop_last,
326
+ eval_steps = eval_steps,
327
+ dataloader_num_workers = dataloader_num_workers,
328
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
329
+ past_index = past_index,
330
+ run_name = run_name,
331
+ disable_tqdm = disable_tqdm,
332
+ remove_unused_columns = remove_unused_columns,
333
+ label_names = label_names,
334
+ load_best_model_at_end = load_best_model_at_end,
335
+ metric_for_best_model = metric_for_best_model,
336
+ greater_is_better = greater_is_better,
337
+ ignore_data_skip = ignore_data_skip,
338
+ fsdp = fsdp,
339
+ fsdp_min_num_params = fsdp_min_num_params,
340
+ fsdp_config = fsdp_config,
341
+ tp_size = tp_size,
342
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
343
+ accelerator_config = accelerator_config,
344
+ deepspeed = deepspeed,
345
+ label_smoothing_factor = label_smoothing_factor,
346
+ optim = optim,
347
+ optim_args = optim_args,
348
+ adafactor = adafactor,
349
+ group_by_length = group_by_length,
350
+ length_column_name = length_column_name,
351
+ report_to = report_to,
352
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
353
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
354
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
355
+ dataloader_pin_memory = dataloader_pin_memory,
356
+ dataloader_persistent_workers = dataloader_persistent_workers,
357
+ skip_memory_metrics = skip_memory_metrics,
358
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
359
+ push_to_hub = push_to_hub,
360
+ resume_from_checkpoint = resume_from_checkpoint,
361
+ hub_model_id = hub_model_id,
362
+ hub_strategy = hub_strategy,
363
+ hub_token = hub_token,
364
+ hub_private_repo = hub_private_repo,
365
+ hub_always_push = hub_always_push,
366
+ gradient_checkpointing = gradient_checkpointing,
367
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
368
+ include_inputs_for_metrics = include_inputs_for_metrics,
369
+ eval_do_concat_batches = eval_do_concat_batches,
370
+ fp16_backend = fp16_backend,
371
+ evaluation_strategy = evaluation_strategy,
372
+ push_to_hub_model_id = push_to_hub_model_id,
373
+ push_to_hub_organization = push_to_hub_organization,
374
+ push_to_hub_token = push_to_hub_token,
375
+ mp_parameters = mp_parameters,
376
+ auto_find_batch_size = auto_find_batch_size,
377
+ full_determinism = full_determinism,
378
+ torchdynamo = torchdynamo,
379
+ ray_scope = ray_scope,
380
+ ddp_timeout = ddp_timeout,
381
+ torch_compile = torch_compile,
382
+ torch_compile_backend = torch_compile_backend,
383
+ torch_compile_mode = torch_compile_mode,
384
+ dispatch_batches = dispatch_batches,
385
+ split_batches = split_batches,
386
+ include_tokens_per_second = include_tokens_per_second,
387
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
388
+ neftune_noise_alpha = neftune_noise_alpha,
389
+ optim_target_modules = optim_target_modules,
390
+ batch_eval_metrics = batch_eval_metrics,
391
+ eval_on_start = eval_on_start,
392
+ use_liger_kernel = use_liger_kernel,
393
+ eval_use_gather_object = eval_use_gather_object,
394
+ average_tokens_across_devices = average_tokens_across_devices,
395
+ reward_model_path = reward_model_path,
396
+ judge = judge,
397
+ max_new_tokens = max_new_tokens,
398
+ max_length = max_length,
399
+ temperature = temperature,
400
+ missing_eos_penalty = missing_eos_penalty,
401
+ loss_type = loss_type,
402
+ dataset_num_proc = dataset_num_proc,
403
+ disable_dropout = disable_dropout,
404
+ use_vllm = use_vllm,
405
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
406
+ self.vllm_sampling_params = vllm_sampling_params
407
+ self.unsloth_num_chunks = unsloth_num_chunks
408
+ pass
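+ # A minimal construction sketch (values are illustrative assumptions). Note the
+ # Unsloth defaults above differ from stock TRL, e.g. seed = 3407 and
+ # optim = 'adamw_8bit':
+ #
+ # >>> config = UnslothOnlineDPOConfig(
+ # ...     output_dir = "outputs",
+ # ...     reward_model_path = "path/to/reward-model",  # or judge = ..., never both
+ # ...     max_new_tokens = 64,
+ # ...     temperature = 0.9,
+ # ... )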
409
+
410
+ class _UnslothOnlineDPOTrainer(Trainer):
411
+ r""""""
412
+
413
+ _tag_names = ["trl", "online-dpo"]
414
+
415
+ def __init__(
416
+ self,
417
+ model: Union[PreTrainedModel, nn.Module],
418
+ ref_model: Union[PreTrainedModel, nn.Module, None] = None,
419
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
420
+ judge: Optional[BasePairwiseJudge] = None,
421
+ args: Optional[OnlineDPOConfig] = None,
422
+ data_collator: Optional[DataCollator] = None,
423
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
424
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
425
+ processing_class: Optional[
426
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
427
+ ] = None,
428
+ reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
429
+ peft_config: Optional[dict] = None,
430
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
431
+ callbacks: Optional[list[TrainerCallback]] = None,
432
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
433
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
434
+ ) -> None:
435
+
436
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
437
+ if ref_model is model:
438
+ raise ValueError(
439
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
440
+ "same as `model`, either omit the `ref_model` argument or pass `None`."
441
+ )
442
+
443
+ self.ref_model = ref_model
444
+
445
+ if reward_model is not None and judge is not None:
446
+ warnings.warn(
447
+ "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
448
+ "Ignoring `judge` and using `reward_model`.",
449
+ UserWarning,
450
+ )
451
+ judge = None
452
+ elif reward_model is None and judge is None:
453
+ raise ValueError("Either `reward_model` or `judge` must be provided.")
454
+
455
+ self.reward_model = reward_model
456
+ self.reward_processing_class = reward_processing_class
457
+ self.judge = judge
458
+
459
+ if args.missing_eos_penalty is not None and judge is not None:
460
+ raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
461
+
462
+ if args is None:
463
+ raise ValueError("`args` must be provided.")
464
+
465
+ # Check that the processing_class is provided
466
+ if processing_class is None:
467
+ raise ValueError("`processing_class` must be provided.")
468
+
469
+ # Convert to PEFT model if peft_config is provided
470
+ if False:
471
+ # Check if PEFT is available
472
+ if not is_peft_available():
473
+ raise ImportError(
474
+ "PEFT is not available and passed `peft_config`. Please install PEFT with "
475
+ "`pip install peft` to use it."
476
+ )
477
+
478
+ # If the model is already a PeftModel, we need to merge and unload it.
479
+ # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
480
+ if isinstance(model, PeftModel):
481
+ model = model.merge_and_unload()
482
+
483
+ # Get peft model with the given config
484
+ model = model
485
+
486
+ # Disable dropout in the model and reference model
487
+ if args.disable_dropout:
488
+ disable_dropout_in_model(model)
489
+ if self.ref_model is not None:
490
+ disable_dropout_in_model(self.ref_model)
491
+
492
+ # Handle the ref_model
493
+ # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
494
+ # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
495
+ # the ref model from the model by copying it, disabling its gradients, and setting it to evaluation mode.
496
+ if ref_model is None: # No ref model provided, the most common case
497
+ if False:
498
+ self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
499
+ else:
500
+ self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
501
+ else: # rare case, the user provided a ref model
502
+ self.ref_model = ref_model
503
+ self.ref_model.eval()
504
+
505
+ # Disable the gradient and set the reward model in eval mode
506
+ if self.reward_model is not None:
507
+ self.reward_model.eval()
508
+
509
+ # Define the collator if one is not provided
510
+ if data_collator is None:
511
+ data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
512
+
513
+ self.max_length = args.max_length
514
+
515
+ self.stats = {
516
+ "objective/kl": [],
517
+ "objective/entropy": [],
518
+ "objective/non_score_reward": [],
519
+ "rewards/chosen": [],
520
+ "rewards/rejected": [],
521
+ "rewards/accuracies": [],
522
+ "rewards/margins": [],
523
+ "logps/chosen": [],
524
+ "logps/rejected": [],
525
+ "val/contain_eos_token": [],
526
+ "beta": [],
527
+ }
528
+ if self.reward_model is not None:
529
+ self.stats["objective/rlhf_reward"] = []
530
+ self.stats["objective/scores_margin"] = []
531
+ self.stats["objective/scores"] = []
532
+
533
+ if args.use_vllm:
534
+ self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
535
+ n=2, max_tokens=args.max_new_tokens,
536
+ temperature=args.temperature,
537
+ top_k=50,
538
+ top_p=1.0,
539
+ detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
540
+ else:
541
+ self.generation_config = GenerationConfig(
542
+ max_new_tokens=args.max_new_tokens,
543
+ temperature=args.temperature,
544
+ top_k=50,
545
+ top_p=1.0,
546
+ do_sample=True,
547
+ use_cache=False if args.gradient_checkpointing else True,
548
+ )
549
+
550
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
551
+ # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
552
+ # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
553
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
554
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
555
+ # that the warning has already been issued.
556
+ model.warnings_issued["estimate_tokens"] = True
557
+
558
+ super().__init__(
559
+ model=model,
560
+ args=args,
561
+ data_collator=data_collator,
562
+ train_dataset=train_dataset,
563
+ eval_dataset=eval_dataset,
564
+ processing_class=processing_class,
565
+ compute_metrics=compute_metrics,
566
+ callbacks=callbacks,
567
+ optimizers=optimizers,
568
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
569
+ )
570
+
571
+ # Add tags for models that have been loaded with the correct transformers version
572
+ if hasattr(self.model, "add_model_tags"):
573
+ self.model.add_model_tags(self._tag_names)
574
+
575
+ self._beta = args.beta
576
+
577
+ # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
578
+ if self.is_deepspeed_enabled:
579
+ if self.reward_model is not None:
580
+ self.reward_model = prepare_deepspeed(
581
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
582
+ )
583
+ if self.ref_model is not None:
584
+ self.ref_model = prepare_deepspeed(
585
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
586
+ )
587
+ else:
588
+ if self.ref_model is not None:
589
+ self.ref_model = self.ref_model.to(self.accelerator.device)
590
+ if self.reward_model is not None:
591
+ self.reward_model = self.reward_model.to(self.accelerator.device)
592
+
593
+ @property
594
+ def beta(self):
595
+ if isinstance(self._beta, list):
596
+ epoch = self.state.epoch
597
+ return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
598
+ else:
599
+ return self._beta
600
+
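+ # A sketch of the schedule above: with a list, epoch i uses beta[i] and the
+ # last value covers any remaining epochs (the list below is an assumption):
+ #
+ # >>> _beta = [0.1, 0.2, 0.3]
+ # >>> [_beta[e] if e < len(_beta) else _beta[-1] for e in range(5)]
+ # [0.1, 0.2, 0.3, 0.3, 0.3]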
601
+ @staticmethod
602
+ def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
603
+ """Tokenize a single row from a DPO specific dataset."""
604
+ if not is_encoder_decoder:
605
+ batch = tokenizer(feature["prompt"], add_special_tokens=False)
606
+ # Add BOS token to head of prompt. Avoid adding if it's already there
607
+ if tokenizer.bos_token_id is not None:
608
+ prompt_len_input_ids = len(batch["input_ids"])
609
+ if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
610
+ batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
611
+ batch["attention_mask"] = [1] + batch["attention_mask"]
612
+ else:
613
+ batch = tokenizer(feature["prompt"], add_special_tokens=True)
614
+ batch = {f"prompt_{key}": value for key, value in batch.items()}
615
+ return batch
616
+
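+ # A worked sketch of the BOS handling above (token ids are assumptions):
+ # with bos_token_id = 1 and tokenizer output input_ids = [5, 6], the row becomes
+ # {"prompt_input_ids": [1, 5, 6], "prompt_attention_mask": [1, 1, 1]}; a prompt
+ # already starting with token 1 is left unchanged.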
617
+ # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
618
+ @wraps(Trainer.get_train_dataloader)
619
+ def get_train_dataloader(self) -> DataLoader:
620
+ if self.train_dataset is None:
621
+ raise ValueError("Trainer: training requires a train_dataset.")
622
+
623
+ train_dataset = self.train_dataset
624
+ data_collator = self.data_collator
625
+ dataloader_params = {
626
+ "batch_size": self._train_batch_size,
627
+ "collate_fn": data_collator,
628
+ "num_workers": self.args.dataloader_num_workers,
629
+ "pin_memory": self.args.dataloader_pin_memory,
630
+ "persistent_workers": self.args.dataloader_persistent_workers,
631
+ }
632
+
633
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
634
+ dataloader_params["sampler"] = self._get_train_sampler()
635
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
636
+ dataloader_params["worker_init_fn"] = seed_worker
637
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
638
+
639
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
640
+
641
+ # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
642
+ @wraps(Trainer.get_eval_dataloader)
643
+ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
644
+ if eval_dataset is None and self.eval_dataset is None:
645
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
646
+
647
+ # If we have persistent workers, reuse the cached dataloader instead of forking
648
+ # new workers on every call, especially as eval datasets don't change during training
649
+ dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
650
+ if (
651
+ hasattr(self, "_eval_dataloaders")
652
+ and dataloader_key in self._eval_dataloaders
653
+ and self.args.dataloader_persistent_workers
654
+ ):
655
+ return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
656
+
657
+ eval_dataset = (
658
+ self.eval_dataset[eval_dataset]
659
+ if isinstance(eval_dataset, str)
660
+ else eval_dataset
661
+ if eval_dataset is not None
662
+ else self.eval_dataset
663
+ )
664
+ data_collator = self.data_collator
665
+
666
+ dataloader_params = {
667
+ "batch_size": self.args.eval_batch_size,
668
+ "collate_fn": data_collator,
669
+ "num_workers": self.args.dataloader_num_workers,
670
+ "pin_memory": self.args.dataloader_pin_memory,
671
+ "persistent_workers": self.args.dataloader_persistent_workers,
672
+ }
673
+
674
+ if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
675
+ dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
676
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
677
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
678
+
679
+ # accelerator.free_memory() will destroy the references, so
680
+ # we need to store the non-prepared version
681
+ eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
682
+ if self.args.dataloader_persistent_workers:
683
+ if hasattr(self, "_eval_dataloaders"):
684
+ self._eval_dataloaders[dataloader_key] = eval_dataloader
685
+ else:
686
+ self._eval_dataloaders = {dataloader_key: eval_dataloader}
687
+
688
+ return self.accelerator.prepare(eval_dataloader)
689
+
690
+ def _generate_vllm(self, model, prompts):
691
+ eos_token_id = self.processing_class.eos_token_id
692
+ pad_token_id = self.processing_class.pad_token_id
693
+
694
+ # Load the latest weights
695
+
696
+ pass
697
+
698
+ pass
699
+
700
+ if is_conversational({"prompt": prompts[0]}):
701
+ outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
702
+ else:
703
+ outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
704
+
705
+ completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
706
+ prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
707
+
708
+ # Create mask and pad the prompt and completion
709
+ max_prompt_length = max(len(ids) for ids in prompt_ids)
710
+ prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
711
+ prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
712
+ max_tokens = self.generation_config.max_tokens
713
+ completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
714
+ completion_ids = [
715
+ ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
716
+ for ids in completion_ids
717
+ ]
718
+ completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
719
+
720
+ # Convert to tensors
721
+ prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
722
+ prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
723
+ completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
724
+ completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
725
+
726
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
727
+
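+ # Padding convention above, sketched on toy ids (pad_token_id = 0 is an
+ # assumption): prompts are left-padded so every row ends where generation
+ # starts, e.g. [7, 8] -> ids [0, 7, 8] with mask [0, 1, 1]; completions are
+ # right-padded with pad tokens up to `max_tokens` after an EOS is appended to
+ # unterminated sequences that still have room.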
728
+ def _generate(self, model, prompts):
729
+ eos_token_id = self.processing_class.eos_token_id
730
+ pad_token_id = self.processing_class.pad_token_id
731
+
732
+ # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
733
+ # policies with different tokenizers / chat templates.
734
+ inputs = [{"prompt": prompt} for prompt in prompts]
735
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
736
+ inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
737
+ inputs = self.data_collator(inputs)
738
+
739
+ # Sample 2 completions per prompt of size `max_new_tokens` from the model
740
+ inputs = self._prepare_inputs(inputs)
741
+ prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
742
+ prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
743
+ with unwrap_model_for_generation(
744
+ model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
745
+ ) as unwrapped_model:
746
+ output = unwrapped_model.generate(
747
+ input_ids=prompt_ids,
748
+ attention_mask=prompt_mask,
749
+ generation_config=self.generation_config,
750
+ )
751
+
752
+ completion_ids = output[:, prompt_ids.size(1) :]
753
+ completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
754
+
755
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
756
+
757
+ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
758
+ # Get the number of tokens to truncate from prompt
759
+ num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
760
+
761
+ # Truncate left to avoid oom
762
+ prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
763
+ prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
764
+
765
+ # Concat the prompt and completion
766
+ prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
767
+ prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
768
+
769
+ # Get the logprobs of the completions from the model
770
+ output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
771
+
772
+ # There is an offset of 1 because the model predicts the next token
773
+ logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
774
+
775
+ # Take the completion tokens logprob
776
+ logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
777
+ return logprobs
778
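+ # --- Editor's sketch (illustrative, not part of the generated trainer) ---
+ # _forward slices logits from position len(prompt) - 1 because logits[t]
+ # scores token t + 1. The same gather in isolation, with a toy vocab of 4:
+ #
+ #   logits = torch.randn(1, 3, 4)          # 3 completion positions
+ #   completion = torch.tensor([[2, 1, 3]])
+ #   lp = torch.take_along_dim(logits.log_softmax(-1),
+ #                             completion.unsqueeze(-1), dim=2).squeeze(-1)
+ #   assert lp.shape == (1, 3)              # one logprob per completion token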
+
779
+ def training_step(
780
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
781
+ ) -> torch.Tensor:
782
+ model.train()
783
+
784
+ prompts = inputs["prompt"]
785
+ batch_size = len(prompts)
786
+
787
+ if self.args.use_vllm:
788
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
789
+ else:
790
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
791
+
792
+ contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
793
+
794
+ logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
795
+ with torch.no_grad():
796
+ if self.ref_model is not None:
797
+ ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
798
+ else: # peft case: we just need to disable the adapter
799
+ with self.model.disable_adapter():
800
+ ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
801
+
802
+ # Decode the completions, and format them if the input is conversational
803
+ device = logprobs.device
804
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
805
+ if is_conversational({"prompt": prompts[0]}):
806
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
807
+
808
+ # Get the reward from the reward model or judge
809
+ if self.judge is not None:
810
+ # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
811
+ # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
812
+ # independent of the model's chat template, we use the raw conversation data, and apply our own chat
813
+ # template to it.
814
+ if is_conversational({"prompt": prompts[0]}):
815
+ environment = jinja2.Environment()
816
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
817
+ prompts = [template.render(messages=prompt) for prompt in prompts]
818
+ completions = [template.render(messages=completion) for completion in completions]
819
+
820
+ ranks_of_first_completion = self.judge.judge(
821
+ prompts, list(zip(completions[:batch_size], completions[batch_size:]))
822
+ )
823
+
824
+ # convert ranks to a True/False mask:
825
+ # when rank == 0, it means the first completion is the best
826
+ # when rank == 1, it means the second completion is the best
827
+ mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
828
+ else:
829
+ # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
830
+ # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
831
+ prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
832
+ if is_conversational({"prompt": prompts[0]}):
833
+ examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
834
+ examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
835
+ prompts = [example["prompt"] for example in examples]
836
+ completions = [example["completion"] for example in examples]
837
+
838
+ # Tokenize the prompts
839
+ prompts_ids = self.reward_processing_class(
840
+ prompts, padding=True, return_tensors="pt", padding_side="left"
841
+ )["input_ids"].to(device)
842
+ context_length = prompts_ids.shape[1]
843
+
844
+ # Tokenize the completions
845
+ completions_ids = self.reward_processing_class(
846
+ completions, padding=True, return_tensors="pt", padding_side="right"
847
+ )["input_ids"].to(device)
848
+
849
+ # Concatenate the prompts and completions and get the reward
850
+ prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
851
+ with torch.inference_mode():
852
+ _, scores, _ = get_reward(
853
+ self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
854
+ )
855
+
856
+ # Filter completions: ensure that each sample contains the stop_token_id.
857
+ # Completions not passing that filter will receive a lower score.
858
+ if self.args.missing_eos_penalty is not None:
859
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
860
+
861
+ # Split the scores in 2 (the prompts of the first half are the same as the second half)
862
+ first_half, second_half = scores.split(batch_size)
863
+
864
+ # Get the indices of the chosen and rejected examples
865
+ mask = first_half >= second_half
866
+
867
+ batch_range = torch.arange(batch_size, device=device)
868
+ chosen_indices = batch_range + (~mask * batch_size)
869
+ rejected_indices = batch_range + (mask * batch_size)
870
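+ # --- Editor's note (illustrative) ---
+ # Completions are stacked as [first half | second half], so with
+ # batch_size = 2 and mask = [True, False]:
+ #   chosen_indices   = [0, 1] + [0, 2] = [0, 3]
+ #   rejected_indices = [0, 1] + [2, 0] = [2, 1]
+ # i.e. prompt 0 keeps its first completion, prompt 1 its second.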
+
871
+ # Build tensor so that the first half is the chosen examples and the second half the rejected examples
872
+ cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
873
+ cr_logprobs = logprobs[cr_indices]
874
+ cr_ref_logprobs = ref_logprobs[cr_indices]
875
+
876
+ # mask out the padding tokens
877
+ padding_mask = ~completion_mask.bool()
878
+ cr_padding_mask = padding_mask[cr_indices]
879
+
880
+ cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
881
+ cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
882
+
883
+ # Split the chosen and rejected examples
884
+ chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
885
+ chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
886
+ pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
887
+ ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
888
+
889
+ logits = pi_logratios - ref_logratios
890
+
891
+ if self.args.loss_type == "sigmoid":
892
+ losses = -F.logsigmoid(self.beta * logits)
893
+ elif self.args.loss_type == "ipo":
894
+ losses = (logits - 1 / (2 * self.beta)) ** 2
895
+ else:
896
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
897
+
898
+ loss = losses.mean()
899
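+ # --- Editor's worked example (illustrative) ---
+ # With beta = 0.1 and logits = pi_logratios - ref_logratios = 2.0:
+ #   sigmoid loss: -log(sigmoid(0.1 * 2.0)) ~= 0.598
+ #   ipo loss:     (2.0 - 1 / (2 * 0.1)) ** 2 = 9.0
+ # IPO regresses the margin toward 1 / (2 * beta) rather than just pushing it up.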
+
900
+ # Log everything
901
+ if self.reward_model is not None:
902
+ scores_margin = scores[chosen_indices] - scores[rejected_indices]
903
+ self.stats["objective/scores_margin"].append(
904
+ self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
905
+ )
906
+ self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
907
+ self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
908
+ self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
909
+ self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
910
+
911
+ kl = logprobs - ref_logprobs
912
+ mean_kl = kl.sum(1).mean()
913
+ self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
914
+ non_score_reward = (-self.beta * kl).sum(1)
915
+ mean_non_score_reward = non_score_reward.mean()
916
+ self.stats["objective/non_score_reward"].append(
917
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
918
+ )
919
+ if self.reward_model is not None:
920
+ rlhf_reward = scores + non_score_reward
921
+ self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
922
+ mean_entropy = -logprobs.sum(1).mean()
923
+ self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
924
+ chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
925
+ gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
926
+ self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
927
+ rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
928
+ gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
929
+ self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
930
+ margin = gathered_chosen_rewards - gathered_rejected_rewards
931
+ self.stats["rewards/margins"].append(margin.mean().item())
932
+ accuracy = margin > 0
933
+ self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
934
+ self.stats["beta"].append(self.beta)
935
+
936
+ if (
937
+ self.args.torch_empty_cache_steps is not None
938
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
939
+ ):
940
+ empty_cache()
941
+
942
+ kwargs = {}
943
+
944
+ # For LOMO optimizers you need to explicitly use the learning rate
945
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
946
+ kwargs["learning_rate"] = self._get_learning_rate()
947
+
948
+ if self.args.n_gpu > 1:
949
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
950
+
951
+ if self.use_apex:
952
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
953
+ scaled_loss.backward()
954
+ else:
955
+ self.accelerator.backward(loss, **kwargs)
956
+
957
+ return loss.detach() / self.args.gradient_accumulation_steps
958
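+ # --- Editor's note (illustrative) ---
+ # The value returned here only feeds the Trainer's logged running loss;
+ # the backward pass already ran above. Dividing by
+ # gradient_accumulation_steps keeps the logged loss comparable to one
+ # effective (accumulated) batch, e.g. 1.2 / 4 = 0.3 per micro-step.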
+
959
+ # Same as Trainer._maybe_log_save_evaluate but log our metrics
960
+ # start_time defaults to None to allow compatibility with transformers<=4.46
961
+ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
962
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
963
+ logs: dict[str, float] = {}
964
+
965
+ # all_gather + mean() to get average loss over all processes
966
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
967
+
968
+ # reset tr_loss to zero
969
+ tr_loss -= tr_loss
970
+
971
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
972
+ if grad_norm is not None:
973
+ logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
974
+ logs["learning_rate"] = self._get_learning_rate()
975
+
976
+ # Add our metrics
977
+ for key, val in self.stats.items():
978
+ logs[key] = sum(val) / len(val)
979
+ self.stats = {key: [] for key in self.stats} # reset stats
980
+
981
+ self._total_loss_scalar += tr_loss_scalar
982
+ self._globalstep_last_logged = self.state.global_step
983
+ self.store_flos()
984
+
985
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
986
+ self.log(logs, start_time)
987
+ else: # transformers<=4.46
988
+ self.log(logs)
989
+
990
+ metrics = None
991
+ if self.control.should_evaluate:
992
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
993
+ is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
994
+
995
+ if self.args.save_strategy == "best":
996
+ self.control.should_save = is_new_best_metric
997
+
998
+ if self.control.should_save:
999
+ self._save_checkpoint(model, trial)
1000
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1001
+
1002
+ # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
1003
+ # This can be removed once the minimum transformers version is updated to 4.47.
1004
+ # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
1005
+ def _determine_best_metric(self, metrics, trial):
1006
+ """
1007
+ Determine if the model should be saved based on the evaluation metrics.
1008
+ If args.metric_for_best_model is not set, the loss is used.
1009
+ Returns:
1010
+ bool: True if a new best metric was found, else False
1011
+ """
1012
+ is_new_best_metric = False
1013
+
1014
+ if self.args.metric_for_best_model is not None:
1015
+ metric_to_check = self.args.metric_for_best_model
1016
+
1017
+ if not metric_to_check.startswith("eval_"):
1018
+ metric_to_check = f"eval_{metric_to_check}"
1019
+
1020
+ try:
1021
+ metric_value = metrics[metric_to_check]
1022
+ except KeyError as exc:
1023
+ raise KeyError(
1024
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
1025
+ f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
1026
+ ) from exc
1027
+
1028
+ operator = np.greater if self.args.greater_is_better else np.less
1029
+
1030
+ if self.state.best_metric is None:
1031
+ self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
1032
+
1033
+ if operator(metric_value, self.state.best_metric):
1034
+ run_dir = self._get_output_dir(trial=trial)
1035
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
1036
+ output_dir = os.path.join(run_dir, checkpoint_folder)
1037
+ self.state.best_metric = metric_value
1038
+ self.state.best_model_checkpoint = output_dir
1039
+
1040
+ is_new_best_metric = True
1041
+
1042
+ return is_new_best_metric
1043
+
1044
+ def create_model_card(
1045
+ self,
1046
+ model_name: Optional[str] = None,
1047
+ dataset_name: Optional[str] = None,
1048
+ tags: Union[str, list[str], None] = None,
1049
+ ):
1050
+ """
1051
+ Creates a draft of a model card using the information available to the `Trainer`.
1052
+
1053
+ Args:
1054
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1055
+ Name of the model.
1056
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1057
+ Name of the dataset used for training.
1058
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1059
+ Tags to be associated with the model card.
1060
+ """
1061
+ if not self.is_world_process_zero():
1062
+ return
1063
+
1064
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1065
+ base_model = self.model.config._name_or_path
1066
+ else:
1067
+ base_model = None
1068
+
1069
+ tags = tags or []
1070
+ if isinstance(tags, str):
1071
+ tags = [tags]
1072
+
1073
+ if hasattr(self.model.config, "unsloth_version"):
1074
+ tags.append("unsloth")
1075
+
1076
+ citation = textwrap.dedent("""\
1077
+ @article{guo2024direct,
1078
+ title = {{Direct Language Model Alignment from Online AI Feedback}},
1079
+ author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
1080
+ year = 2024,
1081
+ eprint = {arXiv:2402.04792}
1082
+ }""")
1083
+
1084
+ model_card = generate_model_card(
1085
+ base_model=base_model,
1086
+ model_name=model_name,
1087
+ hub_model_id=self.hub_model_id,
1088
+ dataset_name=dataset_name,
1089
+ tags=tags,
1090
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1091
+ comet_url=get_comet_experiment_url(),
1092
+ trainer_name="Online DPO",
1093
+ trainer_citation=citation,
1094
+ paper_title="Direct Language Model Alignment from Online AI Feedback",
1095
+ paper_id="2402.04792",
1096
+ )
1097
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1098
+ class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
1099
+ """
1100
+
1101
+ Initialize OnlineDPOTrainer.
1102
+
1103
+ Args:
1104
+ model (`transformers.PreTrainedModel` or `torch.nn.Module`):
1105
+ The model to train, preferably an `AutoModelForCausalLM`.
1106
+ ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1107
+ The reference model to use for training. If None is specified, the reference model will be created from
1108
+ the model.
1109
+ reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1110
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
1111
+ judge (`BasePairwiseJudge`):
1112
+ The judge to use for pairwise comparison of model completions.
1113
+ args (`OnlineDPOConfig`):
1114
+ The online DPO config arguments to use for training.
1115
+ data_collator (`transformers.DataCollator`):
1116
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1117
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1118
+ train_dataset (`datasets.Dataset`):
1119
+ The dataset to use for training.
1120
+ eval_dataset (`datasets.Dataset`):
1121
+ The dataset to use for evaluation.
1122
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1123
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1124
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1125
+ reuse the fine-tuned model.
1126
+ peft_config (`dict`):
1127
+ The peft config to use for training.
1128
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1129
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1130
+ a dictionary mapping strings to metric values.
1131
+ callbacks (`list[transformers.TrainerCallback]`):
1132
+ The callbacks to use for training.
1133
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1134
+ The optimizer and scheduler to use for training.
1135
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1136
+ The function to use to preprocess the logits before computing the metrics.
1137
+
1138
+ """
1139
+ def __init__(
1140
+ self,
1141
+ model,
1142
+ ref_model = None,
1143
+ reward_model = None,
1144
+ judge = None,
1145
+ args = None,
1146
+ data_collator = None,
1147
+ train_dataset = None,
1148
+ eval_dataset = None,
1149
+ processing_class = None,
1150
+ reward_processing_class = None,
1151
+ peft_config = None,
1152
+ compute_metrics = None,
1153
+ callbacks = None,
1154
+ preprocess_logits_for_metrics = None,
1155
+ **kwargs
1156
+ ):
1157
+ if args is None: args = UnslothOnlineDPOConfig()
1158
+ use_bf16 = getattr(args, 'bf16', False)
1159
+ use_fp16 = getattr(args, 'fp16', False)
1160
+ force_float32 = False
1161
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1162
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1163
+ force_float32 = True
1164
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1165
+ dtype = getattr(model.config, 'torch_dtype', None)
1166
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1167
+ from unsloth_zoo.utils import _get_dtype
1168
+ dtype = _get_dtype(dtype)
1169
+ float16 = dtype == torch.float16
1170
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1171
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1172
+ if force_float32:
1173
+ args.fp16 = False
1174
+ args.bf16 = False
1175
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1176
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1177
+ args.fp16 = float16
1178
+ args.bf16 = not float16
1179
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1180
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1181
+ args.eval_strategy = 'steps'
1182
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1183
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1184
+ if ga_steps is not None and ga_steps > 1:
1185
+ from transformers import __version__ as transformers_version
1186
+ if Version(transformers_version) <= Version('4.45.2'):
1187
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1188
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1189
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1190
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1191
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1192
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1193
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1194
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1195
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1196
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1197
+ if force_float32:
1198
+ args.bf16_full_eval = False
1199
+ args.fp16_full_eval = False
1200
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1201
+ args.bf16_full_eval = True
1202
+ args.fp16_full_eval = False
1203
+ elif not bf16_full_eval and not fp16_full_eval:
1204
+ args.bf16_full_eval = args.bf16
1205
+ args.fp16_full_eval = args.fp16
1206
+ _output_logits = False
1207
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1208
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1209
+ if _output_logits:
1210
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1211
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1212
+ pass
1213
+ else:
1214
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1215
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1216
+ if args_max_seq_length is None and model_max_seq_length is not None:
1217
+ max_seq_length = model.max_seq_length
1218
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1219
+ if model is not None and hasattr(model, 'for_training'):
1220
+ model.for_training()
1221
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1222
+ if 'processing_class' in locals():
1223
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1224
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1225
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1226
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1227
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1228
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1229
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1230
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1231
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1232
+ else:
1233
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1234
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1235
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1236
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1237
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1238
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1239
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1240
+ else:
1241
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1242
+ other_metrics = []
1243
+
1244
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1245
+ PatchRLStatistics('online_dpo_trainer', other_metrics)
1246
+
1247
+ super().__init__(
1248
+ model = model,
1249
+ ref_model = ref_model,
1250
+ reward_model = reward_model,
1251
+ judge = judge,
1252
+ args = args,
1253
+ data_collator = data_collator,
1254
+ train_dataset = train_dataset,
1255
+ eval_dataset = eval_dataset,
1256
+ processing_class = processing_class,
1257
+ reward_processing_class = reward_processing_class,
1258
+ peft_config = peft_config,
1259
+ compute_metrics = compute_metrics,
1260
+ callbacks = callbacks,
1261
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
1262
+ if hasattr(self, 'neftune_hook_handle'):
1263
+ self.neftune_hook_handle.remove()
1264
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1265
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1266
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1267
+ pass
1268
+
1269
+ pass
unsloth_compiled_cache/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1259 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
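+ # --- Editor's usage sketch (illustrative, not part of the generated module) ---
+ # selective_log_softmax(logits, index) matches the naive gather (in float32)
+ # while only materializing a logsumexp instead of a full log-softmax tensor:
+ #
+ #   logits = torch.randn(2, 5, 8)     # (batch, seq, vocab)
+ #   index = torch.randint(0, 8, (2, 5))
+ #   a = selective_log_softmax(logits, index)
+ #   b = logits.float().log_softmax(-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
+ #   torch.testing.assert_close(a, b)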
+ @dataclass
43
+ class UnslothPPOConfig(PPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`PPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
54
+ Name of this experiment.
55
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
56
+ Path to the reward model.
57
+ model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
58
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
59
+ ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
60
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
61
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
62
+ Number of epochs to train.
63
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
64
+ Whether to whiten the rewards.
65
+ kl_coef (`float`, *optional*, defaults to `0.05`):
66
+ KL coefficient.
67
+ cliprange (`float`, *optional*, defaults to `0.2`):
68
+ Clip range.
69
+ vf_coef (`float`, *optional*, defaults to `0.1`):
70
+ Value function coefficient.
71
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
72
+ Clip range for the value function.
73
+ gamma (`float`, *optional*, defaults to `1.0`):
74
+ Discount factor.
75
+ lam (`float`, *optional*, defaults to `0.95`):
76
+ Lambda value for GAE.
77
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
78
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
79
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
80
+ capacity of a single GPU, albeit at the cost of slower generation.
81
+
82
+ """
83
+ vllm_sampling_params: Optional[Any] = field(
84
+ default = None,
85
+ metadata = {'help': 'vLLM SamplingParams'},
86
+ )
87
+ unsloth_num_chunks : Optional[int] = field(
88
+ default = -1,
89
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
90
+ )
91
+ def __init__(
92
+ self,
93
+ output_dir = None,
94
+ overwrite_output_dir = None,
95
+ do_train = False,
96
+ do_eval = False,
97
+ do_predict = False,
98
+ eval_strategy = 'no',
99
+ prediction_loss_only = False,
100
+ per_device_train_batch_size = 4,
101
+ per_device_eval_batch_size = 4,
102
+ per_gpu_train_batch_size = None,
103
+ per_gpu_eval_batch_size = None,
104
+ gradient_accumulation_steps = 2,
105
+ eval_accumulation_steps = 2,
106
+ eval_delay = 0,
107
+ torch_empty_cache_steps = 250,
108
+ learning_rate = 5e-05,
109
+ weight_decay = 0.01,
110
+ adam_beta1 = 0.9,
111
+ adam_beta2 = 0.999,
112
+ adam_epsilon = 1e-08,
113
+ max_grad_norm = 1.0,
114
+ num_train_epochs = 3.0,
115
+ max_steps = -1,
116
+ lr_scheduler_type = 'linear',
117
+ warmup_ratio = 0.1,
118
+ warmup_steps = 0,
119
+ log_level = 'passive',
120
+ log_level_replica = 'warning',
121
+ log_on_each_node = True,
122
+ logging_dir = None,
123
+ logging_strategy = 'steps',
124
+ logging_first_step = False,
125
+ logging_steps = 1,
126
+ logging_nan_inf_filter = False,
127
+ save_strategy = 'steps',
128
+ save_steps = 500,
129
+ save_total_limit = None,
130
+ save_safetensors = True,
131
+ save_on_each_node = False,
132
+ save_only_model = False,
133
+ restore_callback_states_from_checkpoint = False,
134
+ no_cuda = False,
135
+ use_cpu = False,
136
+ use_mps_device = False,
137
+ seed = 3407,
138
+ data_seed = 3407,
139
+ jit_mode_eval = False,
140
+ use_ipex = False,
141
+ bf16 = False,
142
+ fp16 = False,
143
+ fp16_opt_level = 'O1',
144
+ half_precision_backend = 'auto',
145
+ bf16_full_eval = False,
146
+ fp16_full_eval = False,
147
+ tf32 = None,
148
+ local_rank = -1,
149
+ ddp_backend = None,
150
+ tpu_num_cores = None,
151
+ tpu_metrics_debug = False,
152
+ debug = '',
153
+ dataloader_drop_last = False,
154
+ eval_steps = None,
155
+ dataloader_num_workers = 0,
156
+ dataloader_prefetch_factor = None,
157
+ past_index = -1,
158
+ run_name = None,
159
+ disable_tqdm = None,
160
+ remove_unused_columns = True,
161
+ label_names = None,
162
+ load_best_model_at_end = False,
163
+ metric_for_best_model = None,
164
+ greater_is_better = None,
165
+ ignore_data_skip = False,
166
+ fsdp = '',
167
+ fsdp_min_num_params = 0,
168
+ fsdp_config = None,
169
+ tp_size = 0,
170
+ fsdp_transformer_layer_cls_to_wrap = None,
171
+ accelerator_config = None,
172
+ deepspeed = None,
173
+ label_smoothing_factor = 0.0,
174
+ optim = 'adamw_8bit',
175
+ optim_args = None,
176
+ adafactor = False,
177
+ group_by_length = False,
178
+ length_column_name = 'length',
179
+ report_to = None,
180
+ ddp_find_unused_parameters = None,
181
+ ddp_bucket_cap_mb = None,
182
+ ddp_broadcast_buffers = None,
183
+ dataloader_pin_memory = True,
184
+ dataloader_persistent_workers = False,
185
+ skip_memory_metrics = True,
186
+ use_legacy_prediction_loop = False,
187
+ push_to_hub = False,
188
+ resume_from_checkpoint = None,
189
+ hub_model_id = None,
190
+ hub_strategy = 'every_save',
191
+ hub_token = None,
192
+ hub_private_repo = None,
193
+ hub_always_push = False,
194
+ gradient_checkpointing = False,
195
+ gradient_checkpointing_kwargs = None,
196
+ include_inputs_for_metrics = False,
197
+ eval_do_concat_batches = True,
198
+ fp16_backend = 'auto',
199
+ evaluation_strategy = None,
200
+ push_to_hub_model_id = None,
201
+ push_to_hub_organization = None,
202
+ push_to_hub_token = None,
203
+ mp_parameters = '',
204
+ auto_find_batch_size = False,
205
+ full_determinism = False,
206
+ torchdynamo = None,
207
+ ray_scope = 'last',
208
+ ddp_timeout = 1800,
209
+ torch_compile = False,
210
+ torch_compile_backend = None,
211
+ torch_compile_mode = None,
212
+ dispatch_batches = None,
213
+ split_batches = None,
214
+ include_tokens_per_second = False,
215
+ include_num_input_tokens_seen = False,
216
+ neftune_noise_alpha = None,
217
+ optim_target_modules = None,
218
+ batch_eval_metrics = False,
219
+ eval_on_start = False,
220
+ use_liger_kernel = False,
221
+ eval_use_gather_object = False,
222
+ average_tokens_across_devices = False,
223
+ dataset_num_proc = None,
224
+ num_mini_batches = 1,
225
+ total_episodes = None,
226
+ local_rollout_forward_batch_size = 64,
227
+ num_sample_generations = 10,
228
+ response_length = 53,
229
+ stop_token = None,
230
+ stop_token_id = None,
231
+ temperature = 0.7,
232
+ missing_eos_penalty = None,
233
+ sft_model_path = 'EleutherAI/pythia-160m',
234
+ world_size = None,
235
+ num_total_batches = None,
236
+ micro_batch_size = None,
237
+ local_batch_size = None,
238
+ batch_size = None,
239
+ local_mini_batch_size = None,
240
+ mini_batch_size = None,
241
+ exp_name = 'ppo_config',
242
+ reward_model_path = 'EleutherAI/pythia-160m',
243
+ model_adapter_name = None,
244
+ ref_adapter_name = None,
245
+ num_ppo_epochs = 4,
246
+ whiten_rewards = False,
247
+ kl_coef = 0.05,
248
+ cliprange = 0.2,
249
+ vf_coef = 0.1,
250
+ cliprange_value = 0.2,
251
+ gamma = 1.0,
252
+ lam = 0.95,
253
+ ds3_gather_for_generation = True,
254
+ vllm_sampling_params = None,
255
+ unsloth_num_chunks = -1,
256
+ **kwargs,
257
+ ):
258
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
259
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
260
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
261
+ output_dir = 'unsloth_training_checkpoints'
262
+ save_strategy = 'no'
263
+ if dataset_num_proc is None:
264
+ from multiprocessing import cpu_count
265
+ dataset_num_proc = cpu_count()
266
+
267
+ super().__init__(
268
+ output_dir = output_dir,
269
+ overwrite_output_dir = overwrite_output_dir,
270
+ do_train = do_train,
271
+ do_eval = do_eval,
272
+ do_predict = do_predict,
273
+ eval_strategy = eval_strategy,
274
+ prediction_loss_only = prediction_loss_only,
275
+ per_device_train_batch_size = per_device_train_batch_size,
276
+ per_device_eval_batch_size = per_device_eval_batch_size,
277
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
278
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
279
+ gradient_accumulation_steps = gradient_accumulation_steps,
280
+ eval_accumulation_steps = eval_accumulation_steps,
281
+ eval_delay = eval_delay,
282
+ torch_empty_cache_steps = torch_empty_cache_steps,
283
+ learning_rate = learning_rate,
284
+ weight_decay = weight_decay,
285
+ adam_beta1 = adam_beta1,
286
+ adam_beta2 = adam_beta2,
287
+ adam_epsilon = adam_epsilon,
288
+ max_grad_norm = max_grad_norm,
289
+ num_train_epochs = num_train_epochs,
290
+ max_steps = max_steps,
291
+ lr_scheduler_type = lr_scheduler_type,
292
+ warmup_ratio = warmup_ratio,
293
+ warmup_steps = warmup_steps,
294
+ log_level = log_level,
295
+ log_level_replica = log_level_replica,
296
+ log_on_each_node = log_on_each_node,
297
+ logging_dir = logging_dir,
298
+ logging_strategy = logging_strategy,
299
+ logging_first_step = logging_first_step,
300
+ logging_steps = logging_steps,
301
+ logging_nan_inf_filter = logging_nan_inf_filter,
302
+ save_strategy = save_strategy,
303
+ save_steps = save_steps,
304
+ save_total_limit = save_total_limit,
305
+ save_safetensors = save_safetensors,
306
+ save_on_each_node = save_on_each_node,
307
+ save_only_model = save_only_model,
308
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
309
+ no_cuda = no_cuda,
310
+ use_cpu = use_cpu,
311
+ use_mps_device = use_mps_device,
312
+ seed = seed,
313
+ data_seed = data_seed,
314
+ jit_mode_eval = jit_mode_eval,
315
+ use_ipex = use_ipex,
316
+ bf16 = bf16,
317
+ fp16 = fp16,
318
+ fp16_opt_level = fp16_opt_level,
319
+ half_precision_backend = half_precision_backend,
320
+ bf16_full_eval = bf16_full_eval,
321
+ fp16_full_eval = fp16_full_eval,
322
+ tf32 = tf32,
323
+ local_rank = local_rank,
324
+ ddp_backend = ddp_backend,
325
+ tpu_num_cores = tpu_num_cores,
326
+ tpu_metrics_debug = tpu_metrics_debug,
327
+ debug = debug,
328
+ dataloader_drop_last = dataloader_drop_last,
329
+ eval_steps = eval_steps,
330
+ dataloader_num_workers = dataloader_num_workers,
331
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
332
+ past_index = past_index,
333
+ run_name = run_name,
334
+ disable_tqdm = disable_tqdm,
335
+ remove_unused_columns = remove_unused_columns,
336
+ label_names = label_names,
337
+ load_best_model_at_end = load_best_model_at_end,
338
+ metric_for_best_model = metric_for_best_model,
339
+ greater_is_better = greater_is_better,
340
+ ignore_data_skip = ignore_data_skip,
341
+ fsdp = fsdp,
342
+ fsdp_min_num_params = fsdp_min_num_params,
343
+ fsdp_config = fsdp_config,
344
+ tp_size = tp_size,
345
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
346
+ accelerator_config = accelerator_config,
347
+ deepspeed = deepspeed,
348
+ label_smoothing_factor = label_smoothing_factor,
349
+ optim = optim,
350
+ optim_args = optim_args,
351
+ adafactor = adafactor,
352
+ group_by_length = group_by_length,
353
+ length_column_name = length_column_name,
354
+ report_to = report_to,
355
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
356
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
357
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
358
+ dataloader_pin_memory = dataloader_pin_memory,
359
+ dataloader_persistent_workers = dataloader_persistent_workers,
360
+ skip_memory_metrics = skip_memory_metrics,
361
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
362
+ push_to_hub = push_to_hub,
363
+ resume_from_checkpoint = resume_from_checkpoint,
364
+ hub_model_id = hub_model_id,
365
+ hub_strategy = hub_strategy,
366
+ hub_token = hub_token,
367
+ hub_private_repo = hub_private_repo,
368
+ hub_always_push = hub_always_push,
369
+ gradient_checkpointing = gradient_checkpointing,
370
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
371
+ include_inputs_for_metrics = include_inputs_for_metrics,
372
+ eval_do_concat_batches = eval_do_concat_batches,
373
+ fp16_backend = fp16_backend,
374
+ evaluation_strategy = evaluation_strategy,
375
+ push_to_hub_model_id = push_to_hub_model_id,
376
+ push_to_hub_organization = push_to_hub_organization,
377
+ push_to_hub_token = push_to_hub_token,
378
+ mp_parameters = mp_parameters,
379
+ auto_find_batch_size = auto_find_batch_size,
380
+ full_determinism = full_determinism,
381
+ torchdynamo = torchdynamo,
382
+ ray_scope = ray_scope,
383
+ ddp_timeout = ddp_timeout,
384
+ torch_compile = torch_compile,
385
+ torch_compile_backend = torch_compile_backend,
386
+ torch_compile_mode = torch_compile_mode,
387
+ dispatch_batches = dispatch_batches,
388
+ split_batches = split_batches,
389
+ include_tokens_per_second = include_tokens_per_second,
390
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
391
+ neftune_noise_alpha = neftune_noise_alpha,
392
+ optim_target_modules = optim_target_modules,
393
+ batch_eval_metrics = batch_eval_metrics,
394
+ eval_on_start = eval_on_start,
395
+ use_liger_kernel = use_liger_kernel,
396
+ eval_use_gather_object = eval_use_gather_object,
397
+ average_tokens_across_devices = average_tokens_across_devices,
398
+ dataset_num_proc = dataset_num_proc,
399
+ num_mini_batches = num_mini_batches,
400
+ total_episodes = total_episodes,
401
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
402
+ num_sample_generations = num_sample_generations,
403
+ response_length = response_length,
404
+ stop_token = stop_token,
405
+ stop_token_id = stop_token_id,
406
+ temperature = temperature,
407
+ missing_eos_penalty = missing_eos_penalty,
408
+ sft_model_path = sft_model_path,
409
+ world_size = world_size,
410
+ num_total_batches = num_total_batches,
411
+ micro_batch_size = micro_batch_size,
412
+ local_batch_size = local_batch_size,
413
+ batch_size = batch_size,
414
+ local_mini_batch_size = local_mini_batch_size,
415
+ mini_batch_size = mini_batch_size,
416
+ exp_name = exp_name,
417
+ reward_model_path = reward_model_path,
418
+ model_adapter_name = model_adapter_name,
419
+ ref_adapter_name = ref_adapter_name,
420
+ num_ppo_epochs = num_ppo_epochs,
421
+ whiten_rewards = whiten_rewards,
422
+ kl_coef = kl_coef,
423
+ cliprange = cliprange,
424
+ vf_coef = vf_coef,
425
+ cliprange_value = cliprange_value,
426
+ gamma = gamma,
427
+ lam = lam,
428
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
429
+ self.vllm_sampling_params = vllm_sampling_params
430
+ self.unsloth_num_chunks = unsloth_num_chunks
431
+ pass
432
+
433
+ class _UnslothPPOTrainer(Trainer):
434
+ _tag_names = ["trl", "ppo"]
435
+
436
+ def __init__(
437
+ self,
438
+ args: PPOConfig,
439
+ processing_class: Optional[
440
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
441
+ ],
442
+ model: nn.Module,
443
+ ref_model: Optional[nn.Module],
444
+ reward_model: nn.Module,
445
+ train_dataset: Dataset,
446
+ value_model: Optional[nn.Module] = None,
447
+ data_collator: Optional[DataCollatorWithPadding] = None,
448
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
449
+ # less commonly used
450
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
451
+ callbacks: Optional[list[TrainerCallback]] = None,
452
+ peft_config: Optional["PeftConfig"] = None,
453
+ ) -> None:
454
+ if ref_model is model:
455
+ raise ValueError(
456
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
457
+ "same as `model`, you must make a copy of it, or `None` if you use peft."
458
+ )
459
+
460
+ self.args = args
461
+ self.processing_class = processing_class
462
+ self.policy_model = model
463
+
464
+ # Define the collator if not provided
465
+ if data_collator is None:
466
+ data_collator = DataCollatorWithPadding(self.processing_class)
467
+
468
+ # Handle stop token settings: update policy model's generation_config to use provided stop token
469
+ if args.stop_token and args.stop_token_id:
470
+ raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
471
+ elif args.stop_token:
472
+ if args.stop_token == "eos":
473
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
474
+ else:
475
+ raise ValueError(
476
+ f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
477
+ )
478
+ else:
479
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
480
+
481
+ # peft support
482
+ if not is_peft_available() and peft_config is not None:
483
+ raise ImportError(
484
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
485
+ )
486
+ elif is_peft_available() and peft_config is not None:
487
+ # if model is a peft model and we have a peft_config, we merge and unload it first
488
+ if isinstance(self.policy_model, PeftModel):
489
+ self.policy_model = self.policy_model.merge_and_unload()
490
+
491
+ # get peft model with the given config
492
+ self.policy_model = get_peft_model(self.policy_model, peft_config)
493
+ if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
494
+ peft_module_casting_to_bf16(self.policy_model)
495
+
496
+ self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
497
+ self.model_adapter_name = args.model_adapter_name
498
+ self.ref_adapter_name = args.ref_adapter_name
499
+
500
+ if ref_model:
501
+ self.ref_model = ref_model
502
+ elif self.is_peft_model:
503
+ self.ref_model = None
504
+ else:
505
+ self.ref_model = create_reference_model(self.policy_model)
506
+
507
+ self.reward_model = reward_model
508
+ self.train_dataset = train_dataset
509
+ self.train_dataset_len = len(train_dataset)
510
+ self.value_model = value_model
511
+ self.data_collator = data_collator
512
+ self.eval_dataset = eval_dataset
513
+ self.optimizer, self.lr_scheduler = optimizers
514
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
515
+
516
+ #########
517
+ # calculate various batch sizes
518
+ #########
519
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
520
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
521
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
522
+ self.accelerator = accelerator
523
+ args.world_size = accelerator.num_processes
524
+ args.local_batch_size = (
525
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
526
+ )
527
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
528
+ args.batch_size = int(args.local_batch_size * args.world_size)
529
+ args.mini_batch_size = exact_div(
530
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
531
+ )
532
+ args.local_mini_batch_size = exact_div(
533
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
534
+ )
535
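+ # --- Editor's worked example (illustrative) ---
+ # With per_device_train_batch_size = 4, gradient_accumulation_steps = 2,
+ # num_mini_batches = 2 and world_size = 2:
+ #   local_batch_size      = 4 * 2 * 2 = 16   (rollout per process)
+ #   micro_batch_size      = 4 * 2     = 8
+ #   batch_size            = 16 * 2    = 32   (rollout across processes)
+ #   mini_batch_size       = 32 / 2    = 16
+ #   local_mini_batch_size = 16 / 2    = 8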
+ if args.whiten_rewards:
536
+ assert (
537
+ args.local_mini_batch_size >= 8
538
+ ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
539
+ # `per_rank_rollout_batch_size` is our `args.local_batch_size`
540
+ # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
541
+ args.num_total_batches = math.ceil(
542
+ args.total_episodes / args.batch_size
543
+ ) # we may train for more than `total_episodes`
544
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
545
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
546
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
547
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
548
+ if args.num_sample_generations > 0:
549
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
550
+ self.local_dataloader_batch_size = args.local_batch_size
551
+
552
+ #########
553
+ # setup model, optimizer, and others
554
+ #########
555
+ for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
556
+ if module is not None:
557
+ disable_dropout_in_model(module)
558
+ self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
559
+ self.model.config = self.policy_model.config # needed for pushing to hub
560
+ self.create_optimizer_and_scheduler(
561
+ num_training_steps=args.num_total_batches
562
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
563
+
564
+ #########
565
+ ### trainer specifics
566
+ #########
567
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
568
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
569
+ self.callback_handler = CallbackHandler(
570
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
571
+ )
572
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
573
+ self.control = TrainerControl()
574
+ self.state = OnlineTrainerState(
575
+ is_local_process_zero=self.is_local_process_zero(),
576
+ is_world_process_zero=self.is_world_process_zero(),
577
+ stateful_callbacks=[
578
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
579
+ ],
580
+ )
581
+ self.current_flos = 0
582
+ self.hp_search_backend = None
583
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
584
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
585
+ # Create distant repo and output directory if needed
586
+ self.hub_model_id = None
587
+ if self.args.push_to_hub:
588
+ self.init_hf_repo()
589
+ if self.args.should_save:
590
+ os.makedirs(self.args.output_dir, exist_ok=True)
591
+
592
+ # Add tags for models that have been loaded with the correct transformers version
593
+ if hasattr(self.model, "add_model_tags"):
594
+ self.model.add_model_tags(self._tag_names)
595
+
596
+ #########
597
+ ### setup dataloader
598
+ #########
599
+ self.dataloader = DataLoader(
600
+ self.train_dataset,
601
+ batch_size=self.local_dataloader_batch_size,
602
+ shuffle=True,
603
+ collate_fn=self.data_collator,
604
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
605
+ )
606
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
607
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
608
+ torch.manual_seed(args.seed)
609
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
610
+ torch.manual_seed(self.local_seed) # reset the local seed again
611
+
612
+ self.eval_dataloader = DataLoader(
613
+ self.eval_dataset,
614
+ batch_size=args.per_device_eval_batch_size,
615
+ collate_fn=self.data_collator,
616
+ drop_last=True,
617
+ ) # no need to shuffle eval dataset
618
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
619
+
620
+ if self.is_deepspeed_enabled:
621
+ self.reward_model = prepare_deepspeed(
622
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
623
+ )
624
+
625
+ if self.ref_model is None:
626
+ if not self.is_peft_model:
627
+ raise ValueError("No reference model and model is not a Peft model.")
628
+ else:
629
+ self.ref_model = prepare_deepspeed(
630
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
631
+ )
632
+ else:
633
+ if self.ref_model is None:
634
+ if not self.is_peft_model:
635
+ raise ValueError("No reference model and model is not a Peft model.")
636
+ else:
637
+ self.ref_model = self.ref_model.to(self.accelerator.device)
638
+ self.reward_model = self.reward_model.to(self.accelerator.device)
639
+
640
+ def get_train_dataloader(self) -> DataLoader:
641
+ return self.dataloader
642
+
643
+ def get_eval_dataloader(self) -> DataLoader:
644
+ return self.eval_dataloader
645
+
646
+ @contextmanager
647
+ def null_ref_context(self):
648
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
649
+ with (
650
+ self.accelerator.unwrap_model(self.model.policy).disable_adapter()
651
+ if self.is_peft_model and not self.ref_adapter_name
652
+ else nullcontext()
653
+ ):
654
+ if self.ref_adapter_name:
655
+ self.model.policy.set_adapter(self.ref_adapter_name)
656
+ yield
657
+ if self.ref_adapter_name:
658
+ self.model.policy.set_adapter(self.model_adapter_name or "default")
659
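+ # --- Editor's usage sketch (illustrative) ---
+ # For a PEFT policy the reference logprobs come from the frozen base
+ # weights, obtained by disabling the adapter (or swapping to
+ # ref_adapter_name), e.g.:
+ #
+ #   with self.null_ref_context():
+ #       ref_output = forward(model.policy, query_response, pad_token_id)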
+
660
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
661
+ backup_model = self.model
662
+ self.model = self.model.policy # save only the policy
663
+
664
+ if self.is_deepspeed_enabled:
665
+ backup_deepspeed = self.deepspeed
666
+ self.deepspeed = self.model
667
+
668
+ super().save_model(output_dir, _internal_call)
669
+
670
+ self.model = backup_model
671
+
672
+ if self.is_deepspeed_enabled:
673
+ self.deepspeed = backup_deepspeed
674
+
675
+ def train(self):
676
+ args = self.args
677
+ accelerator = self.accelerator
678
+ optimizer = self.optimizer
679
+ model = self.model
680
+ ref_policy = self.ref_model
681
+ reward_model = self.reward_model
682
+ processing_class = self.processing_class
683
+ dataloader = self.dataloader
684
+ device = accelerator.device
685
+
686
+ def repeat_generator():
687
+ while True:
688
+ yield from dataloader
689
+
690
+ iter_dataloader = iter(repeat_generator())
691
+ generation_config = GenerationConfig(
692
+ max_new_tokens=args.response_length,
693
+ temperature=(args.temperature + 1e-7),
694
+ top_k=0.0,
695
+ top_p=1.0,
696
+ do_sample=True,
697
+ )
698
+
699
+ accelerator.print("===training policy===")
700
+ start_time = time.time()
701
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
702
+ approxkl_stats = torch.zeros(stats_shape, device=device)
703
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
704
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
705
+ vf_loss_stats = torch.zeros(stats_shape, device=device)
706
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
707
+ entropy_stats = torch.zeros(stats_shape, device=device)
708
+ ratio_stats = torch.zeros(stats_shape, device=device)
709
+ model.train()
710
+
711
+ # trainer state initialization
712
+ self.state.global_step = 0
713
+ self.state.episode = 0
714
+ self.state.max_steps = args.num_total_batches * args.num_mini_batches
715
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
716
+ # Compute absolute values for logging, eval, and save if given as ratio
717
+ if args.logging_steps is not None:
718
+ if args.logging_steps < 1:
719
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
720
+ else:
721
+ self.state.logging_steps = args.logging_steps
722
+ if args.eval_steps is not None:
723
+ if args.eval_steps < 1:
724
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
725
+ else:
726
+ self.state.eval_steps = args.eval_steps
727
+ if args.save_steps is not None:
728
+ if args.save_steps < 1:
729
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
730
+ else:
731
+ self.state.save_steps = args.save_steps
732
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
733
+
734
+ # backward compatibility
735
+ if self.is_deepspeed_enabled:
736
+ self.deepspeed = self.model
737
+ self.model_wrapped = self.model
738
+
739
+ for update in range(1, args.num_total_batches + 1):
740
+ self.state.episode += 1 * args.batch_size
741
+ data = next(iter_dataloader)
742
+ with torch.no_grad():
743
+ queries = data["input_ids"].to(device)
744
+ context_length = queries.shape[1]
745
+ responses = []
746
+ postprocessed_responses = []
747
+ logprobs = []
748
+ ref_logprobs = []
749
+ scores = []
750
+ sequence_lengths = []
751
+ values = []
752
+ with unwrap_model_for_generation(
753
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
754
+ ) as unwrapped_model:
755
+ query_responses, logitss = batch_generation(
756
+ unwrapped_model.policy,
757
+ queries,
758
+ args.local_rollout_forward_batch_size,
759
+ processing_class.pad_token_id,
760
+ generation_config,
761
+ )
762
+
763
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
764
+ query = queries[i : i + args.local_rollout_forward_batch_size]
765
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
766
+ response = query_response[:, context_length:]
767
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
768
+ logprob = selective_log_softmax(logits, response)
769
+ del logits
770
+ torch.cuda.empty_cache()
771
+
772
+ if ref_policy is None:
773
+ with self.null_ref_context():
774
+ ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
775
+ else:
776
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
777
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
778
+ ref_logits /= args.temperature + 1e-7
779
+ ref_logprob = selective_log_softmax(ref_logits, response)
780
+ del ref_output, ref_logits
781
+ torch.cuda.empty_cache()
782
+
783
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
784
+ postprocessed_response = response
785
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
786
+ postprocessed_response = truncate_response(
787
+ self.stop_token_id, processing_class.pad_token_id, response
788
+ )
789
+
790
+ # Response Processing 2. run reward model on the truncated responses
791
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
792
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
793
+ unwrapped_value_model = accelerator.unwrap_model(model).value_model
794
+ full_value, _, _ = get_reward(
795
+ unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
796
+ )
797
+ value = full_value[:, context_length - 1 : -1].squeeze(-1)
798
+ _, score, _ = get_reward(
799
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
800
+ )
801
+
802
+ responses.append(response)
803
+ postprocessed_responses.append(postprocessed_response)
804
+ logprobs.append(logprob)
805
+ ref_logprobs.append(ref_logprob)
806
+ sequence_lengths.append(sequence_length)
807
+ scores.append(score)
808
+ values.append(value)
809
+ responses = torch.cat(responses, 0)
810
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
811
+ logprobs = torch.cat(logprobs, 0)
812
+ ref_logprobs = torch.cat(ref_logprobs, 0)
813
+ sequence_lengths = torch.cat(sequence_lengths, 0)
814
+ scores = torch.cat(scores, 0)
815
+ values = torch.cat(values, 0)
816
+ del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
817
+ torch.cuda.empty_cache()
818
+ gc.collect()
819
+
820
+ # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
821
+ # Completions not passing that filter will receive a lower score.
822
+ contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
823
+ if self.args.missing_eos_penalty is not None:
824
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
825
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
826
+
827
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
828
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
829
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
830
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
831
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
832
+ sequence_lengths_p1 = sequence_lengths + 1
833
+ padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
834
+ values = torch.masked_fill(values, padding_mask_p1, 0)
835
+
836
+ # 4. compute rewards
837
+ kl = logprobs - ref_logprobs
838
+ non_score_reward = -args.kl_coef * kl
839
+ rewards = non_score_reward.clone()
840
+ actual_start = torch.arange(rewards.size(0), device=rewards.device)
841
+ actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
842
+ rewards[[actual_start, actual_end]] += scores
843
+
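# Sketch of the reward construction in step 4, on toy tensors: every generated
# token receives a KL penalty -kl_coef * (logprob - ref_logprob), and the
# scalar reward-model score is added only at the last valid token. All values
# here are illustrative.
import torch

kl_coef = 0.05
logprobs = torch.tensor([[-1.0, -0.5, -0.2]])
ref_logprobs = torch.tensor([[-1.1, -0.7, -0.2]])
score = torch.tensor([1.5])
last_token = torch.tensor([2])  # index of the final token per sequence

rewards = -kl_coef * (logprobs - ref_logprobs)           # non-score (KL) reward
rewards[torch.arange(rewards.size(0)), last_token] += score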
844
+ # 5. whiten rewards
845
+ if args.whiten_rewards:
846
+ rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
847
+ rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
848
+
849
+ # 6. compute advantages and returns
850
+ lastgaelam = 0
851
+ advantages_reversed = []
852
+ gen_length = responses.shape[1]
853
+ for t in reversed(range(gen_length)):
854
+ nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
855
+ delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
856
+ lastgaelam = delta + args.gamma * args.lam * lastgaelam
857
+ advantages_reversed.append(lastgaelam)
858
+ advantages = torch.stack(advantages_reversed[::-1], axis=1)
859
+ returns = advantages + values
860
+ advantages = masked_whiten(advantages, ~padding_mask)
861
+ advantages = torch.masked_fill(advantages, padding_mask, 0)
862
+ torch.cuda.empty_cache()
863
+
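# Standalone sketch of the GAE recursion in step 6: each advantage is a
# gamma/lam-discounted sum of TD residuals, computed backwards in time.
# Tensor values and the gamma/lam choices are illustrative only.
import torch

gamma, lam = 1.0, 0.95
rewards = torch.tensor([[0.0, 0.0, 1.0]])
values = torch.tensor([[0.1, 0.2, 0.3]])

lastgaelam = torch.zeros(rewards.size(0))
advantages_reversed = []
gen_length = rewards.size(1)
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]  # TD residual
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values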
864
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
865
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
866
+ b_inds = np.random.permutation(args.local_batch_size)
867
+ minibatch_idx = 0
868
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
869
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
870
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
871
+ gradient_accumulation_idx = 0
872
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
873
+ with accelerator.accumulate(model):
874
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
875
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
876
+ mb_advantage = advantages[micro_batch_inds]
877
+ mb_responses = responses[micro_batch_inds]
878
+ mb_query_responses = query_responses[micro_batch_inds]
879
+ mb_logprobs = logprobs[micro_batch_inds]
880
+ mb_return = returns[micro_batch_inds]
881
+ mb_values = values[micro_batch_inds]
882
+
883
+ output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
884
+ logits = output.logits[:, context_length - 1 : -1]
885
+ logits /= args.temperature + 1e-7
886
+ new_logprobs = selective_log_softmax(logits, mb_responses)
887
+ new_logprobs = torch.masked_fill(
888
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
889
+ )
890
+ vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
891
+ vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
892
+ vpredclipped = torch.clamp(
893
+ vpred,
894
+ mb_values - args.cliprange_value,
895
+ mb_values + args.cliprange_value,
896
+ )
897
+ vf_losses1 = torch.square(vpred - mb_return)
898
+ vf_losses2 = torch.square(vpredclipped - mb_return)
899
+ vf_loss_max = torch.max(vf_losses1, vf_losses2)
900
+ vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
901
+ vf_clipfrac = masked_mean(
902
+ (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
903
+ )
904
+ logprobs_diff = new_logprobs - mb_logprobs
905
+ ratio = torch.exp(logprobs_diff)
906
+ pg_losses = -mb_advantage * ratio
907
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
908
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
909
+ pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
910
+ loss = pg_loss + args.vf_coef * vf_loss
911
+ accelerator.backward(loss)
912
+ optimizer.step()
913
+ optimizer.zero_grad()
914
+ with torch.no_grad():
915
+ pg_clipfrac = masked_mean(
916
+ (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
917
+ )
918
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
919
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
920
+ approxkl = 0.5 * (logprobs_diff**2).mean()
921
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
922
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
923
+ pg_clipfrac
924
+ )
925
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
926
+ vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
927
+ vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
928
+ vf_clipfrac
929
+ )
930
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
931
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
932
+ gradient_accumulation_idx += 1
933
+ minibatch_idx += 1
934
+ # del everything and empty cache
935
+ # fmt: off
936
+ del (
937
+ output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
938
+ vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
939
+ pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
940
+ mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
941
+ )
942
+ # fmt: on
943
+ torch.cuda.empty_cache()
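# Compact sketch of the per-micro-batch losses computed above: the clipped PPO
# surrogate, the clipped value loss, and the entropy identity
# H = logsumexp(logits) - sum(softmax(logits) * logits). Toy tensors only.
import torch

def ppo_losses(new_logprobs, old_logprobs, advantages, vpred, old_values,
               returns, cliprange=0.2, cliprange_value=0.2):
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg1 = -advantages * ratio
    pg2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(pg1, pg2).mean()  # pessimistic (clipped) policy loss
    vclip = torch.clamp(vpred, old_values - cliprange_value, old_values + cliprange_value)
    vf_loss = 0.5 * torch.max((vpred - returns) ** 2, (vclip - returns) ** 2).mean()
    return pg_loss, vf_loss

logits = torch.randn(4, 7)
prob_dist = torch.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
assert torch.allclose(entropy, -(prob_dist * prob_dist.log()).sum(-1), atol=1e-5)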
944
+ with torch.no_grad():
945
+ mean_kl = kl.sum(1).mean()
946
+ mean_entropy = (-logprobs).sum(1).mean()
947
+ mean_non_score_reward = non_score_reward.sum(1).mean()
948
+ rlhf_reward = mean_non_score_reward + scores.mean()
949
+ eps = int(self.state.episode / (time.time() - start_time))
950
+ metrics = {}
951
+ metrics["eps"] = eps
952
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
953
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
954
+ metrics["objective/non_score_reward"] = (
955
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
956
+ )
957
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
958
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
959
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
960
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
961
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
962
+ metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
963
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
964
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
965
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
966
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
967
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
968
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
969
+ metrics["episode"] = self.state.episode
970
+ self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
971
+ self.state.global_step += 1
972
+ self.log(metrics)
973
+
974
+ self.lr_scheduler.step()
975
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
976
+ if self.control.should_save:
977
+ self._save_checkpoint(model, trial=None)
978
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
979
+ del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
980
+ torch.cuda.empty_cache()
981
+ gc.collect()
982
+
983
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
984
+ self.generate_completions(sampling=True)
985
+ torch.cuda.empty_cache()
986
+ del (
987
+ query_responses,
988
+ responses,
989
+ postprocessed_responses,
990
+ logprobs,
991
+ ref_logprobs,
992
+ values,
993
+ sequence_lengths,
994
+ contain_eos_token,
995
+ sequence_lengths_p1,
996
+ response_idxs,
997
+ padding_mask,
998
+ padding_mask_p1,
999
+ rewards,
1000
+ actual_start,
1001
+ actual_end,
1002
+ advantages,
1003
+ returns,
1004
+ )
1005
+ torch.cuda.empty_cache()
1006
+
1007
+ # HF trainer specifics
1008
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
1009
+ if self.control.should_save:
1010
+ self._save_checkpoint(model, trial=None, metrics=None)
1011
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1012
+
1013
+ def generate_completions(self, sampling: bool = False):
1014
+ args = self.args
1015
+ processing_class = self.processing_class
1016
+ generation_config = GenerationConfig(
1017
+ max_new_tokens=self.args.response_length,
1018
+ temperature=(0.01 + 1e-7),
1019
+ top_k=0.0,
1020
+ top_p=1.0,
1021
+ do_sample=True,
1022
+ )
1023
+
1024
+ table = defaultdict(list)
1025
+ with unwrap_model_for_generation(
1026
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
1027
+ ) as unwrapped_model:
1028
+ for batch in self.eval_dataloader:
1029
+ query = batch["input_ids"]
1030
+ with torch.no_grad():
1031
+ context_length = query.shape[1]
1032
+ query_response, _ = batch_generation(
1033
+ unwrapped_model.policy,
1034
+ query,
1035
+ query.shape[0],
1036
+ processing_class.pad_token_id,
1037
+ generation_config,
1038
+ )
1039
+ response = query_response[:, context_length:]
1040
+ postprocessed_response = response
1041
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
1042
+ postprocessed_response = truncate_response(
1043
+ self.stop_token_id, processing_class.pad_token_id, response
1044
+ )
1045
+ table["query"].extend(
1046
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
1047
+ )
1048
+ table["model response"].extend(
1049
+ gather_object(processing_class.batch_decode(postprocessed_response))
1050
+ )
1051
+
1052
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
1053
+ _, score, _ = get_reward(
1054
+ self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
1055
+ )
1056
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
1057
+
1058
+ if sampling:
1059
+ break
1060
+ df = pd.DataFrame(table)
1061
+
1062
+ if self.accelerator.is_main_process:
1063
+ print_rich_table(df.iloc[0 : 0 + 5])
1064
+ if "wandb" in args.report_to:
1065
+ import wandb
1066
+
1067
+ if wandb.run is not None:
1068
+ wandb.log({"completions": wandb.Table(dataframe=df)})
1069
+
1070
+ if "comet_ml" in args.report_to:
1071
+ log_table_to_comet_experiment(
1072
+ name="completions.csv",
1073
+ table=df,
1074
+ )
1075
+
1076
+ def create_model_card(
1077
+ self,
1078
+ model_name: Optional[str] = None,
1079
+ dataset_name: Optional[str] = None,
1080
+ tags: Union[str, list[str], None] = None,
1081
+ ):
1082
+ """
1083
+ Creates a draft of a model card using the information available to the `Trainer`.
1084
+
1085
+ Args:
1086
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1087
+ Name of the model.
1088
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1089
+ Name of the dataset used for training.
1090
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1091
+ Tags to be associated with the model card.
1092
+ """
1093
+ if not self.is_world_process_zero():
1094
+ return
1095
+
1096
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1097
+ base_model = self.model.config._name_or_path
1098
+ else:
1099
+ base_model = None
1100
+
1101
+ tags = tags or []
1102
+ if isinstance(tags, str):
1103
+ tags = [tags]
1104
+
1105
+ if hasattr(self.model.config, "unsloth_version"):
1106
+ tags.append("unsloth")
1107
+
1108
+ citation = textwrap.dedent("""\
1109
+ @article{mziegler2019fine-tuning,
1110
+ title = {{Fine-Tuning Language Models from Human Preferences}},
1111
+ author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
1112
+ year = 2019,
1113
+ eprint = {arXiv:1909.08593}
1114
+ }""")
1115
+
1116
+ model_card = generate_model_card(
1117
+ base_model=base_model,
1118
+ model_name=model_name,
1119
+ hub_model_id=self.hub_model_id,
1120
+ dataset_name=dataset_name,
1121
+ tags=tags,
1122
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1123
+ comet_url=get_comet_experiment_url(),
1124
+ trainer_name="PPO",
1125
+ trainer_citation=citation,
1126
+ paper_title="Fine-Tuning Language Models from Human Preferences",
1127
+ paper_id="1909.08593",
1128
+ )
1129
+
1130
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1131
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
1132
+ """
1133
+
1134
+ """
1135
+ def __init__(
1136
+ self,
1137
+ args,
1138
+ processing_class,
1139
+ model,
1140
+ ref_model,
1141
+ reward_model,
1142
+ train_dataset,
1143
+ value_model = None,
1144
+ data_collator = None,
1145
+ eval_dataset = None,
1146
+ callbacks = None,
1147
+ peft_config = None,
1148
+ **kwargs
1149
+ ):
1150
+ if args is None: args = UnslothPPOConfig()
1151
+ use_bf16 = getattr(args, 'bf16', False)
1152
+ use_fp16 = getattr(args, 'fp16', False)
1153
+ force_float32 = False
1154
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1155
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1156
+ force_float32 = True
1157
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1158
+ dtype = getattr(model.config, 'torch_dtype', None)
1159
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1160
+ from unsloth_zoo.utils import _get_dtype
1161
+ dtype = _get_dtype(dtype)
1162
+ float16 = dtype == torch.float16
1163
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1164
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1165
+ if force_float32:
1166
+ args.fp16 = False
1167
+ args.bf16 = False
1168
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1169
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1170
+ args.fp16 = float16
1171
+ args.bf16 = not float16
1172
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1173
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1174
+ args.eval_strategy = 'steps'
1175
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1176
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1177
+ if ga_steps is not None and ga_steps > 1:
1178
+ from transformers import __version__ as transformers_version
1179
+ if Version(transformers_version) <= Version('4.45.2'):
1180
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1181
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1182
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1183
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1184
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1185
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1186
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1187
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1188
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1189
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1190
+ if force_float32:
1191
+ args.bf16_full_eval = False
1192
+ args.fp16_full_eval = False
1193
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1194
+ args.bf16_full_eval = True
1195
+ args.fp16_full_eval = False
1196
+ elif not bf16_full_eval and not fp16_full_eval:
1197
+ args.bf16_full_eval = args.bf16
1198
+ args.fp16_full_eval = args.fp16
1199
+ _output_logits = False
1200
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1201
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1202
+ if _output_logits:
1203
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1204
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1205
+ pass
1206
+ else:
1207
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1208
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1209
+ if args_max_seq_length is None and model_max_seq_length is not None:
1210
+ max_seq_length = model.max_seq_length
1211
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1212
+ if model is not None and hasattr(model, 'for_training'):
1213
+ model.for_training()
1214
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1215
+ if 'processing_class' in locals():
1216
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1217
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1218
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1219
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1220
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1221
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1222
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1223
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1224
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1225
+ else:
1226
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1227
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1228
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1229
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1230
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1231
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1232
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1233
+ else:
1234
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1235
+ other_metrics = []
1236
+
1237
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1238
+ PatchRLStatistics('ppo_trainer', other_metrics)
1239
+
1240
+ super().__init__(
1241
+ args = args,
1242
+ processing_class = processing_class,
1243
+ model = model,
1244
+ ref_model = ref_model,
1245
+ reward_model = reward_model,
1246
+ train_dataset = train_dataset,
1247
+ value_model = value_model,
1248
+ data_collator = data_collator,
1249
+ eval_dataset = eval_dataset,
1250
+ callbacks = callbacks,
1251
+ peft_config = peft_config,**kwargs)
1252
+ if hasattr(self, 'neftune_hook_handle'):
1253
+ self.neftune_hook_handle.remove()
1254
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1255
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1256
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1257
+ pass
1258
+
1259
+ pass
unsloth_compiled_cache/UnslothPRMTrainer.py ADDED
@@ -0,0 +1,800 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
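# Quick check on toy tensors that the math above matches a direct
# log_softmax + gather, which selective_log_softmax avoids materializing in
# full to keep peak memory lower:
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)
index = torch.randint(0, 11, (2, 5))
direct = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
fused = logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits.float(), dim=-1)
assert torch.allclose(direct, fused, atol=1e-5)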
42
+ @dataclass
43
+ class UnslothPRMConfig(PRMConfig):
44
+ """
45
+
46
+ Configuration class for the [`PRMTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-5`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) used for truncation.
58
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
59
+ Maximum length of the prompt used for truncation.
60
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
61
+ Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
62
+ disable_dropout (`bool`, *optional*, defaults to `True`):
63
+ Whether to disable dropout in the model.
64
+ step_separator (`str`, *optional*, defaults to `"\n"`):
65
+ Separator used to separate each step of the reasoning process.
66
+ train_on_last_step_only (`bool`, *optional*, defaults to `False`):
67
+ Whether to train only on the last step.
68
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
69
+ Number of processes to use for processing the dataset.
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ tp_size = 0,
159
+ fsdp_transformer_layer_cls_to_wrap = None,
160
+ accelerator_config = None,
161
+ deepspeed = None,
162
+ label_smoothing_factor = 0.0,
163
+ optim = 'adamw_8bit',
164
+ optim_args = None,
165
+ adafactor = False,
166
+ group_by_length = False,
167
+ length_column_name = 'length',
168
+ report_to = None,
169
+ ddp_find_unused_parameters = None,
170
+ ddp_bucket_cap_mb = None,
171
+ ddp_broadcast_buffers = None,
172
+ dataloader_pin_memory = True,
173
+ dataloader_persistent_workers = False,
174
+ skip_memory_metrics = True,
175
+ use_legacy_prediction_loop = False,
176
+ push_to_hub = False,
177
+ resume_from_checkpoint = None,
178
+ hub_model_id = None,
179
+ hub_strategy = 'every_save',
180
+ hub_token = None,
181
+ hub_private_repo = None,
182
+ hub_always_push = False,
183
+ gradient_checkpointing = False,
184
+ gradient_checkpointing_kwargs = None,
185
+ include_inputs_for_metrics = False,
186
+ eval_do_concat_batches = True,
187
+ fp16_backend = 'auto',
188
+ evaluation_strategy = None,
189
+ push_to_hub_model_id = None,
190
+ push_to_hub_organization = None,
191
+ push_to_hub_token = None,
192
+ mp_parameters = '',
193
+ auto_find_batch_size = False,
194
+ full_determinism = False,
195
+ torchdynamo = None,
196
+ ray_scope = 'last',
197
+ ddp_timeout = 1800,
198
+ torch_compile = False,
199
+ torch_compile_backend = None,
200
+ torch_compile_mode = None,
201
+ dispatch_batches = None,
202
+ split_batches = None,
203
+ include_tokens_per_second = False,
204
+ include_num_input_tokens_seen = False,
205
+ neftune_noise_alpha = None,
206
+ optim_target_modules = None,
207
+ batch_eval_metrics = False,
208
+ eval_on_start = False,
209
+ use_liger_kernel = False,
210
+ eval_use_gather_object = False,
211
+ average_tokens_across_devices = False,
212
+ max_length = 1024,
213
+ max_prompt_length = 512,
214
+ max_completion_length = None,
215
+ disable_dropout = True,
216
+ step_separator = '\n',
218
+ train_on_last_step_only = False,
219
+ dataset_num_proc = None,
220
+ vllm_sampling_params = None,
221
+ unsloth_num_chunks = -1,
222
+ **kwargs,
223
+ ):
224
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
225
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
226
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
227
+ output_dir = 'unsloth_training_checkpoints'
228
+ save_strategy = 'no'
229
+ if dataset_num_proc is None:
230
+ from multiprocessing import cpu_count
231
+ dataset_num_proc = cpu_count()
232
+
233
+ super().__init__(
234
+ output_dir = output_dir,
235
+ overwrite_output_dir = overwrite_output_dir,
236
+ do_train = do_train,
237
+ do_eval = do_eval,
238
+ do_predict = do_predict,
239
+ eval_strategy = eval_strategy,
240
+ prediction_loss_only = prediction_loss_only,
241
+ per_device_train_batch_size = per_device_train_batch_size,
242
+ per_device_eval_batch_size = per_device_eval_batch_size,
243
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
244
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
245
+ gradient_accumulation_steps = gradient_accumulation_steps,
246
+ eval_accumulation_steps = eval_accumulation_steps,
247
+ eval_delay = eval_delay,
248
+ torch_empty_cache_steps = torch_empty_cache_steps,
249
+ learning_rate = learning_rate,
250
+ weight_decay = weight_decay,
251
+ adam_beta1 = adam_beta1,
252
+ adam_beta2 = adam_beta2,
253
+ adam_epsilon = adam_epsilon,
254
+ max_grad_norm = max_grad_norm,
255
+ num_train_epochs = num_train_epochs,
256
+ max_steps = max_steps,
257
+ lr_scheduler_type = lr_scheduler_type,
258
+ warmup_ratio = warmup_ratio,
259
+ warmup_steps = warmup_steps,
260
+ log_level = log_level,
261
+ log_level_replica = log_level_replica,
262
+ log_on_each_node = log_on_each_node,
263
+ logging_dir = logging_dir,
264
+ logging_strategy = logging_strategy,
265
+ logging_first_step = logging_first_step,
266
+ logging_steps = logging_steps,
267
+ logging_nan_inf_filter = logging_nan_inf_filter,
268
+ save_strategy = save_strategy,
269
+ save_steps = save_steps,
270
+ save_total_limit = save_total_limit,
271
+ save_safetensors = save_safetensors,
272
+ save_on_each_node = save_on_each_node,
273
+ save_only_model = save_only_model,
274
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
275
+ no_cuda = no_cuda,
276
+ use_cpu = use_cpu,
277
+ use_mps_device = use_mps_device,
278
+ seed = seed,
279
+ data_seed = data_seed,
280
+ jit_mode_eval = jit_mode_eval,
281
+ use_ipex = use_ipex,
282
+ bf16 = bf16,
283
+ fp16 = fp16,
284
+ fp16_opt_level = fp16_opt_level,
285
+ half_precision_backend = half_precision_backend,
286
+ bf16_full_eval = bf16_full_eval,
287
+ fp16_full_eval = fp16_full_eval,
288
+ tf32 = tf32,
289
+ local_rank = local_rank,
290
+ ddp_backend = ddp_backend,
291
+ tpu_num_cores = tpu_num_cores,
292
+ tpu_metrics_debug = tpu_metrics_debug,
293
+ debug = debug,
294
+ dataloader_drop_last = dataloader_drop_last,
295
+ eval_steps = eval_steps,
296
+ dataloader_num_workers = dataloader_num_workers,
297
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
298
+ past_index = past_index,
299
+ run_name = run_name,
300
+ disable_tqdm = disable_tqdm,
301
+ remove_unused_columns = remove_unused_columns,
302
+ label_names = label_names,
303
+ load_best_model_at_end = load_best_model_at_end,
304
+ metric_for_best_model = metric_for_best_model,
305
+ greater_is_better = greater_is_better,
306
+ ignore_data_skip = ignore_data_skip,
307
+ fsdp = fsdp,
308
+ fsdp_min_num_params = fsdp_min_num_params,
309
+ fsdp_config = fsdp_config,
310
+ tp_size = tp_size,
311
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
312
+ accelerator_config = accelerator_config,
313
+ deepspeed = deepspeed,
314
+ label_smoothing_factor = label_smoothing_factor,
315
+ optim = optim,
316
+ optim_args = optim_args,
317
+ adafactor = adafactor,
318
+ group_by_length = group_by_length,
319
+ length_column_name = length_column_name,
320
+ report_to = report_to,
321
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
322
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
323
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
324
+ dataloader_pin_memory = dataloader_pin_memory,
325
+ dataloader_persistent_workers = dataloader_persistent_workers,
326
+ skip_memory_metrics = skip_memory_metrics,
327
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
328
+ push_to_hub = push_to_hub,
329
+ resume_from_checkpoint = resume_from_checkpoint,
330
+ hub_model_id = hub_model_id,
331
+ hub_strategy = hub_strategy,
332
+ hub_token = hub_token,
333
+ hub_private_repo = hub_private_repo,
334
+ hub_always_push = hub_always_push,
335
+ gradient_checkpointing = gradient_checkpointing,
336
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
337
+ include_inputs_for_metrics = include_inputs_for_metrics,
338
+ eval_do_concat_batches = eval_do_concat_batches,
339
+ fp16_backend = fp16_backend,
340
+ evaluation_strategy = evaluation_strategy,
341
+ push_to_hub_model_id = push_to_hub_model_id,
342
+ push_to_hub_organization = push_to_hub_organization,
343
+ push_to_hub_token = push_to_hub_token,
344
+ mp_parameters = mp_parameters,
345
+ auto_find_batch_size = auto_find_batch_size,
346
+ full_determinism = full_determinism,
347
+ torchdynamo = torchdynamo,
348
+ ray_scope = ray_scope,
349
+ ddp_timeout = ddp_timeout,
350
+ torch_compile = torch_compile,
351
+ torch_compile_backend = torch_compile_backend,
352
+ torch_compile_mode = torch_compile_mode,
353
+ dispatch_batches = dispatch_batches,
354
+ split_batches = split_batches,
355
+ include_tokens_per_second = include_tokens_per_second,
356
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
357
+ neftune_noise_alpha = neftune_noise_alpha,
358
+ optim_target_modules = optim_target_modules,
359
+ batch_eval_metrics = batch_eval_metrics,
360
+ eval_on_start = eval_on_start,
361
+ use_liger_kernel = use_liger_kernel,
362
+ eval_use_gather_object = eval_use_gather_object,
363
+ average_tokens_across_devices = average_tokens_across_devices,
364
+ max_length = max_length,
365
+ max_prompt_length = max_prompt_length,
366
+ max_completion_length = max_completion_length,
367
+ disable_dropout = disable_dropout,
368
+ step_separator = step_separator,
369
+ train_on_last_step_only = train_on_last_step_only,
370
+ dataset_num_proc = dataset_num_proc, **kwargs)
371
+ self.vllm_sampling_params = vllm_sampling_params
372
+ self.unsloth_num_chunks = unsloth_num_chunks
373
+ pass
374
+
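# Hypothetical construction of the config above; all field values are
# illustrative, and unspecified fields fall back to the PRMConfig /
# TrainingArguments defaults listed in __init__.
cfg = UnslothPRMConfig(
    output_dir="outputs/prm",
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    max_length=1024,
    max_prompt_length=512,
    step_separator="\n",
    train_on_last_step_only=False,
)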
375
+ class _UnslothPRMTrainer(Trainer):
376
+ """"""
377
+
378
+ _tag_names = ["trl", "prm"]
379
+
380
+ def __init__(
381
+ self,
382
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
383
+ args: Optional[PRMConfig] = None,
384
+ data_collator: Optional[DataCollator] = None,
385
+ train_dataset: Optional[Dataset] = None,
386
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
387
+ processing_class: Optional[
388
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
389
+ ] = None,
390
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
391
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
392
+ callbacks: Optional[list[TrainerCallback]] = None,
393
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
394
+ None,
395
+ None,
396
+ ),
397
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
398
+ peft_config: Optional[dict] = None,
399
+ ):
400
+ if not is_peft_available() and peft_config is not None:
401
+ raise ValueError(
402
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
403
+ )
404
+ elif is_peft_available() and peft_config is not None:
405
+ if not isinstance(model, PeftModel):
406
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
407
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
408
+ inspect.signature(prepare_model_for_kbit_training).parameters
409
+ )
410
+
411
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
412
+
413
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
414
+ warnings.warn(
415
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
416
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
417
+ )
418
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
419
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
420
+
421
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
422
+
423
+ model = model
424
+
425
+ # Disable dropout in the model
426
+ if args.disable_dropout:
427
+ disable_dropout_in_model(model)
428
+
429
+ if compute_metrics is None:
430
+ compute_metrics = compute_accuracy
431
+
432
+ if data_collator is None:
433
+ if processing_class is None:
434
+ raise ValueError(
435
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
436
+ )
437
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
438
+
439
+ if "input_ids" not in train_dataset.column_names:
440
+ with PartialState().local_main_process_first():
441
+ fn_kwargs = {
442
+ "tokenizer": processing_class,
443
+ "step_separator": args.step_separator,
444
+ "max_length": args.max_length,
445
+ "max_prompt_length": args.max_prompt_length,
446
+ "max_completion_length": args.max_completion_length,
447
+ "train_on_last_step_only": args.train_on_last_step_only,
448
+ }
449
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
450
+ train_dataset = train_dataset.map(
451
+ self.tokenize_row,
452
+ fn_kwargs=train_fn_kwargs,
453
+ num_proc=args.dataset_num_proc,
454
+ remove_columns=train_dataset.features,
455
+ desc="Tokenizing train dataset",
456
+ features=features.Features( # needed to avoid map to cast labels to bool
457
+ {
458
+ "labels": features.Sequence(features.Value("int64")),
459
+ "input_ids": features.Sequence(features.Value("int64")),
460
+ }
461
+ ),
462
+ )
463
+
464
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
465
+ if eval_dataset is not None:
466
+ eval_dataset = eval_dataset.map(
467
+ self.tokenize_row,
468
+ fn_kwargs=eval_fn_kwargs,
469
+ num_proc=args.dataset_num_proc,
470
+ remove_columns=eval_dataset.features,
471
+ desc="Tokenizing eval dataset",
472
+ features=features.Features( # needed to avoid map to cast labels to bool
473
+ {
474
+ "labels": features.Sequence(features.Value("int64")),
475
+ "input_ids": features.Sequence(features.Value("int64")),
476
+ }
477
+ ),
478
+ )
479
+
480
+ super().__init__(
481
+ model=model,
482
+ args=args,
483
+ data_collator=data_collator,
484
+ train_dataset=train_dataset,
485
+ eval_dataset=eval_dataset,
486
+ processing_class=processing_class,
487
+ model_init=model_init,
488
+ compute_metrics=compute_metrics,
489
+ callbacks=callbacks,
490
+ optimizers=optimizers,
491
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
492
+ )
493
+
494
+ # Add tags for models that have been loaded with the correct transformers version
495
+ if hasattr(self.model, "add_model_tags"):
496
+ self.model.add_model_tags(self._tag_names)
497
+
498
+ @staticmethod
499
+ def tokenize_row(
500
+ features,
501
+ tokenizer,
502
+ step_separator,
503
+ max_length,
504
+ max_prompt_length,
505
+ max_completion_length,
506
+ train_on_last_step_only,
507
+ is_eval,
508
+ ):
509
+ r"""
510
+ Tokenize a row of the dataset.
511
+
512
+ Args:
513
+ features (`dict[str, str]`):
514
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
515
+ tokenizer (`PreTrainedTokenizerBase`):
516
+ Tokenizer used to process the data.
517
+ step_separator (`str`):
518
+ Separator between steps in the completion.
519
+ max_length (`int` or `None`):
520
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
521
+ max_prompt_length (`int` or `None`):
522
+ Maximum length of the prompt. If `None`, the prompt is not truncated.
523
+ max_completion_length (`int` or `None`):
524
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
525
+ train_on_last_step_only (`bool`):
526
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
527
+ token of the completion.
528
+ is_eval (`bool`):
529
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
530
+
531
+ Returns:
532
+ `dict[str, list[int]]`:
533
+ Tokenized sequences with the keys `"input_ids"` and `"labels"`.
534
+
535
+ Example:
536
+ ```python
537
+ >>> from transformers import AutoTokenizer
538
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
539
+ >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
540
+ ... "completions": ["11 is greater than 8.",
541
+ ... "Hence, 9.11 > 9.8."],
542
+ ... "labels": [True, False]}
543
+ >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
544
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
545
+ 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
546
+ ```
547
+ """
548
+ # Tokenize the prompt and completions
549
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
550
+ completions_ids = [
551
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
552
+ ]
553
+ if train_on_last_step_only and not is_eval:
554
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
555
+ else:
556
+ labels = [int(label) for label in features["labels"]]
557
+
558
+ # Get the ID of the separator token and add it to the completions
559
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
560
+ completions_ids = [completion + separator_ids for completion in completions_ids]
561
+
562
+ # Create the label
563
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
564
+
565
+ # Join the completions and labels steps
566
+ completion_ids = list(chain(*completions_ids))
567
+ labels = list(chain(*labels))
568
+
569
+ if tokenizer.bos_token_id is not None:
570
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
571
+
572
+ # Truncate prompt and completion sequences
573
+ if max_prompt_length is not None:
574
+ prompt_ids = prompt_ids[-max_prompt_length:]
575
+ if max_completion_length is not None:
576
+ completion_ids = completion_ids[:max_completion_length]
577
+ labels = labels[:max_completion_length]
578
+
579
+ input_ids = prompt_ids + completion_ids
580
+ labels = [-100] * len(prompt_ids) + labels
581
+
582
+ if max_length is not None:
583
+ input_ids = input_ids[:max_length]
584
+ labels = labels[:max_length]
585
+
586
+ return {"input_ids": input_ids, "labels": labels}
587
+
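# Toy illustration of the label layout tokenize_row produces: every token of a
# step is masked with -100 except the step separator, which carries the 0/1
# step label. Token ids here are made up.
completions_ids = [[5, 6], [7, 8, 9]]  # two reasoning steps, already tokenized
separator_ids = [99]                   # e.g. the id of "\n"
step_labels = [True, False]

completions_ids = [c + separator_ids for c in completions_ids]
labels = [[-100] * (len(c) - 1) + [int(l)] for c, l in zip(completions_ids, step_labels)]
flat_labels = [x for step in labels for x in step]
assert flat_labels == [-100, -100, 1, -100, -100, -100, 0]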
588
+ def create_model_card(
589
+ self,
590
+ model_name: Optional[str] = None,
591
+ dataset_name: Optional[str] = None,
592
+ tags: Union[str, list[str], None] = None,
593
+ ):
594
+ """
595
+ Creates a draft of a model card using the information available to the `Trainer`.
596
+
597
+ Args:
598
+ model_name (`str` or `None`, *optional*, defaults to `None`):
599
+ Name of the model.
600
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
601
+ Name of the dataset used for training.
602
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
603
+ Tags to be associated with the model card.
604
+ """
605
+ if not self.is_world_process_zero():
606
+ return
607
+
608
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
609
+ base_model = self.model.config._name_or_path
610
+ else:
611
+ base_model = None
612
+
613
+ tags = tags or []
614
+ if isinstance(tags, str):
615
+ tags = [tags]
616
+
617
+ if hasattr(self.model.config, "unsloth_version"):
618
+ tags.append("unsloth")
619
+
620
+ citation = textwrap.dedent("""\
621
+ @article{uesato2022solving,
622
+ title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
623
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
624
+ year = 2022,
625
+ journal = {arXiv preprint arXiv:2211.14275}
626
+ }""")
627
+
628
+ model_card = generate_model_card(
629
+ base_model=base_model,
630
+ model_name=model_name,
631
+ hub_model_id=self.hub_model_id,
632
+ dataset_name=dataset_name,
633
+ tags=tags,
634
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
635
+ trainer_name="PRM",
636
+ trainer_citation=citation,
637
+ paper_title="Solving math word problems with process-and outcome-based feedback",
638
+ )
639
+
640
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
641
+ class UnslothPRMTrainer(_UnslothPRMTrainer):
642
+ """
643
+
644
+ Initialize PRMTrainer.
645
+
646
+ Args:
647
+ model (`transformers.PreTrainedModel`):
648
+ The model to train, preferably an `AutoModelForTokenClassification`.
649
+ args (`PRMConfig`):
650
+ The arguments to use for training.
651
+ data_collator (`transformers.DataCollator`):
652
+ The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
653
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
654
+ train_dataset (`datasets.Dataset`):
655
+ The dataset to use for training.
656
+ eval_dataset (`datasets.Dataset`):
657
+ The dataset to use for evaluation.
658
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
659
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
660
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
661
+ reuse the fine-tuned model.
662
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
663
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
664
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
665
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
666
+ callbacks (`list[transformers.TrainerCallback]`):
667
+ The callbacks to use for training.
668
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
669
+ The optimizer and scheduler to use for training.
670
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
671
+ The function to use to preprocess the logits before computing the metrics.
672
+ peft_config (`dict`, defaults to `None`):
673
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
674
+
675
+ """
+ def __init__(
+ self,
+ model = None,
+ args = None,
+ data_collator = None,
+ train_dataset = None,
+ eval_dataset = None,
+ processing_class = None,
+ model_init = None,
+ compute_metrics = None,
+ callbacks = None,
+ preprocess_logits_for_metrics = None,
+ peft_config = None,
+ **kwargs
+ ):
+ if args is None: args = UnslothPRMConfig()
+ use_bf16 = getattr(args, 'bf16', False)
+ use_fp16 = getattr(args, 'fp16', False)
+ force_float32 = False
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
+ force_float32 = True
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+ dtype = getattr(model.config, 'torch_dtype', None)
+ if dtype is None: dtype = model.get_input_embeddings().dtype
+ from unsloth_zoo.utils import _get_dtype
+ dtype = _get_dtype(dtype)
+ float16 = dtype == torch.float16
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+ if force_float32:
+ args.fp16 = False
+ args.bf16 = False
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+ args.fp16 = float16
+ args.bf16 = not float16
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+ args.eval_strategy = 'steps'
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+ if ga_steps is not None and ga_steps > 1:
+ from transformers import __version__ as transformers_version
+ if Version(transformers_version) <= Version('4.45.2'):
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+ if getattr(args, 'eval_strategy', 'no') != 'no':
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+ if force_float32:
+ args.bf16_full_eval = False
+ args.fp16_full_eval = False
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+ args.bf16_full_eval = True
+ args.fp16_full_eval = False
+ elif not bf16_full_eval and not fp16_full_eval:
+ args.bf16_full_eval = args.bf16
+ args.fp16_full_eval = args.fp16
+ _output_logits = False
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+ if _output_logits:
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+ pass
+ else:
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
+ if args_max_seq_length is None and model_max_seq_length is not None:
+ max_seq_length = model.max_seq_length
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+ if model is not None and hasattr(model, 'for_training'):
+ model.for_training()
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+ if 'processing_class' in locals():
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
+ else:
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+ else:
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+ other_metrics = []
+
+ from unsloth_zoo.logging_utils import PatchRLStatistics
+ PatchRLStatistics('prm_trainer', other_metrics)
+
+ super().__init__(
+ model = model,
+ args = args,
+ data_collator = data_collator,
+ train_dataset = train_dataset,
+ eval_dataset = eval_dataset,
+ processing_class = processing_class,
+ model_init = model_init,
+ compute_metrics = compute_metrics,
+ callbacks = callbacks,
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+ peft_config = peft_config,**kwargs)
+ if hasattr(self, 'neftune_hook_handle'):
+ self.neftune_hook_handle.remove()
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+ pass
+
+ pass
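For reference, a minimal usage sketch for the PRM wrapper above. It assumes TRL's stepwise-supervision data format (`prompt`, `completions`, `labels` columns) and a token-classification head, per the docstring; the model id, dataset values, and config values are illustrative only:

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from datasets import Dataset

model_id = "Qwen/Qwen2.5-0.5B"  # illustrative base model, not a recommendation
model = AutoModelForTokenClassification.from_pretrained(model_id, num_labels = 2)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# One math problem, two reasoning steps, one correctness label per step.
train_dataset = Dataset.from_dict({
    "prompt": ["Sarah has 3 apples and buys 2 more. How many does she have?"],
    "completions": [["3 + 2 = 5.", "So she has 5 apples."]],
    "labels": [[True, True]],
})

trainer = UnslothPRMTrainer(
    model = model,
    args = UnslothPRMConfig(output_dir = "prm_outputs", per_device_train_batch_size = 2),
    processing_class = tokenizer,
    train_dataset = train_dataset,
)
trainer.train()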
unsloth_compiled_cache/UnslothRLOOTrainer.py ADDED
@@ -0,0 +1,1133 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.3
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+ "epilogue_fusion" : True,
+ "max_autotune" : False,
+ "shape_padding" : True,
+ "trace.enabled" : False,
+ "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+ logits = logits.to(torch.float32)
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+ # loop to reduce peak mem consumption
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
+ return per_token_logps
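The comment above relies on the identity log_softmax(x)_i = x_i - logsumexp(x), which lets the helper avoid materializing a full log-softmax tensor. A quick sanity check of that identity, replicating the math eagerly rather than calling the compiled helper so it runs anywhere (shapes are illustrative):

import torch
from torch.nn import functional as F

logits = torch.randn(2, 6, 32)        # (batch, sequence, vocab)
index = torch.randint(0, 32, (2, 6))  # token ids whose log-probabilities we want

reference = F.log_softmax(logits, dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
fast = logits.gather(-1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, dim = -1)
assert torch.allclose(fast, reference, atol = 1e-5)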
+ @dataclass
+ class UnslothRLOOConfig(RLOOConfig):
+ """
+
+ Configuration class for the [`RLOOTrainer`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
+ Name of this experiment.
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
+ Path to the reward model.
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
+ Number of epochs to train.
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
+ Whether to whiten the rewards.
+ kl_coef (`float`, *optional*, defaults to `0.05`):
+ KL coefficient.
+ cliprange (`float`, *optional*, defaults to `0.2`):
+ Clip range.
+ rloo_k (`int`, *optional*, defaults to `2`):
+ REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
+ normalize_reward (`bool`, *optional*, defaults to `False`):
+ Whether to normalize rewards.
+ reward_clip_range (`float`, *optional*, defaults to `10.0`):
+ Clip range for rewards.
+ normalize_advantage (`bool`, *optional*, defaults to `False`):
+ Whether to normalize advantages.
+ token_level_kl (`bool`, *optional*, defaults to `False`):
+ Whether to use token-level KL penalty or sequence-level KL penalty.
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
+ capacity of a single GPU, albeit at the cost of slower generation.
+
+ """
+ vllm_sampling_params: Optional[Any] = field(
+ default = None,
+ metadata = {'help': 'vLLM SamplingParams'},
+ )
+ unsloth_num_chunks : Optional[int] = field(
+ default = -1,
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+ )
+ def __init__(
+ self,
+ output_dir = None,
+ overwrite_output_dir = None,
+ do_train = False,
+ do_eval = False,
+ do_predict = False,
+ eval_strategy = 'no',
+ prediction_loss_only = False,
+ per_device_train_batch_size = 4,
+ per_device_eval_batch_size = 4,
+ per_gpu_train_batch_size = None,
+ per_gpu_eval_batch_size = None,
+ gradient_accumulation_steps = 2,
+ eval_accumulation_steps = 2,
+ eval_delay = 0,
+ torch_empty_cache_steps = 250,
+ learning_rate = 5e-05,
+ weight_decay = 0.01,
+ adam_beta1 = 0.9,
+ adam_beta2 = 0.999,
+ adam_epsilon = 1e-08,
+ max_grad_norm = 1.0,
+ num_train_epochs = 3.0,
+ max_steps = -1,
+ lr_scheduler_type = 'linear',
+ warmup_ratio = 0.1,
+ warmup_steps = 0,
+ log_level = 'passive',
+ log_level_replica = 'warning',
+ log_on_each_node = True,
+ logging_dir = None,
+ logging_strategy = 'steps',
+ logging_first_step = False,
+ logging_steps = 1,
+ logging_nan_inf_filter = False,
+ save_strategy = 'steps',
+ save_steps = 500,
+ save_total_limit = None,
+ save_safetensors = True,
+ save_on_each_node = False,
+ save_only_model = False,
+ restore_callback_states_from_checkpoint = False,
+ no_cuda = False,
+ use_cpu = False,
+ use_mps_device = False,
+ seed = 3407,
+ data_seed = 3407,
+ jit_mode_eval = False,
+ use_ipex = False,
+ bf16 = False,
+ fp16 = False,
+ fp16_opt_level = 'O1',
+ half_precision_backend = 'auto',
+ bf16_full_eval = False,
+ fp16_full_eval = False,
+ tf32 = None,
+ local_rank = -1,
+ ddp_backend = None,
+ tpu_num_cores = None,
+ tpu_metrics_debug = False,
+ debug = '',
+ dataloader_drop_last = False,
+ eval_steps = None,
+ dataloader_num_workers = 0,
+ dataloader_prefetch_factor = None,
+ past_index = -1,
+ run_name = None,
+ disable_tqdm = None,
+ remove_unused_columns = True,
+ label_names = None,
+ load_best_model_at_end = False,
+ metric_for_best_model = None,
+ greater_is_better = None,
+ ignore_data_skip = False,
+ fsdp = '',
+ fsdp_min_num_params = 0,
+ fsdp_config = None,
+ tp_size = 0,
+ fsdp_transformer_layer_cls_to_wrap = None,
+ accelerator_config = None,
+ deepspeed = None,
+ label_smoothing_factor = 0.0,
+ optim = 'adamw_8bit',
+ optim_args = None,
+ adafactor = False,
+ group_by_length = False,
+ length_column_name = 'length',
+ report_to = None,
+ ddp_find_unused_parameters = None,
+ ddp_bucket_cap_mb = None,
+ ddp_broadcast_buffers = None,
+ dataloader_pin_memory = True,
+ dataloader_persistent_workers = False,
+ skip_memory_metrics = True,
+ use_legacy_prediction_loop = False,
+ push_to_hub = False,
+ resume_from_checkpoint = None,
+ hub_model_id = None,
+ hub_strategy = 'every_save',
+ hub_token = None,
+ hub_private_repo = None,
+ hub_always_push = False,
+ gradient_checkpointing = False,
+ gradient_checkpointing_kwargs = None,
+ include_inputs_for_metrics = False,
+ eval_do_concat_batches = True,
+ fp16_backend = 'auto',
+ evaluation_strategy = None,
+ push_to_hub_model_id = None,
+ push_to_hub_organization = None,
+ push_to_hub_token = None,
+ mp_parameters = '',
+ auto_find_batch_size = False,
+ full_determinism = False,
+ torchdynamo = None,
+ ray_scope = 'last',
+ ddp_timeout = 1800,
+ torch_compile = False,
+ torch_compile_backend = None,
+ torch_compile_mode = None,
+ dispatch_batches = None,
+ split_batches = None,
+ include_tokens_per_second = False,
+ include_num_input_tokens_seen = False,
+ neftune_noise_alpha = None,
+ optim_target_modules = None,
+ batch_eval_metrics = False,
+ eval_on_start = False,
+ use_liger_kernel = False,
+ eval_use_gather_object = False,
+ average_tokens_across_devices = False,
+ dataset_num_proc = None,
+ num_mini_batches = 1,
+ total_episodes = None,
+ local_rollout_forward_batch_size = 64,
+ num_sample_generations = 10,
+ response_length = 53,
+ stop_token = None,
+ stop_token_id = None,
+ temperature = 0.7,
+ missing_eos_penalty = None,
+ sft_model_path = 'EleutherAI/pythia-160m',
+ world_size = None,
+ num_total_batches = None,
+ micro_batch_size = None,
+ local_batch_size = None,
+ batch_size = None,
+ local_mini_batch_size = None,
+ mini_batch_size = None,
+ exp_name = 'rloo_config',
+ reward_model_path = 'EleutherAI/pythia-160m',
+ num_ppo_epochs = 4,
+ whiten_rewards = False,
+ kl_coef = 0.05,
+ cliprange = 0.2,
+ rloo_k = 2,
+ normalize_reward = False,
+ reward_clip_range = 10.0,
+ normalize_advantage = False,
+ token_level_kl = False,
+ ds3_gather_for_generation = True,
+ vllm_sampling_params = None,
+ unsloth_num_chunks = -1,
+ **kwargs,
+ ):
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (greater than 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+ output_dir = 'unsloth_training_checkpoints'
+ save_strategy = 'no'
+ if dataset_num_proc is None:
+ from multiprocessing import cpu_count
+ dataset_num_proc = cpu_count()
+
+ super().__init__(
+ output_dir = output_dir,
+ overwrite_output_dir = overwrite_output_dir,
+ do_train = do_train,
+ do_eval = do_eval,
+ do_predict = do_predict,
+ eval_strategy = eval_strategy,
+ prediction_loss_only = prediction_loss_only,
+ per_device_train_batch_size = per_device_train_batch_size,
+ per_device_eval_batch_size = per_device_eval_batch_size,
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+ gradient_accumulation_steps = gradient_accumulation_steps,
+ eval_accumulation_steps = eval_accumulation_steps,
+ eval_delay = eval_delay,
+ torch_empty_cache_steps = torch_empty_cache_steps,
+ learning_rate = learning_rate,
+ weight_decay = weight_decay,
+ adam_beta1 = adam_beta1,
+ adam_beta2 = adam_beta2,
+ adam_epsilon = adam_epsilon,
+ max_grad_norm = max_grad_norm,
+ num_train_epochs = num_train_epochs,
+ max_steps = max_steps,
+ lr_scheduler_type = lr_scheduler_type,
+ warmup_ratio = warmup_ratio,
+ warmup_steps = warmup_steps,
+ log_level = log_level,
+ log_level_replica = log_level_replica,
+ log_on_each_node = log_on_each_node,
+ logging_dir = logging_dir,
+ logging_strategy = logging_strategy,
+ logging_first_step = logging_first_step,
+ logging_steps = logging_steps,
+ logging_nan_inf_filter = logging_nan_inf_filter,
+ save_strategy = save_strategy,
+ save_steps = save_steps,
+ save_total_limit = save_total_limit,
+ save_safetensors = save_safetensors,
+ save_on_each_node = save_on_each_node,
+ save_only_model = save_only_model,
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+ no_cuda = no_cuda,
+ use_cpu = use_cpu,
+ use_mps_device = use_mps_device,
+ seed = seed,
+ data_seed = data_seed,
+ jit_mode_eval = jit_mode_eval,
+ use_ipex = use_ipex,
+ bf16 = bf16,
+ fp16 = fp16,
+ fp16_opt_level = fp16_opt_level,
+ half_precision_backend = half_precision_backend,
+ bf16_full_eval = bf16_full_eval,
+ fp16_full_eval = fp16_full_eval,
+ tf32 = tf32,
+ local_rank = local_rank,
+ ddp_backend = ddp_backend,
+ tpu_num_cores = tpu_num_cores,
+ tpu_metrics_debug = tpu_metrics_debug,
+ debug = debug,
+ dataloader_drop_last = dataloader_drop_last,
+ eval_steps = eval_steps,
+ dataloader_num_workers = dataloader_num_workers,
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
+ past_index = past_index,
+ run_name = run_name,
+ disable_tqdm = disable_tqdm,
+ remove_unused_columns = remove_unused_columns,
+ label_names = label_names,
+ load_best_model_at_end = load_best_model_at_end,
+ metric_for_best_model = metric_for_best_model,
+ greater_is_better = greater_is_better,
+ ignore_data_skip = ignore_data_skip,
+ fsdp = fsdp,
+ fsdp_min_num_params = fsdp_min_num_params,
+ fsdp_config = fsdp_config,
+ tp_size = tp_size,
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+ accelerator_config = accelerator_config,
+ deepspeed = deepspeed,
+ label_smoothing_factor = label_smoothing_factor,
+ optim = optim,
+ optim_args = optim_args,
+ adafactor = adafactor,
+ group_by_length = group_by_length,
+ length_column_name = length_column_name,
+ report_to = report_to,
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
+ dataloader_pin_memory = dataloader_pin_memory,
+ dataloader_persistent_workers = dataloader_persistent_workers,
+ skip_memory_metrics = skip_memory_metrics,
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
+ push_to_hub = push_to_hub,
+ resume_from_checkpoint = resume_from_checkpoint,
+ hub_model_id = hub_model_id,
+ hub_strategy = hub_strategy,
+ hub_token = hub_token,
+ hub_private_repo = hub_private_repo,
+ hub_always_push = hub_always_push,
+ gradient_checkpointing = gradient_checkpointing,
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+ include_inputs_for_metrics = include_inputs_for_metrics,
+ eval_do_concat_batches = eval_do_concat_batches,
+ fp16_backend = fp16_backend,
+ evaluation_strategy = evaluation_strategy,
+ push_to_hub_model_id = push_to_hub_model_id,
+ push_to_hub_organization = push_to_hub_organization,
+ push_to_hub_token = push_to_hub_token,
+ mp_parameters = mp_parameters,
+ auto_find_batch_size = auto_find_batch_size,
+ full_determinism = full_determinism,
+ torchdynamo = torchdynamo,
+ ray_scope = ray_scope,
+ ddp_timeout = ddp_timeout,
+ torch_compile = torch_compile,
+ torch_compile_backend = torch_compile_backend,
+ torch_compile_mode = torch_compile_mode,
+ dispatch_batches = dispatch_batches,
+ split_batches = split_batches,
+ include_tokens_per_second = include_tokens_per_second,
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
+ neftune_noise_alpha = neftune_noise_alpha,
+ optim_target_modules = optim_target_modules,
+ batch_eval_metrics = batch_eval_metrics,
+ eval_on_start = eval_on_start,
+ use_liger_kernel = use_liger_kernel,
+ eval_use_gather_object = eval_use_gather_object,
+ average_tokens_across_devices = average_tokens_across_devices,
+ dataset_num_proc = dataset_num_proc,
+ num_mini_batches = num_mini_batches,
+ total_episodes = total_episodes,
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
+ num_sample_generations = num_sample_generations,
+ response_length = response_length,
+ stop_token = stop_token,
+ stop_token_id = stop_token_id,
+ temperature = temperature,
+ missing_eos_penalty = missing_eos_penalty,
+ sft_model_path = sft_model_path,
+ world_size = world_size,
+ num_total_batches = num_total_batches,
+ micro_batch_size = micro_batch_size,
+ local_batch_size = local_batch_size,
+ batch_size = batch_size,
+ local_mini_batch_size = local_mini_batch_size,
+ mini_batch_size = mini_batch_size,
+ exp_name = exp_name,
+ reward_model_path = reward_model_path,
+ num_ppo_epochs = num_ppo_epochs,
+ whiten_rewards = whiten_rewards,
+ kl_coef = kl_coef,
+ cliprange = cliprange,
+ rloo_k = rloo_k,
+ normalize_reward = normalize_reward,
+ reward_clip_range = reward_clip_range,
+ normalize_advantage = normalize_advantage,
+ token_level_kl = token_level_kl,
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
+ self.vllm_sampling_params = vllm_sampling_params
+ self.unsloth_num_chunks = unsloth_num_chunks
+ pass
+
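A minimal construction sketch for the config above, highlighting the RLOO-specific knobs documented in the docstring; the values are illustrative, not recommendations:

config = UnslothRLOOConfig(
    output_dir = "rloo_outputs",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    total_episodes = 10_000,
    rloo_k = 2,              # online samples per prompt, used by the leave-one-out baseline
    kl_coef = 0.05,          # strength of the KL penalty against the reference policy
    token_level_kl = False,  # sequence-level KL penalty (the default set above)
)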
+ class _UnslothRLOOTrainer(Trainer):
+ _tag_names = ["trl", "rloo"]
+
+ def __init__(
+ self,
+ config: RLOOConfig,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ],
+ policy: nn.Module,
+ ref_policy: nn.Module,
+ reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
+ train_dataset: Dataset,
+ data_collator: Optional[DataCollatorWithPadding] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ # less commonly used
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ callbacks: Optional[list[TrainerCallback]] = None,
+ ) -> None:
+ if ref_policy is policy:
+ raise ValueError(
+ "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
+ "same as `policy`, you must pass a copy of it, or `None` if you use peft."
+ )
+
+ self.args = config
+ args = config
+ self.processing_class = processing_class
+ self.policy = policy
+
+ # Define the collator if not provided
+ if data_collator is None:
+ data_collator = DataCollatorWithPadding(self.processing_class)
+
+ self.policy.generation_config.eos_token_id = (
+ None # disable `pad_token_id` and `eos_token_id` because we just want to
+ )
+ self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
+
+ self.ref_policy = ref_policy
+ self.reward_model = reward_model
+ self.train_dataset = train_dataset
+ self.train_dataset_len = len(train_dataset)
+ self.data_collator = data_collator
+ self.eval_dataset = eval_dataset
+ self.optimizer, self.lr_scheduler = optimizers
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
+
+ #########
+ # calculate various batch sizes
+ #########
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+ self.accelerator = accelerator
+ args.world_size = accelerator.num_processes
+ args.local_batch_size = (
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
+ )
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
+ args.batch_size = int(args.local_batch_size * args.world_size)
+ args.mini_batch_size = exact_div(
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
+ )
+ args.local_mini_batch_size = exact_div(
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
+ )
+ args.num_total_batches = math.ceil(
+ args.total_episodes / args.batch_size
+ ) # we may train for more than `total_episodes`
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
+ if args.num_sample_generations > 0:
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
+ self.local_dataloader_batch_size = exact_div(
+ args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
+ ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
+
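To make the batch-size bookkeeping above concrete, a worked example under assumed values: one process (world_size = 1), per_device_train_batch_size = 4, gradient_accumulation_steps = 2, num_mini_batches = 1, rloo_k = 2, total_episodes = 10_000:

local_batch_size = 4 * 2 * 1                         # 8 samples consumed per process per update
batch_size = local_batch_size * 1                    # 8 globally, since world_size = 1
mini_batch_size = batch_size // 1                    # 8, since num_mini_batches = 1
num_total_batches = -(-10_000 // batch_size)         # ceil(10000 / 8) = 1250 policy updates
local_dataloader_batch_size = local_batch_size // 2  # 4 prompts; each is repeated rloo_k = 2 times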
+ #########
+ # setup model, optimizer, and others
+ #########
+ for module in [policy, ref_policy, reward_model]:
+ if isinstance(module, nn.Module):
+ disable_dropout_in_model(module)
+ if args.stop_token and args.stop_token == "eos":
+ args.stop_token_id = self.processing_class.eos_token_id
+ self.model = policy
+ self.create_optimizer_and_scheduler(
+ num_training_steps=args.num_total_batches
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
+
+ #########
+ ### trainer specifics
+ #########
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+ self.control = TrainerControl()
+ self.state = OnlineTrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ],
+ )
+
+ self.current_flos = 0
+ self.hp_search_backend = None
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+ # Create remote repo and output directory if needed
+ self.hub_model_id = None
+ if self.args.push_to_hub:
+ self.init_hf_repo()
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+ self.backup_model = None
+
+ # Add tags for models that have been loaded with the correct transformers version
+ if hasattr(self.model, "add_model_tags"):
+ self.model.add_model_tags(self._tag_names)
+
+ #########
+ ### setup dataloader
+ #########
+ self.dataloader = DataLoader(
+ self.train_dataset,
+ batch_size=self.local_dataloader_batch_size,
+ shuffle=True,
+ collate_fn=self.data_collator,
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
+ )
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
+ torch.manual_seed(args.seed)
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
+ torch.manual_seed(self.local_seed) # reset the local seed again
+
+ self.eval_dataloader = DataLoader(
+ self.eval_dataset,
+ batch_size=args.per_device_eval_batch_size,
+ collate_fn=self.data_collator,
+ drop_last=True,
+ ) # no need to shuffle eval dataset
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
+
+ if self.is_deepspeed_enabled:
+ if isinstance(self.reward_model, nn.Module):
+ self.reward_model = prepare_deepspeed(
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+ )
+ self.ref_policy = prepare_deepspeed(
+ self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
+ )
+ self.deepspeed = self.model
+ else:
+ self.ref_policy = self.ref_policy.to(self.accelerator.device)
+ if isinstance(self.reward_model, nn.Module):
+ self.reward_model = self.reward_model.to(self.accelerator.device)
+
+ def get_train_dataloader(self) -> DataLoader:
+ return self.dataloader
+
+ def get_eval_dataloader(self) -> DataLoader:
+ return self.eval_dataloader
+
+ def train(self):
+ args = self.args
+ accelerator = self.accelerator
+ optimizer = self.optimizer
+ model = self.model
+ self.model_wrapped = self.model
+ ref_policy = self.ref_policy
+ reward_model = self.reward_model
+ processing_class = self.processing_class
+ dataloader = self.dataloader
+ device = accelerator.device
+
+ def repeat_generator():
+ while True:
+ yield from dataloader
+
+ iter_dataloader = iter(repeat_generator())
+ generation_config = GenerationConfig(
+ max_new_tokens=args.response_length,
+ temperature=(args.temperature + 1e-7),
+ top_k=0.0,
+ top_p=1.0,
+ do_sample=True,
+ )
+
+ accelerator.print("===training policy===")
+ start_time = time.time()
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
+ approxkl_stats = torch.zeros(stats_shape, device=device)
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
+ entropy_stats = torch.zeros(stats_shape, device=device)
+ ratio_stats = torch.zeros(stats_shape, device=device)
+ model.train()
+
+ # trainer state initialization
+ self.state.global_step = 0
+ self.state.episode = 0
+ self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
+ # Compute absolute values for logging, eval, and save if given as ratio
+ if args.logging_steps is not None:
+ if args.logging_steps < 1:
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
+ else:
+ self.state.logging_steps = args.logging_steps
+ if args.eval_steps is not None:
+ if args.eval_steps < 1:
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
+ else:
+ self.state.eval_steps = args.eval_steps
+ if args.save_steps is not None:
+ if args.save_steps < 1:
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
+ else:
+ self.state.save_steps = args.save_steps
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ for update in range(1, args.num_total_batches + 1):
+ self.state.episode += 1 * args.batch_size
+ data = next(iter_dataloader)
+ with torch.no_grad():
+ queries = data["input_ids"].to(device)
+ queries = queries.repeat(args.rloo_k, 1)
+ context_length = queries.shape[1]
+ responses = []
+ postprocessed_responses = []
+ logprobs = []
+ ref_logprobs = []
+ scores = []
+ sequence_lengths = []
+
+ # Generate responses and compute logprobs
+ with unwrap_model_for_generation(
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ query_responses, logitss = batch_generation(
+ unwrapped_model,
+ queries,
+ args.local_rollout_forward_batch_size,
+ processing_class.pad_token_id,
+ generation_config,
+ )
+
+ # Process responses in batches
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
+ query = queries[i : i + args.local_rollout_forward_batch_size]
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
+ response = query_response[:, context_length:]
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
+ logprob = selective_log_softmax(logits, response)
+ del logits
+ torch.cuda.empty_cache()
+
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
+ ref_logits /= args.temperature + 1e-7
+ ref_logprob = selective_log_softmax(ref_logits, response)
+ del ref_output, ref_logits
+ torch.cuda.empty_cache()
+
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
+ postprocessed_response = response
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
+ postprocessed_response = truncate_response(
+ args.stop_token_id, processing_class.pad_token_id, response
+ )
+
+ # Response Processing 2. run reward model on the truncated responses
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
+
+ if isinstance(reward_model, nn.Module):
+ _, score, _ = get_reward(
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+ )
+ else:
+ score = torch.tensor(
+ reward_model(
+ processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+ ),
+ dtype=torch.float,
+ ).to(device)
+
+ # Store batch results
+ responses.append(response)
+ postprocessed_responses.append(postprocessed_response)
+ logprobs.append(logprob)
+ ref_logprobs.append(ref_logprob)
+ sequence_lengths.append(sequence_length)
+ scores.append(score)
+
+ # Concatenate all batched results
+ responses = torch.cat(responses, 0)
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
+ logprobs = torch.cat(logprobs, 0)
+ ref_logprobs = torch.cat(ref_logprobs, 0)
+ sequence_lengths = torch.cat(sequence_lengths, 0)
+ scores = torch.cat(scores, 0)
+ del (logprob, ref_logprob, score)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
+ # responses not passing that filter will receive a low (fixed) score
+ # only query humans on responses that pass that filter
+ contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
+ if args.missing_eos_penalty is not None:
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
+
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
+
+ # 4. compute rewards
+ # Compute KL divergence
+ kl = logprobs - ref_logprobs
+
+ # Normalize rewards
+ if args.normalize_reward:
+ scores = (scores - scores.mean()) / (scores.std() + 1e-8)
+ scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
+
+ # Compute total reward with KL penalty
+ if args.token_level_kl:
+ # Token-level KL penalty: apply KL penalty per token
+ kl_reward = -args.kl_coef * kl
+
+ # Get the index of the last non-padded token for each sequence
+ eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
+ last_reward = torch.zeros_like(kl)
+ # Ensure scores has correct shape and type
+ scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
+ last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
+
+ # Combine KL reward and last reward
+ non_score_reward = kl_reward.sum(1) # Keep this for logging
+ reward = last_reward + kl_reward
+ rlhf_reward = reward.sum(1) # Sum across sequence length
+ else:
+ # Sequence-level KL penalty: sum KL across tokens first
+ sequence_kl = kl.sum(1)
+ non_score_reward = -args.kl_coef * sequence_kl
+ rlhf_reward = non_score_reward + scores
+
+ # vectorized RLOO advantages implementation
+ rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
+ baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
+ advantages = rlhf_reward - baseline
+ advantages = advantages.flatten()
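The leave-one-out baseline above gives each sample the mean reward of the other rloo_k - 1 completions of the same prompt, so the advantage measures how much better a completion is than its siblings. A small numeric sketch with illustrative values:

import torch

rloo_k = 2
rlhf_reward = torch.tensor([[1.0], [3.0]])                    # (rloo_k, num_prompts): two completions of one prompt
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)  # [[3.0], [1.0]]: each row sees only its sibling's reward
advantages = rlhf_reward - baseline                           # [[-2.0], [2.0]]: zero-sum within the prompt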
+
+ # Normalize advantages
+ if args.normalize_advantage:
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+ torch.cuda.empty_cache()
+
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
+ b_inds = np.random.permutation(args.local_batch_size)
+ minibatch_idx = 0
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
+ gradient_accumulation_idx = 0
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
+ with accelerator.accumulate(model):
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
+
+ # Get batch data
+ mb_advantage = advantages[micro_batch_inds]
+ mb_responses = responses[micro_batch_inds]
+ mb_query_responses = query_responses[micro_batch_inds]
+ mb_logprobs = logprobs[micro_batch_inds]
+
+ # Forward pass
+ output = forward(model, mb_query_responses, processing_class.pad_token_id)
+ logits = output.logits[:, context_length - 1 : -1]
+ logits /= args.temperature + 1e-7
+
+ # Compute new logprobs
+ new_logprobs = selective_log_softmax(logits, mb_responses)
+ new_logprobs = torch.masked_fill(
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
+ )
+
+ # Compute probability ratios
+ new_ratio = (new_logprobs - mb_logprobs).exp()
+ new_logprobs = new_logprobs.sum(1)
+ mb_logprobs = mb_logprobs.sum(1)
+ logprobs_diff = new_logprobs - mb_logprobs
+ ratio = torch.exp(logprobs_diff)
+
+ # PPO clipped loss
+ pg_losses = -mb_advantage * ratio
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
+ pg_loss = pg_loss_max.mean()
+
+ # Final loss
+ loss = pg_loss
+
+ # Optimization step
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ with torch.no_grad():
+ pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
+ approxkl = 0.5 * (logprobs_diff**2).mean()
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+ pg_clipfrac
+ )
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
+ gradient_accumulation_idx += 1
+ minibatch_idx += 1
+
+ # del everything and empty cache
+ # fmt: off
+ del (
+ output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
+ pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
+ mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
+ )
+ # fmt: on
+ torch.cuda.empty_cache()
+
+ # Compute metrics
+ with torch.no_grad():
+ mean_kl = kl.sum(1).mean()
+ mean_entropy = (-logprobs).sum(1).mean()
+ mean_non_score_reward = non_score_reward.mean()
+ eps = int(self.state.episode / (time.time() - start_time))
+ metrics = {}
+ metrics["eps"] = eps
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+ metrics["objective/non_score_reward"] = (
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+ )
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
+ metrics["episode"] = self.state.episode
+ self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
+ self.log(metrics)
+ del kl, mean_kl, mean_entropy, scores
+
+ self.lr_scheduler.step()
+ self.state.global_step += 1
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+ if self.control.should_save:
+ self._save_checkpoint(model, trial=None)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
+ self.generate_completions(sampling=True)
+
+ # HF trainer specifics
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+ if self.control.should_save:
+ self._save_checkpoint(model, trial=None, metrics=None)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def generate_completions(self, sampling: bool = False):
+ args = self.args
+ processing_class = self.processing_class
+ generation_config = GenerationConfig(
+ max_new_tokens=self.args.response_length,
+ temperature=(0.01 + 1e-7),
+ top_k=0.0,
+ top_p=1.0,
+ do_sample=True,
+ )
+
+ table = defaultdict(list)
+ with unwrap_model_for_generation(
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ for batch in self.eval_dataloader:
+ query = batch["input_ids"]
+ with torch.no_grad():
+ context_length = query.shape[1]
+ query_response, _ = batch_generation(
+ unwrapped_model,
+ query,
+ query.shape[0],
+ processing_class.pad_token_id,
+ generation_config,
+ )
+ response = query_response[:, context_length:]
+ postprocessed_response = response
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
+ postprocessed_response = truncate_response(
+ args.stop_token_id, processing_class.pad_token_id, response
+ )
+ table["query"].extend(
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
+ )
+ table["model response"].extend(
+ gather_object(processing_class.batch_decode(postprocessed_response))
+ )
+
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+
+ if isinstance(self.reward_model, nn.Module):
+ _, score, _ = get_reward(
+ self.reward_model,
+ postprocessed_query_response,
+ processing_class.pad_token_id,
+ context_length,
+ )
+ else:
+ score = torch.tensor(
+ self.reward_model(
+ processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
+ ),
+ dtype=torch.float,
+ ).to(postprocessed_query_response.device)
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
+
+ if sampling:
+ break
+ df = pd.DataFrame(table)
+
+ if self.accelerator.is_main_process:
+ print_rich_table(df.iloc[0 : 0 + 5])
+ if "wandb" in args.report_to:
+ import wandb
+
+ if wandb.run is not None:
+ wandb.log({"completions": wandb.Table(dataframe=df)})
+
+ if "comet_ml" in args.report_to:
+ log_table_to_comet_experiment(
+ name="completions.csv",
+ table=df,
+ )
+
+ def create_model_card(
+ self,
+ model_name: Optional[str] = None,
+ dataset_name: Optional[str] = None,
+ tags: Union[str, list[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ model_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the model.
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
+ Name of the dataset used for training.
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+ Tags to be associated with the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+ base_model = self.model.config._name_or_path
+ else:
+ base_model = None
+
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
+ citation = textwrap.dedent("""\
+ @inproceedings{ahmadian2024back,
+ title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
+ author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
+ year = 2024,
+ booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
+ publisher = {Association for Computational Linguistics},
+ pages = {12248--12267},
+ editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
+ }""")
+
+ model_card = generate_model_card(
+ base_model=base_model,
+ model_name=model_name,
+ hub_model_id=self.hub_model_id,
+ dataset_name=dataset_name,
+ tags=tags,
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+ comet_url=get_comet_experiment_url(),
+ trainer_name="RLOO",
+ trainer_citation=citation,
+ paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
+ paper_id="2402.14740",
+ )
+
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
+ class UnslothRLOOTrainer(_UnslothRLOOTrainer):
+ """
+
+ """
+ def __init__(
+ self,
+ config,
+ processing_class,
+ policy,
+ ref_policy,
+ reward_model,
+ train_dataset,
+ data_collator = None,
+ eval_dataset = None,
+ callbacks = None,
+ **kwargs
+ ):
+ if config is None: config = UnslothRLOOConfig()
+ args = config # alias: this wrapper takes `config`, but the shared Unsloth setup below expects `args`
+ model = policy # alias: the shared setup refers to the policy as `model`
+ _output_logits = False
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+ if _output_logits:
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+ pass
+ else:
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
+ if args_max_seq_length is None and model_max_seq_length is not None:
+ max_seq_length = model.max_seq_length
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+ if model is not None and hasattr(model, 'for_training'):
+ model.for_training()
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+ if 'processing_class' in locals():
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
+ else:
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+ if not isinstance(data_collator, UnslothVisionDataCollator):
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+ else:
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+ other_metrics = []
+
+ from unsloth_zoo.logging_utils import PatchRLStatistics
+ PatchRLStatistics('rloo_trainer', other_metrics)
+
+ super().__init__(
+ config = config,
+ processing_class = processing_class,
+ policy = policy,
+ ref_policy = ref_policy,
+ reward_model = reward_model,
+ train_dataset = train_dataset,
+ data_collator = data_collator,
+ eval_dataset = eval_dataset,
+ callbacks = callbacks,**kwargs)
+ if hasattr(self, 'neftune_hook_handle'):
+ self.neftune_hook_handle.remove()
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
+ pass
+
+ pass
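For reference, a minimal end-to-end sketch for the wrapper above, assuming small Pythia policies (matching the defaults in the config) and a sequence-classification reward head. The prompts, pad-token handling, and sizes are illustrative, and details such as reward-model pad-token configuration are elided:

from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from datasets import Dataset

model_id = "EleutherAI/pythia-160m"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side = "left")
tokenizer.pad_token = tokenizer.eos_token  # Pythia has no pad token by default

policy = AutoModelForCausalLM.from_pretrained(model_id)
ref_policy = AutoModelForCausalLM.from_pretrained(model_id)  # a separate copy, as the constructor requires
reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels = 1)

enc = tokenizer([f"Question {i}: say something nice." for i in range(8)])
dataset = Dataset.from_dict({"input_ids": enc["input_ids"]})

trainer = UnslothRLOOTrainer(
    config = UnslothRLOOConfig(output_dir = "rloo_outputs", total_episodes = 64),
    processing_class = tokenizer,
    policy = policy,
    ref_policy = ref_policy,
    reward_model = reward_model,
    train_dataset = dataset,
    eval_dataset = dataset,
)
trainer.train()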
unsloth_compiled_cache/UnslothRewardTrainer.py ADDED
@@ -0,0 +1,819 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
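+ # Annotation (not part of the generated file): for logits of shape
+ # (batch, seq, vocab) and index of shape (batch, seq), this returns the
+ # log-probability assigned to each indexed token. It is numerically equivalent
+ # to logits.log_softmax(-1).gather(-1, index.unsqueeze(-1)).squeeze(-1), but
+ # never materializes the full log-softmax tensor.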
42
+ @dataclass
43
+ class UnslothRewardConfig(RewardConfig):
44
+ """
45
+
46
+ Configuration class for the [`RewardTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
54
+ Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
55
+ limit. This argument is required if you want to use the default data collator.
56
+ disable_dropout (`bool`, *optional*, defaults to `True`):
57
+ Whether to disable dropout in the model.
58
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
59
+ Number of processes to use for processing the dataset.
60
+ center_rewards_coefficient (`float`, *optional*, defaults to `None`):
61
+ Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
62
+ https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
63
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
64
+ Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
65
+ the dataset is pretokenized.
66
+
67
+ """
68
+ vllm_sampling_params: Optional[Any] = field(
69
+ default = None,
70
+ metadata = {'help': 'vLLM SamplingParams'},
71
+ )
72
+ unsloth_num_chunks : Optional[int] = field(
73
+ default = -1,
74
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
75
+ )
76
+ def __init__(
77
+ self,
78
+ output_dir = None,
79
+ overwrite_output_dir = None,
80
+ do_train = False,
81
+ do_eval = False,
82
+ do_predict = False,
83
+ eval_strategy = 'no',
84
+ prediction_loss_only = False,
85
+ per_device_train_batch_size = 4,
86
+ per_device_eval_batch_size = 4,
87
+ per_gpu_train_batch_size = None,
88
+ per_gpu_eval_batch_size = None,
89
+ gradient_accumulation_steps = 2,
90
+ eval_accumulation_steps = 2,
91
+ eval_delay = 0,
92
+ torch_empty_cache_steps = 250,
93
+ learning_rate = 5e-05,
94
+ weight_decay = 0.01,
95
+ adam_beta1 = 0.9,
96
+ adam_beta2 = 0.999,
97
+ adam_epsilon = 1e-08,
98
+ max_grad_norm = 1.0,
99
+ num_train_epochs = 3.0,
100
+ max_steps = -1,
101
+ lr_scheduler_type = 'linear',
102
+ warmup_ratio = 0.1,
103
+ warmup_steps = 0,
104
+ log_level = 'passive',
105
+ log_level_replica = 'warning',
106
+ log_on_each_node = True,
107
+ logging_dir = None,
108
+ logging_strategy = 'steps',
109
+ logging_first_step = False,
110
+ logging_steps = 1,
111
+ logging_nan_inf_filter = False,
112
+ save_strategy = 'steps',
113
+ save_steps = 500,
114
+ save_total_limit = None,
115
+ save_safetensors = True,
116
+ save_on_each_node = False,
117
+ save_only_model = False,
118
+ restore_callback_states_from_checkpoint = False,
119
+ no_cuda = False,
120
+ use_cpu = False,
121
+ use_mps_device = False,
122
+ seed = 3407,
123
+ data_seed = 3407,
124
+ jit_mode_eval = False,
125
+ use_ipex = False,
126
+ bf16 = False,
127
+ fp16 = False,
128
+ fp16_opt_level = 'O1',
129
+ half_precision_backend = 'auto',
130
+ bf16_full_eval = False,
131
+ fp16_full_eval = False,
132
+ tf32 = None,
133
+ local_rank = -1,
134
+ ddp_backend = None,
135
+ tpu_num_cores = None,
136
+ tpu_metrics_debug = False,
137
+ debug = '',
138
+ dataloader_drop_last = False,
139
+ eval_steps = None,
140
+ dataloader_num_workers = 0,
141
+ dataloader_prefetch_factor = None,
142
+ past_index = -1,
143
+ run_name = None,
144
+ disable_tqdm = None,
145
+ remove_unused_columns = False,
146
+ label_names = None,
147
+ load_best_model_at_end = False,
148
+ metric_for_best_model = None,
149
+ greater_is_better = None,
150
+ ignore_data_skip = False,
151
+ fsdp = '',
152
+ fsdp_min_num_params = 0,
153
+ fsdp_config = None,
154
+ tp_size = 0,
155
+ fsdp_transformer_layer_cls_to_wrap = None,
156
+ accelerator_config = None,
157
+ deepspeed = None,
158
+ label_smoothing_factor = 0.0,
159
+ optim = 'adamw_8bit',
160
+ optim_args = None,
161
+ adafactor = False,
162
+ group_by_length = False,
163
+ length_column_name = 'length',
164
+ report_to = None,
165
+ ddp_find_unused_parameters = None,
166
+ ddp_bucket_cap_mb = None,
167
+ ddp_broadcast_buffers = None,
168
+ dataloader_pin_memory = True,
169
+ dataloader_persistent_workers = False,
170
+ skip_memory_metrics = True,
171
+ use_legacy_prediction_loop = False,
172
+ push_to_hub = False,
173
+ resume_from_checkpoint = None,
174
+ hub_model_id = None,
175
+ hub_strategy = 'every_save',
176
+ hub_token = None,
177
+ hub_private_repo = None,
178
+ hub_always_push = False,
179
+ gradient_checkpointing = False,
180
+ gradient_checkpointing_kwargs = None,
181
+ include_inputs_for_metrics = False,
182
+ eval_do_concat_batches = True,
183
+ fp16_backend = 'auto',
184
+ evaluation_strategy = None,
185
+ push_to_hub_model_id = None,
186
+ push_to_hub_organization = None,
187
+ push_to_hub_token = None,
188
+ mp_parameters = '',
189
+ auto_find_batch_size = False,
190
+ full_determinism = False,
191
+ torchdynamo = None,
192
+ ray_scope = 'last',
193
+ ddp_timeout = 1800,
194
+ torch_compile = False,
195
+ torch_compile_backend = None,
196
+ torch_compile_mode = None,
197
+ dispatch_batches = None,
198
+ split_batches = None,
199
+ include_tokens_per_second = False,
200
+ include_num_input_tokens_seen = False,
201
+ neftune_noise_alpha = None,
202
+ optim_target_modules = None,
203
+ batch_eval_metrics = False,
204
+ eval_on_start = False,
205
+ use_liger_kernel = False,
206
+ eval_use_gather_object = False,
207
+ average_tokens_across_devices = False,
208
+ max_length = 1024,
209
+ disable_dropout = True,
210
+ dataset_num_proc = None,
211
+ center_rewards_coefficient = None,
212
+ vllm_sampling_params = None,
213
+ unsloth_num_chunks = -1,
214
+ **kwargs,
215
+ ):
216
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
217
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
218
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
219
+ output_dir = 'unsloth_training_checkpoints'
220
+ save_strategy = 'no'
221
+ if dataset_num_proc is None:
222
+ from multiprocessing import cpu_count
223
+ dataset_num_proc = cpu_count()
224
+
225
+ super().__init__(
226
+ output_dir = output_dir,
227
+ overwrite_output_dir = overwrite_output_dir,
228
+ do_train = do_train,
229
+ do_eval = do_eval,
230
+ do_predict = do_predict,
231
+ eval_strategy = eval_strategy,
232
+ prediction_loss_only = prediction_loss_only,
233
+ per_device_train_batch_size = per_device_train_batch_size,
234
+ per_device_eval_batch_size = per_device_eval_batch_size,
235
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
236
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
237
+ gradient_accumulation_steps = gradient_accumulation_steps,
238
+ eval_accumulation_steps = eval_accumulation_steps,
239
+ eval_delay = eval_delay,
240
+ torch_empty_cache_steps = torch_empty_cache_steps,
241
+ learning_rate = learning_rate,
242
+ weight_decay = weight_decay,
243
+ adam_beta1 = adam_beta1,
244
+ adam_beta2 = adam_beta2,
245
+ adam_epsilon = adam_epsilon,
246
+ max_grad_norm = max_grad_norm,
247
+ num_train_epochs = num_train_epochs,
248
+ max_steps = max_steps,
249
+ lr_scheduler_type = lr_scheduler_type,
250
+ warmup_ratio = warmup_ratio,
251
+ warmup_steps = warmup_steps,
252
+ log_level = log_level,
253
+ log_level_replica = log_level_replica,
254
+ log_on_each_node = log_on_each_node,
255
+ logging_dir = logging_dir,
256
+ logging_strategy = logging_strategy,
257
+ logging_first_step = logging_first_step,
258
+ logging_steps = logging_steps,
259
+ logging_nan_inf_filter = logging_nan_inf_filter,
260
+ save_strategy = save_strategy,
261
+ save_steps = save_steps,
262
+ save_total_limit = save_total_limit,
263
+ save_safetensors = save_safetensors,
264
+ save_on_each_node = save_on_each_node,
265
+ save_only_model = save_only_model,
266
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
267
+ no_cuda = no_cuda,
268
+ use_cpu = use_cpu,
269
+ use_mps_device = use_mps_device,
270
+ seed = seed,
271
+ data_seed = data_seed,
272
+ jit_mode_eval = jit_mode_eval,
273
+ use_ipex = use_ipex,
274
+ bf16 = bf16,
275
+ fp16 = fp16,
276
+ fp16_opt_level = fp16_opt_level,
277
+ half_precision_backend = half_precision_backend,
278
+ bf16_full_eval = bf16_full_eval,
279
+ fp16_full_eval = fp16_full_eval,
280
+ tf32 = tf32,
281
+ local_rank = local_rank,
282
+ ddp_backend = ddp_backend,
283
+ tpu_num_cores = tpu_num_cores,
284
+ tpu_metrics_debug = tpu_metrics_debug,
285
+ debug = debug,
286
+ dataloader_drop_last = dataloader_drop_last,
287
+ eval_steps = eval_steps,
288
+ dataloader_num_workers = dataloader_num_workers,
289
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
290
+ past_index = past_index,
291
+ run_name = run_name,
292
+ disable_tqdm = disable_tqdm,
293
+ remove_unused_columns = remove_unused_columns,
294
+ label_names = label_names,
295
+ load_best_model_at_end = load_best_model_at_end,
296
+ metric_for_best_model = metric_for_best_model,
297
+ greater_is_better = greater_is_better,
298
+ ignore_data_skip = ignore_data_skip,
299
+ fsdp = fsdp,
300
+ fsdp_min_num_params = fsdp_min_num_params,
301
+ fsdp_config = fsdp_config,
302
+ tp_size = tp_size,
303
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
304
+ accelerator_config = accelerator_config,
305
+ deepspeed = deepspeed,
306
+ label_smoothing_factor = label_smoothing_factor,
307
+ optim = optim,
308
+ optim_args = optim_args,
309
+ adafactor = adafactor,
310
+ group_by_length = group_by_length,
311
+ length_column_name = length_column_name,
312
+ report_to = report_to,
313
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
314
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
315
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
316
+ dataloader_pin_memory = dataloader_pin_memory,
317
+ dataloader_persistent_workers = dataloader_persistent_workers,
318
+ skip_memory_metrics = skip_memory_metrics,
319
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
320
+ push_to_hub = push_to_hub,
321
+ resume_from_checkpoint = resume_from_checkpoint,
322
+ hub_model_id = hub_model_id,
323
+ hub_strategy = hub_strategy,
324
+ hub_token = hub_token,
325
+ hub_private_repo = hub_private_repo,
326
+ hub_always_push = hub_always_push,
327
+ gradient_checkpointing = gradient_checkpointing,
328
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
329
+ include_inputs_for_metrics = include_inputs_for_metrics,
330
+ eval_do_concat_batches = eval_do_concat_batches,
331
+ fp16_backend = fp16_backend,
332
+ evaluation_strategy = evaluation_strategy,
333
+ push_to_hub_model_id = push_to_hub_model_id,
334
+ push_to_hub_organization = push_to_hub_organization,
335
+ push_to_hub_token = push_to_hub_token,
336
+ mp_parameters = mp_parameters,
337
+ auto_find_batch_size = auto_find_batch_size,
338
+ full_determinism = full_determinism,
339
+ torchdynamo = torchdynamo,
340
+ ray_scope = ray_scope,
341
+ ddp_timeout = ddp_timeout,
342
+ torch_compile = torch_compile,
343
+ torch_compile_backend = torch_compile_backend,
344
+ torch_compile_mode = torch_compile_mode,
345
+ dispatch_batches = dispatch_batches,
346
+ split_batches = split_batches,
347
+ include_tokens_per_second = include_tokens_per_second,
348
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
349
+ neftune_noise_alpha = neftune_noise_alpha,
350
+ optim_target_modules = optim_target_modules,
351
+ batch_eval_metrics = batch_eval_metrics,
352
+ eval_on_start = eval_on_start,
353
+ use_liger_kernel = use_liger_kernel,
354
+ eval_use_gather_object = eval_use_gather_object,
355
+ average_tokens_across_devices = average_tokens_across_devices,
356
+ max_length = max_length,
357
+ disable_dropout = disable_dropout,
358
+ dataset_num_proc = dataset_num_proc,
359
+ center_rewards_coefficient = center_rewards_coefficient,**kwargs)
360
+ self.vllm_sampling_params = vllm_sampling_params
361
+ self.unsloth_num_chunks = unsloth_num_chunks
362
+ pass
363
+
364
+ class _UnslothRewardTrainer(Trainer):
365
+ _tag_names = ["trl", "reward-trainer"]
366
+
367
+ def __init__(
368
+ self,
369
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
370
+ args: Optional[RewardConfig] = None,
371
+ data_collator: Optional[DataCollator] = None,
372
+ train_dataset: Optional[Dataset] = None,
373
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
374
+ processing_class: Optional[
375
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
376
+ ] = None,
377
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
378
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
379
+ callbacks: Optional[list[TrainerCallback]] = None,
380
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
381
+ None,
382
+ None,
383
+ ),
384
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
385
+ peft_config: Optional[dict] = None,
386
+ ):
387
+ """
388
+ Initialize RewardTrainer.
389
+
390
+ Args:
391
+ model (`transformers.PreTrainedModel`):
392
+ The model to train, preferably an `AutoModelForSequenceClassification`.
393
+ args (`RewardConfig`):
394
+ The arguments to use for training.
395
+ data_collator (`transformers.DataCollator`):
396
+ The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used,
397
+ which pads the sequences to the maximum sequence length in the batch, given a dataset of paired sequences.
398
+ train_dataset (`datasets.Dataset`):
399
+ The dataset to use for training.
400
+ eval_dataset (`datasets.Dataset`):
401
+ The dataset to use for evaluation.
402
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
403
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
404
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
405
+ reuse the fine-tuned model.
406
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
407
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
408
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
409
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
410
+ callbacks (`list[transformers.TrainerCallback]`):
411
+ The callbacks to use for training.
412
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
413
+ The optimizer and scheduler to use for training.
414
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
415
+ The function to use to preprocess the logits before computing the metrics.
416
+ peft_config (`dict`, defaults to `None`):
417
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
418
+ """
419
+ if not is_peft_available() and peft_config is not None:
420
+ raise ValueError(
421
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
422
+ )
423
+ elif is_peft_available() and peft_config is not None:
424
+ if not isinstance(model, PeftModel):
425
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
426
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
427
+ inspect.signature(prepare_model_for_kbit_training).parameters
428
+ )
429
+
430
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
431
+
432
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
433
+ warnings.warn(
434
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
435
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
436
+ UserWarning,
437
+ )
438
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
439
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
440
+
441
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
442
+
443
+ model = model
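+ # Annotation (not part of the generated file): the assignment above is a no-op
+ # left by Unsloth's code generation; upstream TRL wraps the model with
+ # get_peft_model(model, peft_config) at this point, and the assumption is that
+ # Unsloth performs PEFT wrapping through its own patching instead.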
444
+
445
+ # Disable dropout in the model
446
+ if args.disable_dropout:
447
+ disable_dropout_in_model(model)
448
+
449
+ if compute_metrics is None:
450
+ compute_metrics = compute_accuracy
451
+
452
+ if data_collator is None:
453
+ if processing_class is None:
454
+ raise ValueError(
455
+ "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
456
+ )
457
+
458
+ max_length = args.max_length
459
+
460
+ data_collator = RewardDataCollatorWithPadding(processing_class)
461
+
462
+ if args.remove_unused_columns:
463
+ try: # for bc before https://github.com/huggingface/transformers/pull/25435
464
+ args.remove_unused_columns = False
465
+ except FrozenInstanceError:
466
+ args = replace(args, remove_unused_columns=False)
467
+ # warn users
468
+ warnings.warn(
469
+ "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
470
+ " we have set it for you, but you should do it yourself in the future.",
471
+ UserWarning,
472
+ )
473
+
474
+ self.use_reward_data_collator = True
475
+ else:
476
+ self.use_reward_data_collator = False
477
+
478
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
479
+ # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
480
+ # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
481
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
482
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
483
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
484
+ # issued.
485
+ model.warnings_issued["estimate_tokens"] = True
486
+
487
+ if "input_ids_chosen" not in train_dataset.column_names:
488
+ with PartialState().local_main_process_first():
489
+ fn_kwargs = {"tokenizer": processing_class}
490
+ train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
491
+ train_dataset = train_dataset.map(
492
+ _tokenize,
493
+ batched=True,
494
+ fn_kwargs=fn_kwargs,
495
+ num_proc=args.dataset_num_proc,
496
+ )
497
+ # This filter is important because otherwise you get samples that exceed the model's context length and
498
+ # get truncated => a noisy signal, since the chosen/rejected label gets lost. The downside is that the
499
+ # user might get surprised if N samples are missing from training.
500
+ train_dataset = train_dataset.filter(
501
+ lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
502
+ num_proc=args.dataset_num_proc,
503
+ )
504
+ if eval_dataset is not None:
505
+ eval_dataset = eval_dataset.map(
506
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
507
+ )
508
+ eval_dataset = eval_dataset.map(
509
+ _tokenize,
510
+ fn_kwargs=fn_kwargs,
511
+ batched=True,
512
+ num_proc=args.dataset_num_proc,
513
+ )
514
+ # This filter is important because otherwise you get samples that exceed the model's context length and
515
+ # get truncated => a noisy signal, since the chosen/rejected label gets lost. The downside is that the
516
+ # user might get surprised if N samples are missing from training.
517
+ eval_dataset = eval_dataset.filter(
518
+ lambda x: len(x["input_ids_chosen"]) <= max_length
519
+ and len(x["input_ids_rejected"]) <= max_length,
520
+ num_proc=args.dataset_num_proc,
521
+ )
522
+
523
+ super().__init__(
524
+ model=model,
525
+ args=args,
526
+ data_collator=data_collator,
527
+ train_dataset=train_dataset,
528
+ eval_dataset=eval_dataset,
529
+ processing_class=processing_class,
530
+ model_init=model_init,
531
+ compute_metrics=compute_metrics,
532
+ callbacks=callbacks,
533
+ optimizers=optimizers,
534
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
535
+ )
536
+
537
+ # Add tags for models that have been loaded with the correct transformers version
538
+ if hasattr(self.model, "add_model_tags"):
539
+ self.model.add_model_tags(self._tag_names)
540
+
541
+ def compute_loss(
542
+ self,
543
+ model: Union[PreTrainedModel, nn.Module],
544
+ inputs: dict[str, Union[torch.Tensor, Any]],
545
+ return_outputs=False,
546
+ num_items_in_batch=None,
547
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
548
+ rewards_chosen = model(
549
+ input_ids=inputs["input_ids_chosen"],
550
+ attention_mask=inputs["attention_mask_chosen"],
551
+ return_dict=True,
552
+ )["logits"]
553
+ rewards_rejected = model(
554
+ input_ids=inputs["input_ids_rejected"],
555
+ attention_mask=inputs["attention_mask_rejected"],
556
+ return_dict=True,
557
+ )["logits"]
558
+ # calculate loss, optionally modulate with margin
559
+ if "margin" in inputs:
560
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
561
+ else:
562
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
563
+
564
+ if self.args.center_rewards_coefficient is not None:
565
+ loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
566
+
567
+ if return_outputs:
568
+ return loss, {
569
+ "rewards_chosen": rewards_chosen,
570
+ "rewards_rejected": rewards_rejected,
571
+ }
572
+ return loss
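+ # Annotation (not part of the generated file): the objective above is the
+ # pairwise Bradley-Terry loss, loss = -log(sigmoid(r_chosen - r_rejected - margin))
+ # with margin treated as 0 when absent, plus an optional centering term
+ # center_rewards_coefficient * mean((r_chosen + r_rejected)**2) that pushes the
+ # rewards toward zero mean.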
573
+
574
+ def prediction_step(
575
+ self,
576
+ model: Union[PreTrainedModel, nn.Module],
577
+ inputs: dict[str, Union[torch.Tensor, Any]],
578
+ prediction_loss_only: bool,
579
+ ignore_keys: Optional[list[str]] = None,
580
+ ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
581
+ inputs = self._prepare_inputs(inputs)
582
+ if ignore_keys is None:
583
+ if hasattr(self.model, "config"):
584
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
585
+ else:
586
+ ignore_keys = []
587
+
588
+ with torch.no_grad():
589
+ loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
590
+
591
+ if prediction_loss_only:
592
+ return (loss, None, None)
593
+
594
+ loss = loss.detach()
595
+ logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
596
+ logits = nested_detach(logits)
597
+ # Stack accepted against rejected, mean over logits
598
+ # and softmax to get preferences between accepted and rejected to sum to 1
599
+ logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
600
+
601
+ labels = torch.zeros(logits.shape[0])
602
+ labels = self._prepare_inputs(labels)
603
+
604
+ return loss, logits, labels
605
+
606
+ def evaluate(self, *args, **kwargs):
607
+ num_print_samples = kwargs.pop("num_print_samples", 4)
608
+ self.visualize_samples(num_print_samples)
609
+ return super().evaluate(*args, **kwargs)
610
+
611
+ def visualize_samples(self, num_print_samples: int):
612
+ """
613
+ Visualize the reward model logits prediction
614
+
615
+ Args:
616
+ num_print_samples (`int`, defaults to `4`):
617
+ The number of samples to print. Set to `-1` to print all samples.
618
+ """
619
+ eval_dataloader = self.get_eval_dataloader()
620
+ table = defaultdict(list)
621
+ for _, inputs in enumerate(eval_dataloader):
622
+ _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
623
+ chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
624
+ rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
625
+ table["chosen_text"].extend(gather_object(chosen_text))
626
+ table["rejected_text"].extend(gather_object(rejected_text))
627
+ table["logits"].extend(
628
+ gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
629
+ )
630
+ if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
631
+ break
632
+ df = pd.DataFrame(table)
633
+ if self.accelerator.process_index == 0:
634
+ print_rich_table(df[:num_print_samples])
635
+ if "wandb" in self.args.report_to:
636
+ import wandb
637
+
638
+ if wandb.run is not None:
639
+ wandb.log({"completions": wandb.Table(dataframe=df)})
640
+
641
+ if "comet_ml" in self.args.report_to:
642
+ log_table_to_comet_experiment(
643
+ name="completions.csv",
644
+ table=df,
645
+ )
646
+
647
+ def create_model_card(
648
+ self,
649
+ model_name: Optional[str] = None,
650
+ dataset_name: Optional[str] = None,
651
+ tags: Union[str, list[str], None] = None,
652
+ ):
653
+ """
654
+ Creates a draft of a model card using the information available to the `Trainer`.
655
+
656
+ Args:
657
+ model_name (`str` or `None`, *optional*, defaults to `None`):
658
+ Name of the model.
659
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
660
+ Name of the dataset used for training.
661
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
662
+ Tags to be associated with the model card.
663
+ """
664
+ if not self.is_world_process_zero():
665
+ return
666
+
667
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
668
+ base_model = self.model.config._name_or_path
669
+ else:
670
+ base_model = None
671
+
672
+ tags = tags or []
673
+ if isinstance(tags, str):
674
+ tags = [tags]
675
+
676
+ if hasattr(self.model.config, "unsloth_version"):
677
+ tags.append("unsloth")
678
+
679
+ model_card = generate_model_card(
680
+ base_model=base_model,
681
+ model_name=model_name,
682
+ hub_model_id=self.hub_model_id,
683
+ dataset_name=dataset_name,
684
+ tags=tags,
685
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
686
+ comet_url=get_comet_experiment_url(),
687
+ trainer_name="Reward",
688
+ )
689
+
690
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
691
+ class UnslothRewardTrainer(_UnslothRewardTrainer):
692
+ """
693
+
694
+ """
695
+ def __init__(
696
+ self,
697
+ model = None,
698
+ args = None,
699
+ data_collator = None,
700
+ train_dataset = None,
701
+ eval_dataset = None,
702
+ processing_class = None,
703
+ model_init = None,
704
+ compute_metrics = None,
705
+ callbacks = None,
706
+ preprocess_logits_for_metrics = None,
707
+ peft_config = None,
708
+ **kwargs
709
+ ):
710
+ if args is None: args = UnslothRewardConfig()
711
+ use_bf16 = getattr(args, 'bf16', False)
712
+ use_fp16 = getattr(args, 'fp16', False)
713
+ force_float32 = False
714
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
715
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
716
+ force_float32 = True
717
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
718
+ dtype = getattr(model.config, 'torch_dtype', None)
719
+ if dtype is None: dtype = model.get_input_embeddings().dtype
720
+ from unsloth_zoo.utils import _get_dtype
721
+ dtype = _get_dtype(dtype)
722
+ float16 = dtype == torch.float16
723
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
724
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
725
+ if force_float32:
726
+ args.fp16 = False
727
+ args.bf16 = False
728
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
729
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
730
+ args.fp16 = float16
731
+ args.bf16 = not float16
732
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
733
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
734
+ args.eval_strategy = 'steps'
735
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
736
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
737
+ if ga_steps is not None and ga_steps > 1:
738
+ from transformers import __version__ as transformers_version
739
+ if Version(transformers_version) <= Version('4.45.2'):
740
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
741
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
742
+ if getattr(args, 'eval_strategy', 'no') != 'no':
743
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
744
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
745
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
746
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
747
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
748
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
749
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
750
+ if force_float32:
751
+ args.bf16_full_eval = False
752
+ args.fp16_full_eval = False
753
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
754
+ args.bf16_full_eval = True
755
+ args.fp16_full_eval = False
756
+ elif not bf16_full_eval and not fp16_full_eval:
757
+ args.bf16_full_eval = args.bf16
758
+ args.fp16_full_eval = args.fp16
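+ # Annotation (not part of the generated file): the block above resolves mixed
+ # precision. UNSLOTH_FORCE_FLOAT32 disables fp16/bf16 entirely; otherwise, when
+ # neither flag is set, the training dtype follows the model's own dtype (fp16
+ # models train with fp16, everything else with bf16), and the *_full_eval flags
+ # are then made consistent with whichever half precision was chosen.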
759
+ _output_logits = False
760
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
761
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
762
+ if _output_logits:
763
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
764
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
765
+ pass
766
+ else:
767
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
768
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
769
+ if args_max_seq_length is None and model_max_seq_length is not None:
770
+ max_seq_length = model.max_seq_length
771
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
772
+ if model is not None and hasattr(model, 'for_training'):
773
+ model.for_training()
774
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
775
+ if 'processing_class' in locals():
776
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
777
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
778
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
779
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
780
+ if not isinstance(data_collator, UnslothVisionDataCollator):
781
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
782
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
783
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
784
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
785
+ else:
786
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
787
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
788
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
789
+ if not isinstance(data_collator, UnslothVisionDataCollator):
790
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
791
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
792
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
793
+ else:
794
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
795
+ other_metrics = []
796
+
797
+ from unsloth_zoo.logging_utils import PatchRLStatistics
798
+ PatchRLStatistics('reward_trainer', other_metrics)
799
+
800
+ super().__init__(
801
+ model = model,
802
+ args = args,
803
+ data_collator = data_collator,
804
+ train_dataset = train_dataset,
805
+ eval_dataset = eval_dataset,
806
+ processing_class = processing_class,
807
+ model_init = model_init,
808
+ compute_metrics = compute_metrics,
809
+ callbacks = callbacks,
810
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
811
+ peft_config = peft_config,**kwargs)
812
+ if hasattr(self, 'neftune_hook_handle'):
813
+ self.neftune_hook_handle.remove()
814
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
815
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
816
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
817
+ pass
818
+
819
+ pass
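+ # Illustrative usage sketch (an annotation, not generated code). Per the
+ # docstring above, the model is preferably a sequence-classification model and
+ # the dataset provides chosen/rejected pairs; all names are placeholders:
+ #
+ # trainer = UnslothRewardTrainer(
+ #     model = model,  # e.g. an AutoModelForSequenceClassification with num_labels = 1
+ #     args = UnslothRewardConfig(max_length = 1024),
+ #     processing_class = tokenizer,
+ #     train_dataset = train_dataset,  # "chosen"/"rejected" columns
+ # )
+ # trainer.train()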
unsloth_compiled_cache/UnslothSFTTrainer.py ADDED
@@ -0,0 +1,1031 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothSFTConfig(SFTConfig):
44
+ """
45
+
46
+ Configuration class for the [`SFTTrainer`].
47
+
48
+ Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
49
+ [`~transformers.TrainingArguments`] documentation.
50
+
51
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
52
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
53
+ command line.
54
+
55
+ Parameters:
56
+ > Parameters that control the model
57
+
58
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
59
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
60
+ argument of the [`SFTTrainer`] is provided as a string.
61
+ use_liger (`bool`, *optional*, defaults to `False`):
62
+ Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
63
+
64
+ > Parameters that control the data preprocessing
65
+
66
+ dataset_text_field (`str`, *optional*, defaults to `"text"`):
67
+ Name of the column that contains text data in the dataset.
68
+ dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
69
+ Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
70
+ `skip_prepare_dataset`.
71
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
72
+ Number of processes to use for processing the dataset.
73
+ max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
74
+ Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
75
+ right.
76
+ If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
77
+ packing (`bool`, *optional*, defaults to `False`):
78
+ Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
79
+ length.
80
+ eval_packing (`bool` or `None`, *optional*, defaults to `None`):
81
+ Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
82
+
83
+ > Parameters that control the training
84
+
85
+ learning_rate (`float`, *optional*, defaults to `2e-5`):
86
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
87
+ [`~transformers.TrainingArguments`].
88
+
89
+ """
90
+ vllm_sampling_params: Optional[Any] = field(
91
+ default = None,
92
+ metadata = {'help': 'vLLM SamplingParams'},
93
+ )
94
+ unsloth_num_chunks : Optional[int] = field(
95
+ default = -1,
96
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
97
+ )
98
+ def __init__(
99
+ self,
100
+ output_dir = None,
101
+ overwrite_output_dir = None,
102
+ do_train = False,
103
+ do_eval = False,
104
+ do_predict = False,
105
+ eval_strategy = 'no',
106
+ prediction_loss_only = False,
107
+ per_device_train_batch_size = 4,
108
+ per_device_eval_batch_size = 4,
109
+ per_gpu_train_batch_size = None,
110
+ per_gpu_eval_batch_size = None,
111
+ gradient_accumulation_steps = 2,
112
+ eval_accumulation_steps = 2,
113
+ eval_delay = 0,
114
+ torch_empty_cache_steps = 250,
115
+ learning_rate = 5e-05,
116
+ weight_decay = 0.01,
117
+ adam_beta1 = 0.9,
118
+ adam_beta2 = 0.999,
119
+ adam_epsilon = 1e-08,
120
+ max_grad_norm = 1.0,
121
+ num_train_epochs = 3.0,
122
+ max_steps = -1,
123
+ lr_scheduler_type = 'linear',
124
+ warmup_ratio = 0.1,
125
+ warmup_steps = 0,
126
+ log_level = 'passive',
127
+ log_level_replica = 'warning',
128
+ log_on_each_node = True,
129
+ logging_dir = None,
130
+ logging_strategy = 'steps',
131
+ logging_first_step = False,
132
+ logging_steps = 1,
133
+ logging_nan_inf_filter = False,
134
+ save_strategy = 'steps',
135
+ save_steps = 500,
136
+ save_total_limit = None,
137
+ save_safetensors = True,
138
+ save_on_each_node = False,
139
+ save_only_model = False,
140
+ restore_callback_states_from_checkpoint = False,
141
+ no_cuda = False,
142
+ use_cpu = False,
143
+ use_mps_device = False,
144
+ seed = 3407,
145
+ data_seed = 3407,
146
+ jit_mode_eval = False,
147
+ use_ipex = False,
148
+ bf16 = False,
149
+ fp16 = False,
150
+ fp16_opt_level = 'O1',
151
+ half_precision_backend = 'auto',
152
+ bf16_full_eval = False,
153
+ fp16_full_eval = False,
154
+ tf32 = None,
155
+ local_rank = -1,
156
+ ddp_backend = None,
157
+ tpu_num_cores = None,
158
+ tpu_metrics_debug = False,
159
+ debug = '',
160
+ dataloader_drop_last = False,
161
+ eval_steps = None,
162
+ dataloader_num_workers = 0,
163
+ dataloader_prefetch_factor = None,
164
+ past_index = -1,
165
+ run_name = None,
166
+ disable_tqdm = None,
167
+ remove_unused_columns = True,
168
+ label_names = None,
169
+ load_best_model_at_end = False,
170
+ metric_for_best_model = None,
171
+ greater_is_better = None,
172
+ ignore_data_skip = False,
173
+ fsdp = '',
174
+ fsdp_min_num_params = 0,
175
+ fsdp_config = None,
176
+ tp_size = 0,
177
+ fsdp_transformer_layer_cls_to_wrap = None,
178
+ accelerator_config = None,
179
+ deepspeed = None,
180
+ label_smoothing_factor = 0.0,
181
+ optim = 'adamw_8bit',
182
+ optim_args = None,
183
+ adafactor = False,
184
+ group_by_length = False,
185
+ length_column_name = 'length',
186
+ report_to = None,
187
+ ddp_find_unused_parameters = None,
188
+ ddp_bucket_cap_mb = None,
189
+ ddp_broadcast_buffers = None,
190
+ dataloader_pin_memory = True,
191
+ dataloader_persistent_workers = False,
192
+ skip_memory_metrics = True,
193
+ use_legacy_prediction_loop = False,
194
+ push_to_hub = False,
195
+ resume_from_checkpoint = None,
196
+ hub_model_id = None,
197
+ hub_strategy = 'every_save',
198
+ hub_token = None,
199
+ hub_private_repo = None,
200
+ hub_always_push = False,
201
+ gradient_checkpointing = False,
202
+ gradient_checkpointing_kwargs = None,
203
+ include_inputs_for_metrics = False,
204
+ eval_do_concat_batches = True,
205
+ fp16_backend = 'auto',
206
+ evaluation_strategy = None,
207
+ push_to_hub_model_id = None,
208
+ push_to_hub_organization = None,
209
+ push_to_hub_token = None,
210
+ mp_parameters = '',
211
+ auto_find_batch_size = False,
212
+ full_determinism = False,
213
+ torchdynamo = None,
214
+ ray_scope = 'last',
215
+ ddp_timeout = 1800,
216
+ torch_compile = False,
217
+ torch_compile_backend = None,
218
+ torch_compile_mode = None,
219
+ dispatch_batches = None,
220
+ split_batches = None,
221
+ include_tokens_per_second = False,
222
+ include_num_input_tokens_seen = False,
223
+ neftune_noise_alpha = None,
224
+ optim_target_modules = None,
225
+ batch_eval_metrics = False,
226
+ eval_on_start = False,
227
+ use_liger_kernel = False,
228
+ eval_use_gather_object = False,
229
+ average_tokens_across_devices = False,
230
+ model_init_kwargs = None,
231
+ use_liger = False,
232
+ dataset_text_field = 'text',
233
+ dataset_kwargs = None,
234
+ dataset_num_proc = None,
235
+ max_seq_length = None,
236
+ packing = False,
237
+ eval_packing = None,
238
+ dataset_batch_size = None,
239
+ num_of_sequences = None,
240
+ chars_per_token = None,
241
+ vllm_sampling_params = None,
242
+ unsloth_num_chunks = -1,
243
+ **kwargs,
244
+ ):
245
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
246
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
247
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
248
+ output_dir = 'unsloth_training_checkpoints'
249
+ save_strategy = 'no'
250
+ if dataset_num_proc is None:
251
+ from multiprocessing import cpu_count
252
+ dataset_num_proc = cpu_count()
253
+
254
+ super().__init__(
255
+ output_dir = output_dir,
256
+ overwrite_output_dir = overwrite_output_dir,
257
+ do_train = do_train,
258
+ do_eval = do_eval,
259
+ do_predict = do_predict,
260
+ eval_strategy = eval_strategy,
261
+ prediction_loss_only = prediction_loss_only,
262
+ per_device_train_batch_size = per_device_train_batch_size,
263
+ per_device_eval_batch_size = per_device_eval_batch_size,
264
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
265
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
266
+ gradient_accumulation_steps = gradient_accumulation_steps,
267
+ eval_accumulation_steps = eval_accumulation_steps,
268
+ eval_delay = eval_delay,
269
+ torch_empty_cache_steps = torch_empty_cache_steps,
270
+ learning_rate = learning_rate,
271
+ weight_decay = weight_decay,
272
+ adam_beta1 = adam_beta1,
273
+ adam_beta2 = adam_beta2,
274
+ adam_epsilon = adam_epsilon,
275
+ max_grad_norm = max_grad_norm,
276
+ num_train_epochs = num_train_epochs,
277
+ max_steps = max_steps,
278
+ lr_scheduler_type = lr_scheduler_type,
279
+ warmup_ratio = warmup_ratio,
280
+ warmup_steps = warmup_steps,
281
+ log_level = log_level,
282
+ log_level_replica = log_level_replica,
283
+ log_on_each_node = log_on_each_node,
284
+ logging_dir = logging_dir,
285
+ logging_strategy = logging_strategy,
286
+ logging_first_step = logging_first_step,
287
+ logging_steps = logging_steps,
288
+ logging_nan_inf_filter = logging_nan_inf_filter,
289
+ save_strategy = save_strategy,
290
+ save_steps = save_steps,
291
+ save_total_limit = save_total_limit,
292
+ save_safetensors = save_safetensors,
293
+ save_on_each_node = save_on_each_node,
294
+ save_only_model = save_only_model,
295
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
296
+ no_cuda = no_cuda,
297
+ use_cpu = use_cpu,
298
+ use_mps_device = use_mps_device,
299
+ seed = seed,
300
+ data_seed = data_seed,
301
+ jit_mode_eval = jit_mode_eval,
302
+ use_ipex = use_ipex,
303
+ bf16 = bf16,
304
+ fp16 = fp16,
305
+ fp16_opt_level = fp16_opt_level,
306
+ half_precision_backend = half_precision_backend,
307
+ bf16_full_eval = bf16_full_eval,
308
+ fp16_full_eval = fp16_full_eval,
309
+ tf32 = tf32,
310
+ local_rank = local_rank,
311
+ ddp_backend = ddp_backend,
312
+ tpu_num_cores = tpu_num_cores,
313
+ tpu_metrics_debug = tpu_metrics_debug,
314
+ debug = debug,
315
+ dataloader_drop_last = dataloader_drop_last,
316
+ eval_steps = eval_steps,
317
+ dataloader_num_workers = dataloader_num_workers,
318
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
319
+ past_index = past_index,
320
+ run_name = run_name,
321
+ disable_tqdm = disable_tqdm,
322
+ remove_unused_columns = remove_unused_columns,
323
+ label_names = label_names,
324
+ load_best_model_at_end = load_best_model_at_end,
325
+ metric_for_best_model = metric_for_best_model,
326
+ greater_is_better = greater_is_better,
327
+ ignore_data_skip = ignore_data_skip,
328
+ fsdp = fsdp,
329
+ fsdp_min_num_params = fsdp_min_num_params,
330
+ fsdp_config = fsdp_config,
331
+ tp_size = tp_size,
332
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
333
+ accelerator_config = accelerator_config,
334
+ deepspeed = deepspeed,
335
+ label_smoothing_factor = label_smoothing_factor,
336
+ optim = optim,
337
+ optim_args = optim_args,
338
+ adafactor = adafactor,
339
+ group_by_length = group_by_length,
340
+ length_column_name = length_column_name,
341
+ report_to = report_to,
342
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
343
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
344
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
345
+ dataloader_pin_memory = dataloader_pin_memory,
346
+ dataloader_persistent_workers = dataloader_persistent_workers,
347
+ skip_memory_metrics = skip_memory_metrics,
348
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
349
+ push_to_hub = push_to_hub,
350
+ resume_from_checkpoint = resume_from_checkpoint,
351
+ hub_model_id = hub_model_id,
352
+ hub_strategy = hub_strategy,
353
+ hub_token = hub_token,
354
+ hub_private_repo = hub_private_repo,
355
+ hub_always_push = hub_always_push,
356
+ gradient_checkpointing = gradient_checkpointing,
357
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
358
+ include_inputs_for_metrics = include_inputs_for_metrics,
359
+ eval_do_concat_batches = eval_do_concat_batches,
360
+ fp16_backend = fp16_backend,
361
+ evaluation_strategy = evaluation_strategy,
362
+ push_to_hub_model_id = push_to_hub_model_id,
363
+ push_to_hub_organization = push_to_hub_organization,
364
+ push_to_hub_token = push_to_hub_token,
365
+ mp_parameters = mp_parameters,
366
+ auto_find_batch_size = auto_find_batch_size,
367
+ full_determinism = full_determinism,
368
+ torchdynamo = torchdynamo,
369
+ ray_scope = ray_scope,
370
+ ddp_timeout = ddp_timeout,
371
+ torch_compile = torch_compile,
372
+ torch_compile_backend = torch_compile_backend,
373
+ torch_compile_mode = torch_compile_mode,
374
+ dispatch_batches = dispatch_batches,
375
+ split_batches = split_batches,
376
+ include_tokens_per_second = include_tokens_per_second,
377
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
378
+ neftune_noise_alpha = neftune_noise_alpha,
379
+ optim_target_modules = optim_target_modules,
380
+ batch_eval_metrics = batch_eval_metrics,
381
+ eval_on_start = eval_on_start,
382
+ use_liger_kernel = use_liger_kernel,
383
+ eval_use_gather_object = eval_use_gather_object,
384
+ average_tokens_across_devices = average_tokens_across_devices,
385
+ model_init_kwargs = model_init_kwargs,
386
+ use_liger = use_liger,
387
+ dataset_text_field = dataset_text_field,
388
+ dataset_kwargs = dataset_kwargs,
389
+ dataset_num_proc = dataset_num_proc,
390
+ max_seq_length = max_seq_length,
391
+ packing = packing,
392
+ eval_packing = eval_packing,
393
+ dataset_batch_size = dataset_batch_size,
394
+ num_of_sequences = num_of_sequences,
395
+ chars_per_token = chars_per_token,**kwargs)
396
+ self.vllm_sampling_params = vllm_sampling_params
397
+ self.unsloth_num_chunks = unsloth_num_chunks
398
+ pass
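+ # Illustrative usage sketch (an annotation, not generated code). Parameter
+ # names come from the __init__ signature above; the values are placeholders:
+ #
+ # args = UnslothSFTConfig(
+ #     output_dir = "outputs",
+ #     dataset_text_field = "text",
+ #     max_seq_length = 2048,
+ #     packing = True,  # pack multiple sequences into fixed max_seq_length blocks
+ #     per_device_train_batch_size = 4,
+ #     gradient_accumulation_steps = 2,
+ # )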
399
+
400
+ class _UnslothSFTTrainer(Trainer):
401
+ """"""
402
+
403
+ _tag_names = ["trl", "sft"]
404
+
405
+ @deprecate_kwarg(
406
+ "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
407
+ )
408
+ def __init__(
409
+ self,
410
+ model: Union[str, nn.Module, PreTrainedModel],
411
+ args: Optional[Union[SFTConfig, TrainingArguments]] = None,
412
+ data_collator: Optional[DataCollator] = None, # type: ignore
413
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
414
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
415
+ processing_class: Optional[
416
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
417
+ ] = None,
418
+ compute_loss_func: Optional[Callable] = None,
419
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
420
+ callbacks: Optional[list[TrainerCallback]] = None,
421
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
422
+ optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
423
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
424
+ peft_config: Optional["PeftConfig"] = None,
425
+ formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
426
+ ):
427
+ # Args
428
+ if args is None:
429
+ model_name = model if isinstance(model, str) else model.config._name_or_path
430
+ model_name = model_name.split("/")[-1]
431
+ args = SFTConfig(f"{model_name}-SFT")
432
+ elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
433
+ dict_args = args.to_dict()
434
+ dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
435
+ dict_args.pop("push_to_hub_token")
436
+ args = SFTConfig(**dict_args)
437
+
438
+ # Model
439
+ if args.model_init_kwargs is not None and not isinstance(model, str):
440
+ warnings.warn(
441
+ "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
442
+ "The `model_init_kwargs` will be ignored."
443
+ )
444
+ if isinstance(model, str):
445
+ model = self._create_model_from_path(model, args)
446
+
447
+ # PEFT configuration and model wrapping
448
+ if False:  # intentionally disabled in this compiled trainer; the PEFT wrapping branch below never runs
449
+ model = self._prepare_peft_model(model, peft_config, args)
450
+
451
+ # Handle the tokenizer
452
+ if processing_class is None:
453
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
454
+ if processing_class.pad_token is None:
455
+ processing_class.pad_token = processing_class.eos_token # required for padding when collating data
456
+
457
+ # Dataset
458
+ preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
459
+ if preprocess_dataset:
460
+ train_dataset = self._prepare_dataset(
461
+ train_dataset, processing_class, args, args.packing, formatting_func, "train"
462
+ )
463
+ if eval_dataset is not None:
464
+ packing = args.packing if args.eval_packing is None else args.eval_packing
465
+ if isinstance(eval_dataset, dict):
466
+ eval_dataset = {
467
+ key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
468
+ for key, dataset in eval_dataset.items()
469
+ }
470
+ else:
471
+ eval_dataset = self._prepare_dataset(
472
+ eval_dataset, processing_class, args, packing, formatting_func, "eval"
473
+ )
474
+
475
+ # Data collator
476
+ if data_collator is None:
477
+ data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
478
+
479
+ # Initialize the metrics
480
+ self._metrics = defaultdict(list)
481
+
482
+ # Initialize the Trainer. Parent class will handle:
483
+ # - DeepSpeed configuration (through create_accelerator_and_postprocess)
484
+ # - FSDP setup
485
+ # - Distributed training setup
486
+ # - Optimizer and scheduler creation
487
+ # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
488
+ super_init_kwargs = {}
489
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
490
+ super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
491
+ else:
492
+ if optimizer_cls_and_kwargs is not None:
493
+ warnings.warn(
494
+ "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
495
+ "The default optimizer will be used. "
496
+ "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
497
+ )
498
+ super().__init__(
499
+ model=model,
500
+ args=args,
501
+ data_collator=data_collator,
502
+ train_dataset=train_dataset,
503
+ eval_dataset=eval_dataset,
504
+ processing_class=processing_class,
505
+ compute_loss_func=compute_loss_func,
506
+ compute_metrics=compute_metrics,
507
+ callbacks=callbacks,
508
+ optimizers=optimizers,
509
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
510
+ **super_init_kwargs,
511
+ )
512
+
513
+ # Add tags for models that have been loaded with the correct transformers version
514
+ if hasattr(self.model, "add_model_tags"):
515
+ self.model.add_model_tags(self._tag_names)
516
+
517
+ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
518
+ """Creates a model from a path or model identifier."""
519
+ model_init_kwargs = args.model_init_kwargs or {}
520
+ # Handle torch dtype
521
+ torch_dtype = model_init_kwargs.get("torch_dtype")
522
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
523
+ pass # torch_dtype is already a torch.dtype or "auto" or None
524
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
525
+ torch_dtype = getattr(torch, torch_dtype)
526
+ model_init_kwargs["torch_dtype"] = torch_dtype
527
+ else:
528
+ raise ValueError(
529
+ "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
530
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
531
+ )
532
+ # Disable caching if gradient checkpointing is enabled (not supported)
533
+ if args.gradient_checkpointing:
534
+ model_init_kwargs["use_cache"] = False
535
+
536
+ # Create model
537
+ if args.use_liger:
538
+ if not is_liger_kernel_available():
539
+ raise ImportError("Please install Liger-kernel for use_liger=True")
540
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
541
+ else:
542
+ model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
543
+ return model
544
+
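
The `torch_dtype` normalization above accepts a real `torch.dtype`, the string `"auto"`, `None`, or a dtype name that is resolved with `getattr`. A minimal standalone sketch of the same rule; the `normalize_torch_dtype` helper is illustrative, not part of the trainer:

```python
import torch

def normalize_torch_dtype(torch_dtype):
    # Pass through torch.dtype, "auto", and None; resolve other strings to a torch.dtype.
    if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
        return torch_dtype
    if isinstance(torch_dtype, str):
        return getattr(torch, torch_dtype)  # e.g. "float16" -> torch.float16
    raise ValueError(f"Expected 'auto' or a torch.dtype name, got {torch_dtype!r}")

assert normalize_torch_dtype("bfloat16") is torch.bfloat16
assert normalize_torch_dtype("auto") == "auto"
```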
545
+ def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
546
+ """Prepares a model for PEFT training."""
547
+ if not is_peft_available():
548
+ raise ImportError("To use PeftModel, you need to install the `peft` library.")
549
+
550
+ if not isinstance(peft_config, PeftConfig):
551
+ raise ValueError(
552
+ f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
553
+ "to pass a PeftConfig object to the SFTTrainer."
554
+ )
555
+
556
+ if isinstance(model, PeftModel):
557
+ return model
558
+
559
+ # Handle quantized models (QLoRA)
560
+ is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
561
+
562
+ is_sharded_qlora = False
563
+ if getattr(model, "is_loaded_in_4bit", False):
564
+ # Check if model is sharded (FSDP/DS-Zero3)
565
+ for _, param in model.named_parameters():
566
+ if param.__class__.__name__ == "Params4bit":
567
+ is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
568
+ break
569
+
570
+ # Prepare model for kbit training if needed
571
+ if is_qlora and not is_sharded_qlora:
572
+ model = self._prepare_model_for_kbit_training(model, args)
573
+ # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
574
+ args = dataclasses.replace(args, gradient_checkpointing=False)
575
+ elif args.gradient_checkpointing:
576
+ model = self._enable_gradient_checkpointing(model, args)
577
+
578
+ # Create PEFT model
579
+ if (
580
+ version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
581
+ and getattr(model, "is_loaded_in_4bit", False)
582
+ and is_sharded_qlora
583
+ ):
584
+ model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
585
+ else:
586
+ model = get_peft_model(model, peft_config)
587
+
588
+ # Handle bf16 casting for 4-bit models
589
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
590
+ peft_module_casting_to_bf16(model)
591
+
592
+ return model
593
+
594
+ def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
595
+ """Prepares a quantized model for kbit training."""
596
+ prepare_model_kwargs = {
597
+ "use_gradient_checkpointing": args.gradient_checkpointing,
598
+ "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
599
+ }
600
+
601
+ return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
602
+
603
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
604
+ """Enables gradient checkpointing for the model."""
605
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
606
+ use_reentrant = (
607
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
608
+ )
609
+
610
+ if use_reentrant:
611
+ if hasattr(model, "enable_input_require_grads"):
612
+ model.enable_input_require_grads()
613
+ else:
614
+
615
+ def make_inputs_require_grad(module, input, output):
616
+ output.requires_grad_(True)
617
+
618
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
619
+
620
+ return model
621
+
622
+ def _prepare_dataset(
623
+ self,
624
+ dataset: Union[Dataset, IterableDataset],
625
+ processing_class,
626
+ args,
627
+ packing: bool,
628
+ formatting_func: Optional[Callable[[dict], str]],
629
+ dataset_name: str,
630
+ ) -> Union[Dataset, IterableDataset]:
631
+ # All Unsloth Zoo code licensed under LGPLv3
632
+ if isinstance(dataset, ConstantLengthDataset): return dataset
633
+
634
+ map_kwargs = {}
635
+ use_desc = isinstance(dataset, Dataset)
636
+ is_vlm = hasattr(processing_class, "tokenizer")
637
+ tokenizer = processing_class
638
+ if is_vlm: tokenizer = processing_class.tokenizer
639
+
640
+ # Get max length
641
+ max_seq_length = getattr(args, "max_length", 0)
642
+ if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
643
+ if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
644
+ if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
645
+ if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
646
+ dataset_text_field = getattr(args, "dataset_text_field", "text")
647
+ do_truncation = max_seq_length != 0
648
+ do_formatting_func = False
649
+ do_tokenize = True
650
+
651
+ # Get correct column names
652
+ column_names = set(next(iter(dataset)).keys())
653
+ used_column_names = ["input_ids"]
654
+ if "attention_mask" in column_names:
655
+ used_column_names.append("attention_mask")
656
+
657
+ # Check if already tokenized so skip
658
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
659
+ if "labels" in column_names:
660
+ # Most likely forgot data collator!
661
+ if is_vlm and not hasattr(tokenizer, "pad"):
662
+ # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
663
+ raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
664
+ self.data_collator = DataCollatorForSeq2Seq(tokenizer)
665
+ used_column_names.append("labels")
666
+ do_tokenize = False
667
+ elif "input_ids" in column_names:
668
+ # Skip dataset prep, and set data collator
669
+ if is_vlm and not hasattr(tokenizer, "pad"):
670
+ # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
671
+ raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
672
+ self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
673
+ do_tokenize = False
674
+ elif dataset_text_field not in column_names:
675
+ do_formatting_func = True
676
+ if formatting_func is None:
677
+ raise RuntimeError("Unsloth: You must specify a `formatting_func`")
678
+ pass
679
+
680
+ if do_tokenize:
681
+ # Check double BOS tokens
682
+ if do_formatting_func:
683
+ test_text = formatting_func(next(iter(dataset)))
684
+ if not isinstance(test_text, list):
685
+ raise ValueError(
686
+ "Unsloth: The `formatting_func` should return a list of processed strings."
687
+ )
688
+ test_text = test_text[0]
689
+ else:
690
+ test_text = next(iter(dataset))[dataset_text_field][0]
691
+
692
+ # Get chat template
693
+ chat_template = getattr(processing_class, 'chat_template', '')
694
+ if chat_template == '' and is_vlm:
695
+ chat_template = getattr(tokenizer, 'chat_template', '')
696
+ if chat_template is None:
697
+ chat_template = ''
698
+
699
+ # Get bos_token
700
+ add_special_tokens = True
701
+ bos_token_1 = getattr(processing_class, 'bos_token', None)
702
+ bos_token_2 = getattr(tokenizer, 'bos_token', None)
703
+ bos_token = bos_token_1 or bos_token_2
704
+
705
+ if bos_token is not None:
706
+ if test_text.startswith(bos_token) or bos_token in chat_template:
707
+ add_special_tokens = False
708
+ print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
709
+ pass
710
+
711
+ # Create tokenize function
712
+ def _tokenize(example):
713
+ return tokenizer(
714
+ example[dataset_text_field] if not do_formatting_func else formatting_func(example),
715
+ truncation = do_truncation,
716
+ max_length = max_seq_length,
717
+ return_token_type_ids = False,
718
+ add_special_tokens = add_special_tokens,
719
+ )
720
+ pass
721
+
722
+ if not isinstance(dataset, IterableDataset):
723
+ map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
724
+ else:
725
+ map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
726
+
727
+ if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
728
+ dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
729
+
730
+ # If VLM, switch data collator since .pad is needed!
731
+ if is_vlm and not hasattr(processing_class, "pad"):
732
+ data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
733
+ self.data_collator = data_collator
734
+ pass
735
+ pass
736
+ if packing:
737
+ print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
738
+ return dataset
739
+
740
+ if max_seq_length == 0:
741
+ raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
742
+
743
+ if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
744
+ dataset = dataset.select_columns(used_column_names).map(
745
+ pack_examples,
746
+ batched = True,
747
+ fn_kwargs = {"seq_length": max_seq_length,},
748
+ **map_kwargs,
749
+ )
750
+ pass
751
+ return dataset
752
+
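
The double-BOS guard in `_prepare_dataset` reduces to one rule: if the sample text already starts with the BOS token, or the chat template itself emits one, the tokenizer must not add special tokens a second time. A minimal sketch of that rule, assuming an illustrative `<s>` BOS token (the helper name is made up):

```python
def should_add_special_tokens(test_text: str, bos_token, chat_template: str) -> bool:
    # Mirror of the guard above: skip adding special tokens when BOS is already supplied.
    if bos_token is not None and (test_text.startswith(bos_token) or bos_token in chat_template):
        return False
    return True

assert should_add_special_tokens("<s>Hello", "<s>", "") is False             # text already has BOS
assert should_add_special_tokens("Hello", "<s>", "{{ '<s>' }}...") is False  # template emits BOS
assert should_add_special_tokens("Hello", "<s>", "") is True                 # tokenizer must add BOS
```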
753
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
754
+ outputs = super().compute_loss(
755
+ model,
756
+ inputs,
757
+ return_outputs = return_outputs,
758
+ num_items_in_batch = num_items_in_batch,
759
+ )
760
+ return outputs
761
+
762
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
763
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
764
+
765
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
766
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
767
+ if next(iter(logs.keys())).startswith("eval_"):
768
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
769
+
770
+ logs = {**logs, **metrics}
771
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
772
+ super().log(logs, start_time)
773
+ else: # transformers<=4.46
774
+ super().log(logs)
775
+ self._metrics.clear()
776
+
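
The `log` override above averages the buffered metric lists and, when called during evaluation, re-prefixes them with `eval_` so they match the keys already in `logs`. A small self-contained sketch of that merge, with made-up values:

```python
from collections import defaultdict

_metrics = defaultdict(list)
_metrics["token_accuracy"] += [0.5, 0.75]          # buffered per-step values

logs = {"eval_loss": 1.2}
metrics = {k: sum(v) / len(v) for k, v in _metrics.items()}   # average the buffers
if next(iter(logs)).startswith("eval_"):                       # match the eval_ naming of `logs`
    metrics = {f"eval_{k}": v for k, v in metrics.items()}

assert {**logs, **metrics} == {"eval_loss": 1.2, "eval_token_accuracy": 0.625}
```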
777
+ def create_model_card(
778
+ self,
779
+ model_name: Optional[str] = None,
780
+ dataset_name: Optional[str] = None,
781
+ tags: Union[str, list[str], None] = None,
782
+ ):
783
+ """
784
+ Creates a draft of a model card using the information available to the `Trainer`.
785
+
786
+ Args:
787
+ model_name (`str` or `None`, *optional*, defaults to `None`):
788
+ Name of the model.
789
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
790
+ Name of the dataset used for training.
791
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
792
+ Tags to be associated with the model card.
793
+ """
794
+ if not self.is_world_process_zero():
795
+ return
796
+
797
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
798
+ base_model = self.model.config._name_or_path
799
+ else:
800
+ base_model = None
801
+
802
+ tags = tags or []
803
+ if isinstance(tags, str):
804
+ tags = [tags]
805
+
806
+ if hasattr(self.model.config, "unsloth_version"):
807
+ tags.append("unsloth")
808
+
809
+ model_card = generate_model_card(
810
+ base_model=base_model,
811
+ model_name=model_name,
812
+ hub_model_id=self.hub_model_id,
813
+ dataset_name=dataset_name,
814
+ tags=tags,
815
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
816
+ comet_url=get_comet_experiment_url(),
817
+ trainer_name="SFT",
818
+ )
819
+
820
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
821
+ class UnslothSFTTrainer(_UnslothSFTTrainer):
822
+ """
823
+
824
+ Trainer for the Supervised Fine-Tuning (SFT) method.
825
+
826
+ This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
827
+
828
+ Example:
829
+
830
+ ```python
831
+ from datasets import load_dataset
832
+ from trl import SFTTrainer
833
+
834
+ dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
835
+
836
+ trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
837
+ trainer.train()
838
+ ```
839
+
840
+ Args:
841
+ model (`Union[str, PreTrainedModel]`):
842
+ Model to be trained. Can be either:
843
+
844
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
845
+ a path to a *directory* containing model weights saved using
846
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
847
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
848
+ in `args.model_init_kwargs`.
849
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
850
+ args ([`SFTConfig`], *optional*, defaults to `None`):
851
+ Configuration for this trainer. If `None`, a default configuration is used.
852
+ data_collator (`DataCollator`, *optional*):
853
+ Function used to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
854
+ Defaults to [`~transformers.default_data_collator`] if no `processing_class` is provided, or to an instance
855
+ of [`~transformers.DataCollatorWithPadding`] if the `processing_class` is a feature extractor or
856
+ tokenizer.
857
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
858
+ Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
859
+ [prompt-completion](#prompt-completion) type. The format of the samples can be either:
860
+
861
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
862
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
863
+ and content).
864
+
865
+ The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
866
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
867
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
868
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
869
+ Processing class used to process the data. If `None`, the processing class is loaded from the model's name
870
+ with [`~transformers.AutoTokenizer.from_pretrained`].
871
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
872
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
873
+ detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
874
+
875
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
876
+ method.
877
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
878
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
879
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
880
+ optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
881
+ A tuple containing the optimizer class and keyword arguments to use.
882
+ Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
883
+
884
+ Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
885
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
886
+ A function that preprocesses the logits right before caching them at each evaluation step. Must take two
887
+ tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
888
+ by this function will be reflected in the predictions received by `compute_metrics`.
889
+
890
+ Note that the labels (second parameter) will be `None` if the dataset does not have them.
891
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
892
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
893
+ formatting_func (`Optional[Callable]`):
894
+ Formatting function applied to the dataset before tokenization.
895
+
896
+ """
897
+ def __init__(
898
+ self,
899
+ model,
900
+ args = None,
901
+ data_collator = None,
902
+ train_dataset = None,
903
+ eval_dataset = None,
904
+ processing_class = None,
905
+ compute_loss_func = None,
906
+ compute_metrics = None,
907
+ callbacks = None,
908
+ optimizer_cls_and_kwargs = None,
909
+ preprocess_logits_for_metrics = None,
910
+ peft_config = None,
911
+ formatting_func = None,
912
+ **kwargs
913
+ ):
914
+ if args is None: args = UnslothSFTConfig()
915
+ use_bf16 = getattr(args, 'bf16', False)
916
+ use_fp16 = getattr(args, 'fp16', False)
917
+ force_float32 = False
918
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
919
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
920
+ force_float32 = True
921
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
922
+ dtype = getattr(model.config, 'torch_dtype', None)
923
+ if dtype is None: dtype = model.get_input_embeddings().dtype
924
+ from unsloth_zoo.utils import _get_dtype
925
+ dtype = _get_dtype(dtype)
926
+ float16 = dtype == torch.float16
927
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
928
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
929
+ if force_float32:
930
+ args.fp16 = False
931
+ args.bf16 = False
932
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
933
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
934
+ args.fp16 = float16
935
+ args.bf16 = not float16
936
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
937
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
938
+ args.eval_strategy = 'steps'
939
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
940
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
941
+ if ga_steps is not None and ga_steps > 1:
942
+ from transformers import __version__ as transformers_version
943
+ if Version(transformers_version) <= Version('4.45.2'):
944
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
945
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
946
+ if getattr(args, 'eval_strategy', 'no') != 'no':
947
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
948
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
949
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
950
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
951
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
952
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
953
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
954
+ if force_float32:
955
+ args.bf16_full_eval = False
956
+ args.fp16_full_eval = False
957
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
958
+ args.bf16_full_eval = True
959
+ args.fp16_full_eval = False
960
+ elif not bf16_full_eval and not fp16_full_eval:
961
+ args.bf16_full_eval = args.bf16
962
+ args.fp16_full_eval = args.fp16
963
+ _output_logits = False
964
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
965
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
966
+ if _output_logits:
967
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
968
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
969
+ pass
970
+ else:
971
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
972
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
973
+ if args_max_seq_length is None and model_max_seq_length is not None:
974
+ max_seq_length = model.max_seq_length
975
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
976
+ if model is not None and hasattr(model, 'for_training'):
977
+ model.for_training()
978
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
979
+ if 'processing_class' in locals():
980
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
981
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
982
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
983
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
984
+ if not isinstance(data_collator, UnslothVisionDataCollator):
985
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
986
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
987
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
988
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
989
+ else:
990
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
991
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
992
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
993
+ if not isinstance(data_collator, UnslothVisionDataCollator):
994
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
995
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
996
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
997
+ else:
998
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
999
+ other_metrics = []
1000
+
1001
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1002
+ PatchRLStatistics('sft_trainer', other_metrics)
1003
+ IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
1004
+ from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
1005
+ from unsloth_zoo.training_utils import fix_zero_training_loss
1006
+ if 'tokenizer' not in locals(): tokenizer = processing_class
1007
+ fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
1008
+ fix_zero_training_loss(model, tokenizer, train_dataset)
1009
+
1010
+ super().__init__(
1011
+ model = model,
1012
+ args = args,
1013
+ data_collator = data_collator,
1014
+ train_dataset = train_dataset,
1015
+ eval_dataset = eval_dataset,
1016
+ processing_class = processing_class,
1017
+ compute_loss_func = compute_loss_func,
1018
+ compute_metrics = compute_metrics,
1019
+ callbacks = callbacks,
1020
+ optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
1021
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1022
+ peft_config = peft_config,
1023
+ formatting_func = formatting_func,**kwargs)
1024
+ if hasattr(self, 'neftune_hook_handle'):
1025
+ self.neftune_hook_handle.remove()
1026
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1027
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1028
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1029
+ pass
1030
+
1031
+ pass
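
The precision checks at the top of `UnslothSFTTrainer.__init__` reduce to a simple rule: a float16 model may not train under bf16 mixed precision (and vice versa), and when neither flag is set, the flag matching the model dtype is switched on. A standalone sketch of that rule, assuming a plain `torch.dtype` and ignoring the force-float32 and environment-variable paths:

```python
import torch

def pick_mixed_precision(model_dtype: torch.dtype, use_bf16: bool, use_fp16: bool) -> dict:
    # Mirror of the fp16/bf16 consistency checks above (no UNSLOTH_FORCE_FLOAT32 path).
    float16 = model_dtype == torch.float16
    if float16 and use_bf16:
        raise TypeError("Model is float16 but bf16 mixed precision was requested.")
    if not float16 and use_fp16:
        raise TypeError("Model is bfloat16 but fp16 mixed precision was requested.")
    if not use_bf16 and not use_fp16:
        use_fp16, use_bf16 = float16, not float16  # default to the model's own dtype
    return {"fp16": use_fp16, "bf16": use_bf16}

assert pick_mixed_precision(torch.float16, False, False) == {"fp16": True, "bf16": False}
assert pick_mixed_precision(torch.bfloat16, False, False) == {"fp16": False, "bf16": True}
```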
unsloth_compiled_cache/UnslothXPOTrainer.py ADDED
@@ -0,0 +1,1010 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
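
`selective_log_softmax` gathers per-token log-probabilities without materialising the full log-softmax tensor, using the identity noted in the comment. A quick equivalence check against the naive formulation (run without the `torch.compile` decorator; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 32)            # (batch, seq, vocab)
index = torch.randint(0, 32, (2, 5))      # chosen token ids

# Fused form, as in selective_log_softmax above.
fast = (logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1)
        - torch.logsumexp(logits.float(), dim=-1))

# Naive reference: full log_softmax, then gather.
ref = torch.gather(F.log_softmax(logits.float(), dim=-1), -1, index.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(fast, ref, atol=1e-5)
```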
42
+ @dataclass
43
+ class UnslothXPOConfig(XPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`XPOTrainer`].
47
+
48
+ Subclass of [`OnlineDPOConfig`]; we can use all of its arguments and add the following:
49
+
50
+ Parameters:
51
+ alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
52
+ Weight of the XPO loss term. If a list of floats is provided, the alpha for each new epoch is taken
53
+ from the list in order, and the last value is reused for the remaining epochs.
54
+
55
+ """
56
+ vllm_sampling_params: Optional[Any] = field(
57
+ default = None,
58
+ metadata = {'help': 'vLLM SamplingParams'},
59
+ )
60
+ unsloth_num_chunks : Optional[int] = field(
61
+ default = -1,
62
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
63
+ )
64
+ def __init__(
65
+ self,
66
+ output_dir = None,
67
+ overwrite_output_dir = None,
68
+ do_train = False,
69
+ do_eval = False,
70
+ do_predict = False,
71
+ eval_strategy = 'no',
72
+ prediction_loss_only = False,
73
+ per_device_train_batch_size = 4,
74
+ per_device_eval_batch_size = 4,
75
+ per_gpu_train_batch_size = None,
76
+ per_gpu_eval_batch_size = None,
77
+ gradient_accumulation_steps = 2,
78
+ eval_accumulation_steps = 2,
79
+ eval_delay = 0,
80
+ torch_empty_cache_steps = 250,
81
+ learning_rate = 5e-05,
82
+ weight_decay = 0.01,
83
+ adam_beta1 = 0.9,
84
+ adam_beta2 = 0.999,
85
+ adam_epsilon = 1e-08,
86
+ max_grad_norm = 1.0,
87
+ num_train_epochs = 3.0,
88
+ max_steps = -1,
89
+ lr_scheduler_type = 'linear',
90
+ warmup_ratio = 0.1,
91
+ warmup_steps = 0,
92
+ log_level = 'passive',
93
+ log_level_replica = 'warning',
94
+ log_on_each_node = True,
95
+ logging_dir = None,
96
+ logging_strategy = 'steps',
97
+ logging_first_step = False,
98
+ logging_steps = 1,
99
+ logging_nan_inf_filter = False,
100
+ save_strategy = 'steps',
101
+ save_steps = 500,
102
+ save_total_limit = None,
103
+ save_safetensors = True,
104
+ save_on_each_node = False,
105
+ save_only_model = False,
106
+ restore_callback_states_from_checkpoint = False,
107
+ no_cuda = False,
108
+ use_cpu = False,
109
+ use_mps_device = False,
110
+ seed = 3407,
111
+ data_seed = 3407,
112
+ jit_mode_eval = False,
113
+ use_ipex = False,
114
+ bf16 = False,
115
+ fp16 = False,
116
+ fp16_opt_level = 'O1',
117
+ half_precision_backend = 'auto',
118
+ bf16_full_eval = False,
119
+ fp16_full_eval = False,
120
+ tf32 = None,
121
+ local_rank = -1,
122
+ ddp_backend = None,
123
+ tpu_num_cores = None,
124
+ tpu_metrics_debug = False,
125
+ debug = '',
126
+ dataloader_drop_last = False,
127
+ eval_steps = None,
128
+ dataloader_num_workers = 0,
129
+ dataloader_prefetch_factor = None,
130
+ past_index = -1,
131
+ run_name = None,
132
+ disable_tqdm = None,
133
+ remove_unused_columns = True,
134
+ label_names = None,
135
+ load_best_model_at_end = False,
136
+ metric_for_best_model = None,
137
+ greater_is_better = None,
138
+ ignore_data_skip = False,
139
+ fsdp = '',
140
+ fsdp_min_num_params = 0,
141
+ fsdp_config = None,
142
+ tp_size = 0,
143
+ fsdp_transformer_layer_cls_to_wrap = None,
144
+ accelerator_config = None,
145
+ deepspeed = None,
146
+ label_smoothing_factor = 0.0,
147
+ optim = 'adamw_8bit',
148
+ optim_args = None,
149
+ adafactor = False,
150
+ group_by_length = False,
151
+ length_column_name = 'length',
152
+ report_to = None,
153
+ ddp_find_unused_parameters = None,
154
+ ddp_bucket_cap_mb = None,
155
+ ddp_broadcast_buffers = None,
156
+ dataloader_pin_memory = True,
157
+ dataloader_persistent_workers = False,
158
+ skip_memory_metrics = True,
159
+ use_legacy_prediction_loop = False,
160
+ push_to_hub = False,
161
+ resume_from_checkpoint = None,
162
+ hub_model_id = None,
163
+ hub_strategy = 'every_save',
164
+ hub_token = None,
165
+ hub_private_repo = None,
166
+ hub_always_push = False,
167
+ gradient_checkpointing = False,
168
+ gradient_checkpointing_kwargs = None,
169
+ include_inputs_for_metrics = False,
170
+ eval_do_concat_batches = True,
171
+ fp16_backend = 'auto',
172
+ evaluation_strategy = None,
173
+ push_to_hub_model_id = None,
174
+ push_to_hub_organization = None,
175
+ push_to_hub_token = None,
176
+ mp_parameters = '',
177
+ auto_find_batch_size = False,
178
+ full_determinism = False,
179
+ torchdynamo = None,
180
+ ray_scope = 'last',
181
+ ddp_timeout = 1800,
182
+ torch_compile = False,
183
+ torch_compile_backend = None,
184
+ torch_compile_mode = None,
185
+ dispatch_batches = None,
186
+ split_batches = None,
187
+ include_tokens_per_second = False,
188
+ include_num_input_tokens_seen = False,
189
+ neftune_noise_alpha = None,
190
+ optim_target_modules = None,
191
+ batch_eval_metrics = False,
192
+ eval_on_start = False,
193
+ use_liger_kernel = False,
194
+ eval_use_gather_object = False,
195
+ average_tokens_across_devices = False,
196
+ reward_model_path = None,
197
+ judge = None,
198
+ max_new_tokens = 64,
199
+ max_length = 512,
200
+ temperature = 0.9,
201
+ missing_eos_penalty = None,
202
+ loss_type = 'sigmoid',
203
+ dataset_num_proc = None,
204
+ disable_dropout = True,
205
+ use_vllm = False,
206
+ ds3_gather_for_generation = True,
207
+ vllm_sampling_params = None,
208
+ unsloth_num_chunks = -1,
209
+ **kwargs,
210
+ ):
211
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
212
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
213
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
214
+ output_dir = 'unsloth_training_checkpoints'
215
+ save_strategy = 'no'
216
+ if dataset_num_proc is None:
217
+ from multiprocessing import cpu_count
218
+ dataset_num_proc = cpu_count()
219
+
220
+ super().__init__(
221
+ output_dir = output_dir,
222
+ overwrite_output_dir = overwrite_output_dir,
223
+ do_train = do_train,
224
+ do_eval = do_eval,
225
+ do_predict = do_predict,
226
+ eval_strategy = eval_strategy,
227
+ prediction_loss_only = prediction_loss_only,
228
+ per_device_train_batch_size = per_device_train_batch_size,
229
+ per_device_eval_batch_size = per_device_eval_batch_size,
230
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
231
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
232
+ gradient_accumulation_steps = gradient_accumulation_steps,
233
+ eval_accumulation_steps = eval_accumulation_steps,
234
+ eval_delay = eval_delay,
235
+ torch_empty_cache_steps = torch_empty_cache_steps,
236
+ learning_rate = learning_rate,
237
+ weight_decay = weight_decay,
238
+ adam_beta1 = adam_beta1,
239
+ adam_beta2 = adam_beta2,
240
+ adam_epsilon = adam_epsilon,
241
+ max_grad_norm = max_grad_norm,
242
+ num_train_epochs = num_train_epochs,
243
+ max_steps = max_steps,
244
+ lr_scheduler_type = lr_scheduler_type,
245
+ warmup_ratio = warmup_ratio,
246
+ warmup_steps = warmup_steps,
247
+ log_level = log_level,
248
+ log_level_replica = log_level_replica,
249
+ log_on_each_node = log_on_each_node,
250
+ logging_dir = logging_dir,
251
+ logging_strategy = logging_strategy,
252
+ logging_first_step = logging_first_step,
253
+ logging_steps = logging_steps,
254
+ logging_nan_inf_filter = logging_nan_inf_filter,
255
+ save_strategy = save_strategy,
256
+ save_steps = save_steps,
257
+ save_total_limit = save_total_limit,
258
+ save_safetensors = save_safetensors,
259
+ save_on_each_node = save_on_each_node,
260
+ save_only_model = save_only_model,
261
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
262
+ no_cuda = no_cuda,
263
+ use_cpu = use_cpu,
264
+ use_mps_device = use_mps_device,
265
+ seed = seed,
266
+ data_seed = data_seed,
267
+ jit_mode_eval = jit_mode_eval,
268
+ use_ipex = use_ipex,
269
+ bf16 = bf16,
270
+ fp16 = fp16,
271
+ fp16_opt_level = fp16_opt_level,
272
+ half_precision_backend = half_precision_backend,
273
+ bf16_full_eval = bf16_full_eval,
274
+ fp16_full_eval = fp16_full_eval,
275
+ tf32 = tf32,
276
+ local_rank = local_rank,
277
+ ddp_backend = ddp_backend,
278
+ tpu_num_cores = tpu_num_cores,
279
+ tpu_metrics_debug = tpu_metrics_debug,
280
+ debug = debug,
281
+ dataloader_drop_last = dataloader_drop_last,
282
+ eval_steps = eval_steps,
283
+ dataloader_num_workers = dataloader_num_workers,
284
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
285
+ past_index = past_index,
286
+ run_name = run_name,
287
+ disable_tqdm = disable_tqdm,
288
+ remove_unused_columns = remove_unused_columns,
289
+ label_names = label_names,
290
+ load_best_model_at_end = load_best_model_at_end,
291
+ metric_for_best_model = metric_for_best_model,
292
+ greater_is_better = greater_is_better,
293
+ ignore_data_skip = ignore_data_skip,
294
+ fsdp = fsdp,
295
+ fsdp_min_num_params = fsdp_min_num_params,
296
+ fsdp_config = fsdp_config,
297
+ tp_size = tp_size,
298
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
299
+ accelerator_config = accelerator_config,
300
+ deepspeed = deepspeed,
301
+ label_smoothing_factor = label_smoothing_factor,
302
+ optim = optim,
303
+ optim_args = optim_args,
304
+ adafactor = adafactor,
305
+ group_by_length = group_by_length,
306
+ length_column_name = length_column_name,
307
+ report_to = report_to,
308
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
309
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
310
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
311
+ dataloader_pin_memory = dataloader_pin_memory,
312
+ dataloader_persistent_workers = dataloader_persistent_workers,
313
+ skip_memory_metrics = skip_memory_metrics,
314
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
315
+ push_to_hub = push_to_hub,
316
+ resume_from_checkpoint = resume_from_checkpoint,
317
+ hub_model_id = hub_model_id,
318
+ hub_strategy = hub_strategy,
319
+ hub_token = hub_token,
320
+ hub_private_repo = hub_private_repo,
321
+ hub_always_push = hub_always_push,
322
+ gradient_checkpointing = gradient_checkpointing,
323
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
324
+ include_inputs_for_metrics = include_inputs_for_metrics,
325
+ eval_do_concat_batches = eval_do_concat_batches,
326
+ fp16_backend = fp16_backend,
327
+ evaluation_strategy = evaluation_strategy,
328
+ push_to_hub_model_id = push_to_hub_model_id,
329
+ push_to_hub_organization = push_to_hub_organization,
330
+ push_to_hub_token = push_to_hub_token,
331
+ mp_parameters = mp_parameters,
332
+ auto_find_batch_size = auto_find_batch_size,
333
+ full_determinism = full_determinism,
334
+ torchdynamo = torchdynamo,
335
+ ray_scope = ray_scope,
336
+ ddp_timeout = ddp_timeout,
337
+ torch_compile = torch_compile,
338
+ torch_compile_backend = torch_compile_backend,
339
+ torch_compile_mode = torch_compile_mode,
340
+ dispatch_batches = dispatch_batches,
341
+ split_batches = split_batches,
342
+ include_tokens_per_second = include_tokens_per_second,
343
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
344
+ neftune_noise_alpha = neftune_noise_alpha,
345
+ optim_target_modules = optim_target_modules,
346
+ batch_eval_metrics = batch_eval_metrics,
347
+ eval_on_start = eval_on_start,
348
+ use_liger_kernel = use_liger_kernel,
349
+ eval_use_gather_object = eval_use_gather_object,
350
+ average_tokens_across_devices = average_tokens_across_devices,
351
+ reward_model_path = reward_model_path,
352
+ judge = judge,
353
+ max_new_tokens = max_new_tokens,
354
+ max_length = max_length,
355
+ temperature = temperature,
356
+ missing_eos_penalty = missing_eos_penalty,
357
+ loss_type = loss_type,
358
+ dataset_num_proc = dataset_num_proc,
359
+ disable_dropout = disable_dropout,
360
+ use_vllm = use_vllm,
361
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
362
+ self.vllm_sampling_params = vllm_sampling_params
363
+ self.unsloth_num_chunks = unsloth_num_chunks
364
+ pass
365
+
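
A hedged usage sketch of the config defined above, assuming `trl` is installed and this module's definitions are importable; the argument values are illustrative:

```python
config = UnslothXPOConfig(
    learning_rate = 5e-6,   # must lie in [1e-7, 1], or __init__ raises immediately
    max_new_tokens = 64,
    temperature = 0.9,
)
# With no output_dir and the default save settings, the constructor falls back to
# 'unsloth_training_checkpoints' and switches save_strategy to 'no'.
assert config.unsloth_num_chunks == -1   # default: chunking disabled / most efficient
```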
366
+ class _UnslothXPOTrainer(OnlineDPOTrainer):
367
+ r""""""
368
+
369
+ _tag_names = ["trl", "xpo"]
370
+
371
+ def __init__(
372
+ self,
373
+ model: Union[PreTrainedModel, nn.Module] = None,
374
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
375
+ reward_model: Optional[nn.Module] = None,
376
+ judge: Optional[BasePairwiseJudge] = None,
377
+ args: Optional[XPOConfig] = None,
378
+ data_collator: Optional[Callable] = None,
379
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
380
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
381
+ processing_class: Optional[
382
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
383
+ ] = None,
384
+ peft_config: Optional[dict] = None,
385
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
386
+ callbacks: Optional[list[TrainerCallback]] = None,
387
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
388
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
389
+ ) -> None:
390
+ super().__init__(
391
+ model=model,
392
+ ref_model=ref_model,
393
+ judge=judge,
394
+ reward_model=reward_model,
395
+ args=args,
396
+ data_collator=data_collator,
397
+ train_dataset=train_dataset,
398
+ eval_dataset=eval_dataset,
399
+ processing_class=processing_class,
400
+ reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
401
+ peft_config=peft_config,
402
+ compute_metrics=compute_metrics,
403
+ callbacks=callbacks,
404
+ optimizers=optimizers,
405
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
406
+ )
407
+
408
+ self._alpha = self.args.alpha
409
+
410
+ # Overwrite the stats dictionary to include XPO specific statistics
411
+ self.stats = {
412
+ # Remove "non_score_reward", "rlhf_reward", "scores"
413
+ # Add "loss/dpo", "loss/xpo"
414
+ "loss/dpo": [],
415
+ "loss/xpo": [],
416
+ "objective/kl": [],
417
+ "objective/entropy": [],
418
+ "rewards/chosen": [],
419
+ "rewards/rejected": [],
420
+ "rewards/accuracies": [],
421
+ "rewards/margins": [],
422
+ "logps/chosen": [],
423
+ "logps/rejected": [],
424
+ # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
425
+ "val/model_contain_eos_token": [],
426
+ "val/ref_contain_eos_token": [],
427
+ "alpha": [],
428
+ "beta": [],
429
+ }
430
+ if self.reward_model is not None:
431
+ # Replace "scores" by "model_scores" and "ref_scores"
432
+ self.stats["objective/model_scores"] = []
433
+ self.stats["objective/ref_scores"] = []
434
+ self.stats["objective/scores_margin"] = []
435
+
436
+ @property
437
+ def alpha(self):
438
+ if isinstance(self._alpha, list):
439
+ epoch = self.state.epoch
440
+ return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
441
+ else:
442
+ return self._alpha
443
+
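
The `alpha` property implements a per-epoch schedule when a list is supplied: epoch `i` uses `alpha[i]`, and epochs past the end of the list reuse the final value. The same rule in isolation (the helper name is made up, and the epoch is assumed to be an integer):

```python
def alpha_for_epoch(alpha, epoch: int):
    # Mirror of the property above: index a list schedule, clamping to the last entry.
    if isinstance(alpha, list):
        return alpha[epoch] if epoch < len(alpha) else alpha[-1]
    return alpha

schedule = [1e-4, 5e-5, 1e-5]
assert alpha_for_epoch(schedule, 0) == 1e-4
assert alpha_for_epoch(schedule, 10) == 1e-5   # past the list: last value is reused
assert alpha_for_epoch(2e-5, 3) == 2e-5        # scalar alpha stays constant
```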
444
+ def _generate_completions(self, prompts, model):
445
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
446
+ model_output = unwrapped_model.generate(
447
+ input_ids=prompts["input_ids"],
448
+ attention_mask=prompts["attention_mask"],
449
+ generation_config=self.generation_config,
450
+ )
451
+
452
+ ref_model = model if self.ref_model is None else self.ref_model
453
+ with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
454
+ ref_output = unwrapped_ref_model.generate(
455
+ input_ids=prompts["input_ids"],
456
+ attention_mask=prompts["attention_mask"],
457
+ generation_config=self.generation_config,
458
+ )
459
+
460
+ return model_output, ref_output
461
+
462
+ def _process_completions(self, model_output, ref_output, prompts):
463
+ context_length = prompts["input_ids"].shape[1]
464
+
465
+ # Process model completions
466
+ model_completion_ids = model_output[:, context_length:]
467
+ model_completion_ids, model_completion_mask = truncate_right(
468
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
469
+ )
470
+ model_data = {
471
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
472
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
473
+ "raw": prompts["raw"],
474
+ }
475
+
476
+ # Process reference model completions
477
+ ref_completion_ids = ref_output[:, context_length:]
478
+ ref_completion_ids, ref_completion_mask = truncate_right(
479
+ ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
480
+ )
481
+ ref_data = {
482
+ "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
483
+ "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
484
+ "raw": prompts["raw"],
485
+ }
486
+
487
+ return model_data, ref_data
488
+
489
+ def _compute_rewards(self, model_data, ref_data, context_length):
490
+ with torch.no_grad():
491
+ _, model_scores, _ = get_reward(
492
+ self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
493
+ )
494
+ _, ref_scores, _ = get_reward(
495
+ self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
496
+ )
497
+
498
+ # Apply EOS penalty if needed
499
+ if self.args.missing_eos_penalty is not None:
500
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
501
+ ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
502
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
503
+ ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
504
+
505
+ return model_scores, ref_scores
506
+
507
+ def _compute_judge(self, model_data, ref_data, context_length):
508
+ prompts = model_data["raw"]
509
+ model_data_completions = self.processing_class.batch_decode(
510
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
511
+ )
512
+ model_data_completions = [completion.strip() for completion in model_data_completions]
513
+
514
+ ref_data_completions = self.processing_class.batch_decode(
515
+ ref_data["input_ids"][:, context_length:], skip_special_tokens=True
516
+ )
517
+ ref_data_completions = [completion.strip() for completion in ref_data_completions]
518
+
519
+ if is_conversational({"prompt": prompts[0]}):
520
+ model_data_completions = [
521
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
522
+ ]
523
+ environment = jinja2.Environment()
524
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
525
+ prompts = [template.render(messages=message) for message in prompts]
526
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
527
+
528
+ ref_data_completions = [
529
+ [{"role": "assistant", "content": completion}] for completion in ref_data_completions
530
+ ]
531
+ ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
532
+
533
+ ranks_of_first_completion = self.judge.judge(
534
+ prompts,
535
+ list(zip(model_data_completions, ref_data_completions)),
536
+ )
537
+ # convert ranks to a True/False mask:
538
+ # when rank == 0, it means the first completion is the best
539
+ # when rank == 1, it means the second completion is the best
540
+ return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
541
+
542
+ def _compute_logprobs(self, model, model_data, ref_data, context_length):
543
+ def compute_logprobs_for_data(m, data):
544
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
545
+ logits = output.logits[:, context_length - 1 : -1]
546
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
547
+ return token_logprobs
548
+
549
+ # Compute logprobs for model completions
550
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
551
+ # Compute logprobs for model on reference completions (for XPO loss)
552
+ model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
553
+
554
+ # Compute logprobs for reference model completions
555
+ with torch.no_grad():
556
+ if self.ref_model is None:
557
+ with model.disable_adapter():
558
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
559
+ ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
560
+ else:
561
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
562
+ ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
563
+
564
+ # Mask padding tokens
565
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
566
+ ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
567
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
568
+ model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
569
+ ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
570
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
571
+
572
+ return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
573
+
574
+ def _compute_losses(
575
+ self,
576
+ model_logprobs_model_data,
577
+ model_logprobs_ref_data,
578
+ ref_logprobs_ref_data,
579
+ ref_logprobs_model_data,
580
+ chosen_mask,
581
+ ):
582
+ # Compute log probs
583
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
584
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
585
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
586
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
587
+
588
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
589
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
590
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
591
+
592
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
593
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
594
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
595
+
596
+ # Compute logits as the difference between chosen and rejected log ratios
597
+ logits = chosen_log_ratios - rejected_log_ratios
598
+
599
+ if self.args.loss_type == "sigmoid":
600
+ dpo_losses = -F.logsigmoid(self.beta * logits)
601
+ elif self.args.loss_type == "ipo":
602
+ dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
603
+ else:
604
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
605
+
606
+ # Compute XPO specific loss
607
+ xpo_losses = self.alpha * model_logprobs_ref_data_sum
608
+
609
+ # Total loss
610
+ loss = (dpo_losses + xpo_losses).mean()
611
+
612
+ return loss, dpo_losses, xpo_losses
613
+
+    def _log_statistics(
+        self,
+        model_data,
+        ref_data,
+        model_logprobs_model_data,
+        model_logprobs_ref_data,
+        ref_logprobs_ref_data,
+        ref_logprobs_model_data,
+        chosen_mask,
+        dpo_losses,
+        xpo_losses,
+        context_length,
+        model_scores=None,
+        ref_scores=None,
+    ):
+        # Helper function to gather and compute mean
+        def gather_mean(tensor):
+            return self.accelerator.gather_for_metrics(tensor).mean().item()
+
+        # Log losses
+        self.stats["loss/dpo"].append(gather_mean(dpo_losses))
+        self.stats["loss/xpo"].append(gather_mean(xpo_losses))
+
+        # Log scores
+        if self.reward_model is not None:
+            self.stats["objective/model_scores"].append(gather_mean(model_scores))
+            self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
+            self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
+
+        # Log logprobs
+        model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
+        model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
+        ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
+        ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
+
+        chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+        chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+        chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
+
+        rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
+        rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
+        rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
+
+        self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
+        self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
+
+        # Log rewards
+        # Compute various statistics
+        chosen_rewards = chosen_log_ratios * self.beta
+        rejected_rewards = rejected_log_ratios * self.beta
+        self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
+        self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
+
+        # Calculate KL divergence for model and ref data
+        kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
+        kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
+        mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
+        self.stats["objective/kl"].append(gather_mean(mean_kl))
+
+        # Calculate entropy for model and ref data
+        entropy_model_data = -model_logprobs_model_data.sum(1)
+        entropy_ref_data = -model_logprobs_ref_data.sum(1)
+        mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
+        self.stats["objective/entropy"].append(gather_mean(mean_entropy))
+
+        # Calculate margins
+        margin = chosen_rewards - rejected_rewards
+        self.stats["rewards/margins"].append(gather_mean(margin.mean()))
+
+        # Calculate accuracy
+        accuracy = (margin > 0).float()
+        self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
+
+        # Log EOS token statistics
+        model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
+        ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
+        self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
+        self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
+
+        # Log alpha and beta
+        self.stats["alpha"].append(self.alpha)
+        self.stats["beta"].append(self.beta)
+
+    def training_step(
+        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
+    ) -> torch.Tensor:
+        model.train()
+
+        # Apply chat template and tokenize the input
+        batch_size = len(next(iter(inputs.values())))
+        prompts = inputs["prompt"]
+        inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
+        inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
+        inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
+        inputs = self.data_collator(inputs)
+
+        # Only the prompt_* fields are needed from here on
+        inputs = self._prepare_inputs(inputs)
+        context_length = inputs["prompt_input_ids"].shape[1]
+        prompts = {
+            "input_ids": inputs["prompt_input_ids"],
+            "attention_mask": inputs["prompt_attention_mask"],
+            "raw": prompts,
+        }
+        del inputs
+
+        # Sample completions from both the model and the reference model
+        model_output, ref_output = self._generate_completions(prompts, model)
+
+        # Process model completions
+        model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
+
+        # Compute rewards
+        if self.reward_model is not None:
+            model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
+            chosen_mask = model_scores >= ref_scores
+        else:
+            model_scores, ref_scores = None, None
+            chosen_mask = self._compute_judge(model_data, ref_data, context_length)
+
+        # Compute logprobs
+        model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
+            self._compute_logprobs(model, model_data, ref_data, context_length)
+        )
+
+        # Compute loss
+        loss, dpo_losses, xpo_losses = self._compute_losses(
+            model_logprobs_model_data,
+            model_logprobs_ref_data,
+            ref_logprobs_ref_data,
+            ref_logprobs_model_data,
+            chosen_mask,
+        )
+
+        # Log everything
+        self._log_statistics(
+            model_data,
+            ref_data,
+            model_logprobs_model_data.detach(),
+            model_logprobs_ref_data.detach(),
+            ref_logprobs_ref_data,
+            ref_logprobs_model_data,
+            chosen_mask,
+            dpo_losses.detach(),
+            xpo_losses.detach(),
+            context_length,
+            model_scores,
+            ref_scores,
+        )
+
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            empty_cache()
+
+        kwargs = {}
+        # For LOMO optimizers you need to explicitly use the learning rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            self.accelerator.backward(loss, **kwargs)
+
+        return loss.detach() / self.args.gradient_accumulation_steps
+
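
When `reward_model` is `None`, `training_step` delegates the chosen/rejected decision to `self._compute_judge`, which wraps a pairwise judge. A toy judge sketch, assuming TRL's `BasePairwiseJudge` interface where `judge` returns, per prompt, the index (0 or 1) of the preferred completion:

```python
from trl import BasePairwiseJudge

class LongerIsBetterJudge(BasePairwiseJudge):
    """Illustrative judge that simply prefers the longer completion."""

    def judge(self, prompts, completions, shuffle_order=True):
        # completions[i] is a pair [candidate_0, candidate_1] for prompts[i].
        return [0 if len(pair[0]) >= len(pair[1]) else 1 for pair in completions]
```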
+    def create_model_card(
+        self,
+        model_name: Optional[str] = None,
+        dataset_name: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            model_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the model.
+            dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                Name of the dataset used for training.
+            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                Tags to be associated with the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        # BibTeX key matches the paper's authors (Xie et al., 2024)
+        citation = textwrap.dedent("""\
+            @article{xie2024exploratory,
+                title  = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
+                author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
+                year   = 2024,
+                eprint = {arXiv:2405.21046}
+            }""")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            comet_url=get_comet_experiment_url(),
+            trainer_name="XPO",
+            trainer_citation=citation,
+            paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
+            paper_id="2405.21046",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+class UnslothXPOTrainer(_UnslothXPOTrainer):
+    """
+
+    Initialize XPOTrainer as a subclass of [`OnlineDPOTrainer`].
+
+    Args:
+        model (`transformers.PreTrainedModel`):
+            The model to train, preferably an `AutoModelForCausalLM`.
+        ref_model (`PreTrainedModelWrapper`):
+            Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
+            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
+        reward_model (`transformers.PreTrainedModel`):
+            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
+        judge (`BasePairwiseJudge`):
+            The judge to use for pairwise comparison of model completions.
+        args (`XPOConfig`):
+            The XPO config arguments to use for training.
+        data_collator (`transformers.DataCollator`):
+            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
+            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
+        train_dataset (`datasets.Dataset`):
+            The dataset to use for training.
+        eval_dataset (`datasets.Dataset`):
+            The dataset to use for evaluation.
+        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+            Processing class used to process the data. If provided, will be used to automatically process the inputs
+            for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or
+            reuse the fine-tuned model.
+        peft_config (`dict`):
+            The peft config to use for training.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function used to compute the metrics. Must take an `EvalPrediction` and return
+            a dictionary mapping metric names to metric values.
+        callbacks (`list[transformers.TrainerCallback]`):
+            The callbacks to use for training.
+        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
+            The optimizer and scheduler to use for training.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+            The function to use to preprocess the logits before computing the metrics.
+
+    """
+    def __init__(
+        self,
+        model = None,
+        ref_model = None,
+        reward_model = None,
+        judge = None,
+        args = None,
+        data_collator = None,
+        train_dataset = None,
+        eval_dataset = None,
+        processing_class = None,
+        peft_config = None,
+        compute_metrics = None,
+        callbacks = None,
+        preprocess_logits_for_metrics = None,
+        **kwargs
+    ):
+        if args is None: args = UnslothXPOConfig()
+        use_bf16 = getattr(args, 'bf16', False)
+        use_fp16 = getattr(args, 'fp16', False)
+        force_float32 = False
+        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
+            print('Unsloth: Switching to float32 training since model cannot work with float16')
+            force_float32 = True
+        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
+        dtype = getattr(model.config, 'torch_dtype', None)
+        if dtype is None: dtype = model.get_input_embeddings().dtype
+        from unsloth_zoo.utils import _get_dtype
+        dtype = _get_dtype(dtype)
+        float16 = dtype == torch.float16
+        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
+        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
+        if force_float32:
+            args.fp16 = False
+            args.bf16 = False
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
+        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
+            args.fp16 = float16
+            args.bf16 = not float16
+            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
+        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
+            args.eval_strategy = 'steps'
+            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
+        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
+        if ga_steps is not None and ga_steps > 1:
+            from transformers import __version__ as transformers_version
+            if Version(transformers_version) <= Version('4.45.2'):
+                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
+                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
+        if getattr(args, 'eval_strategy', 'no') != 'no':
+            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
+            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
+            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
+        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
+        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
+        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
+        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
+        if force_float32:
+            args.bf16_full_eval = False
+            args.fp16_full_eval = False
+        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
+            args.bf16_full_eval = True
+            args.fp16_full_eval = False
+        elif not bf16_full_eval and not fp16_full_eval:
+            args.bf16_full_eval = args.bf16
+            args.fp16_full_eval = args.fp16
+        _output_logits = False
+        if locals().get('compute_metrics', None) is not None: _output_logits = True
+        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
+        if _output_logits:
+            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
+            pass
+        else:
+            model_max_seq_length = getattr(model, 'max_seq_length', None)
+            args_max_seq_length = getattr(args, 'max_seq_length', None)
+            if args_max_seq_length is None and model_max_seq_length is not None:
+                max_seq_length = model.max_seq_length
+                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
+        if model is not None and hasattr(model, 'for_training'):
+            model.for_training()
+        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
+        if 'processing_class' in locals():
+            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
+            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
+        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
+        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
+                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
+            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
+                data_collator = DataCollatorForSeq2Seq(__tokenizer)
+        else:
+            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
+            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
+            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
+        if not isinstance(data_collator, UnslothVisionDataCollator):
+            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
+                if isinstance(data_collator, DataCollatorForSeq2Seq):
+                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
+                else:
+                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
+        other_metrics = []
+
+        from unsloth_zoo.logging_utils import PatchRLStatistics
+        PatchRLStatistics('xpo_trainer', other_metrics)
+
+        super().__init__(
+            model = model,
+            ref_model = ref_model,
+            reward_model = reward_model,
+            judge = judge,
+            args = args,
+            data_collator = data_collator,
+            train_dataset = train_dataset,
+            eval_dataset = eval_dataset,
+            processing_class = processing_class,
+            peft_config = peft_config,
+            compute_metrics = compute_metrics,
+            callbacks = callbacks,
+            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
+            **kwargs,
+        )
+        if hasattr(self, 'neftune_hook_handle'):
+            self.neftune_hook_handle.remove()
+            del self.neftune_hook_handle
+        if getattr(args, 'neftune_noise_alpha', None) is not None:
+            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
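
A hedged end-to-end sketch of how the wrapper above might be instantiated (the model id, dataset, and config values are placeholders, not recommendations from this repo):

```python
from datasets import load_dataset
from unsloth import FastLanguageModel

# Placeholder checkpoint; any Unsloth-supported causal LM should work similarly.
model, tokenizer = FastLanguageModel.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")  # prompt-only dataset

trainer = UnslothXPOTrainer(
    model=model,
    judge=LongerIsBetterJudge(),  # toy judge from the earlier sketch
    args=UnslothXPOConfig(output_dir="outputs", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```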
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc ADDED
Binary file (32.9 kB)
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc ADDED
Binary file (91.7 kB)
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc ADDED
Binary file (75.6 kB)
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc ADDED
Binary file (45.5 kB)
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d636dd9bd05907f328159064584e1667333117b7120ddbed1c3c316bc279cc36
+size 103583
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc ADDED
Binary file (37.7 kB)
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc ADDED
Binary file (78.5 kB)
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc ADDED
Binary file (87.4 kB)
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc ADDED
Binary file (47.3 kB)
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc ADDED
Binary file (75.6 kB)
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc ADDED
Binary file (67.2 kB)
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc ADDED
Binary file (62.7 kB)
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc ADDED
Binary file (36.5 kB)
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc ADDED
Binary file (54.2 kB)
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc ADDED
Binary file (38.9 kB)
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc ADDED
Binary file (48.1 kB)
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc ADDED
Binary file (49.9 kB)
upload_utils.py ADDED
@@ -0,0 +1,77 @@
+from huggingface_hub import HfApi
+import os
+
+def upload_to_huggingface(model, tokenizer, repo_name, token):
+    """
+    Upload a fine-tuned model and tokenizer to Hugging Face.
+
+    Args:
+        model: The fine-tuned model to upload.
+        tokenizer: The tokenizer associated with the model.
+        repo_name (str): The name of the repository to create/update on Hugging Face.
+        token (str): Hugging Face API token.
+
+    Returns:
+        str: A message indicating the success or failure of the upload.
+    """
+    try:
+        # Save the model and tokenizer to a temporary directory
+        temp_dir = "temp_model"
+        model.save_pretrained(temp_dir)
+        tokenizer.save_pretrained(temp_dir)
+
+        # Initialize the Hugging Face API
+        api = HfApi()
+
+        # Create or update the repository
+        api.create_repo(repo_id=repo_name, token=token, exist_ok=True)
+
+        # Upload the model and tokenizer files
+        api.upload_folder(
+            folder_path=temp_dir,
+            repo_id=repo_name,
+            token=token
+        )
+
+        # Clean up the temporary directory
+        for file in os.listdir(temp_dir):
+            os.remove(os.path.join(temp_dir, file))
+        os.rmdir(temp_dir)
+
+        return f"Model successfully uploaded to https://huggingface.co/{repo_name}"
+    except Exception as e:
+        return f"Error uploading model: {str(e)}"
+
+def upload_gguf_to_huggingface(gguf_file_path, repo_name, token):
+    """
+    Upload a GGUF converted model to Hugging Face.
+
+    Args:
+        gguf_file_path (str): The path to the GGUF file.
+        repo_name (str): The name of the repository to create/update on Hugging Face.
+        token (str): Hugging Face API token.
+
+    Returns:
+        str: A message indicating the success or failure of the upload.
+    """
+    try:
+        if not os.path.exists(gguf_file_path):
+            return f"Error: GGUF file not found at {gguf_file_path}"
+
+        # Initialize the Hugging Face API
+        api = HfApi()
+
+        # Create or update the repository
+        api.create_repo(repo_id=repo_name, token=token, exist_ok=True)
+
+        # Upload the GGUF file
+        api.upload_file(
+            path_or_fileobj=gguf_file_path,
+            path_in_repo=os.path.basename(gguf_file_path),
+            repo_id=repo_name,
+            token=token
+        )
+
+        return f"GGUF model successfully uploaded to https://huggingface.co/{repo_name}"
+    except Exception as e:
+        return f"Error uploading GGUF model: {str(e)}"