from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
from typing import Optional, Dict, Any
import os
import torch
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Dream Interpretation API")

# Get HF token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable must be set")

# Define the model names
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2"
}

# Cache for loaded models
loaded_models = {}
loaded_tokenizers = {}

# Pydantic models for request/response validation
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"  # Default to v1
    parameters: Optional[Dict[str, Any]] = {}

class PredictionResponse(BaseModel):
    generated_text: str
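# Illustrative request/response bodies for the schemas above (example values only,
# not taken from the original Space; "num_beams" is just one possible generate kwarg):
#   request:  {"inputs": "I was flying over the ocean", "model": "nidra-v2", "parameters": {"num_beams": 4}}
#   response: {"generated_text": "..."}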
def load_model(model_name: str):
    """Load model and tokenizer on demand"""
    if model_name not in loaded_models:
        logger.info(f"Loading {model_name}...")
        try:
            model_path = MODELS[model_name]

            # Load tokenizer with minimal settings
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                token=HF_TOKEN,
                use_fast=False  # Use slower but more stable tokenizer
            )

            # Load model with minimal settings
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_path,
                token=HF_TOKEN,
                torch_dtype=torch.float32,  # Use standard precision
            )

            # Move model to CPU explicitly
            model = model.cpu()

            loaded_models[model_name] = model
            loaded_tokenizers[model_name] = tokenizer
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            logger.error(f"Error loading {model_name}: {str(e)}")
            raise

    return loaded_tokenizers[model_name], loaded_models[model_name]
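# Optional eager warm-up (a sketch, not part of the original code): the API loads
# models lazily on first request; caching both at startup trades a slower boot for
# faster first predictions.
@app.on_event("startup")
def warm_up_models():
    for name in MODELS:
        load_model(name)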
@app.get("/")
def read_root():
    """Root endpoint with API info"""
    return {
        "api_name": "Dream Interpretation API",
        "models_available": list(MODELS.keys()),
        "endpoints": {
            "/predict": "POST - Make predictions",
            "/health": "GET - Health check"
        }
    }

@app.get("/health")
def health_check():
    """Basic health check endpoint"""
    return {"status": "healthy"}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Make a prediction using the specified model"""
    try:
        if request.model not in MODELS:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
            )

        # Load model on demand
        tokenizer, model = load_model(request.model)

        # Prepend the shared prefix
        full_input = "Interpret this dream: " + request.inputs

        # Tokenize and generate with explicit error handling
        try:
            input_ids = tokenizer(
                full_input,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).input_ids

            outputs = model.generate(
                input_ids,
                max_length=200,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                **request.parameters
            )

            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Error in model prediction pipeline: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")

        return PredictionResponse(generated_text=decoded)
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
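
# Local entry point and a sample request (a minimal sketch, not part of the original
# Space code; port 7860 is assumed because it is the Hugging Face Spaces default):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call once the server is up (hypothetical dream text):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "I dreamt my teeth were falling out", "model": "nidra-v1"}'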