from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import logging
from typing import Optional, Dict, Any
import os
import torch

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Dream Interpretation API")

# Get HF token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable must be set")

# Define the model names
MODELS = {
    "nidra-v1": "m1k3wn/nidra-v1",
    "nidra-v2": "m1k3wn/nidra-v2",
}

# Caches for loaded models and tokenizers
loaded_models = {}
loaded_tokenizers = {}
# Pydantic models for request/response validation
class PredictionRequest(BaseModel):
    inputs: str
    model: str = "nidra-v1"  # Default to v1
    parameters: Optional[Dict[str, Any]] = {}


class PredictionResponse(BaseModel):
    generated_text: str
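
# Illustrative /predict request body (field names follow PredictionRequest above;
# the "parameters" values shown are hypothetical examples and are simply forwarded
# to model.generate as keyword arguments):
# {
#     "inputs": "I was flying over a city made of glass",
#     "model": "nidra-v2",
#     "parameters": {"do_sample": true, "temperature": 0.9}
# }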

def load_model(model_name: str):
    """Load model and tokenizer on demand"""
    if model_name not in loaded_models:
        logger.info(f"Loading {model_name}...")
        try:
            model_path = MODELS[model_name]

            # Load tokenizer with minimal settings
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                token=HF_TOKEN,
                use_fast=False,  # Use the slower but more stable tokenizer
            )

            # Load model with minimal settings
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_path,
                token=HF_TOKEN,
                torch_dtype=torch.float32,  # Use standard precision
            )

            # Move model to CPU explicitly
            model = model.cpu()

            loaded_models[model_name] = model
            loaded_tokenizers[model_name] = tokenizer
            logger.info(f"Successfully loaded {model_name}")
        except Exception as e:
            logger.error(f"Error loading {model_name}: {str(e)}")
            raise
    return loaded_tokenizers[model_name], loaded_models[model_name]
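
# Illustrative cache behaviour (not part of the request path): the first call for a
# given name downloads and stores the weights; repeated calls return the same
# in-memory objects without reloading.
#   tokenizer, model = load_model("nidra-v1")  # fetched from the Hub
#   tokenizer, model = load_model("nidra-v1")  # served from the in-memory cache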
@app.get("/")
def read_root():
"""Root endpoint with API info"""
return {
"api_name": "Dream Interpretation API",
"models_available": list(MODELS.keys()),
"endpoints": {
"/predict": "POST - Make predictions",
"/health": "GET - Health check"
}
}
@app.get("/health")
def health_check():
"""Basic health check endpoint"""
return {"status": "healthy"}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""Make a prediction using the specified model"""
try:
if request.model not in MODELS:
raise HTTPException(
status_code=400,
detail=f"Invalid model choice. Available models: {list(MODELS.keys())}"
)
# Load model on demand
tokenizer, model = load_model(request.model)
# Prepend the shared prefix
full_input = "Interpret this dream: " + request.inputs
# Tokenize and generate with explicit error handling
try:
input_ids = tokenizer(
full_input,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
).input_ids
outputs = model.generate(
input_ids,
max_length=200,
num_return_sequences=1,
no_repeat_ngram_size=2,
**request.parameters
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
logger.error(f"Error in model prediction pipeline: {str(e)}")
raise HTTPException(status_code=500, detail=f"Model prediction failed: {str(e)}")
return PredictionResponse(generated_text=decoded)
except Exception as e:
logger.error(f"Error in prediction: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) |
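

# A minimal local entry point, assuming this file is run directly with uvicorn on
# port 7860 (the usual Hugging Face Spaces port); adjust to match the Space's
# actual start command if it differs.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against a running instance (hypothetical prompt text):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "I dreamt I was lost in a library", "model": "nidra-v1"}'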