# flask_pfe / app.py
# Guetat Youssef
# test
# 9774f95
# raw
# history blame
# 12.6 kB
from flask import Flask, jsonify, request
import threading
import time
import os
import tempfile
import shutil
import uuid
from datetime import datetime, timedelta
# The Flask application; every route below is registered on this object.
app = Flask(__name__)

# In-memory job registry mapping job_id -> TrainingProgress. Written by the
# /api/train handler and read by the background training threads and the
# status endpoints. Contents are lost when the process restarts.
training_jobs = dict()
class TrainingProgress:
    """Progress record for one background training job.

    The training thread mutates it, and the status endpoints serialise it
    with :meth:`to_dict`.
    """

    def __init__(self, job_id):
        self.job_id = job_id
        self.status = "initializing"
        self.progress = 0  # percent complete, 0-100
        self.current_step = 0
        self.total_steps = 0
        self.start_time = time.time()  # creation time, epoch seconds
        self.estimated_finish_time = None  # timezone-aware datetime once known
        self.message = "Starting training..."
        self.error = None
        # (epoch seconds, step) captured on the first update_progress() call.
        # Later ETAs measure the step rate from this point, so setup work done
        # before training began (imports, downloads) is not counted as time
        # spent per training step.
        self._rate_baseline = None

    def update_progress(self, current_step, total_steps, message=""):
        """Record the latest step count and refresh the percentage and ETA.

        Args:
            current_step: Number of steps completed so far.
            total_steps: Total number of steps expected; 0 means unknown.
            message: Status text to show; replaces the current message.
        """
        now = time.time()
        self.current_step = current_step
        self.total_steps = total_steps
        if total_steps > 0:
            # Clamp so an overshooting step counter never reports > 100%.
            self.progress = min(max(current_step / total_steps * 100, 0), 100)
        else:
            self.progress = 0
        self.message = message

        if self._rate_baseline is None:
            self._rate_baseline = (now, current_step)
        baseline_time, baseline_step = self._rate_baseline

        if current_step > baseline_step:
            # Rate measured only over steps observed since the first update.
            time_per_step = (now - baseline_time) / (current_step - baseline_step)
        elif current_step > 0:
            # Only one observation so far: fall back to time since creation.
            # This over-estimates, because it includes setup time.
            time_per_step = (now - self.start_time) / current_step
        else:
            return  # no completed steps yet, so there is no basis for an ETA

        remaining_steps = max(total_steps - current_step, 0)
        # Aware local time, so the ISO string carries its UTC offset and
        # clients in other timezones can interpret it.
        self.estimated_finish_time = datetime.now().astimezone() + timedelta(
            seconds=remaining_steps * time_per_step
        )

    def to_dict(self):
        """Return a JSON-serialisable snapshot of this job's progress."""
        finish = self.estimated_finish_time
        return {
            "job_id": self.job_id,
            "status": self.status,
            "progress": round(self.progress, 2),
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "message": self.message,
            "estimated_finish_time": finish.isoformat() if finish else None,
            "error": self.error,
        }
