import os
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import List, Optional
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up the Hugging Face cache directory
os.makedirs("/app/cache", exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = "/app/cache"
# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1

class GenerateResponse(BaseModel):
    generated_text: List[str]

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str
# Initialize the FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using a fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for the model and tokenizer
model = None
tokenizer = None
@app.get("/", response_model=HealthResponse)
async def root():
    """
    Root endpoint to check API health and model status.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device
    )
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on an input prompt.

    Parameters:
    - text: Input text prompt
    - max_length: Maximum length of the generated text
    - temperature: Sampling temperature (0.0 to 1.0)
    - num_return_sequences: Number of sequences to generate

    Returns:
    - List of generated text sequences
    """
    # Fail fast if the model has not finished loading. Raising this outside
    # the try block keeps it from being rewrapped as a generic 500 below.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=request.max_length
        ).to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=request.max_length,
                num_return_sequences=request.num_return_sequences,
                temperature=request.temperature,
                do_sample=True,  # required for temperature and num_return_sequences > 1 to take effect
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True)
            for g in generated_ids
        ]
        return GenerateResponse(generated_text=generated_texts)
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
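
# Example client call for the endpoint above, assuming the server is reachable
# at localhost:7860 (the default Hugging Face Spaces port; adjust as needed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"text": "Patient presents with fever and cough.", "max_length": 128},
#   )
#   print(resp.json()["generated_text"])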
@app.get("/health")
async def health_check():
    """
    Check the health status of the API and model.
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
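
# init_model() is called by the startup hook below but was missing from this
# snippet. A minimal sketch follows; MODEL_NAME is a hypothetical placeholder,
# not the actual fine-tuned checkpoint, so point it at your own model repo.
MODEL_NAME = os.environ.get("MODEL_NAME", "your-username/medical-llama")

def init_model():
    """Load the tokenizer and model, returning them as a (tokenizer, model) pair."""
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir="/app/cache")
    # LLaMA tokenizers ship without a pad token; reuse EOS so padding works.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir="/app/cache",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    mdl.to("cuda" if torch.cuda.is_available() else "cpu")
    mdl.eval()
    return tok, mdl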
@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")