import os

# Configure the Hugging Face cache BEFORE importing transformers/datasets:
# transformers resolves its default cache directory from the environment when
# it is imported, so assigning TRANSFORMERS_CACHE afterwards has no effect.
CACHE_DIR = "/app/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR

import json  # noqa: E402,F401
import logging  # noqa: E402
from typing import List, Optional  # noqa: E402

import torch  # noqa: E402
from datasets import load_dataset  # noqa: E402
from fastapi import BackgroundTasks, FastAPI, HTTPException  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory the /train endpoint saves fine-tuned weights into, and that
# init_model() loads from when a saved model is present there.
model_output_path = os.environ.get("MODEL_OUTPUT_PATH", "/app/model")

# Hub id of the base model used when no fine-tuned model has been saved yet.
# NOTE(review): this repo id suggests ONNX/INT4 weights, which
# AutoModelForCausalLM (PyTorch) may not be able to load -- verify, or set
# BASE_MODEL_NAME to a PyTorch/safetensors checkpoint.
BASE_MODEL_NAME = os.environ.get(
    "BASE_MODEL_NAME", "nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4"
)


# Pydantic models for request/response
class GenerateRequest(BaseModel):
    """Body of POST /generate."""

    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7  # <= 0 (or None) selects greedy decoding
    num_return_sequences: Optional[int] = 1


class GenerateResponse(BaseModel):
    """Response of POST /generate: one decoded string per returned sequence."""

    generated_text: List[str]


class HealthResponse(BaseModel):
    """Response of GET /: service and model status."""

    status: str
    model_loaded: bool
    gpu_available: bool
    device: str


class TrainRequest(BaseModel):
    """Body of POST /train."""

    dataset_path: str  # path on the SERVER's filesystem to a JSON dataset
    num_epochs: Optional[int] = 3
    batch_size: Optional[int] = 4
    learning_rate: Optional[float] = 2e-5


class TrainResponse(BaseModel):
    """Response of POST /train."""

    status: str
    message: str


class TrainingStatus:
    """Mutable, process-wide record of the current/last training run."""

    def __init__(self):
        self.is_training = False
        self.current_epoch = 0
        self.current_loss = None
        self.status = "idle"


training_status = TrainingStatus()


class _StatusCallback(TrainerCallback):
    """Mirror Trainer progress (epoch, loss) into ``training_status``."""

    def on_epoch_begin(self, args, state, control, **kwargs):
        training_status.current_epoch = state.epoch

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Not every log event carries "loss"; only update when it does so the
        # last reported loss is not overwritten with None.
        if logs and "loss" in logs:
            training_status.current_loss = logs["loss"]


# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer (populated by startup_event)
model = None
tokenizer = None


@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
    """
    Root endpoint to check API health and model status
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device,
    )


@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on input prompt

    Parameters:
    - text: Input text prompt
    - max_length: Maximum total length (prompt + generated tokens)
    - temperature: Sampling temperature (0.0 to 1.0); <= 0 means greedy decoding
    - num_return_sequences: Number of sequences to generate

    Returns:
    - List of generated text sequences

    Raises:
    - 500 if the model is not loaded or generation fails
    - 503 while a training run is modifying the model weights
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    if training_status.is_training:
        # Training runs in a worker thread and updates the same model object.
        raise HTTPException(status_code=503, detail="Model is currently training")

    # NOTE(review): model.generate() is blocking and runs on the event loop,
    # stalling other requests for the duration of a generation.
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=request.max_length,
        ).to(model.device)

        # Temperature only takes effect when sampling is enabled; a
        # non-positive temperature is treated as a request for greedy decoding.
        do_sample = request.temperature is not None and request.temperature > 0
        gen_kwargs = {
            "max_length": request.max_length,
            "num_return_sequences": request.num_return_sequences,
            "do_sample": do_sample,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = request.temperature

        with torch.no_grad():
            # **inputs passes the attention_mask along with input_ids.
            generated_ids = model.generate(**inputs, **gen_kwargs)

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids
        ]
        return GenerateResponse(generated_text=generated_texts)
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health", tags=["Health"])
async def health_check():
    """
    Check the health status of the API and model
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }


@app.on_event("startup")
async def startup_event():
    """Load the model at startup; on failure the API stays up without a model."""
    logger.info("Starting up application...")
    global tokenizer, model
    try:
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")


