# Spanish Embedding API — main application module.
# Author: Jordi Catafal
# History: "Move files to root for HF Spaces" (commit 734683c)
from fastapi import FastAPI, HTTPException
from typing import List
import torch
import uvicorn
import gc
import os
from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
from utils.helpers import load_models, get_embeddings, cleanup_memory
# FastAPI application instance for the dual Spanish embedding service.
app = FastAPI(
    title="Spanish Embedding API",
    description="Dual Spanish embedding models API",
    version="1.0.0",
)

# Module-level cache mapping model ids to loaded model objects;
# populated once by the startup handler.
models_cache = {}
@app.on_event("startup")
async def startup_event():
    """Load both embedding models into the module-level cache at startup."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    # favor of lifespan handlers — consider migrating when upgrading FastAPI.
    global models_cache
    models_cache = load_models()
    print("Models loaded successfully!")
@app.get("/")
async def root():
    """Landing endpoint: basic API metadata and a pointer to the docs."""
    payload = {
        "message": "Spanish Embedding API",
        "models": ["jina", "robertalex"],
        "status": "running",
        "docs": "/docs",
    }
    return payload
@app.post("/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for the texts in *request*.

    Raises:
        HTTPException(400): empty input, more than 50 texts, or a
            ValueError from the embedding helper (e.g. unknown model).
        HTTPException(500): any other unexpected failure.
    """
    try:
        if not request.texts:
            raise HTTPException(status_code=400, detail="No texts provided")
        if len(request.texts) > 50:  # Rate limiting
            raise HTTPException(status_code=400, detail="Maximum 50 texts per request")
        embeddings = get_embeddings(
            request.texts,
            request.model,
            models_cache,
            request.normalize,
            request.max_length
        )
        # Cleanup memory after large batches
        if len(request.texts) > 20:
            cleanup_memory()
        return EmbeddingResponse(
            embeddings=embeddings,
            model_used=request.model,
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        # BUG FIX: HTTPException subclasses Exception, so without this clause
        # the 400 validation errors raised above were swallowed by the broad
        # handler below and re-surfaced to clients as generic 500s.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """List available models and their specifications."""
    # Keep the raw specs as plain dicts so adding a model is a one-entry change.
    specs = [
        dict(
            model_id="jina",
            name="jinaai/jina-embeddings-v2-base-es",
            dimensions=768,
            max_sequence_length=8192,
            languages=["Spanish", "English"],
            model_type="bilingual",
            description="Bilingual Spanish-English embeddings with long context support",
        ),
        dict(
            model_id="robertalex",
            name="PlanTL-GOB-ES/RoBERTalex",
            dimensions=768,
            max_sequence_length=512,
            languages=["Spanish"],
            model_type="legal domain",
            description="Spanish legal domain specialized embeddings",
        ),
    ]
    return [ModelInfo(**spec) for spec in specs]
@app.get("/health")
async def health_check():
    """Report service liveness and which models are currently loaded."""
    loaded = list(models_cache.keys())
    return {
        "status": "healthy",
        # Both "jina" and "robertalex" are expected after startup.
        "models_loaded": len(loaded) == 2,
        "available_models": loaded,
    }
if __name__ == "__main__":
    # Configure CPU threading before any inference work starts
    # (presumably tuned for the HF Spaces CPU tier — confirm).
    torch.set_num_interop_threads(1)
    torch.set_num_threads(8)
    uvicorn.run(app, host="0.0.0.0", port=7860)