Guetat Youssef
committed on
Commit
·
28d8f8e
1
Parent(s):
9774f95
test
Browse files
app.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
-
from flask import Flask, jsonify, request
|
2 |
import threading
|
3 |
import time
|
4 |
import os
|
5 |
import tempfile
|
6 |
import shutil
|
7 |
import uuid
|
|
|
|
|
8 |
from datetime import datetime, timedelta
|
9 |
|
10 |
app = Flask(__name__)
|
@@ -23,6 +25,8 @@ class TrainingProgress:
|
|
23 |
self.estimated_finish_time = None
|
24 |
self.message = "Starting training..."
|
25 |
self.error = None
|
|
|
|
|
26 |
|
27 |
def update_progress(self, current_step, total_steps, message=""):
|
28 |
self.current_step = current_step
|
@@ -47,10 +51,67 @@ class TrainingProgress:
|
|
47 |
"total_steps": self.total_steps,
|
48 |
"message": self.message,
|
49 |
"estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
|
50 |
-
"error": self.error
|
|
|
|
|
51 |
}
|
52 |
|
53 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
"""Background training function with progress tracking"""
|
55 |
progress = training_jobs[job_id]
|
56 |
|
@@ -92,8 +153,7 @@ def train_model_background(job_id):
|
|
92 |
progress.message = "Loading base model and tokenizer..."
|
93 |
|
94 |
# === Configuration ===
|
95 |
-
base_model = "microsoft/DialoGPT-small"
|
96 |
-
dataset_name = "ruslanmv/ai-medical-chatbot"
|
97 |
new_model = f"trained-model-{job_id}"
|
98 |
max_length = 256
|
99 |
|
@@ -138,11 +198,22 @@ def train_model_background(job_id):
|
|
138 |
# === Load & Prepare Dataset ===
|
139 |
dataset = load_dataset(
|
140 |
dataset_name,
|
141 |
-
split="all",
|
142 |
cache_dir=temp_dir,
|
143 |
trust_remote_code=True
|
144 |
)
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
# Custom dataset class for proper handling
|
148 |
class CustomDataset(torch.utils.data.Dataset):
|
@@ -171,7 +242,6 @@ def train_model_background(job_id):
|
|
171 |
attention_mask = encoding['attention_mask'].squeeze()
|
172 |
|
173 |
# For causal language modeling, labels are the same as input_ids
|
174 |
-
# But we shift them so the model predicts the next token
|
175 |
labels = input_ids.clone()
|
176 |
|
177 |
# Set labels to -100 for padding tokens (they won't contribute to loss)
|
@@ -183,10 +253,12 @@ def train_model_background(job_id):
|
|
183 |
'labels': labels
|
184 |
}
|
185 |
|
186 |
-
# Prepare texts
|
187 |
texts = []
|
188 |
for item in dataset:
|
189 |
-
|
|
|
|
|
190 |
texts.append(text)
|
191 |
|
192 |
# Create custom dataset
|
@@ -214,7 +286,7 @@ def train_model_background(job_id):
|
|
214 |
gradient_accumulation_steps=gradient_accumulation_steps,
|
215 |
num_train_epochs=num_epochs,
|
216 |
logging_steps=1,
|
217 |
-
save_steps=
|
218 |
save_total_limit=1,
|
219 |
learning_rate=5e-5,
|
220 |
warmup_steps=2,
|
@@ -272,15 +344,20 @@ def train_model_background(job_id):
|
|
272 |
trainer.save_model(output_dir)
|
273 |
tokenizer.save_pretrained(output_dir)
|
274 |
|
|
|
|
|
275 |
progress.status = "completed"
|
276 |
progress.progress = 100
|
277 |
-
progress.message = f"Training completed! Model
|
278 |
|
279 |
-
#
|
280 |
def cleanup_temp_dir():
|
281 |
-
time.sleep(
|
282 |
try:
|
283 |
shutil.rmtree(temp_dir)
|
|
|
|
|
|
|
284 |
except:
|
285 |
pass
|
286 |
|
@@ -300,23 +377,46 @@ def train_model_background(job_id):
|
|
300 |
except:
|
301 |
pass
|
302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
# ============== API ROUTES ==============
|
304 |
@app.route('/api/train', methods=['POST'])
|
305 |
def start_training():
|
306 |
"""Start training and return job ID for tracking"""
|
307 |
try:
|
|
|
|
|
|
|
|
|
308 |
job_id = str(uuid.uuid4())[:8] # Short UUID
|
309 |
progress = TrainingProgress(job_id)
|
310 |
training_jobs[job_id] = progress
|
311 |
|
312 |
# Start training in background thread
|
313 |
-
training_thread = threading.Thread(
|
|
|
|
|
|
|
314 |
training_thread.daemon = True
|
315 |
training_thread.start()
|
316 |
|
317 |
return jsonify({
|
318 |
"status": "started",
|
319 |
"job_id": job_id,
|
|
|
|
|
320 |
"message": "Training started. Use /api/status/<job_id> to track progress."
|
321 |
})
|
322 |
|
@@ -332,6 +432,40 @@ def get_training_status(job_id):
|
|
332 |
progress = training_jobs[job_id]
|
333 |
return jsonify(progress.to_dict())
|
334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
@app.route('/api/jobs', methods=['GET'])
|
336 |
def list_jobs():
|
337 |
"""List all training jobs"""
|
@@ -341,11 +475,28 @@ def list_jobs():
|
|
341 |
@app.route('/')
|
342 |
def home():
|
343 |
return jsonify({
|
344 |
-
"message": "Welcome to LLaMA Fine-tuning API!",
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
"endpoints": {
|
346 |
-
"POST /api/train": "Start training",
|
347 |
-
"GET /api/status/<job_id>": "Get training status",
|
|
|
348 |
"GET /api/jobs": "List all jobs"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
}
|
350 |
})
|
351 |
|
|
|
1 |
+
from flask import Flask, jsonify, request, send_file
|
2 |
import threading
|
3 |
import time
|
4 |
import os
|
5 |
import tempfile
|
6 |
import shutil
|
7 |
import uuid
|
8 |
+
import zipfile
|
9 |
+
import io
|
10 |
from datetime import datetime, timedelta
|
11 |
|
12 |
app = Flask(__name__)
|
|
|
25 |
self.estimated_finish_time = None
|
26 |
self.message = "Starting training..."
|
27 |
self.error = None
|
28 |
+
self.model_path = None
|
29 |
+
self.detected_columns = None
|
30 |
|
31 |
def update_progress(self, current_step, total_steps, message=""):
|
32 |
self.current_step = current_step
|
|
|
51 |
"total_steps": self.total_steps,
|
52 |
"message": self.message,
|
53 |
"estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
|
54 |
+
"error": self.error,
|
55 |
+
"model_path": self.model_path,
|
56 |
+
"detected_columns": self.detected_columns
|
57 |
}
|
58 |
|
59 |
+
def detect_qa_columns(dataset):
    """Automatically detect the question and answer columns in the dataset.

    Column names are matched against common naming patterns, in priority
    order. Single-character patterns ('q', 'a') must match the column name
    exactly — substring matching for them would spuriously hit almost any
    column (e.g. 'data' contains 'a'). If pattern matching fails, falls
    back to the first one or two string-valued columns.

    Args:
        dataset: Object exposing ``column_names`` and row access
            (``dataset[0][col]``), e.g. a HuggingFace ``Dataset``.

    Returns:
        ``(question_col, answer_col)`` tuple; either entry may be ``None``
        when nothing suitable is found.
    """
    # Common patterns for question columns (priority order)
    question_patterns = [
        'question', 'prompt', 'input', 'query', 'patient', 'user', 'human',
        'instruction', 'context', 'q', 'text', 'source'
    ]

    # Common patterns for answer columns (priority order)
    answer_patterns = [
        'answer', 'response', 'output', 'reply', 'doctor', 'assistant', 'ai',
        'completion', 'target', 'a', 'label', 'ground_truth'
    ]

    # Get column names
    columns = list(dataset.column_names)

    def _first_match(patterns, exclude=None):
        """Return the first column matching any pattern, in pattern order."""
        for pattern in patterns:
            for col in columns:
                if col == exclude:
                    continue
                if len(pattern) == 1:
                    # Exact match only: substring matching a single letter
                    # would match nearly every column name.
                    if col.lower() == pattern:
                        return col
                elif pattern.lower() in col.lower():
                    return col
        return None

    # Find question column, then an answer column distinct from it
    question_col = _first_match(question_patterns)
    answer_col = _first_match(answer_patterns, exclude=question_col)

    # Fallback: use first two text columns if patterns don't match
    if not question_col or not answer_col:
        text_columns = []
        for col in columns:
            # Check if column contains non-empty text data
            # (probes only the first row — assumes it is representative)
            sample = dataset[0][col]
            if isinstance(sample, str) and len(sample.strip()) > 0:
                text_columns.append(col)

        if len(text_columns) >= 2:
            question_col = text_columns[0]
            answer_col = text_columns[1]
        elif len(text_columns) == 1:
            # Single column case - use it for both (self-supervised)
            question_col = answer_col = text_columns[0]

    return question_col, answer_col
|
113 |
+
|
114 |
+
def train_model_background(job_id, dataset_name, base_model_name=None):
|
115 |
"""Background training function with progress tracking"""
|
116 |
progress = training_jobs[job_id]
|
117 |
|
|
|
153 |
progress.message = "Loading base model and tokenizer..."
|
154 |
|
155 |
# === Configuration ===
|
156 |
+
base_model = base_model_name or "microsoft/DialoGPT-small"
|
|
|
157 |
new_model = f"trained-model-{job_id}"
|
158 |
max_length = 256
|
159 |
|
|
|
198 |
# === Load & Prepare Dataset ===
|
199 |
dataset = load_dataset(
|
200 |
dataset_name,
|
201 |
+
split="train" if "train" in load_dataset(dataset_name, cache_dir=temp_dir).keys() else "all",
|
202 |
cache_dir=temp_dir,
|
203 |
trust_remote_code=True
|
204 |
)
|
205 |
+
|
206 |
+
# Automatically detect question and answer columns
|
207 |
+
question_col, answer_col = detect_qa_columns(dataset)
|
208 |
+
|
209 |
+
if not question_col or not answer_col:
|
210 |
+
raise ValueError("Could not automatically detect question and answer columns in the dataset")
|
211 |
+
|
212 |
+
progress.detected_columns = {"question": question_col, "answer": answer_col}
|
213 |
+
progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
|
214 |
+
|
215 |
+
# Use subset for faster testing (can be made configurable)
|
216 |
+
dataset = dataset.shuffle(seed=65).select(range(min(100, len(dataset))))
|
217 |
|
218 |
# Custom dataset class for proper handling
|
219 |
class CustomDataset(torch.utils.data.Dataset):
|
|
|
242 |
attention_mask = encoding['attention_mask'].squeeze()
|
243 |
|
244 |
# For causal language modeling, labels are the same as input_ids
|
|
|
245 |
labels = input_ids.clone()
|
246 |
|
247 |
# Set labels to -100 for padding tokens (they won't contribute to loss)
|
|
|
253 |
'labels': labels
|
254 |
}
|
255 |
|
256 |
+
# Prepare texts using detected columns
|
257 |
texts = []
|
258 |
for item in dataset:
|
259 |
+
question = str(item[question_col]).strip()
|
260 |
+
answer = str(item[answer_col]).strip()
|
261 |
+
text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
|
262 |
texts.append(text)
|
263 |
|
264 |
# Create custom dataset
|
|
|
286 |
gradient_accumulation_steps=gradient_accumulation_steps,
|
287 |
num_train_epochs=num_epochs,
|
288 |
logging_steps=1,
|
289 |
+
save_steps=max(1, total_steps // 2),
|
290 |
save_total_limit=1,
|
291 |
learning_rate=5e-5,
|
292 |
warmup_steps=2,
|
|
|
344 |
trainer.save_model(output_dir)
|
345 |
tokenizer.save_pretrained(output_dir)
|
346 |
|
347 |
+
# Save model info
|
348 |
+
progress.model_path = output_dir
|
349 |
progress.status = "completed"
|
350 |
progress.progress = 100
|
351 |
+
progress.message = f"Training completed! Model ready for download."
|
352 |
|
353 |
+
# Keep the temp directory for download (cleanup after 1 hour)
|
354 |
def cleanup_temp_dir():
|
355 |
+
time.sleep(3600) # Wait 1 hour before cleanup
|
356 |
try:
|
357 |
shutil.rmtree(temp_dir)
|
358 |
+
# Remove from training_jobs after cleanup
|
359 |
+
if job_id in training_jobs:
|
360 |
+
del training_jobs[job_id]
|
361 |
except:
|
362 |
pass
|
363 |
|
|
|
377 |
except:
|
378 |
pass
|
379 |
|
380 |
+
def create_model_zip(model_path, job_id):
    """Zip every file under *model_path* into an in-memory buffer.

    Archive member names are stored relative to *model_path*, and the
    returned io.BytesIO is rewound so it can be streamed straight to the
    client. *job_id* is accepted for interface compatibility but unused.
    """
    buffer = io.BytesIO()

    with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
        for directory, _subdirs, filenames in os.walk(model_path):
            for name in filenames:
                absolute = os.path.join(directory, name)
                archive.write(absolute, os.path.relpath(absolute, model_path))

    buffer.seek(0)
    return buffer
|
393 |
+
|
394 |
# ============== API ROUTES ==============
|
395 |
@app.route('/api/train', methods=['POST'])
|
396 |
def start_training():
|
397 |
"""Start training and return job ID for tracking"""
|
398 |
try:
|
399 |
+
data = request.get_json() if request.is_json else {}
|
400 |
+
dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
|
401 |
+
base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')
|
402 |
+
|
403 |
job_id = str(uuid.uuid4())[:8] # Short UUID
|
404 |
progress = TrainingProgress(job_id)
|
405 |
training_jobs[job_id] = progress
|
406 |
|
407 |
# Start training in background thread
|
408 |
+
training_thread = threading.Thread(
|
409 |
+
target=train_model_background,
|
410 |
+
args=(job_id, dataset_name, base_model_name)
|
411 |
+
)
|
412 |
training_thread.daemon = True
|
413 |
training_thread.start()
|
414 |
|
415 |
return jsonify({
|
416 |
"status": "started",
|
417 |
"job_id": job_id,
|
418 |
+
"dataset_name": dataset_name,
|
419 |
+
"base_model": base_model_name,
|
420 |
"message": "Training started. Use /api/status/<job_id> to track progress."
|
421 |
})
|
422 |
|
|
|
432 |
progress = training_jobs[job_id]
|
433 |
return jsonify(progress.to_dict())
|
434 |
|
435 |
+
@app.route('/api/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Serve the finished model for *job_id* as a zip attachment.

    Responds 404 for unknown jobs or already-cleaned-up model files,
    400 while the job is not yet completed, and 500 if zipping or
    sending the archive fails.
    """
    job = training_jobs.get(job_id)
    if job is None:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    if job.status != "completed":
        body = {
            "status": "error",
            "message": f"Model not ready for download. Current status: {job.status}"
        }
        return jsonify(body), 400

    model_dir = job.model_path
    if not model_dir or not os.path.exists(model_dir):
        body = {
            "status": "error",
            "message": "Model files not found. They may have been cleaned up."
        }
        return jsonify(body), 404

    try:
        # Build the archive in memory and stream it to the client.
        archive = create_model_zip(model_dir, job_id)
        return send_file(
            archive,
            as_attachment=True,
            download_name=f"trained_model_{job_id}.zip",
            mimetype='application/zip'
        )
    except Exception as e:
        return jsonify({"status": "error", "message": f"Download failed: {str(e)}"}), 500
|
468 |
+
|
469 |
@app.route('/api/jobs', methods=['GET'])
|
470 |
def list_jobs():
|
471 |
"""List all training jobs"""
|
|
|
475 |
@app.route('/')
def home():
    """Landing page: summarize the API's features, routes, and usage."""
    endpoints = {
        "POST /api/train": "Start training (accepts dataset_name and base_model in JSON)",
        "GET /api/status/<job_id>": "Get training status and detected columns",
        "GET /api/download/<job_id>": "Download trained model as zip",
        "GET /api/jobs": "List all jobs"
    }
    usage_example = {
        "start_training": {
            "method": "POST",
            "url": "/api/train",
            "body": {
                "dataset_name": "your-dataset-name",
                "base_model": "microsoft/DialoGPT-small"
            }
        }
    }
    return jsonify({
        "message": "Welcome to Enhanced LLaMA Fine-tuning API!",
        "features": [
            "Automatic question/answer column detection",
            "Configurable base model and dataset",
            "Local model download",
            "Progress tracking with ETA"
        ],
        "endpoints": endpoints,
        "usage_example": usage_example
    })
|
502 |
|