def train_model_background(job_id):
    """Fine-tune a LoRA adapter for one job, reporting progress as it goes.

    Intended to run in the background thread started by ``start_training``.
    All outcome reporting goes through ``training_jobs[job_id]``. Its
    ``status`` moves through "loading_libraries", "loading_model",
    "preparing_model", "loading_dataset", "training", "saving" and finally
    "completed" or "error". Exceptions are not propagated; they are recorded
    on the tracker instead.

    Artifacts are written to a fresh temporary directory. After a successful
    run, that directory is deleted 5 minutes later by a daemon thread. On
    failure, it is deleted immediately.

    Args:
        job_id: Key of an existing entry in ``training_jobs``.
    """
    import math
    import traceback

    progress = training_jobs[job_id]
    temp_dir = None  # set once mkdtemp succeeds; checked by the error cleanup
    try:
        # Create a temporary directory for this job
        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")

        # Point the Hugging Face / torch caches at the job directory.
        # NOTE(review): os.environ is process-global. Concurrent jobs overwrite
        # each other's values, and any library that reads these variables only
        # at import time presumably keeps the first job's value (confirm). The
        # explicit cache_dir arguments below are what isolate each job.
        os.environ['HF_HOME'] = temp_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_dir
        os.environ['HF_DATASETS_CACHE'] = temp_dir
        os.environ['TORCH_HOME'] = temp_dir

        progress.status = "loading_libraries"
        progress.message = "Loading required libraries..."

        # Import heavy libraries after setting cache paths
        import torch
        from datasets import load_dataset
        from huggingface_hub import login
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            TrainingArguments,
            Trainer,
            TrainerCallback,
        )
        from peft import LoraConfig, get_peft_model

        # === Authentication ===
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)

        progress.status = "loading_model"
        progress.message = "Loading base model and tokenizer..."

        # === Configuration ===
        base_model = "microsoft/DialoGPT-small"
        dataset_name = "ruslanmv/ai-medical-chatbot"
        new_model = f"trained-model-{job_id}"
        max_length = 256
        max_samples = 30  # use only a few samples for faster testing

        # === Load Model and Tokenizer ===
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            torch_dtype=torch.float32,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            trust_remote_code=True,
        )
        # Reuse EOS as the padding token when the tokenizer defines none
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Keep the embedding matrix in sync with the tokenizer vocabulary
        model.resize_token_embeddings(len(tokenizer))

        progress.status = "preparing_model"
        progress.message = "Setting up LoRA configuration..."

        # === LoRA Config ===
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)

        progress.status = "loading_dataset"
        progress.message = "Loading and preparing dataset..."

        # === Load & Prepare Dataset ===
        dataset = load_dataset(
            dataset_name,
            split="all",
            cache_dir=temp_dir,
            trust_remote_code=True,
        )
        # Never request more rows than exist; select() fails on out-of-range
        # indices.
        sample_count = min(max_samples, len(dataset))
        dataset = dataset.shuffle(seed=65).select(range(sample_count))

        class CustomDataset(torch.utils.data.Dataset):
            """Tokenises raw dialogue strings on demand for causal-LM training."""

            def __init__(self, texts, tokenizer, max_length):
                self.texts = texts
                self.tokenizer = tokenizer
                self.max_length = max_length

            def __len__(self):
                return len(self.texts)

            def __getitem__(self, idx):
                encoding = self.tokenizer(
                    self.texts[idx],
                    truncation=True,
                    padding='max_length',
                    max_length=self.max_length,
                    return_tensors='pt',
                )
                # Drop the leading batch dimension added by return_tensors='pt'
                input_ids = encoding['input_ids'].squeeze(0)
                attention_mask = encoding['attention_mask'].squeeze(0)
                # Causal LM: labels are the inputs themselves. Hugging Face
                # causal-LM heads shift them internally to predict the next token.
                labels = input_ids.clone()
                # -100 is ignored by the loss, so padding positions don't count
                labels[attention_mask == 0] = -100
                return {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'labels': labels,
                }

        # One "Patient/Doctor" exchange per sample, terminated by EOS
        texts = [
            f"Patient: {item['Patient']}\nDoctor: {item['Doctor']}{tokenizer.eos_token}"
            for item in dataset
        ]
        train_dataset = CustomDataset(texts, tokenizer, max_length)

        # Estimate total optimizer steps. The Trainer keeps the final partial
        # batch (dataloader_drop_last defaults to False), so round up.
        batch_size = 2
        gradient_accumulation_steps = 1
        num_epochs = 1
        batches_per_epoch = math.ceil(len(train_dataset) / batch_size)
        steps_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
        total_steps = steps_per_epoch * num_epochs
        progress.total_steps = total_steps

        progress.status = "training"
        progress.message = "Starting training..."

        # === Training Arguments ===
        output_dir = os.path.join(temp_dir, new_model)
        os.makedirs(output_dir, exist_ok=True)
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_epochs,
            logging_steps=1,
            save_steps=15,
            save_total_limit=1,
            learning_rate=5e-5,
            warmup_steps=2,
            logging_strategy="steps",
            save_strategy="steps",
            fp16=False,
            bf16=False,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            # "none" disables reporting. None would instead enable every
            # installed integration (wandb, tensorboard, ...).
            report_to="none",
            prediction_loss_only=True,
        )

        class ProgressCallback(TrainerCallback):
            """Mirrors Trainer state into the job's TrainingProgress."""

            def __init__(self, progress_tracker):
                self.progress_tracker = progress_tracker
                self.last_update = time.time()

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()
                # Throttle tracker updates to at most one every 3 seconds
                if current_time - self.last_update >= 3:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
                        f"Training step {state.global_step}/{state.max_steps}",
                    )
                    self.last_update = current_time
                    # Prefer a message that includes the latest logged loss
                    if logs:
                        loss = logs.get('train_loss', logs.get('loss', 'N/A'))
                        self.progress_tracker.message = (
                            f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"
                        )

            def on_train_begin(self, args, state, control, **kwargs):
                self.progress_tracker.status = "training"
                self.progress_tracker.message = "Training started..."

            def on_train_end(self, args, state, control, **kwargs):
                self.progress_tracker.status = "saving"
                self.progress_tracker.message = "Training complete, saving model..."

        # === Trainer Initialization ===
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            callbacks=[ProgressCallback(progress)],
            tokenizer=tokenizer,
        )

        # === Train & Save ===
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        # The callback throttles updates, so sync the final counters here.
        progress.current_step = trainer.state.global_step
        progress.total_steps = trainer.state.max_steps
        progress.status = "completed"
        progress.progress = 100
        # NOTE(review): this path is deleted by cleanup_temp_dir below, after
        # which the message points at a directory that no longer exists.
        progress.message = f"Training completed! Model saved to {output_dir}"

        def cleanup_temp_dir():
            time.sleep(300)  # keep artifacts around for 5 minutes
            # Best-effort removal; the directory may already be gone.
            shutil.rmtree(temp_dir, ignore_errors=True)

        cleanup_thread = threading.Thread(target=cleanup_temp_dir)
        cleanup_thread.daemon = True
        cleanup_thread.start()

    except Exception as e:
        traceback.print_exc()  # keep the full stack trace in the server log
        progress.status = "error"
        progress.error = str(e)
        progress.message = f"Training failed: {str(e)}"
        # Clean up on error (best-effort)
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
# ============== API ROUTES ==============
@app.route('/api/train', methods=['POST'])
def start_training():
    """Register a new job, start its training thread, and return its id.

    Returns:
        JSON ``{"status": "started", "job_id": ..., "message": ...}``. On an
        unexpected failure, returns JSON ``{"status": "error", "message": ...}``
        with HTTP 500.
    """
    try:
        # Short id for convenience. With only 8 hex chars (32 bits), a clash
        # with an existing job is possible and would silently replace that
        # job's tracker, so draw again until the id is unused.
        job_id = str(uuid.uuid4())[:8]
        while job_id in training_jobs:
            job_id = str(uuid.uuid4())[:8]

        progress = TrainingProgress(job_id)
        training_jobs[job_id] = progress

        # Daemon thread, so it does not keep the process alive on shutdown
        training_thread = threading.Thread(
            target=train_model_background, args=(job_id,)
        )
        training_thread.daemon = True
        training_thread.start()

        return jsonify({
            "status": "started",
            "job_id": job_id,
            "message": "Training started. Use /api/status/<job_id> to track progress."
        })
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
    """Return one job's progress snapshot, or a 404 JSON error if unknown."""
    tracker = training_jobs.get(job_id)
    if tracker is None:
        return jsonify({"status": "error", "message": "Job not found"}), 404
    snapshot = tracker.to_dict()
    return jsonify(snapshot)
