Guetat Youssef committed
Commit e4256df · Parent(s): 3d011a9
test
app.py
ADDED
@@ -0,0 +1,289 @@
+from flask import Flask, jsonify, request
+import threading
+import time
+import os
+import torch
+from datasets import load_dataset
+from huggingface_hub import login
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TrainerCallback,
+    TrainingArguments,
+    pipeline,
+    logging,
+    DataCollatorForLanguageModeling,
+)
+from peft import (
+    LoraConfig,
+    PeftModel,
+    prepare_model_for_kbit_training,
+    get_peft_model,
+)
+from trl import SFTTrainer, setup_chat_format
+import uuid
+from datetime import datetime, timedelta
+
+# ============== CONFIGURATION ==============
+app = Flask(__name__)
+
+# Global variables to track training progress
+training_jobs = {}
+
+class TrainingProgress:
+    def __init__(self, job_id):
+        self.job_id = job_id
+        self.status = "initializing"
+        self.progress = 0
+        self.current_step = 0
+        self.total_steps = 0
+        self.start_time = time.time()
+        self.estimated_finish_time = None
+        self.message = "Starting training..."
+        self.error = None
+
+    def update_progress(self, current_step, total_steps, message=""):
+        self.current_step = current_step
+        self.total_steps = total_steps
+        self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
+        self.message = message
+
+        # Calculate estimated finish time from the average time per completed step
+        if current_step > 0:
+            elapsed_time = time.time() - self.start_time
+            time_per_step = elapsed_time / current_step
+            remaining_steps = total_steps - current_step
+            estimated_remaining_time = remaining_steps * time_per_step
+            self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)
+
+    def to_dict(self):
+        return {
+            "job_id": self.job_id,
+            "status": self.status,
+            "progress": round(self.progress, 2),
+            "current_step": self.current_step,
+            "total_steps": self.total_steps,
+            "message": self.message,
+            "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
+            "error": self.error
+        }
+
+def train_model_background(job_id):
+    """Background training function with progress tracking"""
+    progress = training_jobs[job_id]
+
+    try:
+        # === Authentication ===
+        # os and login are already imported at module level
+        hf_token = os.getenv('HF_TOKEN')
+
+        if not hf_token:
+            raise ValueError("HF_TOKEN is not set. Please define it as an environment variable or secret.")
+
+        login(token=hf_token)
+
+        progress.status = "loading_model"
+        progress.message = "Loading base model and tokenizer..."
+
+        # === Configuration ===
+        base_model = "meta-llama/Llama-3.2-1B"
+        dataset_name = "ruslanmv/ai-medical-chatbot"
+        new_model = f"Llama-3.2-1B-chat-doctor-{job_id}"  # named after the 1B base model
+        torch_dtype = torch.float16
+        attn_implementation = "eager"
+
+        # === QLoRA Config ===
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch_dtype,
+            bnb_4bit_use_double_quant=True,
+        )
+
+        # === Load Model and Tokenizer ===
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            quantization_config=bnb_config,
+            device_map="auto",
+            attn_implementation=attn_implementation
+        )
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
+        model, tokenizer = setup_chat_format(model, tokenizer)
+
+        progress.status = "preparing_model"
+        progress.message = "Setting up LoRA configuration..."
+
+        # === LoRA Config ===
+        peft_config = LoraConfig(
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+            target_modules=[
+                'up_proj', 'down_proj', 'gate_proj',
+                'k_proj', 'q_proj', 'v_proj', 'o_proj'
+            ]
+        )
+        model = get_peft_model(model, peft_config)
+
+        progress.status = "loading_dataset"
+        progress.message = "Loading and preparing dataset..."
+
+        # === Load & Prepare Dataset ===
+        dataset = load_dataset(dataset_name, split="all")
+        dataset = dataset.shuffle(seed=65).select(range(1000))  # Use 1000 samples
+
+        def format_chat_template(row, tokenizer):
+            row_json = [
+                {"role": "user", "content": row["Patient"]},
+                {"role": "assistant", "content": row["Doctor"]}
+            ]
+            row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
+            return row
+
+        dataset = dataset.map(
+            format_chat_template,
+            fn_kwargs={"tokenizer": tokenizer},
+            num_proc=4
+        )
+
+        dataset = dataset.train_test_split(test_size=0.1)
+
+        # Calculate total training steps
+        train_size = len(dataset["train"])
+        batch_size = 1
+        gradient_accumulation_steps = 2
+        num_epochs = 1
+
+        steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
+        total_steps = steps_per_epoch * num_epochs
+
+
progress.total_steps = total_steps
|
167 |
+
progress.status = "training"
|
168 |
+
progress.message = "Starting training..."
|
169 |
+
|
170 |
+
# === Training Arguments ===
|
171 |
+
training_args = TrainingArguments(
|
172 |
+
output_dir=new_model,
|
173 |
+
per_device_train_batch_size=batch_size,
|
174 |
+
per_device_eval_batch_size=1,
|
175 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
176 |
+
optim="paged_adamw_32bit",
|
177 |
+
num_train_epochs=num_epochs,
|
178 |
+
eval_steps=0.2,
|
179 |
+
logging_steps=1,
|
180 |
+
warmup_steps=10,
|
181 |
+
logging_strategy="steps",
|
182 |
+
learning_rate=2e-5,
|
183 |
+
fp16=False,
|
184 |
+
bf16=False,
|
185 |
+
group_by_length=True,
|
186 |
+
save_steps=50,
|
187 |
+
save_total_limit=2,
|
188 |
+
report_to=None # Disable wandb for HF Spaces
|
189 |
+
)
|
190 |
+
|
191 |
+
# === Data Collator ===
|
192 |
+
tokenizer.model_max_length = 512
|
193 |
+
|
194 |
+
# Custom callback to track progress
|
195 |
+
class ProgressCallback:
|
196 |
+
def __init__(self, progress_tracker):
|
197 |
+
self.progress_tracker = progress_tracker
|
198 |
+
self.last_update = time.time()
|
199 |
+
|
200 |
+
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
201 |
+
current_time = time.time()
|
202 |
+
# Update every 10 seconds or on significant step changes
|
203 |
+
if current_time - self.last_update >= 10 or state.global_step % 10 == 0:
|
204 |
+
self.progress_tracker.update_progress(
|
205 |
+
state.global_step,
|
206 |
+
state.max_steps,
|
207 |
+
f"Training step {state.global_step}/{state.max_steps}"
|
208 |
+
)
|
209 |
+
self.last_update = current_time
|
210 |
+
|
+        # === Trainer Initialization ===
+        # Note: depending on the TRL version, the tokenizer may need to be passed
+        # explicitly (tokenizer=... on older releases, processing_class=... on newer ones)
+        trainer = SFTTrainer(
+            model=model,
+            train_dataset=dataset["train"],
+            eval_dataset=dataset["test"],
+            peft_config=peft_config,
+            args=training_args,
+            callbacks=[ProgressCallback(progress)]
+        )
+
+        # === Train & Save ===
+        trainer.train()
+        trainer.save_model(new_model)
+
+        progress.status = "completed"
+        progress.progress = 100
+        progress.message = f"Training completed! Model saved as {new_model}"
+
+    except Exception as e:
+        progress.status = "error"
+        progress.error = str(e)
+        progress.message = f"Training failed: {str(e)}"
+
+
# ============== API ROUTES ==============
|
235 |
+
@app.route('/api/train', methods=['POST'])
|
236 |
+
def start_training():
|
237 |
+
"""Start training and return job ID for tracking"""
|
238 |
+
try:
|
239 |
+
job_id = str(uuid.uuid4())[:8] # Short UUID
|
240 |
+
progress = TrainingProgress(job_id)
|
241 |
+
training_jobs[job_id] = progress
|
242 |
+
|
243 |
+
# Start training in background thread
|
244 |
+
training_thread = threading.Thread(target=train_model_background, args=(job_id,))
|
245 |
+
training_thread.daemon = True
|
246 |
+
training_thread.start()
|
247 |
+
|
248 |
+
return jsonify({
|
249 |
+
"status": "started",
|
250 |
+
"job_id": job_id,
|
251 |
+
"message": "Training started. Use /api/status/<job_id> to track progress."
|
252 |
+
})
|
253 |
+
|
254 |
+
except Exception as e:
|
255 |
+
return jsonify({"status": "error", "message": str(e)}), 500
|
256 |
+
|
257 |
+
@app.route('/api/status/<job_id>', methods=['GET'])
|
258 |
+
def get_training_status(job_id):
|
259 |
+
"""Get training progress and estimated completion time"""
|
260 |
+
if job_id not in training_jobs:
|
261 |
+
return jsonify({"status": "error", "message": "Job not found"}), 404
|
262 |
+
|
263 |
+
progress = training_jobs[job_id]
|
264 |
+
return jsonify(progress.to_dict())
|
265 |
+
|
266 |
+
@app.route('/api/jobs', methods=['GET'])
|
267 |
+
def list_jobs():
|
268 |
+
"""List all training jobs"""
|
269 |
+
jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
|
270 |
+
return jsonify({"jobs": jobs})
|
271 |
+
|
272 |
+
@app.route('/')
|
273 |
+
def home():
|
274 |
+
return jsonify({
|
275 |
+
"message": "Welcome to LLaMA Fine-tuning API!",
|
276 |
+
"endpoints": {
|
277 |
+
"POST /api/train": "Start training",
|
278 |
+
"GET /api/status/<job_id>": "Get training status",
|
279 |
+
"GET /api/jobs": "List all jobs"
|
280 |
+
}
|
281 |
+
})
|
282 |
+
|
283 |
+
@app.route('/health')
|
284 |
+
def health():
|
285 |
+
return jsonify({"status": "healthy"})
|
286 |
+
|
287 |
+
if __name__ == '__main__':
|
288 |
+
port = int(os.environ.get('PORT', 7860)) # HF Spaces uses port 7860
|
289 |
+
app.run(host='0.0.0.0', port=port, debug=False)
|
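
Reviewer note: a minimal client sketch for exercising these endpoints, assuming the app is reachable at http://localhost:7860 (the URL is illustrative; substitute your Space's URL) and that the requests library is installed. The endpoint paths and JSON fields (job_id, status, progress, message) come from the routes above.

import time
import requests

BASE_URL = "http://localhost:7860"  # illustrative; replace with the deployed URL

# Kick off a training job and grab its ID
job = requests.post(f"{BASE_URL}/api/train").json()
job_id = job["job_id"]

# Poll the status endpoint until the job finishes or fails
while True:
    status = requests.get(f"{BASE_URL}/api/status/{job_id}").json()
    print(f"{status['status']}: {status['progress']}% - {status['message']}")
    if status["status"] in ("completed", "error"):
        break
    time.sleep(30)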
main.py
DELETED
File without changes