Riddhi Bhagwat committed on
Commit 430ca63 · unverified · 2 Parent(s): fbc38c9 d062581

Merge pull request #11 from jenbenarye/main

training (lora) & dataset processing scripts

.gitignore CHANGED
@@ -160,4 +160,16 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-user_feedback
+user_feedback
+
+
+# Hugging Face cache
+wandb/
+.cache/
+cached_*
+
+# Hugging Face datasets
+datasets/
+
+# Hugging Face models
+models/

app/app.py CHANGED
@@ -386,6 +386,9 @@ css = """
 .option.svelte-pcaovb {
     display: none !important;
 }
+.retry-btn {
+    display: none !important;
+}
 """
 
 with gr.Blocks(css=css) as demo:
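
The CSS only takes effect if the app's components opt into these class names. A minimal sketch of the wiring, assuming the app tags its retry button via `elem_classes` (the button below is hypothetical; the real components are not shown in this diff):

```python
import gradio as gr

css = """
.retry-btn {
    display: none !important;
}
"""

with gr.Blocks(css=css) as demo:
    # Hypothetical button; elem_classes attaches the class the CSS above hides.
    retry = gr.Button("Retry", elem_classes=["retry-btn"])

demo.launch()
```
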
ml/adapter_metadata.py ADDED
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from datetime import datetime
+from typing import List, Dict
+import json
+
+@dataclass
+class AdapterMetadata:
+    """Metadata for tracking adapter training history"""
+    training_timestamp: str  # ISO format timestamp
+    training_params: Dict    # Training parameters used
+    model_name: str          # Base model name
+    language: str            # Language of the adapter
+    version: str             # Version of the adapter
+
+    # Create class instance from a dictionary
+    @classmethod
+    def from_dict(cls, data: Dict):
+        return cls(**data)
+
+    # Convert class instance to a dictionary
+    def to_dict(self) -> Dict:
+        return {
+            "training_timestamp": self.training_timestamp,
+            "training_params": self.training_params,
+            "model_name": self.model_name,
+            "language": self.language,
+            "version": self.version
+        }
+
+    # Save metadata to a JSON file
+    def save(self, filepath: str):
+        with open(filepath, 'w') as f:
+            json.dump(self.to_dict(), f, indent=2)
+
+    # Load metadata from a JSON file
+    @classmethod
+    def load(cls, filepath: str):
+        with open(filepath, 'r') as f:
+            data = json.load(f)
+        return cls.from_dict(data)
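
A quick round-trip sketch of how this class is meant to be used; the file name and field values here are illustrative, not from the repo:

```python
from adapter_metadata import AdapterMetadata

meta = AdapterMetadata(
    training_timestamp="2025-01-15T12:00:00",  # ISO timestamp
    training_params={"learning_rate": 5e-7, "num_train_epochs": 1},
    model_name="CohereForAI/aya-expanse-8b",
    language="English",
    version="2025-01-15_12-00-00",
)

meta.save("metadata.json")                        # writes indented JSON
restored = AdapterMetadata.load("metadata.json")  # rebuilds the dataclass
assert restored == meta                           # dataclasses compare by field
```
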
ml/dataset_training.ipynb DELETED
(Deleted notebook; Python 3.10.13 kernel, nbformat 4. Its code cells are shown below, with widget, dataframe-preview, and traceback outputs condensed into comments.)
@@ -1,398 +0,0 @@
-# Cell 1: dependencies
-import pandas as pd
-
-import torch
-from transformers import GPT2Tokenizer
-
-from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
-
-# Cell 2: load the Stanford Human Preferences (SHP) dataset
-# (output: three "Resolving data files" progress widgets)
-from datasets import load_dataset
-
-ds = load_dataset("stanfordnlp/SHP", split='train')
-
-# Cell 3: inspect the columns
-df = ds.to_pandas()
-print(df.columns)
-# Index(['post_id', 'domain', 'upvote_ratio', 'history', 'c_root_id_A',
-#        'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'score_A',
-#        'score_B', 'human_ref_A', 'human_ref_B', 'labels', 'seconds_difference',
-#        'score_ratio'],
-#       dtype='object')
-
-# Cell 4: drop identifier/timestamp columns (result displayed, not reassigned;
-# output was a preview of the 348718-row x 8-column frame)
-# df['response_length'] = df['history'].apply(len)
-# df['label'] = df['response_length'].apply(lambda x: 'long' if x > 100 else 'short')
-df.drop(columns=['post_id', 'domain', 'c_root_id_A', 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'seconds_difference'])
-
-# Cell 5: load GPT-2 policy and reference models with value heads
-# (output: a transformers FutureWarning about clean_up_tokenization_spaces)
-model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
-ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
-tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-tokenizer.pad_token = tokenizer.eos_token
-
-# Cell 6: dataset-processing imports
-from trl_rlhf_data import runner, ScriptArguments
-import re
-from dataclasses import dataclass
-from typing import Dict, List, Optional
-
-from datasets import load_dataset
-from transformers import HfArgumentParser
-
-# Cell 7: raised
-# TypeError: runner() takes 0 positional arguments but 1 was given
-dataset = runner(ScriptArguments)

ml/kto.py DELETED
@@ -1,117 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
-
-# Full training:
-python examples/scripts/kto.py \
-    --dataset_name trl-lib/kto-mix-14k \
-    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
-    --per_device_train_batch_size 16 \
-    --num_train_epochs 1 \
-    --learning_rate 5e-7 \
-    --lr_scheduler_type=cosine \
-    --gradient_accumulation_steps 1 \
-    --logging_steps 10 \
-    --eval_steps 500 \
-    --output_dir=kto-aligned-model \
-    --warmup_ratio 0.1 \
-    --report_to wandb \
-    --bf16 \
-    --logging_first_step
-
-# QLoRA:
-python examples/scripts/kto.py \
-    --dataset_name trl-lib/kto-mix-14k \
-    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
-    --per_device_train_batch_size 8 \
-    --num_train_epochs 1 \
-    --learning_rate 5e-7 \
-    --lr_scheduler_type=cosine \
-    --gradient_accumulation_steps 1 \
-    --logging_steps 10 \
-    --eval_steps 500 \
-    --output_dir=kto-aligned-model-lora \
-    --warmup_ratio 0.1 \
-    --report_to wandb \
-    --bf16 \
-    --logging_first_step \
-    --use_peft \
-    --load_in_4bit \
-    --lora_target_modules=all-linear \
-    --lora_r=16 \
-    --lora_alpha=16
-"""
-
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
-
-from trl import (
-    KTOConfig,
-    KTOTrainer,
-    ModelConfig,
-    ScriptArguments,
-    get_peft_config,
-    setup_chat_format,
-)
-
-
-if __name__ == "__main__":
-    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
-    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
-
-    # Load a pretrained model
-    model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-    ref_model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    # If we are aligning a base model, we use ChatML as the default template
-    if tokenizer.chat_template is None:
-        model, tokenizer = setup_chat_format(model, tokenizer)
-
-    # Load the dataset
-    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-
-    # Initialize the KTO trainer
-    trainer = KTOTrainer(
-        model,
-        ref_model,
-        args=training_args,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=(
-            dataset[script_args.dataset_test_split]
-            if training_args.eval_strategy != "no"
-            else None
-        ),
-        processing_class=tokenizer,
-        peft_config=get_peft_config(model_args),
-    )
-
-    # Train and push the model to the Hub
-    trainer.train()
-
-    # Save and push to hub
-    trainer.save_model(training_args.output_dir)
-    if training_args.push_to_hub:
-        trainer.push_to_hub(dataset_name=script_args.dataset_name)
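
For reference, the `--flag value` pairs in those commands are consumed by `HfArgumentParser`, which maps each flag onto the matching dataclass field. A standalone sketch with a toy dataclass (TRL's real `ScriptArguments`/`KTOConfig` carry many more fields):

```python
from dataclasses import dataclass
from transformers import HfArgumentParser

@dataclass
class ToyArguments:
    # Toy stand-in for the script's dataclasses.
    dataset_name: str = "trl-lib/kto-mix-14k"
    learning_rate: float = 5e-7

parser = HfArgumentParser(ToyArguments)
# Explicit argv here; on the command line this would be `--learning_rate 1e-6`.
(args,) = parser.parse_args_into_dataclasses(["--learning_rate", "1e-6"])
print(args.dataset_name, args.learning_rate)  # trl-lib/kto-mix-14k 1e-06
```
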
ml/kto_dataset_processor.py CHANGED
@@ -1,65 +1,210 @@
-from datasets import load_dataset, Dataset
+from datasets import Dataset, load_dataset
 import pandas as pd
-from pdb import set_trace as st
+from sklearn.model_selection import train_test_split
+import json
+from ipdb import set_trace as st
+from transformers import AutoTokenizer
+from enum import Enum
 
-def process_dataset_ultrafeedback():
+class SupportedLanguages(str, Enum):
+    """Enumeration of supported languages"""
+    ENGLISH = "English"
+    DUTCH = "Dutch"
+    ITALIAN = "Italian"
+    SPANISH = "Spanish"
+    FRENCH = "French"
+    GERMAN = "German"
+    PORTUGUESE = "Portuguese"
+    RUSSIAN = "Russian"
+    CHINESE = "Chinese"
+    JAPANESE = "Japanese"
+    KOREAN = "Korean"
+
+def transform_conversation(
+    entry: dict,
+    model_name: str,
+    max_history_turns: int = 10,
+    max_history_tokens: int = 4000
+) -> list:
+    """Transform conversation into KTO format with history"""
+    data_points = []
+    conversation = entry["conversation"]
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    for i, message in enumerate(conversation):
+        # Only create data points for assistant messages that have ratings
+        if message["role"] != "assistant" or message["rating"] not in [1, -1]:
+            continue
+
+        # Get previous messages up to limits
+        formatted_history = []
+        formatted_prompt = ""
+        tokens = 0
+        pairs = 0  # Count complete user/assistant pairs
+
+        # Start from the current message and work backwards
+        current_idx = i - 1
+        while current_idx >= 0 and pairs < max_history_turns:
+            # We need both user and assistant messages to form a pair
+            if current_idx > 0 and conversation[current_idx]["role"] == "assistant" and conversation[current_idx-1]["role"] == "user":
+                # Add the pair to history
+                formatted_history.insert(0, conversation[current_idx-1])  # user
+                formatted_history.insert(1, conversation[current_idx])    # assistant
+
+                # Check token limit
+                try:
+                    current_formatted = tokenizer.apply_chat_template(formatted_history, tokenize=False)
+                    current_tokens = len(tokenizer.encode(current_formatted))
+
+                    if current_tokens > max_history_tokens:
+                        formatted_history = formatted_history[2:]  # Remove the oldest pair
+                        break
+
+                    formatted_prompt = current_formatted
+                    tokens = current_tokens
+                    pairs += 1
+                    current_idx -= 2
+                except Exception:
+                    # If template application fails, remove the last added pair
+                    formatted_history = formatted_history[2:]
+                    break
+            else:
+                current_idx -= 1
+
+        # Add the final user message that prompted the rated response
+        if i > 0 and conversation[i-1]["role"] == "user":
+            last_history = formatted_history + [conversation[i-1]]
+            try:
+                formatted_prompt = tokenizer.apply_chat_template(last_history, tokenize=False)
+            except Exception:
+                # If template application fails, use the previous valid prompt
+                pass
+
+        data_points.append({
+            "prompt": formatted_prompt.strip(),
+            "completion": message["content"].strip(),
+            "label": message["rating"] == 1,
+            "timestamp": entry["timestamp"],
+            "session_id": entry["session_id"],
+            "conversation_id": entry["conversation_id"],
+            "language": entry["language"]
+        })
+
+    return data_points
+
+def process_feel_dataset(
+    language: str,
+    model_name: str = "CohereForAI/aya-expanse-8b",
+    max_history_turns: int = 10,
+    max_history_tokens: int = 4000
+):
     """
-    Processes the 'train_prefs' and 'test_prefs' splits of the 'HuggingFaceH4/ultrafeedback_binarized' dataset
-    into a unified format for preference modeling.
+    Processes the feel dataset into a format suitable for KTO training using TRL.
+
+    Args:
+        language: Language to filter the dataset for (must be one of SupportedLanguages)
+        model_name: Name of the model to format for
+        max_history_turns: Maximum number of previous turns to include in history
+        max_history_tokens: Maximum number of tokens allowed in history
 
     Returns:
-        dict: A dictionary containing the unified 'train' and 'test' splits of the dataset in the KTO format.
-              Each split is a Hugging Face Dataset object.
+        dict: A dictionary containing the 'train' and 'test' splits of the dataset in KTO format
+
+    Raises:
+        ValueError: If language is not provided or not in SupportedLanguages
     """
-    # Load the relevant splits of the dataset
-    dataset_name = "HuggingFaceH4/ultrafeedback_binarized"
-    train_prefs = load_dataset(dataset_name, split="train_prefs")
-    test_prefs = load_dataset(dataset_name, split="test_prefs")
-
-    # Function to transform a single example into the desired schema
-    def transform_data(example):
-        data_points = []
-        # Chosen completion
-        chosen_completion = example["chosen"][1]["content"]
-        if chosen_completion.strip():  # Check for non-empty completions
-            data_points.append({
-                "prompt": example["prompt"],
-                "completion": chosen_completion.strip(),
-                "label": True
-            })
-        # Rejected completion
-        rejected_completion = example["rejected"][1]["content"]
-        if rejected_completion.strip():  # Check for non-empty completions
-            data_points.append({
-                "prompt": example["prompt"],
-                "completion": rejected_completion.strip(),
-                "label": False
-            })
-        return data_points
-
-    # Process train and test splits
-    train_data = []
-    test_data = []
-
-    for example in train_prefs:
-        train_data.extend(transform_data(example))
-
-    for example in test_prefs:
-        test_data.extend(transform_data(example))
-
-    # Convert unified data to DataFrames
-    train_df = pd.DataFrame(train_data)
-    test_df = pd.DataFrame(test_data)
+    # Validate language
+    if not language:
+        raise ValueError("Language parameter is required")
+
+    try:
+        # Validate that it's a supported language
+        SupportedLanguages(language)
+    except ValueError:
+        supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
+        raise ValueError(
+            f"Invalid language: '{language}'\n"
+            f"Supported languages are:\n- {supported_langs}"
+        )
+
+    # Load feel dataset from HuggingFace
+    feel_dataset = load_dataset("feel-fl/feel-feedback")["train"]
+
+    # Filter dataset by language
+    feel_dataset = feel_dataset.filter(lambda x: x["language"] == language)
+
+    if len(feel_dataset) == 0:
+        raise ValueError(f"No data found for language: {language}")
+
+    kto_data = []
+
+    # Process all conversations in the filtered dataset
+    for entry in feel_dataset:
+        kto_data.extend(transform_conversation(
+            entry,
+            model_name,
+            max_history_turns,
+            max_history_tokens
+        ))
+
+    if len(kto_data) == 0:
+        raise ValueError(f"No valid training examples found for language: {language}")
+
+    # Convert to DataFrame
+    kto_df = pd.DataFrame(kto_data)
+
+    # Split into train and test sets (70% train, 30% test)
+    train_df, test_df = train_test_split(kto_df, test_size=0.3, random_state=42)
+
+    # Reset index to remove '__index_level_0__'
+    train_df = train_df.reset_index(drop=True)
+    test_df = test_df.reset_index(drop=True)
 
     # Convert to Hugging Face Dataset
-    unified_train = Dataset.from_pandas(train_df)
-    unified_test = Dataset.from_pandas(test_df)
+    train_dataset = Dataset.from_pandas(train_df)
+    test_dataset = Dataset.from_pandas(test_df)
 
-    return {"train": unified_train, "test": unified_test}
+    print(f"Processed {len(kto_data)} examples for language: {language}")
+    print(f"Train set size: {len(train_dataset)}")
+    print(f"Test set size: {len(test_dataset)}")
+
+    return {"train": train_dataset, "test": test_dataset}
 
 if __name__ == "__main__":
-    kto_dataset = process_dataset_ultrafeedback()
-    st()
+    # Process the dataset
+    datasets = process_feel_dataset("English")
+
+    # Print distribution of positive/negative labels
+    train_labels = datasets['train']['label']
+    test_labels = datasets['test']['label']
+
+    print("\nLabel Distribution:")
+    print("Train set:")
+    print(f"Positive feedback: {sum(train_labels)}")
+    print(f"Negative feedback: {len(train_labels) - sum(train_labels)}")
+    print(f"Positive ratio: {sum(train_labels)/len(train_labels):.2%}")
+
+    print("\nTest set:")
+    print(f"Positive feedback: {sum(test_labels)}")
+    print(f"Negative feedback: {len(test_labels) - sum(test_labels)}")
+    print(f"Positive ratio: {sum(test_labels)/len(test_labels):.2%}")
+
+    # Load original FEEL dataset
+    feel_dataset = load_dataset("feel-fl/feel-feedback", split="train")
+
+    # Print one original conversation
+    print("\nOriginal conversation from FEEL dataset:")
+    print(json.dumps(feel_dataset[0], indent=2))
+
+    # Print sample entries from processed dataset
+    print("\nSample entries from processed KTO dataset:")
+    print("\n" + "="*80 + "\nTRAIN SET SAMPLES\n" + "="*80)
+
+    # Export datasets to CSV
+    train_df = datasets['train'].to_pandas()
+    test_df = datasets['test'].to_pandas()
+
+    train_df.to_csv('kto_train_dataset.csv', index=False)
+    test_df.to_csv('kto_test_dataset.csv', index=False)
+
+    print("\nDatasets exported to 'kto_train_dataset.csv' and 'kto_test_dataset.csv'")
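
For orientation, here is the record shape a single rated assistant turn becomes. The entry below is invented; in the real pipeline the `prompt` field holds the chat-templated history produced by the tokenizer, which this sketch only stubs out:

```python
# Hypothetical FEEL-style entry, mirroring the fields transform_conversation reads.
entry = {
    "conversation": [
        {"role": "user", "content": "What is KTO?", "rating": None},
        {"role": "assistant", "content": "A preference-tuning method.", "rating": 1},
    ],
    "timestamp": "2025-01-15T12:00:00",
    "session_id": "s-001",
    "conversation_id": "c-001",
    "language": "English",
}

# The one rated assistant turn above yields one KTO example of this shape:
kto_example = {
    "prompt": "<chat-templated history ending with the user turn>",  # stubbed here
    "completion": "A preference-tuning method.",
    "label": True,  # rating == 1; a rating of -1 would give False
    "timestamp": entry["timestamp"],
    "session_id": entry["session_id"],
    "conversation_id": entry["conversation_id"],
    "language": entry["language"],
}
```
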
ml/{kto_pipeline.py → trainer.py} RENAMED
@@ -1,35 +1,58 @@
+import os
 import torch
 from dataclasses import dataclass
 from accelerate import PartialState
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
-from kto_dataset_processor import process_dataset_ultrafeedback
+from kto_dataset_processor import process_feel_dataset, SupportedLanguages
+from adapter_metadata import AdapterMetadata
 from datetime import datetime
 import wandb
+from enum import Enum
+from typing import Optional
+from pathlib import Path
+
+
+# PEFT library: attach and load adapters
+from peft import get_peft_model, PeftModel
 
 ####################################
 # CONFIGURATION
 ####################################
 
+
 @dataclass
 class ScriptArguments:
     """
     Configuration for the script.
     """
-    process_dataset_func: callable = process_dataset_ultrafeedback  # process_dataset function from kto_dataset_processor.py
-    checkpoint_path: str = None  # Checkpoint path
-    push_to_hub: bool = False  # Whether to push the model to the Hugging Face hub
+    process_dataset_func: callable = process_feel_dataset
+    checkpoint_path: str = None
+    push_to_hub: bool = True
+    language: str = "English"  # Default to English
+
+    def __post_init__(self):
+        """Validate the language after initialization"""
+        try:
+            # This will raise ValueError if language is not in the enum
+            SupportedLanguages(self.language)
+        except ValueError:
+            supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
+            raise ValueError(
+                f"Invalid language: '{self.language}'\n"
+                f"Supported languages are:\n- {supported_langs}"
+            )
 
 @dataclass
 class ModelArguments(ModelConfig):
     """
     Configuration for the model.
     """
-    model_name: str = "HuggingFaceH4/zephyr-7b-beta"
+    model_name: str = "CohereForAI/aya-expanse-8b"
     use_peft: bool = True
     lora_target_modules: str = "all-linear"
     lora_r: int = 16
     lora_alpha: int = 16
+    trust_remote_code: bool = True
 
 @dataclass
 class TrainingArguments(KTOConfig):
@@ -38,7 +61,7 @@ class TrainingArguments(KTOConfig):
     """
     output_dir: str = f"kto_{ModelArguments.model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
     num_train_epochs: int = 1
-    per_device_train_batch_size: int = 4  # Highest that runs well
+    per_device_train_batch_size: int = 4
     learning_rate: float = 5e-7
     lr_scheduler_type: str = "cosine"
     gradient_accumulation_steps: int = 1
@@ -48,8 +71,6 @@ class TrainingArguments(KTOConfig):
     bf16: bool = True
     logging_first_step: bool = True
 
-
-
 # Initialize configurations
 script_args = ScriptArguments()
 training_args = TrainingArguments()
@@ -61,7 +82,7 @@ model_args = ModelArguments()
 
 def load_model_and_tokenizer(model_args):
     """
-    Load a model and tokenizer from a specified path.
+    Load the base model and tokenizer from the Hugging Face Hub.
     """
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name,
@@ -74,74 +95,97 @@ def load_model_and_tokenizer(model_args):
         trust_remote_code=model_args.trust_remote_code
     )
 
-    # Set pad token if missing
+    # Set pad token if it is missing
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # Setup chat format if not present
-    if tokenizer.chat_template is None:
+    # Setup chat format if not available on the tokenizer
+    if not getattr(tokenizer, "chat_template", None):
         model, tokenizer = setup_chat_format(model, tokenizer)
 
     return model, tokenizer
 
-# def find_unknown_tokens(tokenizer, texts):
-#     """
-#     Identify tokens in the dataset that are not in the tokenizer's vocabulary.
-#     """
-#     all_tokens = set()
-#     for text in texts:
-#         tokens = tokenizer.tokenize(text)
-#         all_tokens.update(tokens)
-#     vocab = set(tokenizer.get_vocab().keys())
-#     unknown_tokens = all_tokens - vocab
-#     return unknown_tokens
-
-# def add_tokens_to_tokenizer(tokenizer, model, dataset):
-#     """
-#     Extend the tokenizer's vocabulary with missing tokens and resize the model embeddings.
-#     """
-#     # Extract all texts from the dataset
-#     texts = [example["completion"] for example in dataset["train"]]
-
-#     # Identify unknown tokens
-#     unknown_tokens = find_unknown_tokens(tokenizer, texts)
-#     print(f"Found {len(unknown_tokens)} unknown tokens: {list(unknown_tokens)[:10]}...")
-
-#     # Add unknown tokens to tokenizer
-#     tokenizer.add_tokens(list(unknown_tokens))
-#     model.resize_token_embeddings(len(tokenizer))
-#     print(f"Tokenizer vocabulary size after extension: {len(tokenizer)}")
+def get_adapter_path(model_name: str, language: str, timestamp: str = None) -> Path:
+    """
+    Generate standardized adapter path.
+    If timestamp is None, returns the base language directory.
+    Otherwise, returns specific adapter version path.
+
+    Format: adapters/{model_name}/{language}/version_{timestamp}
+    """
+    # Clean model name (remove slashes, etc.)
+    clean_model_name = model_name.replace('/', '_')
+
+    base_path = Path("adapters") / clean_model_name / language
+    if timestamp:
+        return base_path / f"version_{timestamp}"
+    return base_path
+
+def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
+    """
+    Load the most recent adapter for given model and language.
+    Returns: (loaded_model, timestamp of loaded adapter)
+    """
+    adapter_base = get_adapter_path(model_name, language)
+
+    if not adapter_base.exists():
+        return None, None
+
+    # Get all version directories and sort by timestamp
+    versions = sorted(
+        [d for d in adapter_base.glob("version_*")],
+        key=lambda x: x.name,
+        reverse=True
+    )
+
+    if not versions:
+        return None, None
+
+    latest_version = versions[0]
+    timestamp = latest_version.name.replace("version_", "")
+
+    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
+    return model, timestamp
 
 ####################################
 # MAIN LOGIC
 ####################################
 
 def main():
-    # Initialize wandb
+    # Initialize wandb for logging
    wandb.init(project="kto")
 
-    # Load models and tokenizer
-    print("Loading models and tokenizer...")
+    # Get timestamp at start of training
+    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+    print("Loading base model and tokenizer...")
     model, tokenizer = load_model_and_tokenizer(model_args)
     ref_model, _ = load_model_and_tokenizer(model_args)
     print("Models and tokenizer loaded.")
 
-    # Load and process datasets using external function
+    # Load existing adapter or create new one
+    peft_config = get_peft_config(model_args)  # defined up front so the trainer call below always has it
+    loaded_model, previous_timestamp = load_latest_adapter(
+        model,
+        model_args.model_name,
+        script_args.language
+    )
+
+    if loaded_model is not None:
+        model = loaded_model
+        print(f"Loaded existing adapter trained at {previous_timestamp}")
+    else:
+        # Initialize new LoRA adapter
+        model = get_peft_model(model, peft_config)
+        print("Initialized new adapter")
+
+    # -----------------------------
+    # Data Preparation and Training
+    # -----------------------------
     print("Processing dataset...")
-    dataset = process_dataset_ultrafeedback()
+    dataset = script_args.process_dataset_func(script_args.language)
     print("Dataset processed.")
 
-    # # Extend tokenizer with missing tokens
-    # print("Adding unknown tokens to tokenizer...")
-    # add_tokens_to_tokenizer(tokenizer, model, dataset)
-    # print("Tokenizer updated.")
-
-    # Initialize trainer
     print("Initializing trainer...")
     trainer = KTOTrainer(
         model=model,
@@ -149,8 +193,8 @@ def main():
         args=training_args,
         train_dataset=dataset["train"],
         eval_dataset=dataset["test"],
-        tokenizer=tokenizer,
-        peft_config=get_peft_config(model_args),
+        processing_class=tokenizer,
+        peft_config=peft_config,
     )
 
     # Training
@@ -182,10 +226,29 @@ def main():
             "step": metrics.get("step")
         })
 
-    # Save model and optionally push to hub
-    trainer.save_model(training_args.output_dir)
+    # Save the adapter
+    adapter_path = get_adapter_path(
+        model_args.model_name,
+        script_args.language,
+        training_timestamp
+    )
+    adapter_path.parent.mkdir(parents=True, exist_ok=True)
+
+    print(f"Saving adapter to: {adapter_path}")
+    model.save_pretrained(adapter_path)
+
+    # Save metadata
+    metadata = AdapterMetadata(
+        training_timestamp=training_timestamp,
+        training_params=training_args.to_dict(),  # record the run's hyperparameters
+        model_name=model_args.model_name,
+        language=script_args.language,
+        version=training_timestamp,  # reuse the run timestamp as the version tag
+    )
+    metadata.save(adapter_path / "metadata.json")
+
     if script_args.push_to_hub:
-        trainer.push_to_hub()
+        repo_id = f"feel-fl/adapters/{model_args.model_name.replace('/', '_')}/{script_args.language}"
+        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
+        model.push_to_hub(repo_id=repo_id)
 
     print("Process completed.")
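
Given the path scheme above, each run's adapter lands in a per-model, per-language tree, and "latest" is resolved by a name sort. A sketch with illustrative paths (the artifact file names are what recent `peft` versions of `save_pretrained` typically write):

```python
from pathlib import Path

# Expected layout after a run (illustrative):
#
# adapters/
#   CohereForAI_aya-expanse-8b/
#     English/
#       version_2025-01-15_12-00-00/
#         adapter_config.json
#         adapter_model.safetensors
#         metadata.json

# Picking the newest version works because the timestamp format
# ('%Y-%m-%d_%H-%M-%S') sorts lexicographically in chronological order.
base = Path("adapters") / "CohereForAI_aya-expanse-8b" / "English"
versions = sorted(base.glob("version_*"), key=lambda d: d.name, reverse=True)
print(versions[0] if versions else "no adapter yet")
```
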