import os
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up the cache directory before transformers is imported: the library
# resolves its default cache location at import time, so setting the
# environment variable after the import has no effect. The explicit
# cache_dir argument passed to from_pretrained() below covers it either way.
os.makedirs("/app/cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI(title="Medical LLaMA API")

# Add CORS middleware. Note: browsers reject the wildcard "*" origin when
# credentials are allowed, so list explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Check GPU availability
def check_gpu():
    if torch.cuda.is_available():
        logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
        return True
    logger.warning("No GPU available, using CPU")
    return False

# Initialize model on the appropriate device and dtype
def init_model():
    try:
        device = "cuda" if check_gpu() else "cpu"
        model_path = os.getenv("MODEL_PATH", "./model/medical_llama_3b")

        logger.info(f"Loading model from {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/app/cache")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # fp16 halves memory on GPU; fall back to fp32 on CPU
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",  # requires the accelerate package
            cache_dir="/app/cache"
        )
        return tokenizer, model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise

# Rest of your existing code...
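# The sketch below is a hypothetical /generate route illustrating how the
# loaded tokenizer/model pair might be used; the GenerateRequest schema,
# route path, and generation parameters are assumptions for illustration,
# not part of the original code.
from pydantic import BaseModel

# Module-level placeholders, populated by startup_event() below.
tokenizer = None
model = None

class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 256

@app.post("/generate")
async def generate(request: GenerateRequest):
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=request.max_new_tokens)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": text}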

@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise  # fail fast instead of serving requests without a model