|
from fastapi import FastAPI, HTTPException, status |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse |
|
from pydantic import BaseModel |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoConfig |
|
import torch |
|
import os |
|
import sys |
|
import traceback |
|
from typing import Optional, Dict, Any |
|
from accelerate import Accelerator |
|
import time |
|
import psutil |
|
from loguru import logger |
|
|
|
|
|
# Configure loguru: drop the default handler, then install a single
# colorized stderr sink at INFO level with timestamp/level/location fields.
logger.remove()

_LOG_FORMAT = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)

logger.add(sys.stderr, level="INFO", format=_LOG_FORMAT)
|
|
|
|
|
# Public-facing FastAPI application. Swagger UI is served from
# /documentation instead of the default /docs; ReDoc stays at /redoc.
app = FastAPI(
    title="Clinical Report Generator API",
    description="Production API for generating clinical report summaries using T5",
    version="1.0.0",
    docs_url="/documentation",
    redoc_url="/redoc",
)

# CORS is locked down to the single GitHub Pages front-end origin; only
# the verbs the API actually serves are allowed, and browsers may cache
# preflight responses for up to an hour.
_CORS_OPTIONS = {
    "allow_origins": ["https://pdarleyjr.github.io"],
    "allow_credentials": True,
    "allow_methods": ["POST", "GET"],
    "allow_headers": ["*"],
    "max_age": 3600,
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
|
# Hugging Face Hub repository ID of the fine-tuned clinical T5 checkpoint.
MODEL_ID = "pdarleyjr/iplc-t5-clinical"
|
|
|
class ModelManager:
    """Owns the T5 model/tokenizer pair and their (re)load lifecycle.

    A single module-level instance is shared by all request handlers:
    ``load_model`` runs eagerly at application startup and is retried
    lazily from the /predict endpoint if the model is not yet available.
    """

    def __init__(self):
        # Populated by load_model(); both remain None until a load succeeds,
        # which is exactly what is_loaded() checks.
        self.model = None
        self.tokenizer = None
        self.accelerator = Accelerator()
        # Epoch timestamp of the last successful load (None = never loaded).
        self.last_load_time = None
        # Re-entrancy guard: load_model()'s body contains no awaits, so on a
        # single event loop this flag prevents overlapping load attempts.
        # NOTE(review): not safe across threads/processes — confirm the
        # server runs a single worker.
        self.load_lock = False

    async def load_model(self) -> bool:
        """Load model and tokenizer with proper error handling and logging"""
        # Bail out (rather than block) if another load is in flight; the
        # caller treats False as "not ready yet".
        if self.load_lock:
            logger.warning("Model load already in progress")
            return False

        try:
            self.load_lock = True
            logger.info("Starting model and tokenizer loading process...")

            # Snapshot memory before loading so the post-load log lines
            # below show how much the model actually consumed.
            memory = psutil.virtual_memory()
            logger.info(f"System memory: {memory.percent}% used, {memory.available / (1024*1024*1024):.2f}GB available")
            if torch.cuda.is_available():
                logger.info(f"CUDA memory: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated")

            # model_max_length mirrors the 512-token cap used at inference.
            # NOTE(review): `use_fast` is an AutoTokenizer argument; passed
            # directly to T5Tokenizer.from_pretrained it likely has no
            # effect — confirm against the pinned transformers version.
            logger.info("Initializing tokenizer...")
            self.tokenizer = T5Tokenizer.from_pretrained(
                MODEL_ID,
                use_fast=True,
                model_max_length=512
            )
            logger.success("Tokenizer loaded successfully")

            # Fetch the config separately (remote code disabled) so it can
            # be passed explicitly to from_pretrained below.
            logger.info("Fetching model configuration...")
            config = AutoConfig.from_pretrained(
                MODEL_ID,
                trust_remote_code=False
            )
            logger.success("Model configuration loaded successfully")

            logger.info("Loading model (this may take a few minutes)...")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {device}")

            # fp16 on GPU to halve memory; fp32 on CPU (fp16 CPU inference
            # is unsupported/slow for T5).
            self.model = T5ForConditionalGeneration.from_pretrained(
                MODEL_ID,
                config=config,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                low_cpu_mem_usage=True
            ).to(device)
            logger.success("Model loaded successfully")

            # Wrap with accelerate; the prepared model is what /predict
            # uses together with accelerator.autocast().
            self.model = self.accelerator.prepare_model(self.model)
            logger.success("Model prepared with accelerator")

            memory = psutil.virtual_memory()
            logger.info(f"Final memory usage: {memory.percent}% used, {memory.available / (1024*1024*1024):.2f}GB available")
            if torch.cuda.is_available():
                logger.info(f"Final CUDA memory: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated")

            self.last_load_time = time.time()
            return True

        except Exception as e:
            # Reset to a clean "not loaded" state so is_loaded() is False
            # and a later call can retry from scratch.
            logger.exception("Error loading model")
            self.model = None
            self.tokenizer = None
            return False

        finally:
            # Always release the guard, success or failure.
            self.load_lock = False

    def is_loaded(self) -> bool:
        """Check if model and tokenizer are loaded"""
        return self.model is not None and self.tokenizer is not None

    def get_load_time(self) -> Optional[float]:
        """Get the last successful load time"""
        return self.last_load_time
|
|
|
|
|
# Shared singleton used by the route handlers and lifecycle hooks below.
model_manager = ModelManager()
|
|
|
class PredictRequest(BaseModel):
    """Request model for prediction endpoint"""
    # Structured evaluation text to summarize; /predict prepends the
    # T5 task prefix "summarize: " before tokenization.
    text: str

    class Config:
        # Example payload surfaced in the OpenAPI schema / Swagger UI.
        # NOTE(review): `schema_extra` is the Pydantic v1 spelling; under
        # Pydantic v2 it was renamed to `json_schema_extra` and this key
        # would be ignored — confirm the pinned pydantic major version.
        schema_extra = {
            "example": {
                "text": "evaluation type: initial. primary diagnosis: F84.0. severity: mild. primary language: english"
            }
        }
|
|
|
@app.post("/predict",
          response_model=Dict[str, Any],
          status_code=status.HTTP_200_OK,
          responses={
              500: {"description": "Internal server error"},
              503: {"description": "Service unavailable - model loading"}
          })
async def predict(request: PredictRequest) -> JSONResponse:
    """Generate a clinical report summary.

    Lazily (re)loads the model if it is not yet available, then runs
    beam-search generation over the T5 ``summarize:`` task prefix.

    Returns a JSON envelope {"success", "data", "error", "metrics"}.
    Responds 503 while the model is loading or on CUDA OOM, 500 on any
    other unexpected error (details are logged, not leaked to clients).
    """
    start_time = time.time()

    try:
        # Lazy (re)load: startup may have failed or still be in progress.
        if not model_manager.is_loaded():
            logger.warning("Model not loaded, attempting to load...")
            success = await model_manager.load_model()
            if not success:
                return JSONResponse(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    content={
                        "success": False,
                        "error": "Model is initializing. Please try again in a few moments."
                    }
                )

        # T5 is a text-to-text model: the task is selected by a prefix.
        input_text = "summarize: " + request.text
        # Use the tokenizer's __call__ (not .encode) so we also get an
        # attention mask; generate() warns and can mis-score beams when
        # the mask is omitted. padding is unnecessary for one sequence.
        encoded = model_manager.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )

        # Inner try scopes only the GPU work so OOM maps to 503 while
        # every other failure falls through to the 500 handler below.
        try:
            # Run on whatever device the model parameters live on.
            device = next(model_manager.model.parameters()).device
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)

            with torch.no_grad(), model_manager.accelerator.autocast():
                # Deterministic beam search. The previous `temperature=0.7`
                # was removed: temperature only applies with do_sample=True
                # and was silently ignored here.
                outputs = model_manager.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=512,
                    num_beams=5,
                    no_repeat_ngram_size=3,
                    length_penalty=2.0,
                    early_stopping=True,
                    pad_token_id=model_manager.tokenizer.pad_token_id,
                    eos_token_id=model_manager.tokenizer.eos_token_id
                )

            summary = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)

            process_time = time.time() - start_time
            logger.info(f"Summary generated in {process_time:.2f} seconds")

            return JSONResponse(
                content={
                    "success": True,
                    "data": summary,
                    "error": None,
                    "metrics": {
                        "process_time": process_time
                    }
                }
            )

        except torch.cuda.OutOfMemoryError:
            # Free cached allocations so the next request has a chance.
            logger.error("CUDA out of memory error - clearing cache and reducing batch size")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.info(f"CUDA memory after cleanup: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated")
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "success": False,
                    "error": "Server is currently overloaded. Please try again later."
                }
            )

    except Exception:
        # Full traceback goes to the log; clients get a generic message.
        logger.exception("Error in predict endpoint")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "success": False,
                "error": "An unexpected error occurred. Please try again later."
            }
        )
