Guetat Youssef committed on
Commit
28d8f8e
·
1 Parent(s): 9774f95
Files changed (1) hide show
  1. app.py +169 -18
app.py CHANGED
@@ -1,10 +1,12 @@
1
- from flask import Flask, jsonify, request
2
  import threading
3
  import time
4
  import os
5
  import tempfile
6
  import shutil
7
  import uuid
 
 
8
  from datetime import datetime, timedelta
9
 
10
  app = Flask(__name__)
@@ -23,6 +25,8 @@ class TrainingProgress:
23
  self.estimated_finish_time = None
24
  self.message = "Starting training..."
25
  self.error = None
 
 
26
 
27
  def update_progress(self, current_step, total_steps, message=""):
28
  self.current_step = current_step
@@ -47,10 +51,67 @@ class TrainingProgress:
47
  "total_steps": self.total_steps,
48
  "message": self.message,
49
  "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
50
- "error": self.error
 
 
51
  }
52
 
53
- def train_model_background(job_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  """Background training function with progress tracking"""
55
  progress = training_jobs[job_id]
56
 
@@ -92,8 +153,7 @@ def train_model_background(job_id):
92
  progress.message = "Loading base model and tokenizer..."
93
 
94
  # === Configuration ===
95
- base_model = "microsoft/DialoGPT-small"
96
- dataset_name = "ruslanmv/ai-medical-chatbot"
97
  new_model = f"trained-model-{job_id}"
98
  max_length = 256
99
 
@@ -138,11 +198,22 @@ def train_model_background(job_id):
138
  # === Load & Prepare Dataset ===
139
  dataset = load_dataset(
140
  dataset_name,
141
- split="all",
142
  cache_dir=temp_dir,
143
  trust_remote_code=True
144
  )
145
- dataset = dataset.shuffle(seed=65).select(range(30)) # Use only 30 samples for faster testing
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  # Custom dataset class for proper handling
148
  class CustomDataset(torch.utils.data.Dataset):
@@ -171,7 +242,6 @@ def train_model_background(job_id):
171
  attention_mask = encoding['attention_mask'].squeeze()
172
 
173
  # For causal language modeling, labels are the same as input_ids
174
- # But we shift them so the model predicts the next token
175
  labels = input_ids.clone()
176
 
177
  # Set labels to -100 for padding tokens (they won't contribute to loss)
@@ -183,10 +253,12 @@ def train_model_background(job_id):
183
  'labels': labels
184
  }
185
 
186
- # Prepare texts
187
  texts = []
188
  for item in dataset:
189
- text = f"Patient: {item['Patient']}\nDoctor: {item['Doctor']}{tokenizer.eos_token}"
 
 
190
  texts.append(text)
191
 
192
  # Create custom dataset
@@ -214,7 +286,7 @@ def train_model_background(job_id):
214
  gradient_accumulation_steps=gradient_accumulation_steps,
215
  num_train_epochs=num_epochs,
216
  logging_steps=1,
217
- save_steps=15,
218
  save_total_limit=1,
219
  learning_rate=5e-5,
220
  warmup_steps=2,
@@ -272,15 +344,20 @@ def train_model_background(job_id):
272
  trainer.save_model(output_dir)
273
  tokenizer.save_pretrained(output_dir)
274
 
 
 
275
  progress.status = "completed"
276
  progress.progress = 100
277
- progress.message = f"Training completed! Model saved to {output_dir}"
278
 
279
- # Clean up temporary directory after a delay
280
  def cleanup_temp_dir():
281
- time.sleep(300) # Wait 5 minutes before cleanup
282
  try:
283
  shutil.rmtree(temp_dir)
 
 
 
284
  except:
285
  pass
286
 
@@ -300,23 +377,46 @@ def train_model_background(job_id):
300
  except:
301
  pass
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  # ============== API ROUTES ==============
304
  @app.route('/api/train', methods=['POST'])
305
  def start_training():
306
  """Start training and return job ID for tracking"""
307
  try:
 
 
 
 
308
  job_id = str(uuid.uuid4())[:8] # Short UUID
309
  progress = TrainingProgress(job_id)
310
  training_jobs[job_id] = progress
311
 
312
  # Start training in background thread
313
- training_thread = threading.Thread(target=train_model_background, args=(job_id,))
 
 
 
314
  training_thread.daemon = True
315
  training_thread.start()
316
 
317
  return jsonify({
318
  "status": "started",
319
  "job_id": job_id,
 
 
320
  "message": "Training started. Use /api/status/<job_id> to track progress."
321
  })
322
 
@@ -332,6 +432,40 @@ def get_training_status(job_id):
332
  progress = training_jobs[job_id]
333
  return jsonify(progress.to_dict())
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  @app.route('/api/jobs', methods=['GET'])
336
  def list_jobs():
337
  """List all training jobs"""
@@ -341,11 +475,28 @@ def list_jobs():
341
  @app.route('/')
342
  def home():
343
  return jsonify({
344
- "message": "Welcome to LLaMA Fine-tuning API!",
 
 
 
 
 
 
345
  "endpoints": {
346
- "POST /api/train": "Start training",
347
- "GET /api/status/<job_id>": "Get training status",
 
348
  "GET /api/jobs": "List all jobs"
 
 
 
 
 
 
 
 
 
 
349
  }
350
  })
