# flask_pfe / app.py
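"""Flask API for fine-tuning a causal language model with LoRA.

Training runs in a daemon background thread; clients poll
/api/status/<job_id> for progress and an estimated finish time.
"""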
from flask import Flask, jsonify, request
import threading
import time
import os
import tempfile
import shutil
import uuid
from datetime import datetime, timedelta
app = Flask(__name__)

# Global in-memory registry of training jobs, keyed by job ID
# (job state is lost when the process restarts)
training_jobs = {}


class TrainingProgress:
    def __init__(self, job_id):
        self.job_id = job_id
        self.status = "initializing"
        self.progress = 0
        self.current_step = 0
        self.total_steps = 0
        self.start_time = time.time()
        self.estimated_finish_time = None
        self.message = "Starting training..."
        self.error = None

    def update_progress(self, current_step, total_steps, message=""):
        self.current_step = current_step
        self.total_steps = total_steps
        self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
        self.message = message
        # Calculate estimated finish time from the average time per step so far
        if current_step > 0:
            elapsed_time = time.time() - self.start_time
            time_per_step = elapsed_time / current_step
            remaining_steps = total_steps - current_step
            estimated_remaining_time = remaining_steps * time_per_step
            self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)

    def to_dict(self):
        return {
            "job_id": self.job_id,
            "status": self.status,
            "progress": round(self.progress, 2),
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "message": self.message,
            "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
            "error": self.error,
        }
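
# Example payload returned by /api/status/<job_id> (values illustrative):
# {
#   "job_id": "a1b2c3d4",
#   "status": "training",
#   "progress": 48.0,
#   "current_step": 12,
#   "total_steps": 25,
#   "message": "Training step 12/25",
#   "estimated_finish_time": "2025-01-01T12:34:56.789012",
#   "error": null
# }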


def train_model_background(job_id):
    """Background training function with progress tracking"""
    progress = training_jobs[job_id]
    try:
        # Create a temporary directory for this job
        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")

        # Set environment variables for caching
        os.environ['HF_HOME'] = temp_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_dir
        os.environ['HF_DATASETS_CACHE'] = temp_dir
        os.environ['TORCH_HOME'] = temp_dir

        progress.status = "loading_libraries"
        progress.message = "Loading required libraries..."

        # Import heavy libraries after setting cache paths
        import torch
        from datasets import load_dataset
        from huggingface_hub import login
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            TrainingArguments,
            Trainer,
            TrainerCallback,
            DataCollatorForLanguageModeling,
        )
        from peft import LoraConfig, get_peft_model

        # === Authentication ===
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)

        progress.status = "loading_model"
        progress.message = "Loading base model and tokenizer..."

        # === Configuration ===
        base_model = "microsoft/DialoGPT-small"  # Smaller model for testing
        dataset_name = "ruslanmv/ai-medical-chatbot"
        new_model = f"trained-model-{job_id}"

        # === Load Model and Tokenizer ===
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            torch_dtype=torch.float32,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            trust_remote_code=True,
        )
        # Add padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        progress.status = "preparing_model"
        progress.message = "Setting up LoRA configuration..."

        # === LoRA Config ===
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )
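        # get_peft_model (below) freezes the base model's weights and adds small
        # trainable LoRA adapter matrices (rank r=8, scaling factor lora_alpha=16),
        # so only the adapters are updated during training.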
        model = get_peft_model(model, peft_config)

        progress.status = "loading_dataset"
        progress.message = "Loading and preparing dataset..."

        # === Load & Prepare Dataset ===
        dataset = load_dataset(
            dataset_name,
            split="all",
            cache_dir=temp_dir,
            trust_remote_code=True,
        )
        dataset = dataset.shuffle(seed=65).select(range(50))  # Use only 50 samples for faster testing

        def tokenize_function(examples):
            # Format each Patient/Doctor pair as one training example
            texts = []
            for i in range(len(examples['Patient'])):
                text = f"Patient: {examples['Patient'][i]}\nDoctor: {examples['Doctor'][i]}{tokenizer.eos_token}"
                texts.append(text)
            # Tokenize; padding is deferred to the data collator
            tokenized = tokenizer(
                texts,
                truncation=True,
                padding=False,
                max_length=256,
                return_tensors=None,
            )
            # For causal LM, labels are the same as input_ids
            tokenized["labels"] = tokenized["input_ids"].copy()
            return tokenized

        # Tokenize dataset
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset",
        )

        # Data collator for language modeling
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,  # We're doing causal LM, not masked LM
        )

        # Calculate total training steps
        train_size = len(tokenized_dataset)
        batch_size = 2
        gradient_accumulation_steps = 1
        num_epochs = 1
        steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
        total_steps = steps_per_epoch * num_epochs
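        # With the 50-sample subset above: 50 // (2 * 1) = 25 steps per epoch,
        # so 25 optimizer steps in total for the single epoch.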
        progress.total_steps = total_steps
        progress.status = "training"
        progress.message = "Starting training..."

        # === Training Arguments ===
        output_dir = os.path.join(temp_dir, new_model)
        os.makedirs(output_dir, exist_ok=True)
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_epochs,
            logging_steps=1,
            save_steps=20,
            save_total_limit=1,
            learning_rate=5e-5,
            warmup_steps=5,
            logging_strategy="steps",
            save_strategy="steps",
            fp16=False,
            bf16=False,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            report_to="none",  # the string "none" disables reporting integrations; None falls back to the default
        )

        # Custom callback to track progress
        class ProgressCallback(TrainerCallback):
            def __init__(self, progress_tracker):
                self.progress_tracker = progress_tracker
                self.last_update = time.time()

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()
                # Update every 5 seconds or on every even-numbered step
                if current_time - self.last_update >= 5 or state.global_step % 2 == 0:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
                        f"Training step {state.global_step}/{state.max_steps}",
                    )
                    self.last_update = current_time

            def on_train_begin(self, args, state, control, **kwargs):
                self.progress_tracker.status = "training"
                self.progress_tracker.message = "Training started..."

            def on_train_end(self, args, state, control, **kwargs):
                self.progress_tracker.status = "saving"
                self.progress_tracker.message = "Training complete, saving model..."

        # === Trainer Initialization ===
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            callbacks=[ProgressCallback(progress)],
        )

        # === Train & Save ===
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        progress.status = "completed"
        progress.progress = 100
        progress.message = f"Training completed! Model saved to {output_dir}"

        # Clean up temporary directory after a delay
        def cleanup_temp_dir():
            time.sleep(300)  # Wait 5 minutes before cleanup
            try:
                shutil.rmtree(temp_dir)
            except OSError:
                pass

        cleanup_thread = threading.Thread(target=cleanup_temp_dir)
        cleanup_thread.daemon = True
        cleanup_thread.start()
    except Exception as e:
        progress.status = "error"
        progress.error = str(e)
        progress.message = f"Training failed: {str(e)}"
        # Clean up on error
        try:
            if 'temp_dir' in locals():
                shutil.rmtree(temp_dir)
        except OSError:
            pass
# ============== API ROUTES ==============
@app.route('/api/train', methods=['POST'])
def start_training():
    """Start training and return job ID for tracking"""
    try:
        job_id = str(uuid.uuid4())[:8]  # Short UUID
        progress = TrainingProgress(job_id)
        training_jobs[job_id] = progress

        # Start training in background thread
        training_thread = threading.Thread(target=train_model_background, args=(job_id,))
        training_thread.daemon = True
        training_thread.start()

        return jsonify({
            "status": "started",
            "job_id": job_id,
            "message": "Training started. Use /api/status/<job_id> to track progress."
        })
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
    """Get training progress and estimated completion time"""
    if job_id not in training_jobs:
        return jsonify({"status": "error", "message": "Job not found"}), 404
    progress = training_jobs[job_id]
    return jsonify(progress.to_dict())


@app.route('/api/jobs', methods=['GET'])
def list_jobs():
    """List all training jobs"""
    jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
    return jsonify({"jobs": jobs})


@app.route('/')
def home():
    return jsonify({
        "message": "Welcome to the LLM fine-tuning API!",
        "endpoints": {
            "POST /api/train": "Start training",
            "GET /api/status/<job_id>": "Get training status",
            "GET /api/jobs": "List all jobs"
        }
    })


@app.route('/health')
def health():
    return jsonify({"status": "healthy"})

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))  # HF Spaces uses port 7860
    app.run(host='0.0.0.0', port=port, debug=False)