|
|
|
@app.get("/health",
         response_model=Dict[str, Any],
         status_code=status.HTTP_200_OK)
async def health_check() -> JSONResponse:
    """Check API and model health status"""
    try:
        # Report readiness (model loaded), version, and GPU visibility in
        # a single payload; gpu_name is null on CPU-only hosts.
        cuda_ok = torch.cuda.is_available()
        payload = {
            "status": "healthy",
            "model_loaded": model_manager.is_loaded(),
            "last_load_time": model_manager.get_load_time(),
            "version": "1.0.0",
            "gpu_available": cuda_ok,
            "gpu_name": torch.cuda.get_device_name(0) if cuda_ok else None,
        }
        return JSONResponse(content=payload)
    except Exception as e:
        logger.error(f"Error in health check: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"status": "unhealthy", "error": str(e)},
        )
|
|
|
@app.on_event("startup")
async def startup_event() -> None:
    """Initialize model on startup"""
    # NOTE(review): on_event startup/shutdown hooks are deprecated in
    # recent FastAPI releases in favor of a lifespan context manager —
    # confirm the pinned FastAPI version before migrating.
    logger.info("Starting application in production mode...")
    logger.info(f"System resources - CPU: {psutil.cpu_percent()}%, Memory: {psutil.virtual_memory().percent}%")
    if torch.cuda.is_available():
        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
    # Eager load; a failure here is tolerated because /predict retries
    # the load lazily on first request.
    await model_manager.load_model()
|
|
|
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Clean up resources on shutdown"""
    logger.info("Initiating graceful shutdown...")

    # Release cached CUDA allocations before the process exits; logs the
    # allocation level first so leaks are visible in the final log lines.
    if torch.cuda.is_available():
        logger.info(f"Final CUDA memory before cleanup: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB")
        torch.cuda.empty_cache()
        logger.info("CUDA cache cleared")
    logger.info(f"Final system stats - CPU: {psutil.cpu_percent()}%, Memory: {psutil.virtual_memory().percent}%")
    logger.success("Application shutdown complete")
|
|
|
|
|
# Local/dev entry point; in production the app is typically launched by an
# external ASGI server command (e.g. `uvicorn module:app`) instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|