"""FastAPI service exposing a fine-tuned medical LLaMA model for text generation."""
import os
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import List, Optional
# Configure module-level logging (stdlib convention: one logger per module).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Ensure the Hugging Face model cache directory exists and point
# transformers at it. Container path — assumes /app is writable.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME / HF_HUB_CACHE — confirm the pinned version.
os.makedirs("/app/cache", exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = "/app/cache"
# Pydantic models for request/response validation and OpenAPI schema.
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    # Input prompt to condition generation on.
    text: str
    # Maximum token length budget; used both to truncate the prompt and
    # to cap generation (prompt + new tokens).
    max_length: Optional[int] = 512
    # Sampling temperature; higher values produce more random output.
    temperature: Optional[float] = 0.7
    # Number of independent sequences to return.
    num_return_sequences: Optional[int] = 1
class GenerateResponse(BaseModel):
    """Response body for POST /generate: one string per generated sequence."""

    generated_text: List[str]
class HealthResponse(BaseModel):
    """Health/status report returned by the root endpoint."""

    # "online" while the service is up.
    status: str
    # True once the startup hook has populated the global model.
    model_loaded: bool
    # Whether CUDA is visible to torch.
    gpu_available: bool
    # "cuda" or "cpu".
    device: str
# Initialize the FastAPI application with OpenAPI metadata; interactive
# documentation is served at /docs (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Allow cross-origin requests from any origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# contradictory — Starlette will not emit Access-Control-Allow-Credentials
# for a wildcard origin. Pin explicit origins if credentialed browser
# requests are actually needed; otherwise drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model/tokenizer handles, lazily populated by the startup hook.
# Both stay None if loading fails; endpoints must check before use.
model = None
tokenizer = None
@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
    """Report API liveness plus model-load and GPU status."""
    # Query CUDA availability once and reuse it for both fields.
    cuda_ok = torch.cuda.is_available()
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=cuda_ok,
        device="cuda" if cuda_ok else "cpu",
    )
@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on an input prompt.

    Parameters:
    - text: Input text prompt
    - max_length: Maximum token length of prompt + generated text
    - temperature: Sampling temperature (0.0 to 1.0)
    - num_return_sequences: Number of sequences to generate

    Returns:
    - List of generated text sequences

    Raises:
    - HTTPException(500) if the model is not loaded or generation fails.
    """
    # Check outside the try block so this deliberate error is not
    # re-wrapped by the generic handler below (the original raised it
    # inside try, mangling the detail to "500: Model not loaded").
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=request.max_length,
        ).to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                # Pass the attention mask explicitly; relying on generate()
                # to infer it is unreliable with padded inputs.
                attention_mask=inputs.attention_mask,
                max_length=request.max_length,
                num_return_sequences=request.num_return_sequences,
                temperature=request.temperature,
                # temperature only takes effect (and num_return_sequences > 1
                # only works without beam search) when sampling is enabled;
                # without do_sample=True the temperature parameter is ignored.
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        generated_texts = [
            tokenizer.decode(ids, skip_special_tokens=True)
            for ids in generated_ids
        ]
        return GenerateResponse(generated_text=generated_texts)
    except HTTPException:
        # Never convert deliberate HTTP errors into generic 500s.
        raise
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Generation error")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", tags=["Health"])
async def health_check():
    """Lightweight health probe reporting API, model, and GPU status."""
    # Evaluate CUDA availability a single time for both response fields.
    has_gpu = torch.cuda.is_available()
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": has_gpu,
        "device": "cuda" if has_gpu else "cpu",
    }
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers — worth migrating when the FastAPI version allows.
@app.on_event("startup")
async def startup_event():
    """Load the tokenizer and model once when the server starts.

    Failures are logged rather than raised so the API still comes up;
    /generate then returns "Model not loaded" until the problem is fixed.
    """
    # Hoisted out of try: declaring globals is not a failing operation.
    global tokenizer, model
    logger.info("Starting up application...")
    try:
        # init_model() is expected to be defined elsewhere in this module
        # and to return a (tokenizer, model) pair — TODO confirm.
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception:
        # Removed stray trailing "|" (paste residue) that broke this line;
        # logger.exception keeps the traceback instead of just str(e).
        logger.exception("Failed to load model")