from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback

# Initialize FastAPI
app = FastAPI()

# Set up logging with more detailed formatting
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("No HF_TOKEN found in environment variables")

MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2",
}

DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
}


class ModelManager:
    # Process-wide cache of (model, tokenizer) pairs, keyed by model name,
    # so each model is downloaded and loaded at most once.
    _instances: ClassVar[Dict[str, tuple]] = {}

    @classmethod
    def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]

                logger.info(f"Loading tokenizer for {model_name}")
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                )

                logger.info(f"Loading model {model_name}")
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                    device_map="auto",  # This will handle GPU if available
                )

                cls._instances[model_name] = (model, tokenizer)
                logger.info(f"Successfully loaded {model_name}")
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load model {model_name}: {str(e)}",
                )
        return cls._instances[model_name]


class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None


class PredictionResponse(BaseModel):
    generated_text: str
    model_used: str


@app.get("/version")
async def version():
    return {
        "python_version": sys.version,
        "models_available": list(MODELS.keys()),
    }


@app.get("/health")
async def health():
    # More comprehensive health check: try to load at least one model
    # to verify functionality
    try:
        ModelManager.get_model_and_tokenizer("nidra-v1")
        return {
            "status": "healthy",
            "loaded_models": list(ModelManager._instances.keys()),
        }
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return {
            "status": "unhealthy",
            "error": str(e),
        }


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Validate model
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Available models: {list(MODELS.keys())}",
            )

        # Get cached model and tokenizer
        model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)

        # Start from the hard-coded defaults for this model
        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()

        # Overlay the model's saved generation config, if one is available
        try:
            model_generation_config = model.generation_config
            generation_params.update({
                k: v
                for k, v in model_generation_config.to_dict().items()
                if v is not None
            })
        except Exception as config_load_error:
            logger.warning(f"Using default generation config: {config_load_error}")

        # Override with request-specific parameters
        if request.parameters:
            generation_params.update(request.parameters)

        logger.debug(f"Final generation parameters: {generation_params}")

        # Prepare input
        full_input = "Interpret this dream: " + request.inputs
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(model.device)  # Ensure inputs are on the same device as the model

        # Generate, passing only the parameters that generate() understands
        allowed_keys = {
            'max_length', 'min_length', 'do_sample', 'temperature',
            'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
            'repetition_penalty', 'early_stopping', 'length_penalty',
        }
        outputs = model.generate(
            **inputs,
            **{k: v for k, v in generation_params.items() if k in allowed_keys},
        )

        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return PredictionResponse(
            generated_text=result,
            model_used=request.model,
        )
    except HTTPException:
        # Re-raise validation errors as-is so a 400 isn't converted into a 500
        raise
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
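
# Example client call (a sketch, assuming the server above is running on
# localhost:7860 and HF_TOKEN grants access to the m1k3wn models; the dream
# text and parameter override are hypothetical illustrations):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={
#           "inputs": "I was flying over a city at night",
#           "model": "nidra-v1",
#           "parameters": {"max_length": 200},  # optional per-request override
#       },
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])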