Guetat Youssef
commited on
Commit
·
10b3fe6
1
Parent(s):
9a98795
test
Browse files- app.py +11 -2
- requirements.txt +11 -9
app.py
CHANGED
@@ -78,6 +78,7 @@ def train_model_background(job_id):
|
|
78 |
BitsAndBytesConfig,
|
79 |
TrainingArguments,
|
80 |
logging,
|
|
|
81 |
)
|
82 |
from peft import (
|
83 |
LoraConfig,
|
@@ -188,11 +189,11 @@ def train_model_background(job_id):
|
|
188 |
dataloader_num_workers=0,
|
189 |
remove_unused_columns=False,
|
190 |
load_best_model_at_end=False,
|
191 |
-
|
192 |
)
|
193 |
|
194 |
# Custom callback to track progress
|
195 |
-
class ProgressCallback:
|
196 |
def __init__(self, progress_tracker):
|
197 |
self.progress_tracker = progress_tracker
|
198 |
self.last_update = time.time()
|
@@ -207,6 +208,14 @@ def train_model_background(job_id):
|
|
207 |
f"Training step {state.global_step}/{state.max_steps}"
|
208 |
)
|
209 |
self.last_update = current_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
# === Trainer Initialization ===
|
212 |
trainer = SFTTrainer(
|
|
|
78 |
BitsAndBytesConfig,
|
79 |
TrainingArguments,
|
80 |
logging,
|
81 |
+
TrainerCallback
|
82 |
)
|
83 |
from peft import (
|
84 |
LoraConfig,
|
|
|
189 |
dataloader_num_workers=0,
|
190 |
remove_unused_columns=False,
|
191 |
load_best_model_at_end=False,
|
192 |
+
# Remove evaluation_strategy parameter - not supported in this version
|
193 |
)
|
194 |
|
195 |
# Custom callback to track progress
|
196 |
+
class ProgressCallback(TrainerCallback):
|
197 |
def __init__(self, progress_tracker):
|
198 |
self.progress_tracker = progress_tracker
|
199 |
self.last_update = time.time()
|
|
|
208 |
f"Training step {state.global_step}/{state.max_steps}"
|
209 |
)
|
210 |
self.last_update = current_time
|
211 |
+
|
212 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
213 |
+
self.progress_tracker.status = "training"
|
214 |
+
self.progress_tracker.message = "Training started..."
|
215 |
+
|
216 |
+
def on_train_end(self, args, state, control, **kwargs):
|
217 |
+
self.progress_tracker.status = "saving"
|
218 |
+
self.progress_tracker.message = "Training complete, saving model..."
|
219 |
|
220 |
# === Trainer Initialization ===
|
221 |
trainer = SFTTrainer(
|
requirements.txt
CHANGED
@@ -1,15 +1,17 @@
|
|
1 |
flask==2.3.3
|
2 |
-
transformers
|
3 |
-
datasets
|
4 |
-
accelerate
|
5 |
-
peft
|
6 |
-
trl
|
7 |
bitsandbytes
|
8 |
-
torch
|
9 |
torchvision
|
10 |
torchaudio
|
11 |
-
huggingface-hub
|
12 |
scipy
|
13 |
scikit-learn
|
14 |
-
numpy
|
15 |
-
pandas
|
|
|
|
|
|
1 |
flask==2.3.3
|
2 |
+
transformers>=4.36.0,<4.45.0
|
3 |
+
datasets>=2.14.0
|
4 |
+
accelerate>=0.24.0
|
5 |
+
peft>=0.6.0,<0.8.0
|
6 |
+
trl>=0.7.0
|
7 |
bitsandbytes
|
8 |
+
torch>=2.0.0
|
9 |
torchvision
|
10 |
torchaudio
|
11 |
+
huggingface-hub>=0.17.0
|
12 |
scipy
|
13 |
scikit-learn
|
14 |
+
numpy<2.0.0
|
15 |
+
pandas
|
16 |
+
sentencepiece
|
17 |
+
protobuf
|