# flask_pfe / app.py
# Last commit 8f8763e ("test") by Guetat Youssef; file size 19.7 kB.
from flask import Flask, jsonify, request, send_file
import threading
import time
import os
import tempfile
import shutil
import uuid
import zipfile
import io
from datetime import datetime, timedelta
app = Flask(__name__)

# In-memory registry of training jobs: job id -> TrainingProgress.
# Entries are added by start_training and removed by the cleanup thread that
# train_model_background schedules; the status, download and listing routes
# read from it.
training_jobs = dict()
class TrainingProgress:
    """Mutable status record for one background training job.

    Instances are stored in ``training_jobs``; the training thread mutates the
    attributes and the HTTP handlers serialize them with :meth:`to_dict`.
    """

    def __init__(self, job_id):
        self.job_id = job_id
        self.status = "initializing"
        self.progress = 0  # percentage, kept within [0, 100]
        self.current_step = 0
        self.total_steps = 0
        self.start_time = time.time()  # wall-clock creation time, basis for the ETA
        self.estimated_finish_time = None  # naive local datetime, or None if unknown
        self.message = "Starting training..."
        self.error = None
        self.model_path = None
        self.detected_columns = None

    def update_progress(self, current_step, total_steps, message=""):
        """Record the current step and refresh the percentage and the ETA.

        Args:
            current_step: Number of steps completed so far.
            total_steps: Expected total number of steps.
            message: Status text; replaces the current message.

        The percentage is clamped to [0, 100], so an over-run never reports
        more than 100%. The ETA is recomputed only when both step counts are
        positive, with the remaining steps floored at zero. Otherwise the
        previous ETA is left untouched instead of being set to a time in the
        past.
        """
        self.current_step = current_step
        self.total_steps = total_steps
        if total_steps > 0:
            self.progress = min(100.0, max(0.0, (current_step / total_steps) * 100))
        else:
            self.progress = 0
        self.message = message

        if current_step > 0 and total_steps > 0:
            elapsed_time = time.time() - self.start_time
            time_per_step = elapsed_time / current_step
            remaining_steps = max(0, total_steps - current_step)
            estimated_remaining_time = remaining_steps * time_per_step
            self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)

    def to_dict(self):
        """Return a JSON-serializable snapshot of this job's state."""
        return {
            "job_id": self.job_id,
            "status": self.status,
            "progress": round(self.progress, 2),
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "message": self.message,
            "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
            "error": self.error,
            "model_path": self.model_path,
            "detected_columns": self.detected_columns,
        }
def detect_qa_columns(dataset):
    """Guess which dataset columns hold the question text and the answer text.

    Column names are matched against ordered keyword lists (earlier keywords
    win). If either side is still unknown, the first row is inspected for
    non-empty string columns:

    * a column already matched by name (and holding text) is kept and paired
      with the first other text column, or with itself if it is the only one;
    * otherwise the first two text columns are used, or a single text column
      is used for both sides (self-supervised).

    Args:
        dataset: Object exposing ``column_names``, ``len()`` and integer row
            access returning a column -> value mapping (e.g. a Hugging Face
            ``datasets.Dataset``).

    Returns:
        ``(question_col, answer_col)``. Either may be ``None`` when nothing
        suitable was found (for example, an empty dataset with unmatched
        column names).
    """
    # Common patterns for question columns
    question_patterns = [
        'question', 'prompt', 'input', 'query', 'patient', 'user', 'human',
        'instruction', 'context', 'q', 'text', 'source',
    ]
    # Common patterns for answer columns
    answer_patterns = [
        'answer', 'response', 'output', 'reply', 'doctor', 'assistant', 'ai',
        'completion', 'target', 'a', 'label', 'ground_truth',
    ]
    columns = list(dataset.column_names)

    def _matches(pattern, column):
        # Whether the keyword identifies the column name.
        name = column.lower()
        if len(pattern) <= 2:
            # Very short keywords ('q', 'a', 'ai') would substring-match almost
            # any name ('metadata', 'category', 'detail'), so require them to
            # be a whole word of the column name instead.
            words = "".join(ch if ch.isalnum() else " " for ch in name).split()
            return pattern in words
        return pattern in name

    def _first_match(patterns, exclude=None):
        # Find the first column matching the earliest keyword, skipping *exclude*.
        for pattern in patterns:
            for column in columns:
                if column != exclude and _matches(pattern, column):
                    return column
        return None

    question_col = _first_match(question_patterns)
    answer_col = _first_match(answer_patterns, exclude=question_col)

    if question_col and answer_col:
        return question_col, answer_col
    if len(dataset) == 0:
        # No rows to inspect; report whatever name matching found.
        return question_col, answer_col

    # Fallback: look at the first row for columns that actually contain text.
    first_row = dataset[0]
    text_columns = [
        col for col in columns
        if isinstance(first_row[col], str) and first_row[col].strip()
    ]

    known = question_col or answer_col
    if known is not None and known in text_columns:
        # Keep the name-matched side; fill the missing one from the other text columns.
        others = [col for col in text_columns if col != known]
        partner = others[0] if others else known
        if question_col:
            answer_col = partner
        else:
            question_col = partner
    elif len(text_columns) >= 2:
        question_col, answer_col = text_columns[0], text_columns[1]
    elif len(text_columns) == 1:
        # Single column case - use it for both (self-supervised)
        question_col = answer_col = text_columns[0]
    return question_col, answer_col
