import logging
import os

# The cache location must be set before transformers is imported so the
# library picks it up when it resolves its cache paths.
os.makedirs("/app/cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Medical LLaMA API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals populated on startup
tokenizer = None
model = None


# Check GPU availability
def check_gpu():
    if torch.cuda.is_available():
        logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
        return True
    logger.warning("No GPU available, using CPU")
    return False


# Initialize tokenizer and model on the appropriate device
def init_model():
    try:
        device = "cuda" if check_gpu() else "cpu"
        model_path = os.getenv("MODEL_PATH", "./model/medical_llama_3b")
        logger.info(f"Loading model from {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/app/cache")
        # device_map="auto" requires the `accelerate` package to be installed
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
            cache_dir="/app/cache",
        )
        return tokenizer, model
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


# Rest of your existing code...


@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    global tokenizer, model
    try:
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        # The app still starts if loading fails; endpoints should check that the model is loaded.
        logger.error(f"Failed to load model: {e}")
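

# --- Illustrative sketch (assumption, not part of the original app) ---
# A minimal inference endpoint showing how the tokenizer/model loaded at
# startup could be used. The route name, request schema, and generation
# settings below are hypothetical placeholders, not the original API.
from pydantic import BaseModel


class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 256  # assumed default


@app.post("/generate")
def generate(request: GenerateRequest):
    # Plain `def` so FastAPI runs the blocking generate() call in its
    # thread pool instead of blocking the event loop.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=request.max_new_tokens)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": text}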