351
 
 
1
+ from flask import Flask, jsonify, request, send_file
2
  import threading
3
  import time
4
  import os
5
  import tempfile
6
  import shutil
7
  import uuid
8
+ import zipfile
9
+ import io
10
  from datetime import datetime, timedelta
11
 
12
  app = Flask(__name__)
 
25
  self.estimated_finish_time = None
26
  self.message = "Starting training..."
27
  self.error = None
28
+ self.model_path = None
29
+ self.detected_columns = None
30
 
31
  def update_progress(self, current_step, total_steps, message=""):
32
  self.current_step = current_step
 
51
  "total_steps": self.total_steps,
52
  "message": self.message,
53
  "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
54
+ "error": self.error,
55
+ "model_path": self.model_path,
56
+ "detected_columns": self.detected_columns
57
  }
58
 
59
def detect_qa_columns(dataset):
    """Automatically detect the question and answer columns of a dataset.

    Args:
        dataset: A dataset-like object exposing ``column_names`` and
            integer row indexing (``dataset[0]`` returns a dict-like row).

    Returns:
        A ``(question_col, answer_col)`` tuple of column names. Both may
        refer to the same column for single-text-column datasets, and
        either may be ``None`` when nothing suitable is found.
    """
    # Name patterns commonly used for the "question" side, in priority order.
    question_patterns = [
        'question', 'prompt', 'input', 'query', 'patient', 'user', 'human',
        'instruction', 'context', 'q', 'text', 'source'
    ]

    # Name patterns commonly used for the "answer" side, in priority order.
    answer_patterns = [
        'answer', 'response', 'output', 'reply', 'doctor', 'assistant', 'ai',
        'completion', 'target', 'a', 'label', 'ground_truth'
    ]

    columns = list(dataset.column_names)

    def _find(patterns, exclude=None):
        """Return the first column matching any pattern (pattern priority wins).

        Patterns of one or two characters ('q', 'a', 'ai') must equal the
        column name exactly; longer patterns match as case-insensitive
        substrings. A plain substring test for 'a' would otherwise hijack
        almost any column name ('date', 'metadata', ...).
        """
        for pattern in patterns:
            for col in columns:
                if col == exclude:
                    continue
                name = col.lower()
                matched = (name == pattern) if len(pattern) <= 2 else (pattern in name)
                if matched:
                    return col
        return None

    question_col = _find(question_patterns)
    answer_col = _find(answer_patterns, exclude=question_col)

    # Fallback: fill in only the *missing* side from the text-bearing columns,
    # keeping any column already detected by pattern (the previous behavior of
    # overwriting both sides could discard a correct pattern match).
    if not question_col or not answer_col:
        text_columns = [
            col for col in columns
            if isinstance(dataset[0][col], str) and dataset[0][col].strip()
        ]
        if not question_col:
            candidates = [c for c in text_columns if c != answer_col]
            if candidates:
                question_col = candidates[0]
        if not answer_col:
            candidates = [c for c in text_columns if c != question_col]
            if candidates:
                answer_col = candidates[0]
            elif question_col in text_columns:
                # Single text column: use it for both (self-supervised).
                answer_col = question_col

    return question_col, answer_col
113
+
114
+ def train_model_background(job_id, dataset_name, base_model_name=None):
115
  """Background training function with progress tracking"""
116
  progress = training_jobs[job_id]
117
 
 
153
  progress.message = "Loading base model and tokenizer..."
154
 
155
  # === Configuration ===
156
+ base_model = base_model_name or "microsoft/DialoGPT-small"
 
157
  new_model = f"trained-model-{job_id}"
158
  max_length = 256
159
 
 
198
  # === Load & Prepare Dataset ===
199
  dataset = load_dataset(
200
  dataset_name,
201
+ split="train" if "train" in load_dataset(dataset_name, cache_dir=temp_dir).keys() else "all",
202
  cache_dir=temp_dir,
203
  trust_remote_code=True
204
  )
205
+
206
+ # Automatically detect question and answer columns
207
+ question_col, answer_col = detect_qa_columns(dataset)
208
+
209
+ if not question_col or not answer_col:
210
+ raise ValueError("Could not automatically detect question and answer columns in the dataset")
211
+
212
+ progress.detected_columns = {"question": question_col, "answer": answer_col}
213
+ progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
214
+
215
+ # Use subset for faster testing (can be made configurable)
216
+ dataset = dataset.shuffle(seed=65).select(range(min(100, len(dataset))))
217
 
218
  # Custom dataset class for proper handling
219
  class CustomDataset(torch.utils.data.Dataset):
 
242
  attention_mask = encoding['attention_mask'].squeeze()
243
 
244
  # For causal language modeling, labels are the same as input_ids
 
245
  labels = input_ids.clone()
246
 
247
  # Set labels to -100 for padding tokens (they won't contribute to loss)
 
253
  'labels': labels
254
  }
255
 
256
+ # Prepare texts using detected columns
257
  texts = []
