import os
import logging
from typing import List, Optional

# Setup cache directory before transformers is imported so the env var is honored
os.makedirs("/app/cache", exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = "/app/cache"

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1

class GenerateResponse(BaseModel):
    generated_text: List[str]

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str
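# Example /generate request body (illustrative values, not taken from the source):
#   {"text": "Patient presents with chest pain and", "max_length": 256, "temperature": 0.7}
# and the corresponding response shape:
#   {"generated_text": ["Patient presents with chest pain and shortness of breath ..."]}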
# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using a fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for model and tokenizer
model = None
tokenizer = None
@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
"""
Root endpoint to check API health and model status
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
return HealthResponse(
status="online",
model_loaded=model is not None,
gpu_available=torch.cuda.is_available(),
device=device
)
@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
async def generate_text(request: GenerateRequest):
"""
Generate medical text based on input prompt
Parameters:
- text: Input text prompt
- max_length: Maximum length of generated text
- temperature: Sampling temperature (0.0 to 1.0)
- num_return_sequences: Number of sequences to generate
Returns:
- List of generated text sequences
"""
try:
if model is None or tokenizer is None:
raise HTTPException(status_code=500, detail="Model not loaded")
inputs = tokenizer(
request.text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=request.max_length
).to(model.device)
with torch.no_grad():
generated_ids = model.generate(
inputs.input_ids,
max_length=request.max_length,
num_return_sequences=request.num_return_sequences,
temperature=request.temperature,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
generated_texts = [
tokenizer.decode(g, skip_special_tokens=True)
for g in generated_ids
]
return GenerateResponse(generated_text=generated_texts)
except Exception as e:
logger.error(f"Generation error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
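# Example call once the server is up (assumes the API is served locally on port 8000):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Common symptoms of type 2 diabetes include", "max_length": 128}'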
@app.get("/health", tags=["Health"])
async def health_check():
"""
Check the health status of the API and model
"""
return {
"status": "healthy",
"model_loaded": model is not None,
"gpu_available": torch.cuda.is_available(),
"device": "cuda" if torch.cuda.is_available() else "cpu"
}
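# NOTE: init_model() is called by the startup handler below but is not shown in this
# excerpt. The sketch here is an assumed implementation: MODEL_PATH and its
# "medical-llama" default are placeholders, not identifiers confirmed by the source.
def init_model():
    """Load the tokenizer and model (assumed implementation for this excerpt)."""
    model_path = os.environ.get("MODEL_PATH", "medical-llama")  # placeholder model id
    tok = AutoTokenizer.from_pretrained(model_path, cache_dir="/app/cache")
    if tok.pad_token is None:
        # LLaMA tokenizers ship without a pad token; reuse EOS so padding works
        tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        model_path,
        cache_dir="/app/cache",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    mdl.to("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval()
    return tok, mdl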
@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        # No re-raise: the app still starts so /health can report model_loaded=False
        logger.error(f"Failed to load model: {str(e)}")