from flask import Flask, jsonify, request, send_file
import threading
import time
import os
import tempfile
import shutil
import uuid
import zipfile
import io
from datetime import datetime, timedelta

app = Flask(__name__)

# Global registry of training jobs, keyed by job_id
training_jobs = {}


class TrainingProgress:
    def __init__(self, job_id):
        self.job_id = job_id
        self.status = "initializing"
        self.progress = 0
        self.current_step = 0
        self.total_steps = 0
        self.start_time = time.time()
        self.estimated_finish_time = None
        self.message = "Starting training..."
        self.error = None
        self.model_path = None
        self.detected_columns = None

    def update_progress(self, current_step, total_steps, message=""):
        self.current_step = current_step
        self.total_steps = total_steps
        self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
        self.message = message

        # Estimate the finish time from the average time per completed step
        if current_step > 0:
            elapsed_time = time.time() - self.start_time
            time_per_step = elapsed_time / current_step
            remaining_steps = total_steps - current_step
            estimated_remaining_time = remaining_steps * time_per_step
            self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)

    def to_dict(self):
        return {
            "job_id": self.job_id,
            "status": self.status,
            "progress": round(self.progress, 2),
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "message": self.message,
            "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
            "error": self.error,
            "model_path": self.model_path,
            "detected_columns": self.detected_columns,
        }


def detect_qa_columns(dataset):
    """Automatically detect question and answer columns in the dataset."""
    # Common patterns for question columns
    question_patterns = [
        'question', 'prompt', 'input', 'query', 'patient', 'user',
        'human', 'instruction', 'context', 'q', 'text', 'source'
    ]

    # Common patterns for answer columns
    answer_patterns = [
        'answer', 'response', 'output', 'reply', 'doctor', 'assistant',
        'ai', 'completion', 'target', 'a', 'label', 'ground_truth'
    ]

    columns = list(dataset.column_names)

    # Find question column
    question_col = None
    for pattern in question_patterns:
        for col in columns:
            if pattern.lower() in col.lower():
                question_col = col
                break
        if question_col:
            break

    # Find answer column (must differ from the question column)
    answer_col = None
    for pattern in answer_patterns:
        for col in columns:
            if pattern.lower() in col.lower() and col != question_col:
                answer_col = col
                break
        if answer_col:
            break

    # Fallback: use the first two text columns if no pattern matched
    if not question_col or not answer_col:
        text_columns = []
        for col in columns:
            # Check whether the column contains non-empty text data
            sample = dataset[0][col]
            if isinstance(sample, str) and len(sample.strip()) > 0:
                text_columns.append(col)

        if len(text_columns) >= 2:
            question_col = text_columns[0]
            answer_col = text_columns[1]
        elif len(text_columns) == 1:
            # Single-column case: use it for both (self-supervised)
            question_col = answer_col = text_columns[0]

    return question_col, answer_col
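
# A minimal sanity check for detect_qa_columns (toy data; column names are
# illustrative only -- run manually if you want to verify detection locally):
#
#   from datasets import Dataset
#   toy = Dataset.from_dict({
#       "patient": ["What causes headaches?"],
#       "doctor": ["Common causes include stress and dehydration."],
#   })
#   detect_qa_columns(toy)  # -> ("patient", "doctor")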
def train_model_background(job_id, dataset_name, base_model_name=None):
    """Background training function with progress tracking."""
    progress = training_jobs[job_id]

    try:
        # Create a temporary directory for this job
        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")

        # Point all model/dataset caches at the job's temp directory
        os.environ['HF_HOME'] = temp_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_dir
        os.environ['HF_DATASETS_CACHE'] = temp_dir
        os.environ['TORCH_HOME'] = temp_dir

        progress.status = "loading_libraries"
        progress.message = "Loading required libraries..."

        # Import heavy libraries after setting cache paths
        import torch
        from datasets import load_dataset
        from huggingface_hub import login
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            TrainingArguments,
            Trainer,
            TrainerCallback,
        )
        from peft import (
            LoraConfig,
            get_peft_model,
        )

        # === Authentication ===
        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)

        progress.status = "loading_model"
        progress.message = "Loading base model and tokenizer..."

        # === Configuration ===
        base_model = base_model_name or "microsoft/DialoGPT-small"
        new_model = f"trained-model-{job_id}"
        max_length = 256

        # === Load Model and Tokenizer ===
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            torch_dtype=torch.float32,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            trust_remote_code=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            trust_remote_code=True,
        )

        # Add a padding token if the tokenizer lacks one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Resize token embeddings in case new tokens were added
        model.resize_token_embeddings(len(tokenizer))

        progress.status = "preparing_model"
        progress.message = "Setting up LoRA configuration..."

        # === LoRA Config ===
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)

        progress.status = "loading_dataset"
        progress.message = "Loading and preparing dataset..."

        # === Load & Prepare Dataset ===
        # Load once, then pick the "train" split if it exists, otherwise fall
        # back to the first available split (avoids a redundant second download).
        dataset_dict = load_dataset(dataset_name, cache_dir=temp_dir, trust_remote_code=True)
        split_name = "train" if "train" in dataset_dict else list(dataset_dict.keys())[0]
        dataset = dataset_dict[split_name]

        # Automatically detect question and answer columns
        question_col, answer_col = detect_qa_columns(dataset)

        if not question_col or not answer_col:
            raise ValueError("Could not automatically detect question and answer columns in the dataset")

        progress.detected_columns = {"question": question_col, "answer": answer_col}
        progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"

        # Use a subset for faster testing (can be made configurable)
        dataset = dataset.shuffle(seed=65).select(range(min(1000, len(dataset))))

        # Custom dataset class for tokenization and label masking
        class CustomDataset(torch.utils.data.Dataset):
            def __init__(self, texts, tokenizer, max_length):
                self.texts = texts
                self.tokenizer = tokenizer
                self.max_length = max_length

            def __len__(self):
                return len(self.texts)

            def __getitem__(self, idx):
                text = self.texts[idx]

                # Tokenize the text
                encoding = self.tokenizer(
                    text,
                    truncation=True,
                    padding='max_length',
                    max_length=self.max_length,
                    return_tensors='pt',
                )

                # Flatten the tensors (remove the batch dimension)
                input_ids = encoding['input_ids'].squeeze()
                attention_mask = encoding['attention_mask'].squeeze()

                # For causal language modeling, labels are the input_ids;
                # padding positions are set to -100 so they don't contribute to the loss
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100

                return {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'labels': labels,
                }

        # Prepare texts using the detected columns
        texts = []
        for item in dataset:
            question = str(item[question_col]).strip()
            answer = str(item[answer_col]).strip()
            text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
            texts.append(text)

        # Create custom dataset
        train_dataset = CustomDataset(texts, tokenizer, max_length)

        # Calculate total training steps
        batch_size = 2
        gradient_accumulation_steps = 1
        num_epochs = 1
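        # Worked example of the step math below: with the 1000-example cap,
        # batch_size=2 and gradient_accumulation_steps=1 give
        # 1000 // (2 * 1) = 500 optimizer steps per epoch, so one epoch means
        # total_steps = 500. (Actual numbers depend on the dataset size.)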
        steps_per_epoch = len(train_dataset) // (batch_size * gradient_accumulation_steps)
        total_steps = steps_per_epoch * num_epochs
        progress.total_steps = total_steps

        progress.status = "training"
        progress.message = "Starting training..."

        # === Training Arguments ===
        output_dir = os.path.join(temp_dir, new_model)
        os.makedirs(output_dir, exist_ok=True)

        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_epochs,
            logging_steps=1,
            save_steps=max(1, total_steps // 2),
            save_total_limit=1,
            learning_rate=5e-5,
            warmup_steps=2,
            logging_strategy="steps",
            save_strategy="steps",
            fp16=False,
            bf16=False,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            report_to="none",  # "none" (not None) disables external logging integrations
            prediction_loss_only=True,
        )

        # Custom callback to track progress
        class ProgressCallback(TrainerCallback):
            def __init__(self, progress_tracker):
                self.progress_tracker = progress_tracker
                self.last_update = time.time()

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()

                # Update at most every 3 seconds
                if current_time - self.last_update >= 3:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
                        f"Training step {state.global_step}/{state.max_steps}",
                    )
                    self.last_update = current_time

                # Log training metrics if available
                if logs:
                    loss = logs.get('train_loss', logs.get('loss', 'N/A'))
                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"

            def on_train_begin(self, args, state, control, **kwargs):
                self.progress_tracker.status = "training"
                self.progress_tracker.message = "Training started..."

            def on_train_end(self, args, state, control, **kwargs):
                self.progress_tracker.status = "saving"
                self.progress_tracker.message = "Training complete, saving model..."

        # === Trainer Initialization ===
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            callbacks=[ProgressCallback(progress)],
            tokenizer=tokenizer,
        )

        # === Train & Save ===
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        # Record the result
        progress.model_path = output_dir
        progress.status = "completed"
        progress.progress = 100
        progress.message = "Training completed! Model ready for download."
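        # Reloading the saved LoRA adapter later takes two steps (a sketch,
        # not exercised by this service; the base model name is illustrative):
        #
        #   from peft import PeftModel
        #   base = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
        #   tuned = PeftModel.from_pretrained(base, output_dir)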
        # Keep the temp directory around for download; clean up after 1 hour
        def cleanup_temp_dir():
            time.sleep(3600)  # Wait 1 hour before cleanup
            try:
                shutil.rmtree(temp_dir)
                # Remove the job record after cleanup
                if job_id in training_jobs:
                    del training_jobs[job_id]
            except Exception:
                pass

        cleanup_thread = threading.Thread(target=cleanup_temp_dir)
        cleanup_thread.daemon = True
        cleanup_thread.start()

    except Exception as e:
        progress.status = "error"
        progress.error = str(e)
        progress.message = f"Training failed: {str(e)}"

        # Clean up on error
        try:
            if 'temp_dir' in locals():
                shutil.rmtree(temp_dir)
        except Exception:
            pass


def create_model_zip(model_path, job_id):
    """Create an in-memory zip file containing the trained model."""
    memory_file = io.BytesIO()

    with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, dirs, files in os.walk(model_path):
            for file in files:
                file_path = os.path.join(root, file)
                arc_name = os.path.relpath(file_path, model_path)
                zf.write(file_path, arc_name)

    memory_file.seek(0)
    return memory_file
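
# Usage sketch for create_model_zip (hypothetical path; note that job_id is
# currently unused by the function and kept only for API symmetry):
#
#   buf = create_model_zip("/tmp/train_abc123_/trained-model-abc123", "abc123")
#   with open("model.zip", "wb") as f:
#       f.write(buf.read())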

# ============== API ROUTES ==============

@app.route('/api/train', methods=['POST'])
def start_training():
    """Start training and return a job ID for tracking."""
    try:
        data = request.get_json() if request.is_json else {}
        dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
        base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')

        job_id = str(uuid.uuid4())[:8]  # Short UUID
        progress = TrainingProgress(job_id)
        training_jobs[job_id] = progress

        # Start training in a background thread
        training_thread = threading.Thread(
            target=train_model_background,
            args=(job_id, dataset_name, base_model_name),
        )
        training_thread.daemon = True
        training_thread.start()

        return jsonify({
            "status": "started",
            "job_id": job_id,
            "dataset_name": dataset_name,
            "base_model": base_model_name,
            "message": "Training started. Use /api/status/<job_id> to track progress."
        })
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
    """Get training progress and estimated completion time."""
    if job_id not in training_jobs:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    progress = training_jobs[job_id]
    return jsonify(progress.to_dict())


@app.route('/api/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Download the trained model as a zip file."""
    if job_id not in training_jobs:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    progress = training_jobs[job_id]

    if progress.status != "completed":
        return jsonify({
            "status": "error",
            "message": f"Model not ready for download. Current status: {progress.status}"
        }), 400

    if not progress.model_path or not os.path.exists(progress.model_path):
        return jsonify({
            "status": "error",
            "message": "Model files not found. They may have been cleaned up."
        }), 404

    try:
        # Create the zip file in memory
        zip_file = create_model_zip(progress.model_path, job_id)
        return send_file(
            zip_file,
            as_attachment=True,
            download_name=f"trained_model_{job_id}.zip",
            mimetype='application/zip',
        )
    except Exception as e:
        return jsonify({"status": "error", "message": f"Download failed: {str(e)}"}), 500


@app.route('/api/jobs', methods=['GET'])
def list_jobs():
    """List all training jobs."""
    jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
    return jsonify({"jobs": jobs})


@app.route('/')
def home():
    return jsonify({
        "message": "Welcome to Enhanced LLaMA Fine-tuning API!",
        "features": [
            "Automatic question/answer column detection",
            "Configurable base model and dataset",
            "Local model download",
            "Progress tracking with ETA"
        ],
        "endpoints": {
            "POST /api/train": "Start training (accepts dataset_name and base_model in JSON)",
            "GET /api/status/<job_id>": "Get training status and detected columns",
            "GET /api/download/<job_id>": "Download trained model as zip",
            "GET /api/jobs": "List all jobs"
        },
        "usage_example": {
            "start_training": {
                "method": "POST",
                "url": "/api/train",
                "body": {
                    "dataset_name": "your-dataset-name",
                    "base_model": "microsoft/DialoGPT-small"
                }
            }
        }
    })


@app.route('/health')
def health():
    return jsonify({"status": "healthy"})


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))  # HF Spaces uses port 7860
    app.run(host='0.0.0.0', port=port, debug=False)
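
# Example client flow, assuming the server is reachable on localhost:7860
# (the <job_id> placeholders below come from the /api/train response):
#
#   curl -X POST http://localhost:7860/api/train \
#        -H "Content-Type: application/json" \
#        -d '{"dataset_name": "ruslanmv/ai-medical-chatbot"}'
#   curl http://localhost:7860/api/status/<job_id>
#   curl -o model.zip http://localhost:7860/api/download/<job_id>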