258
  for item in dataset:
259
+ question = str(item[question_col]).strip()
260
+ answer = str(item[answer_col]).strip()
261
+ text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
262
  texts.append(text)
263
 
264
  # Create custom dataset
 
286
  gradient_accumulation_steps=gradient_accumulation_steps,
287
  num_train_epochs=num_epochs,
288
  logging_steps=1,
289
+ save_steps=max(1, total_steps // 2),
290
  save_total_limit=1,
291
  learning_rate=5e-5,
292
  warmup_steps=2,
 
344
  trainer.save_model(output_dir)
345
  tokenizer.save_pretrained(output_dir)
346
 
347
+ # Save model info
348
+ progress.model_path = output_dir
349
  progress.status = "completed"
350
  progress.progress = 100
351
+ progress.message = f"Training completed! Model ready for download."
352
 
353
+ # Keep the temp directory for download (cleanup after 1 hour)
354
  def cleanup_temp_dir():
355
+ time.sleep(3600) # Wait 1 hour before cleanup
356
  try:
357
  shutil.rmtree(temp_dir)
358
+ # Remove from training_jobs after cleanup
359
+ if job_id in training_jobs:
360
+ del training_jobs[job_id]
361
  except:
362
  pass
363
 
 
377
  except:
378
  pass
379
 
380
def create_model_zip(model_path, job_id):
    """Bundle every file under *model_path* into an in-memory zip archive.

    Args:
        model_path: Directory holding the saved model files.
        job_id: Training job identifier (kept for interface compatibility;
            not used when building the archive).

    Returns:
        io.BytesIO: Seekable buffer positioned at the start of the zip data.
    """
    buffer = io.BytesIO()
    archive = zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED)
    try:
        for directory, _subdirs, filenames in os.walk(model_path):
            for filename in filenames:
                absolute = os.path.join(directory, filename)
                # Store entries relative to the model root so the archive
                # unpacks without the temporary-directory prefix.
                archive.write(absolute, os.path.relpath(absolute, model_path))
    finally:
        archive.close()
    buffer.seek(0)
    return buffer
393
+
394
  # ============== API ROUTES ==============
395
  @app.route('/api/train', methods=['POST'])
396
  def start_training():
397
  """Start training and return job ID for tracking"""
398
  try:
399
+ data = request.get_json() if request.is_json else {}
400
+ dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
401
+ base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')
402
+
403
  job_id = str(uuid.uuid4())[:8] # Short UUID
404
  progress = TrainingProgress(job_id)
405
  training_jobs[job_id] = progress
406
 
407
  # Start training in background thread
408
+ training_thread = threading.Thread(
409
+ target=train_model_background,
410
+ args=(job_id, dataset_name, base_model_name)
411
+ )
412
  training_thread.daemon = True
413
  training_thread.start()
414
 
415
  return jsonify({
416
  "status": "started",
417
  "job_id": job_id,
418
+ "dataset_name": dataset_name,
419
+ "base_model": base_model_name,
420
  "message": "Training started. Use /api/status/<job_id> to track progress."
421
  })
422
 
 
432
  progress = training_jobs[job_id]
433
  return jsonify(progress.to_dict())
434
 
435
@app.route('/api/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Download the trained model as a zip file"""
    # Guard clauses: unknown job, unfinished job, or already-cleaned files.
    job = training_jobs.get(job_id)
    if job is None:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    if job.status != "completed":
        return jsonify({
            "status": "error",
            "message": f"Model not ready for download. Current status: {job.status}"
        }), 400

    if not job.model_path or not os.path.exists(job.model_path):
        return jsonify({
            "status": "error",
            "message": "Model files not found. They may have been cleaned up."
        }), 404

    try:
        # Build the archive in memory and stream it straight to the client.
        return send_file(
            create_model_zip(job.model_path, job_id),
            as_attachment=True,
            download_name=f"trained_model_{job_id}.zip",
            mimetype='application/zip'
        )
    except Exception as e:
        return jsonify({"status": "error", "message": f"Download failed: {str(e)}"}), 500
468
+
469
  @app.route('/api/jobs', methods=['GET'])
470
  def list_jobs():
471
  """List all training jobs"""
 
475
@app.route('/')
def home():
    """Landing page: advertise the API's features, endpoints and example usage."""
    # Static description payload; assembled once per request and serialized.
    api_description = {
        "message": "Welcome to Enhanced LLaMA Fine-tuning API!",
        "features": [
            "Automatic question/answer column detection",
            "Configurable base model and dataset",
            "Local model download",
            "Progress tracking with ETA"
        ],
        "endpoints": {
            "POST /api/train": "Start training (accepts dataset_name and base_model in JSON)",
            "GET /api/status/<job_id>": "Get training status and detected columns",
            "GET /api/download/<job_id>": "Download trained model as zip",
            "GET /api/jobs": "List all jobs"
        },
        "usage_example": {
            "start_training": {
                "method": "POST",
                "url": "/api/train",
                "body": {
                    "dataset_name": "your-dataset-name",
                    "base_model": "microsoft/DialoGPT-small"
                }
            }
        }
    }
    return jsonify(api_description)
502