# flask_pfe/app.py
from flask import Flask, jsonify, request, send_file
import threading
import time
import os
import tempfile
import shutil
import uuid
import zipfile
import io
from datetime import datetime, timedelta
app = Flask(__name__)
# Global registry of training jobs, keyed by job_id
training_jobs = {}
class TrainingProgress:
def __init__(self, job_id):
self.job_id = job_id
self.status = "initializing"
self.progress = 0
self.current_step = 0
self.total_steps = 0
self.start_time = time.time()
self.estimated_finish_time = None
self.message = "Starting training..."
self.error = None
self.model_path = None
self.detected_columns = None
def update_progress(self, current_step, total_steps, message=""):
self.current_step = current_step
self.total_steps = total_steps
self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
self.message = message
# Calculate estimated finish time
if current_step > 0:
elapsed_time = time.time() - self.start_time
time_per_step = elapsed_time / current_step
remaining_steps = total_steps - current_step
estimated_remaining_time = remaining_steps * time_per_step
self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)
def to_dict(self):
return {
"job_id": self.job_id,
"status": self.status,
"progress": round(self.progress, 2),
"current_step": self.current_step,
"total_steps": self.total_steps,
"message": self.message,
"estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
"error": self.error,
"model_path": self.model_path,
"detected_columns": self.detected_columns
}
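# Illustrative status payload (values are hypothetical, not from a real run):
# a job mid-training might serialize via to_dict() roughly as
#
#   {
#       "job_id": "a1b2c3d4",
#       "status": "training",
#       "progress": 42.0,
#       "current_step": 210,
#       "total_steps": 500,
#       "message": "Step 210/500, Loss: 1.83",
#       "estimated_finish_time": "2025-01-01T12:34:56.789012",
#       "error": None,
#       "model_path": None,
#       "detected_columns": {"question": "Patient", "answer": "Doctor"},
#   }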
def detect_qa_columns(dataset):
"""Automatically detect question and answer columns in the dataset"""
# Common patterns for question columns
question_patterns = [
'question', 'prompt', 'input', 'query', 'patient', 'user', 'human',
'instruction', 'context', 'q', 'text', 'source'
]
# Common patterns for answer columns
answer_patterns = [
'answer', 'response', 'output', 'reply', 'doctor', 'assistant', 'ai',
'completion', 'target', 'a', 'label', 'ground_truth'
]
# Get column names
columns = list(dataset.column_names)
# Find question column
question_col = None
for pattern in question_patterns:
for col in columns:
if pattern.lower() in col.lower():
question_col = col
break
if question_col:
break
# Find answer column
answer_col = None
for pattern in answer_patterns:
for col in columns:
if pattern.lower() in col.lower() and col != question_col:
answer_col = col
break
if answer_col:
break
    # Fallback: use text columns when the patterns don't match
    if not question_col or not answer_col:
        text_columns = []
        for col in columns:
            # Check if column contains non-empty text data
            sample = dataset[0][col]
            if isinstance(sample, str) and len(sample.strip()) > 0:
                text_columns.append(col)
        if len(text_columns) >= 2:
            # Fill only the missing column(s), keeping any pattern match found above
            if not question_col:
                question_col = next(c for c in text_columns if c != answer_col)
            if not answer_col:
                answer_col = next(c for c in text_columns if c != question_col)
        elif len(text_columns) == 1:
            # Single column case - use it for both (self-supervised)
            question_col = answer_col = text_columns[0]
return question_col, answer_col
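# Minimal usage sketch for detect_qa_columns, using a small in-memory dataset
# (the column names here are hypothetical):
#
#   from datasets import Dataset
#   demo = Dataset.from_dict({
#       "Patient": ["I have a headache."],
#       "Doctor": ["How long has it been hurting?"],
#   })
#   q_col, a_col = detect_qa_columns(demo)
#   # q_col == "Patient" ('patient' pattern), a_col == "Doctor" ('doctor' pattern)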
def train_model_background(job_id, dataset_name, base_model_name=None):
"""Background training function with progress tracking"""
progress = training_jobs[job_id]
try:
# Create a temporary directory for this job
temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")
# Set environment variables for caching
os.environ['HF_HOME'] = temp_dir
os.environ['TRANSFORMERS_CACHE'] = temp_dir
os.environ['HF_DATASETS_CACHE'] = temp_dir
os.environ['TORCH_HOME'] = temp_dir
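        # NOTE: os.environ is process-wide, so concurrent jobs overwrite each
        # other's cache variables; the explicit cache_dir arguments passed
        # below are what actually scope downloads to this job's temp_dir.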
progress.status = "loading_libraries"
progress.message = "Loading required libraries..."
# Import heavy libraries after setting cache paths
import torch
from datasets import load_dataset, Dataset
from huggingface_hub import login
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
TrainerCallback,
)
from peft import (
LoraConfig,
get_peft_model,
)
# === Authentication ===
hf_token = os.getenv('HF_TOKEN')
if hf_token:
login(token=hf_token)
progress.status = "loading_model"
progress.message = "Loading base model and tokenizer..."
# === Configuration ===
base_model = base_model_name or "microsoft/DialoGPT-small"
new_model = f"trained-model-{job_id}"
max_length = 256
# === Load Model and Tokenizer ===
model = AutoModelForCausalLM.from_pretrained(
base_model,
cache_dir=temp_dir,
torch_dtype=torch.float32,
device_map="auto" if torch.cuda.is_available() else "cpu",
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
base_model,
cache_dir=temp_dir,
trust_remote_code=True
)
# Add padding token if not present
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Resize token embeddings if needed
model.resize_token_embeddings(len(tokenizer))
progress.status = "preparing_model"
progress.message = "Setting up LoRA configuration..."
# === LoRA Config ===
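        # r sets the rank of the low-rank adapter matrices and lora_alpha the
        # scaling (effective scale alpha/r = 2 here). With no target_modules
        # given, peft applies its built-in defaults for known architectures
        # (e.g. the fused attention projection of GPT-2-style models like DialoGPT).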
peft_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
progress.status = "loading_dataset"
progress.message = "Loading and preparing dataset..."
        # === Load & Prepare Dataset ===
        # Load once and pick the "train" split if present, otherwise fall back
        # to the first available split.
        dataset_dict = load_dataset(dataset_name, cache_dir=temp_dir, trust_remote_code=True)
        split_name = "train" if "train" in dataset_dict else list(dataset_dict.keys())[0]
        dataset = dataset_dict[split_name]
# Automatically detect question and answer columns
question_col, answer_col = detect_qa_columns(dataset)
if not question_col or not answer_col:
raise ValueError("Could not automatically detect question and answer columns in the dataset")
progress.detected_columns = {"question": question_col, "answer": answer_col}
progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
# Use subset for faster testing (can be made configurable)
dataset = dataset.shuffle(seed=65).select(range(min(1000, len(dataset))))
# Custom dataset class for proper handling
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, texts, tokenizer, max_length):
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
# Tokenize the text
encoding = self.tokenizer(
text,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
# Flatten the tensors (remove batch dimension)
input_ids = encoding['input_ids'].squeeze()
attention_mask = encoding['attention_mask'].squeeze()
# For causal language modeling, labels are the same as input_ids
labels = input_ids.clone()
# Set labels to -100 for padding tokens (they won't contribute to loss)
labels[attention_mask == 0] = -100
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels
}
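        # Each item is a dict of 1-D tensors of length max_length:
        # input_ids, attention_mask, and labels, with padded positions in
        # labels set to -100 so they are ignored by the loss.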
# Prepare texts using detected columns
texts = []
for item in dataset:
question = str(item[question_col]).strip()
answer = str(item[answer_col]).strip()
text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
texts.append(text)
# Create custom dataset
train_dataset = CustomDataset(texts, tokenizer, max_length)
# Calculate total training steps
batch_size = 2
gradient_accumulation_steps = 1
num_epochs = 1
steps_per_epoch = len(train_dataset) // (batch_size * gradient_accumulation_steps)
total_steps = steps_per_epoch * num_epochs
progress.total_steps = total_steps
progress.status = "training"
progress.message = "Starting training..."
# === Training Arguments ===
output_dir = os.path.join(temp_dir, new_model)
os.makedirs(output_dir, exist_ok=True)
training_args = TrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
num_train_epochs=num_epochs,
logging_steps=1,
save_steps=max(1, total_steps // 2),
save_total_limit=1,
learning_rate=5e-5,
warmup_steps=2,
logging_strategy="steps",
save_strategy="steps",
fp16=False,
bf16=False,
dataloader_num_workers=0,
remove_unused_columns=False,
            report_to="none",  # the string "none" disables logging integrations; None enables all installed ones
prediction_loss_only=True,
)
# Custom callback to track progress
class ProgressCallback(TrainerCallback):
def __init__(self, progress_tracker):
self.progress_tracker = progress_tracker
self.last_update = time.time()
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
current_time = time.time()
# Update every 3 seconds
if current_time - self.last_update >= 3:
self.progress_tracker.update_progress(
state.global_step,
state.max_steps,
f"Training step {state.global_step}/{state.max_steps}"
)
self.last_update = current_time
# Log training metrics if available
if logs:
loss = logs.get('train_loss', logs.get('loss', 'N/A'))
self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"
def on_train_begin(self, args, state, control, **kwargs):
self.progress_tracker.status = "training"
self.progress_tracker.message = "Training started..."
def on_train_end(self, args, state, control, **kwargs):
self.progress_tracker.status = "saving"
self.progress_tracker.message = "Training complete, saving model..."
# === Trainer Initialization ===
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
callbacks=[ProgressCallback(progress)],
tokenizer=tokenizer,
)
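        # Note: transformers 4.46+ deprecates Trainer's `tokenizer` argument in
        # favor of `processing_class`; passing `tokenizer` still works on 4.x.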
# === Train & Save ===
trainer.train()
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
# Save model info
progress.model_path = output_dir
progress.status = "completed"
progress.progress = 100
progress.message = f"Training completed! Model ready for download."
# Keep the temp directory for download (cleanup after 1 hour)
def cleanup_temp_dir():
time.sleep(3600) # Wait 1 hour before cleanup
try:
shutil.rmtree(temp_dir)
# Remove from training_jobs after cleanup
if job_id in training_jobs:
del training_jobs[job_id]
            except Exception:
                pass
cleanup_thread = threading.Thread(target=cleanup_temp_dir)
cleanup_thread.daemon = True
cleanup_thread.start()
except Exception as e:
progress.status = "error"
progress.error = str(e)
progress.message = f"Training failed: {str(e)}"
# Clean up on error
try:
if 'temp_dir' in locals():
shutil.rmtree(temp_dir)
        except Exception:
            pass
def create_model_zip(model_path, job_id):
"""Create a zip file containing the trained model"""
memory_file = io.BytesIO()
with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
for root, dirs, files in os.walk(model_path):
for file in files:
file_path = os.path.join(root, file)
arc_name = os.path.relpath(file_path, model_path)
zf.write(file_path, arc_name)
memory_file.seek(0)
return memory_file
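# Minimal usage sketch (paths and job id are hypothetical): materializing the
# in-memory zip on disk instead of streaming it through send_file as the
# download route below does:
#
#   buf = create_model_zip("/tmp/train_a1b2c3d4_/trained-model-a1b2c3d4", "a1b2c3d4")
#   with open("trained_model_a1b2c3d4.zip", "wb") as f:
#       f.write(buf.read())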
# ============== API ROUTES ==============
@app.route('/api/train', methods=['POST'])
def start_training():
"""Start training and return job ID for tracking"""
try:
data = request.get_json() if request.is_json else {}
dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')
job_id = str(uuid.uuid4())[:8] # Short UUID
progress = TrainingProgress(job_id)
training_jobs[job_id] = progress
# Start training in background thread
training_thread = threading.Thread(
target=train_model_background,
args=(job_id, dataset_name, base_model_name)
)
training_thread.daemon = True
training_thread.start()
return jsonify({
"status": "started",
"job_id": job_id,
"dataset_name": dataset_name,
"base_model": base_model_name,
"message": "Training started. Use /api/status/<job_id> to track progress."
})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
"""Get training progress and estimated completion time"""
if job_id not in training_jobs:
return jsonify({"status": "error", "message": "Job not found"}), 404
progress = training_jobs[job_id]
return jsonify(progress.to_dict())
@app.route('/api/download/<job_id>', methods=['GET'])
def download_model(job_id):
"""Download the trained model as a zip file"""
if job_id not in training_jobs:
return jsonify({"status": "error", "message": "Job not found"}), 404
progress = training_jobs[job_id]
if progress.status != "completed":
return jsonify({
"status": "error",
"message": f"Model not ready for download. Current status: {progress.status}"
}), 400
if not progress.model_path or not os.path.exists(progress.model_path):
return jsonify({
"status": "error",
"message": "Model files not found. They may have been cleaned up."
}), 404
try:
# Create zip file in memory
zip_file = create_model_zip(progress.model_path, job_id)
return send_file(
zip_file,
as_attachment=True,
download_name=f"trained_model_{job_id}.zip",
mimetype='application/zip'
)
except Exception as e:
return jsonify({"status": "error", "message": f"Download failed: {str(e)}"}), 500
@app.route('/api/jobs', methods=['GET'])
def list_jobs():
"""List all training jobs"""
jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
return jsonify({"jobs": jobs})
@app.route('/')
def home():
return jsonify({
"message": "Welcome to Enhanced LLaMA Fine-tuning API!",
"features": [
"Automatic question/answer column detection",
"Configurable base model and dataset",
"Local model download",
"Progress tracking with ETA"
],
"endpoints": {
"POST /api/train": "Start training (accepts dataset_name and base_model in JSON)",
"GET /api/status/<job_id>": "Get training status and detected columns",
"GET /api/download/<job_id>": "Download trained model as zip",
"GET /api/jobs": "List all jobs"
},
"usage_example": {
"start_training": {
"method": "POST",
"url": "/api/train",
"body": {
"dataset_name": "your-dataset-name",
"base_model": "microsoft/DialoGPT-small"
}
}
}
})
@app.route('/health')
def health():
return jsonify({"status": "healthy"})
if __name__ == '__main__':
port = int(os.environ.get('PORT', 7860)) # HF Spaces uses port 7860
app.run(host='0.0.0.0', port=port, debug=False)
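
# Example client session against a local run (port 7860 per the default above;
# requests is assumed to be installed, it is not a dependency of this app):
#
#   import requests
#   r = requests.post("http://localhost:7860/api/train",
#                     json={"dataset_name": "ruslanmv/ai-medical-chatbot",
#                           "base_model": "microsoft/DialoGPT-small"})
#   job_id = r.json()["job_id"]
#   print(requests.get(f"http://localhost:7860/api/status/{job_id}").json())
#   zip_bytes = requests.get(f"http://localhost:7860/api/download/{job_id}").content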