# nidra / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
from typing import Optional, Dict, Any
import os
import torch

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Dream Interpretation API")

# Get HF token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable must be set")

# Define the model names
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2",
}

# Cache for loaded models
loaded_models = {}
loaded_tokenizers = {}
# Pydantic models for request/response validation
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"  # Default to v1
    parameters: Optional[Dict[str, Any]] = {}


class PredictionResponse(BaseModel):
    generated_text: str

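# For reference, a /predict request for the models above might look like the following
# (illustrative values only; "parameters" is forwarded to model.generate()):
#
#   {
#       "inputs": "I was flying over a city made of glass",
#       "model": "nidra-v2",
#       "parameters": {"max_length": 150, "num_beams": 4}
#   }
#
# The response body is then {"generated_text": "..."}.
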
def load_model(model_name: str):
    """Load model and tokenizer on demand, caching them after the first use."""
    if model_name not in loaded_models:
        logger.info(f"Loading {model_name}...")
        try:
            model_path = MODELS[model_name]
            # Load tokenizer with minimal settings
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                token=HF_TOKEN,
                use_fast=False,  # Use the slower but more stable Python tokenizer
            )
            # Load model with minimal settings
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_path,
                token=HF_TOKEN,
                torch_dtype=torch.float32,  # Use standard precision
            )
            # Move the model to CPU explicitly and switch it to inference mode
            model = model.cpu()
            model.eval()
            loaded_models[model_name] = model
            loaded_tokenizers[model_name] = tokenizer
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            logger.error(f"Error loading {model_name}: {str(e)}")
            raise
    return loaded_tokenizers[model_name], loaded_models[model_name]

@app.get("/")
def read_root():
"""Root endpoint with API info"""
return {
"api_name": "Dream Interpretation API",
"models_available": list(MODELS.keys()),
"endpoints": {
"/predict": "POST - Make predictions",
"/health": "GET - Health check"
}
}
@app.get("/health")
def health_check():
"""Basic health check endpoint"""
return {"status": "healthy"}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""Make a prediction using the specified model"""
try:
if request.model not in MODELS:
raise HTTPException(
status_code=400,
detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
)
# Load model on demand
tokenizer, model = load_model(request.model)
# Prepend the shared prefix
full_input = "Interpret this dream: " + request.inputs
# Tokenize and generate with explicit error handling
try:
input_ids = tokenizer(
full_input,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
).input_ids
outputs = model.generate(
input_ids,
max_length=200,
num_return_sequences=1,
no_repeat_ngram_size=2,
**request.parameters
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
logger.error(f"Error in model prediction pipeline: {str(e)}")
raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")
return PredictionResponse(generated_text=decoded)
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
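# Local entry point: a minimal sketch for running the API outside the hosted environment.
# The host and port are assumptions (7860 is the port Hugging Face Spaces normally expects);
# adjust or remove this block if the platform starts uvicorn itself.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against a local instance (illustrative):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "I dreamt I was lost in a library", "model": "nidra-v1"}'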