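"""FastAPI inference service for the nidra dream-interpretation models.

Serves two fine-tuned T5 checkpoints (nidra-v1, nidra-v2) from the
Hugging Face Hub, with per-model default generation settings that can
be overridden per request.
"""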
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
from typing import Optional, Dict, Any, ClassVar
import logging
import os
import sys
import traceback
# Initialize FastAPI
app = FastAPI()
# Set up logging with more detailed formatting
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Get HF token
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("No HF_TOKEN found in environment variables")
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}
DEFAULT_GENERATION_CONFIGS = {
    "nidra-v1": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.55,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 4.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    },
    "nidra-v2": {
        "max_length": 300,
        "min_length": 150,
        "num_beams": 8,
        "temperature": 0.4,
        "do_sample": True,
        "top_p": 0.95,
        "repetition_penalty": 3.5,
        "no_repeat_ngram_size": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
    }
}
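# Precedence at request time (see /predict below): these tuned defaults are
# merged with the model's saved generation_config, then overridden by any
# "parameters" supplied in the request body.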
class ModelManager:
    """Lazily loads and caches one (model, tokenizer) pair per model name."""

    _instances: ClassVar[Dict[str, tuple]] = {}

    @classmethod
    def get_model_and_tokenizer(cls, model_name: str):
        if model_name not in cls._instances:
            try:
                model_path = MODELS[model_name]
                logger.info(f"Loading tokenizer for {model_name}")
                # Note: return_special_tokens_mask is a tokenizer-call argument,
                # not a loading option, so it is not passed to from_pretrained
                tokenizer = T5Tokenizer.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False
                )
                logger.info(f"Loading model {model_name}")
                model = T5ForConditionalGeneration.from_pretrained(
                    model_path,
                    token=HF_TOKEN,
                    local_files_only=False,
                    device_map="auto"  # Places the model on GPU if one is available
                )
                cls._instances[model_name] = (model, tokenizer)
                logger.info(f"Successfully loaded {model_name}")
            except Exception as e:
                logger.error(f"Error loading {model_name}: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to load model {model_name}: {str(e)}"
                )
        return cls._instances[model_name]
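# Example (illustrative): fetch the cached pair for a model name, exactly as
# the /health and /predict handlers below do.
#   model, tokenizer = ModelManager.get_model_and_tokenizer("nidra-v1")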
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"
    parameters: Optional[Dict[str, Any]] = None


class PredictionResponse(BaseModel):
    generated_text: str
    model_used: str
@app.get("/version")
async def version():
return {
"python_version": sys.version,
"models_available": list(MODELS.keys())
}
@app.get("/health")
async def health():
# More comprehensive health check
try:
# Try to load at least one model to verify functionality
ModelManager.get_model_and_tokenizer("nidra-v1")
return {
"status": "healthy",
"loaded_models": list(ModelManager._instances.keys())
}
except Exception as e:
logger.error(f"Health check failed: {str(e)}")
return {
"status": "unhealthy",
"error": str(e)
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Validate model
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model. Available models: {list(MODELS.keys())}"
            )
        # Get cached model and tokenizer
        model, tokenizer = ModelManager.get_model_and_tokenizer(request.model)

        # Start from the tuned per-model defaults
        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
        # Merge in the model's saved generation config, keeping only values that
        # differ from the library defaults; otherwise unset fields (e.g.
        # max_length=20, num_beams=1) would clobber the tuned defaults above
        try:
            baseline = GenerationConfig().to_dict()
            saved = model.generation_config.to_dict()
            generation_params.update({
                k: v for k, v in saved.items()
                if v is not None and baseline.get(k) != v
            })
        except Exception as config_load_error:
            logger.warning(f"Using default generation config: {config_load_error}")
        # Override with request-specific parameters
        if request.parameters:
            generation_params.update(request.parameters)

        logger.debug(f"Final generation parameters: {generation_params}")
        # Prepare input
        full_input = "Interpret this dream: " + request.inputs
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        ).to(model.device)  # Ensure inputs are on the same device as the model
        # Generate, passing only parameters that model.generate understands
        outputs = model.generate(
            **inputs,
            **{k: v for k, v in generation_params.items() if k in [
                'max_length', 'min_length', 'do_sample', 'temperature',
                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
                'repetition_penalty', 'early_stopping', 'length_penalty'
            ]}
        )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return PredictionResponse(
            generated_text=result,
            model_used=request.model
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) rather than wrapping them in a 500
        raise
    except Exception as e:
        # Log the full traceback server-side, but keep the client-facing detail concise
        logger.error(f"Error during prediction: {str(e)}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
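# Example request (illustrative; assumes the server is running locally on
# the port configured below):
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "I was flying over the ocean", "model": "nidra-v2", "parameters": {"max_length": 200}}'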
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)