import logging
import os
from typing import List, Optional

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up the cache directory before transformers is imported, since the
# library reads TRANSFORMERS_CACHE at import time
os.makedirs("/app/cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer


# Pydantic models for request/response payloads
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1


class GenerateResponse(BaseModel):
    generated_text: List[str]


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str


# Initialize the FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using a fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware (allow all origins; tighten this for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global handles for the model and tokenizer, populated at startup
model = None
tokenizer = None


def init_model():
    """
    Load the tokenizer and model and return them as a (tokenizer, model) pair.

    The checkpoint location is read from the MODEL_PATH environment variable;
    the default below is a placeholder and should point at the fine-tuned
    LLaMA checkpoint (a local directory or a Hub model ID).
    """
    model_path = os.environ.get("MODEL_PATH", "/app/model")  # placeholder default
    logger.info(f"Loading model from {model_path}")

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    # LLaMA tokenizers usually ship without a pad token; fall back to EOS
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    device = "cuda" if torch.cuda.is_available() else "cpu"
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    loaded_model.eval()

    return loaded_tokenizer, loaded_model


@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
    """
    Root endpoint to check API health and model status.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device,
    )


@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate medical text from an input prompt.

    Parameters:
    - text: input text prompt
    - max_length: maximum length of the generated sequence, in tokens
    - temperature: sampling temperature (values > 0; lower is more deterministic)
    - num_return_sequences: number of sequences to generate

    Returns:
    - List of generated text sequences
    """
    # Fail fast if the model did not load at startup
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=request.max_length,
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=request.max_length,
                do_sample=True,  # required for temperature to take effect
                temperature=request.temperature,
                num_return_sequences=request.num_return_sequences,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids
        ]

        return GenerateResponse(generated_text=generated_texts)

    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health", tags=["Health"])
async def health_check():
    """
    Check the health status of the API and model.
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }


@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
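

# A minimal local entry point plus an example request, as a sketch. It assumes
# uvicorn is installed and that the service listens on port 8000; adjust the
# host/port to match the actual deployment (e.g. the container's CMD).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (the prompt below is illustrative):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Patient presents with chest pain and", "max_length": 128}'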