# flask_pfe / app.py
from flask import Flask, jsonify, request
import threading
import time
import os
import torch
from datasets import load_dataset
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, setup_chat_format
import uuid
from datetime import datetime, timedelta
# ============== CONFIGURATION ==============
app = Flask(__name__)
# Global variables to track training progress
training_jobs = {}
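# NOTE: job state lives only in this process's memory; it is lost on restart
# and not shared across workers, so run the app as a single Flask process.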
class TrainingProgress:
def __init__(self, job_id):
self.job_id = job_id
self.status = "initializing"
self.progress = 0
self.current_step = 0
self.total_steps = 0
self.start_time = time.time()
self.estimated_finish_time = None
self.message = "Starting training..."
self.error = None
def update_progress(self, current_step, total_steps, message=""):
self.current_step = current_step
self.total_steps = total_steps
self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
self.message = message
# Calculate estimated finish time
if current_step > 0:
elapsed_time = time.time() - self.start_time
time_per_step = elapsed_time / current_step
remaining_steps = total_steps - current_step
estimated_remaining_time = remaining_steps * time_per_step
self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)
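            # e.g. 100 steps in 200 s -> 2 s/step; 400 steps remaining -> ETA = now + 800 s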
def to_dict(self):
return {
"job_id": self.job_id,
"status": self.status,
"progress": round(self.progress, 2),
"current_step": self.current_step,
"total_steps": self.total_steps,
"message": self.message,
"estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
"error": self.error
}
def train_model_background(job_id):
"""Background training function with progress tracking"""
progress = training_jobs[job_id]
    try:
        # === Authentication ===
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN is not set. Please define it as an environment variable or secret.")
        login(token=hf_token)
progress.status = "loading_model"
progress.message = "Loading base model and tokenizer..."
# === Configuration ===
base_model = "meta-llama/Llama-3.2-1B"
dataset_name = "ruslanmv/ai-medical-chatbot"
        new_model = f"Llama-3.2-1B-chat-doctor-{job_id}"  # name matches the 1B base model
torch_dtype = torch.float16
attn_implementation = "eager"
# === QLoRA Config ===
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
)
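        # NF4 4-bit weights with double quantization shrink the base model's
        # memory footprint to roughly a quarter of fp16, which is what makes
        # QLoRA fine-tuning of this 1B model practical on a small GPU.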
# === Load Model and Tokenizer ===
model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=bnb_config,
device_map="auto",
attn_implementation=attn_implementation
)
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)
progress.status = "preparing_model"
progress.message = "Setting up LoRA configuration..."
# === LoRA Config ===
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
'up_proj', 'down_proj', 'gate_proj',
'k_proj', 'q_proj', 'v_proj', 'o_proj'
]
)
model = get_peft_model(model, peft_config)
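        # Optional sketch: peft's print_trainable_parameters() logs how small the
        # LoRA adapter footprint is relative to the frozen base model.
        model.print_trainable_parameters()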
progress.status = "loading_dataset"
progress.message = "Loading and preparing dataset..."
# === Load & Prepare Dataset ===
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=65).select(range(1000)) # Use 1000 samples
def format_chat_template(row, tokenizer):
row_json = [
{"role": "user", "content": row["Patient"]},
{"role": "assistant", "content": row["Doctor"]}
]
row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
return row
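        # After setup_chat_format, apply_chat_template renders ChatML, so each
        # row["text"] looks roughly like:
        #   <|im_start|>user\n<patient question><|im_end|>\n<|im_start|>assistant\n<doctor answer><|im_end|>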
dataset = dataset.map(
format_chat_template,
fn_kwargs={"tokenizer": tokenizer},
num_proc=4
)
dataset = dataset.train_test_split(test_size=0.1)
# Calculate total training steps
train_size = len(dataset["train"])
batch_size = 1
gradient_accumulation_steps = 2
num_epochs = 1
steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
total_steps = steps_per_epoch * num_epochs
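        # e.g. 1000 samples with a 10% test split leave 900 training rows;
        # 900 // (1 * 2) = 450 optimizer steps for the single epoch.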
progress.total_steps = total_steps
progress.status = "training"
progress.message = "Starting training..."
# === Training Arguments ===
        training_args = TrainingArguments(
            output_dir=new_model,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=1,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim="paged_adamw_32bit",
            num_train_epochs=num_epochs,
            eval_strategy="steps",  # needed for eval_steps to take effect ("evaluation_strategy" on older transformers)
            eval_steps=0.2,  # fraction of total steps between evaluations
            logging_steps=1,
            warmup_steps=10,
            logging_strategy="steps",
            learning_rate=2e-5,
            fp16=False,
            bf16=False,
            group_by_length=True,
            save_steps=50,
            save_total_limit=2,
            report_to="none"  # the string "none" (not None) is what disables wandb on HF Spaces
        )
        # === Sequence Length Cap ===
        tokenizer.model_max_length = 512  # keep tokenized sequences short to fit in memory
        # Custom callback to track progress. It must subclass TrainerCallback,
        # otherwise the Trainer silently ignores its hooks.
        class ProgressCallback(TrainerCallback):
            def __init__(self, progress_tracker):
                self.progress_tracker = progress_tracker
                self.last_update = time.time()

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()
                # Update every 10 seconds or on every 10th step
                if current_time - self.last_update >= 10 or state.global_step % 10 == 0:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
                        f"Training step {state.global_step}/{state.max_steps}"
                    )
                    self.last_update = current_time
# === Trainer Initialization ===
        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
            peft_config=peft_config,
            dataset_text_field="text",  # column produced by format_chat_template (trl < 0.12 API)
            tokenizer=tokenizer,  # newer trl releases take processing_class instead
            args=training_args,
            callbacks=[ProgressCallback(progress)]
        )
# === Train & Save ===
trainer.train()
trainer.save_model(new_model)
progress.status = "completed"
progress.progress = 100
progress.message = f"Training completed! Model saved as {new_model}"
except Exception as e:
progress.status = "error"
progress.error = str(e)
progress.message = f"Training failed: {str(e)}"
# ============== API ROUTES ==============
@app.route('/api/train', methods=['POST'])
def start_training():
"""Start training and return job ID for tracking"""
try:
job_id = str(uuid.uuid4())[:8] # Short UUID
progress = TrainingProgress(job_id)
training_jobs[job_id] = progress
# Start training in background thread
training_thread = threading.Thread(target=train_model_background, args=(job_id,))
training_thread.daemon = True
training_thread.start()
return jsonify({
"status": "started",
"job_id": job_id,
"message": "Training started. Use /api/status/<job_id> to track progress."
})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
"""Get training progress and estimated completion time"""
if job_id not in training_jobs:
return jsonify({"status": "error", "message": "Job not found"}), 404
progress = training_jobs[job_id]
return jsonify(progress.to_dict())
@app.route('/api/jobs', methods=['GET'])
def list_jobs():
"""List all training jobs"""
jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
return jsonify({"jobs": jobs})
@app.route('/')
def home():
return jsonify({
"message": "Welcome to LLaMA Fine-tuning API!",
"endpoints": {
"POST /api/train": "Start training",
"GET /api/status/<job_id>": "Get training status",
"GET /api/jobs": "List all jobs"
}
})
@app.route('/health')
def health():
return jsonify({"status": "healthy"})
if __name__ == '__main__':
port = int(os.environ.get('PORT', 7860)) # HF Spaces uses port 7860
app.run(host='0.0.0.0', port=port, debug=False)
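
# A minimal usage sketch (assuming the server is reachable on localhost:7860;
# the job id below is illustrative, not a real value):
#
#   curl -X POST http://localhost:7860/api/train
#   #   -> {"status": "started", "job_id": "1a2b3c4d", ...}
#   curl http://localhost:7860/api/status/1a2b3c4d
#   curl http://localhost:7860/api/jobs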