import os

# Set the cache directory before transformers is imported, since the
# TRANSFORMERS_CACHE location is resolved at import time; setting it
# afterwards has no effect on the default cache path.
os.makedirs("/app/cache", exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = "/app/cache"

import logging
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1

class GenerateResponse(BaseModel):
    generated_text: List[str]

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str

# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware. Note that browsers reject wildcard origins combined
# with credentials, so restrict allow_origins before exposing this service
# to real front-ends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer
model = None
tokenizer = None
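
# The startup hook at the bottom of this file expects an init_model() helper
# that returns (tokenizer, model). A minimal sketch follows: the checkpoint
# location is an assumption, read from a hypothetical MODEL_PATH environment
# variable with "/app/model" as a guessed default, not a known value.
def init_model():
    """Load the tokenizer and model and move the model to the best device."""
    model_path = os.environ.get("MODEL_PATH", "/app/model")  # assumed location
    device = "cuda" if torch.cuda.is_available() else "cpu"

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path)
    # LLaMA tokenizers typically ship without a pad token; reuse EOS so that
    # padding=True in the /generate endpoint does not raise.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    loaded_model.to(device)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model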

@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
    """
    Root endpoint to check API health and model status
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device
    )

@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on input prompt
    
    Parameters:
    - text: Input text prompt
    - max_length: Maximum length of generated text
    - temperature: Sampling temperature (must be > 0; lower values are more deterministic)
    - num_return_sequences: Number of sequences to generate
    
    Returns:
    - List of generated text sequences
    """
    if model is None or tokenizer is None:
        # Raised outside the try block so it is not converted into a 500.
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=request.max_length
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                # generate()'s max_length includes the prompt, so a prompt
                # truncated to max_length would leave no room for new tokens;
                # max_new_tokens avoids that.
                max_new_tokens=request.max_length,
                num_return_sequences=request.num_return_sequences,
                # temperature has no effect unless sampling is enabled.
                do_sample=True,
                temperature=request.temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True)
            for g in generated_ids
        ]

        return GenerateResponse(generated_text=generated_texts)

    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
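
# Example client call, as a sketch. The host and port are assumptions; they
# depend on how uvicorn is launched:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"text": "Patient presents with fever and cough.", "max_length": 128},
#   )
#   print(resp.json()["generated_text"][0])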

@app.get("/health", tags=["Health"])
async def health_check():
    """
    Check the health status of the API and model
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }

@app.on_event("startup")
async def startup_event():
    """Load the model once at startup; until then, /generate returns 503."""
    global tokenizer, model
    logger.info("Starting up application...")
    try:
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")