@app.route('/api/jobs', methods=['GET'])
def list_jobs():
    """Return every known job's progress as ``{"jobs": {job_id: {...}}}``."""
    # Take a snapshot first. When requests are served concurrently (Flask's
    # dev server is threaded), /api/train can insert into training_jobs
    # mid-iteration. That raises "dictionary changed size during iteration".
    snapshot = list(training_jobs.items())
    jobs = {job_id: progress.to_dict() for job_id, progress in snapshot}
    return jsonify({"jobs": jobs})
@app.route('/')
def home():
    """Landing endpoint: a welcome message plus a map of the job endpoints."""
    endpoints = {
        "POST /api/train": "Start training",
        "GET /api/status/<job_id>": "Get training status",
        "GET /api/jobs": "List all jobs",
    }
    payload = dict(message="Welcome to LLaMA Fine-tuning API!", endpoints=endpoints)
    return jsonify(payload)
@app.route('/health')
def health():
    """Liveness probe: always answers healthy while the app is serving."""
    body = dict(status="healthy")
    return jsonify(body)
if __name__ == '__main__':
    # Listen on $PORT, defaulting to 7860 (the port HF Spaces uses).
    listen_port = int(os.environ.get('PORT', '7860'))
    app.run(debug=False, host='0.0.0.0', port=listen_port)