def _schedule_temp_dir_cleanup(job_id, temp_dir, delay_seconds=7200):
    """Delete *temp_dir* and forget *job_id* after *delay_seconds*, on a daemon thread."""
    def cleanup():
        time.sleep(delay_seconds)
        # Best effort: a partially removed directory must not keep the job listed.
        shutil.rmtree(temp_dir, ignore_errors=True)
        training_jobs.pop(job_id, None)

    threading.Thread(target=cleanup, daemon=True).start()


def train_model_background(job_id, dataset_name, base_model_name=None):
    """Fine-tune a causal LM with LoRA on a Hub dataset, reporting job progress.

    Intended to run on a worker thread. All reporting goes through
    ``training_jobs[job_id]`` (a ``TrainingProgress``). Nothing is returned,
    and exceptions raised inside the training ``try`` are recorded on the
    progress object (``status="error"``) instead of propagating.

    Args:
        job_id: Key of an entry already registered in ``training_jobs``.
        dataset_name: Hugging Face Hub dataset id to train on.
        base_model_name: Hub model id; defaults to ``microsoft/DialoGPT-medium``.

    On success the adapter, tokenizer, ``base_model.txt`` and
    ``training_info.json`` are written to ``progress.model_path``, and the job's
    temporary directory is scheduled for deletion after two hours. On failure
    the temporary directory is removed immediately.
    """
    progress = training_jobs[job_id]
    temp_dir = None
    try:
        # Create a temporary directory for this job
        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")
        # NOTE(review): os.environ is process-global, so concurrent jobs overwrite
        # each other's values, and libraries imported by an earlier job may have
        # already read them. The explicit cache_dir arguments below are what
        # isolate each job's downloads.
        os.environ['HF_HOME'] = temp_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_dir
        os.environ['HF_DATASETS_CACHE'] = temp_dir
        os.environ['TORCH_HOME'] = temp_dir

        progress.status = "loading_libraries"
        progress.message = "Loading required libraries..."
        # Import heavy libraries after setting cache paths
        import json
        import torch
        from datasets import concatenate_datasets, load_dataset
        from huggingface_hub import login
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            DataCollatorForLanguageModeling,
            Trainer,
            TrainerCallback,
            TrainingArguments,
        )
        from peft import LoraConfig, TaskType, get_peft_model

        # === Authentication ===
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)

        progress.status = "loading_model"
        progress.message = "Loading base model and tokenizer..."

        # === Model Configuration ===
        base_model = base_model_name or "microsoft/DialoGPT-medium"
        new_model = f"trained-model-{job_id}"
        max_length = 512
        use_cuda = torch.cuda.is_available()

        # SECURITY(review): trust_remote_code=True lets the Hub repository run its
        # own Python code here, and base_model / dataset_name are supplied by the
        # HTTP client in start_training. Consider disabling it or allow-listing repos.
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else "cpu",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            trust_remote_code=True,
            padding_side="right",
        )
        # Add padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        model.resize_token_embeddings(len(tokenizer))

        progress.status = "preparing_model"
        progress.message = "Setting up LoRA configuration..."

        # gradient_checkpointing=True (TrainingArguments below) on a frozen base
        # model needs the embedding outputs to require grad. Otherwise backward
        # cannot reach the LoRA weights inside the checkpointed blocks.
        model.enable_input_require_grads()

        # "c_attn"/"c_proj" are the attention/projection module names of
        # GPT-2-family models such as the default base model. For other
        # architectures let PEFT pick its per-model defaults instead of failing
        # on missing modules.
        module_names = {name.rsplit(".", 1)[-1] for name, _ in model.named_modules()}
        gpt2_targets = ["c_attn", "c_proj"]
        target_modules = gpt2_targets if set(gpt2_targets) <= module_names else None

        # === LoRA Config ===
        peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
            target_modules=target_modules,
        )
        model = get_peft_model(model, peft_config)

        # With fp16 base weights and fp16=True, trainable parameters must be fp32:
        # the AMP grad scaler refuses to unscale fp16 gradients.
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.float()

        model.print_trainable_parameters()

        progress.status = "loading_dataset"
        progress.message = "Loading and preparing dataset..."

        # === Load & Prepare Dataset ===
        # Load all splits once and pick locally: prefer "train", otherwise
        # concatenate every split.
        dataset_splits = load_dataset(dataset_name, cache_dir=temp_dir, trust_remote_code=True)
        if "train" in dataset_splits:
            dataset = dataset_splits["train"]
        else:
            dataset = concatenate_datasets(list(dataset_splits.values()))

        question_col, answer_col = detect_qa_columns(dataset)
        if not question_col or not answer_col:
            raise ValueError("Could not automatically detect question and answer columns in the dataset")
        progress.detected_columns = {"question": question_col, "answer": answer_col}
        progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"

        # Use subset for faster training
        dataset_size = min(500, len(dataset))
        dataset = dataset.shuffle(seed=42).select(range(dataset_size))

        def format_conversation(example):
            # Build a "Question: ...\nAnswer: ...<eos>" training string.
            question = str(example[question_col]).strip()
            answer = str(example[answer_col]).strip()
            return {"text": f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"}

        formatted_dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
        # Filter out very short or very long examples (character length)
        formatted_dataset = formatted_dataset.filter(lambda x: 10 < len(x["text"]) < max_length * 3)

        def tokenize_function(examples):
            # No "labels" column here: DataCollatorForLanguageModeling(mlm=False)
            # builds labels from the padded input_ids, whereas a ragged per-example
            # labels list cannot be turned into a batch tensor.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding=False,  # the data collator pads each batch
                max_length=max_length,
                return_tensors=None,
            )

        tokenized_dataset = formatted_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=formatted_dataset.column_names,
            desc="Tokenizing dataset",
        )
        if len(tokenized_dataset) == 0:
            raise ValueError("No usable training examples left after filtering the dataset")

        # === Training Configuration ===
        batch_size = 4 if use_cuda else 2
        gradient_accumulation_steps = 2
        num_epochs = 2
        learning_rate = 2e-4
        steps_per_epoch = len(tokenized_dataset) // (batch_size * gradient_accumulation_steps)
        total_steps = steps_per_epoch * num_epochs
        warmup_steps = max(10, total_steps // 10)

        progress.total_steps = total_steps
        progress.status = "training"
        progress.message = "Starting training..."

        output_dir = os.path.join(temp_dir, new_model)
        os.makedirs(output_dir, exist_ok=True)

        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            logging_steps=5,
            save_steps=max(10, total_steps // 4),
            save_total_limit=2,
            # Evaluation stays at its default ("no"). The old evaluation_strategy
            # keyword was renamed to eval_strategy in newer transformers releases,
            # so it is omitted rather than passed.
            logging_strategy="steps",
            save_strategy="steps",
            fp16=use_cuda,
            bf16=False,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            # "none" disables logging integrations; None means "all installed"
            # on transformers 4.x.
            report_to="none",
            prediction_loss_only=True,
            optim="adamw_torch",
            weight_decay=0.01,
            lr_scheduler_type="cosine",
            gradient_checkpointing=True,
            dataloader_pin_memory=False,
        )

        # === Data Collator ===
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            return_tensors="pt",
            pad_to_multiple_of=8 if use_cuda else None,
        )

        class ProgressCallback(TrainerCallback):
            """Mirror Trainer state into the job's TrainingProgress."""

            def __init__(self, progress_tracker):
                self.progress_tracker = progress_tracker
                self.last_update = time.time()

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()
                # Throttle updates to at most one every 5 seconds.
                if current_time - self.last_update >= 5:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
                        f"Training step {state.global_step}/{state.max_steps}",
                    )
                    self.last_update = current_time
                    if logs:
                        loss = logs.get('train_loss', logs.get('loss', 'N/A'))
                        lr = logs.get('learning_rate', 'N/A')
                        if isinstance(loss, (int, float)):
                            loss = f"{loss:.4f}"
                        self.progress_tracker.message = (
                            f"Step {state.global_step}/{state.max_steps}, Loss: {loss}, LR: {lr}"
                        )

            def on_train_begin(self, args, state, control, **kwargs):
                self.progress_tracker.status = "training"
                self.progress_tracker.message = "Training started..."

            def on_train_end(self, args, state, control, **kwargs):
                self.progress_tracker.status = "saving"
                self.progress_tracker.message = "Training complete, saving model..."

        # === Trainer Initialization ===
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            callbacks=[ProgressCallback(progress)],
            tokenizer=tokenizer,
        )

        # === Train & Save ===
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        with open(os.path.join(output_dir, "base_model.txt"), "w", encoding="utf-8") as f:
            f.write(base_model)
        training_info = {
            "base_model": base_model,
            "dataset_name": dataset_name,
            "dataset_size": len(tokenized_dataset),
            "max_length": max_length,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "num_epochs": num_epochs,
            "total_steps": total_steps,
            "detected_columns": progress.detected_columns,
        }
        with open(os.path.join(output_dir, "training_info.json"), "w", encoding="utf-8") as f:
            json.dump(training_info, f, indent=2)

        progress.model_path = output_dir
        progress.status = "completed"
        progress.progress = 100
        progress.message = "Training completed successfully! Model ready for download."

        # Keep the files around for download, then reclaim the disk space.
        _schedule_temp_dir_cleanup(job_id, temp_dir)

    except Exception as e:
        progress.status = "error"
        progress.error = str(e)
        progress.message = f"Training failed: {e}"
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
def create_model_zip(model_path, job_id):
    """Pack every file below *model_path* into an in-memory ZIP archive.

    Args:
        model_path: Directory to archive; entry names are relative to it.
        job_id: Unused; kept so existing call sites keep working.

    Returns:
        A ``BytesIO`` holding the deflate-compressed archive, rewound to the
        start so it can be streamed directly.
    """
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
        for directory, _subdirs, filenames in os.walk(model_path):
            for filename in filenames:
                full_path = os.path.join(directory, filename)
                archive.write(full_path, arcname=os.path.relpath(full_path, model_path))
    buffer.seek(0)
    return buffer