@app.post("/train", response_model=TrainResponse, tags=["Training"])
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Start model training with the specified dataset

    Parameters:
    - dataset_path: Path to the JSON dataset file (on the server)
    - num_epochs: Number of training epochs
    - batch_size: Training batch size
    - learning_rate: Learning rate for training

    Raises:
    - 400 if a training run is already in progress
    - 404 if the dataset file does not exist
    - 500 if the model is not loaded or the task cannot be scheduled
    """
    if training_status.is_training:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # NOTE(review): dataset_path is a client-controlled server filesystem
    # path; consider restricting it to an allow-listed data directory.
    if not os.path.exists(request.dataset_path):
        raise HTTPException(status_code=404, detail="Dataset file not found")

    # Claim the training slot now, not when the background task starts, so a
    # second request arriving before the task runs is rejected above.
    training_status.is_training = True
    training_status.status = "queued"
    try:
        background_tasks.add_task(
            run_training,
            request.dataset_path,
            request.num_epochs,
            request.batch_size,
            request.learning_rate,
        )
    except Exception as e:
        training_status.is_training = False
        training_status.status = "idle"
        logger.exception("Training setup error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return TrainResponse(status="started", message="Training started in background")


@app.get("/train/status", tags=["Training"])
async def get_training_status():
    """
    Get current training status
    """
    return {
        "is_training": training_status.is_training,
        "current_epoch": training_status.current_epoch,
        "current_loss": training_status.current_loss,
        "status": training_status.status,
    }


def run_training(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    """Fine-tune the global model on a JSON dataset, then save it.

    Declared as a plain ``def`` so FastAPI/Starlette executes this background
    task in a worker thread instead of blocking the event loop for the whole
    run (which would also freeze ``/train/status``).

    The dataset is loaded with ``datasets.load_dataset("json", ...)`` and must
    expose a ``"text"`` column. Progress and the final outcome are reported via
    ``training_status``; errors are logged and recorded there, not re-raised.
    """
    try:
        training_status.is_training = True
        training_status.current_epoch = 0
        training_status.current_loss = None
        training_status.status = "loading_dataset"

        dataset = load_dataset("json", data_files=dataset_path)

        training_status.status = "preprocessing"

        def preprocess_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512,
            )

        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )

        training_status.status = "training"

        training_args = TrainingArguments(
            output_dir=os.path.join(model_output_path, "checkpoints"),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            # fp16 mixed precision needs a CUDA device.
            # NOTE(review): the model is loaded in float16 on CUDA; AMP with
            # fp16 master weights may fail ("Attempting to unscale FP16
            # gradients") -- consider loading in float32 for training.
            fp16=torch.cuda.is_available(),
            save_steps=500,
            logging_steps=100,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
            callbacks=[_StatusCallback()],
        )

        trainer.train()

        training_status.status = "saving"
        model.save_pretrained(model_output_path)
        tokenizer.save_pretrained(model_output_path)
        # Return the shared model to inference mode for /generate.
        model.eval()

        training_status.status = "completed"
        logger.info("Training completed successfully")
    except Exception as e:
        training_status.status = f"failed: {str(e)}"
        logger.exception("Training error: %s", e)
    finally:
        training_status.is_training = False


def _load_tokenizer_and_model(source: str, device: str):
    """Load a tokenizer and causal-LM from *source* (local directory or hub id)."""
    tok = AutoTokenizer.from_pretrained(source)
    # Tokenizers without a pad token make padding=True / "max_length" raise;
    # fall back to the EOS token for padding.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        source,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto",
    )
    return tok, mdl


def init_model():
    """Load the fine-tuned model if one was saved, otherwise the base model.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        Any loading error, after logging it.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading model on device: {device}")

        # Require a saved config, not just the directory: training creates
        # model_output_path/checkpoints before a final model is written.
        if os.path.exists(os.path.join(model_output_path, "config.json")):
            source = model_output_path
        else:
            # Load base model if no fine-tuned model exists
            source = BASE_MODEL_NAME
        logger.info("Loading model from %s", source)

        return _load_tokenizer_and_model(source, device)
    except Exception as e:
        logger.error(f"Model initialization error: {str(e)}")
        raise