# ============== API ROUTES ==============
@app.route('/api/train', methods=['POST'])
def start_training():
    """Start a background training job and return its id for tracking.

    Accepts an optional JSON object body with ``dataset_name`` and
    ``base_model``. Missing, null or empty values fall back to
    ``ruslanmv/ai-medical-chatbot`` and ``microsoft/DialoGPT-medium``. A
    request without a JSON content type is treated as an empty body.

    Returns:
        200 with the new ``job_id``. 400 if the body is malformed JSON, is not
        a JSON object, or a field is not a string. 500 if the job could not be
        started.
    """
    if request.is_json:
        # silent=True makes malformed JSON return None instead of raising, so the
        # client gets a 400 rather than the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"status": "error", "message": "Request body must be a JSON object"}), 400
    else:
        data = {}

    dataset_name = data.get('dataset_name') or 'ruslanmv/ai-medical-chatbot'
    base_model_name = data.get('base_model') or 'microsoft/DialoGPT-medium'
    if not isinstance(dataset_name, str) or not isinstance(base_model_name, str):
        return jsonify({"status": "error", "message": "dataset_name and base_model must be strings"}), 400

    # SECURITY(review): both names come from the client and are passed to
    # train_model_background, which loads them with trust_remote_code=True, so
    # code shipped in those repositories may run on this server. Consider an
    # allow-list.
    try:
        job_id = str(uuid.uuid4())[:8]
        progress = TrainingProgress(job_id)
        training_jobs[job_id] = progress

        # Start training in background thread
        training_thread = threading.Thread(
            target=train_model_background,
            args=(job_id, dataset_name, base_model_name),
            daemon=True,
        )
        training_thread.start()

        return jsonify({
            "status": "started",
            "job_id": job_id,
            "dataset_name": dataset_name,
            "base_model": base_model_name,
            "message": "Training started. Use /api/status/<job_id> to track progress.",
        })
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
    """Return the progress record of *job_id* as JSON, or 404 if it is unknown."""
    # A single lookup instead of "in" followed by a subscript: the cleanup
    # thread may delete the entry in between, which would raise KeyError (500).
    progress = training_jobs.get(job_id)
    if progress is None:
        return jsonify({"status": "error", "message": "Job not found"}), 404
    return jsonify(progress.to_dict())
@app.route('/api/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Send the trained model directory of *job_id* as a zip attachment.

    Returns:
        The zip file on success. 404 for an unknown job or missing model files,
        400 while the job has not completed, and 500 if building the archive
        fails.
    """
    # A single lookup instead of "in" followed by a subscript: the cleanup
    # thread may delete the entry in between, which would raise KeyError (500).
    progress = training_jobs.get(job_id)
    if progress is None:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    if progress.status != "completed":
        return jsonify({
            "status": "error",
            "message": f"Model not ready for download. Current status: {progress.status}",
        }), 400

    model_path = progress.model_path
    if not model_path or not os.path.exists(model_path):
        return jsonify({
            "status": "error",
            "message": "Model files not found. They may have been cleaned up.",
        }), 404

    try:
        zip_file = create_model_zip(model_path, job_id)
        return send_file(
            zip_file,
            as_attachment=True,
            download_name=f"trained_model_{job_id}.zip",
            mimetype='application/zip',
        )
    except Exception as e:
        return jsonify({"status": "error", "message": f"Download failed: {str(e)}"}), 500
@app.route('/api/jobs', methods=['GET'])
def list_jobs():
    """List every known job's progress record, keyed by job id."""
    # Snapshot the items first. Cleanup threads delete entries concurrently, and
    # iterating the live dict while calling to_dict() could raise
    # "dictionary changed size during iteration".
    jobs = {job_id: progress.to_dict() for job_id, progress in list(training_jobs.items())}
    return jsonify({"jobs": jobs})
@app.route('/')
def home():
    """Describe the API: its features, endpoints and an example training request."""
    features = [
        "Automatic question/answer column detection",
        "Configurable base model and dataset",
        "Local model download",
        "Progress tracking with ETA",
    ]
    endpoints = {
        "POST /api/train": "Start training (accepts dataset_name and base_model in JSON)",
        "GET /api/status/<job_id>": "Get training status and detected columns",
        "GET /api/download/<job_id>": "Download trained model as zip",
        "GET /api/jobs": "List all jobs",
    }
    example_request = {
        "method": "POST",
        "url": "/api/train",
        "body": {
            "dataset_name": "your-dataset-name",
            "base_model": "microsoft/DialoGPT-medium",
        },
    }
    overview = {
        "message": "Welcome to Enhanced LLaMA Fine-tuning API!",
        "features": features,
        "endpoints": endpoints,
        "usage_example": {"start_training": example_request},
    }
    return jsonify(overview)
@app.route('/health')
def health():
    """Report that the HTTP server is up and answering requests."""
    payload = {"status": "healthy"}
    return jsonify(payload)
if __name__ == '__main__':
    # The PORT environment variable takes precedence; 7860 is the fallback.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=listen_port)