# Hugging Face Spaces page header (Space status: Sleeping)
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from contextlib import asynccontextmanager | |
from typing import List | |
import torch | |
import uvicorn | |
from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo | |
from utils.helpers import load_models, get_embeddings, cleanup_memory | |
# Global model cache: assigned by the lifespan handler at startup and read
# by the request handlers below.
models_cache: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler for startup and shutdown.

    Passed to ``FastAPI(lifespan=...)``. The code before ``yield`` runs
    during startup and fills the module-level ``models_cache``; the
    ``finally`` clause runs when the context exits.

    Args:
        app: The FastAPI application instance (not used by this handler).

    Yields:
        None, once ``load_models()`` has returned.

    Raises:
        Exception: Whatever ``load_models()`` raised, re-raised after the
            failure has been printed.
    """
    global models_cache
    try:
        # Guard only the loading step, so that an exception raised while
        # the application is running is not misreported as a model-loading
        # failure.
        try:
            print("Loading models...")
            models_cache = load_models()
            print("All models loaded successfully!")
        except Exception as e:
            print(f"Failed to load models: {str(e)}")
            raise
        yield
    finally:
        # Shutdown - cleanup resources (also runs after a failed startup).
        cleanup_memory()
# ASGI application object; model loading and cleanup are delegated to the
# lifespan handler.
app = FastAPI(
    title="Multilingual & Legal Embedding API",
    description="Multi-model embedding API for Spanish, Catalan, English and Legal texts",
    version="3.0.0",
    lifespan=lifespan,
)

# Add CORS middleware to allow cross-origin requests.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins/methods/headers with allow_credentials=True. Replace the
# wildcards with explicit values (or drop credentials) before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def root():
    """Return a short service banner.

    Returns:
        dict: The API name, the advertised model ids, a fixed ``"running"``
        status, the docs path and the number of advertised models.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered on ``app``.
    models = ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"]
    return {
        "message": "Multilingual & Legal Embedding API",
        "models": models,
        "status": "running",
        "docs": "/docs",
        # Derived from the list so the count cannot drift from the ids.
        "total_models": len(models),
    }
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for input texts.

    Args:
        request: Payload providing ``texts``, ``model``, ``normalize`` and
            ``max_length``.

    Returns:
        EmbeddingResponse: The embeddings, the model id that was requested,
        the length of the first embedding (0 when there are none) and the
        number of input texts.

    Raises:
        HTTPException: 400 when ``texts`` is empty, holds more than 50
            items, or a ``ValueError`` is raised while embedding; 500 for
            any other failure.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered on ``app``.
    texts = request.texts

    # Validate before entering the try block: an HTTPException raised inside
    # it would be caught by the generic handler below and turned into a 500.
    if not texts:
        raise HTTPException(status_code=400, detail="No texts provided")
    if len(texts) > 50:  # Cap on batch size per request
        raise HTTPException(status_code=400, detail="Maximum 50 texts per request")

    try:
        embeddings = get_embeddings(
            texts,
            request.model,
            models_cache,
            request.normalize,
            request.max_length,
        )

        # Cleanup memory after large batches
        if len(texts) > 20:
            cleanup_memory()

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used=request.model,
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(texts),
        )
    except HTTPException:
        # Keep the status code of an HTTP error raised further down.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}") from e
async def list_models() -> List[ModelInfo]:
    """List available models and their specifications.

    Returns:
        One ``ModelInfo`` per advertised model id, in a fixed order, built
        from the literal values below (nothing is read from the loaded
        models).
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered on ``app``.
    return [
        ModelInfo(
            model_id="jina",
            name="jinaai/jina-embeddings-v2-base-es",
            dimensions=768,
            max_sequence_length=8192,
            languages=["Spanish", "English"],
            model_type="bilingual",
            description="Bilingual Spanish-English embeddings with long context support",
        ),
        ModelInfo(
            model_id="robertalex",
            name="PlanTL-GOB-ES/RoBERTalex",
            dimensions=768,
            max_sequence_length=512,
            languages=["Spanish"],
            model_type="legal domain",
            description="Spanish legal domain specialized embeddings",
        ),
        ModelInfo(
            model_id="jina-v3",
            name="jinaai/jina-embeddings-v3",
            dimensions=1024,
            max_sequence_length=8192,
            languages=["Multilingual"],
            model_type="multilingual",
            description="Latest Jina v3 with superior multilingual performance",
        ),
        ModelInfo(
            model_id="legal-bert",
            name="nlpaueb/legal-bert-base-uncased",
            dimensions=768,
            max_sequence_length=512,
            languages=["English"],
            model_type="legal domain",
            description="English legal domain BERT model",
        ),
        ModelInfo(
            model_id="roberta-ca",
            name="projecte-aina/roberta-large-ca-v2",
            dimensions=1024,
            max_sequence_length=512,
            languages=["Catalan"],
            model_type="general",
            description="Catalan RoBERTa-large model trained on large corpus",
        ),
    ]
async def health_check():
    """Health check endpoint.

    Returns:
        dict: ``status`` is ``"healthy"`` when the number of entries in
        ``models_cache`` equals the number of expected models and
        ``"degraded"`` otherwise; the remaining keys report the cached
        keys, the expected ids and the cache size.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered on ``app``.
    expected_models = ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"]
    models_count = len(models_cache)
    # Compare against the expected list rather than a separate magic number.
    # Only the count is checked, not the individual keys.
    models_loaded = models_count == len(expected_models)
    return {
        "status": "healthy" if models_loaded else "degraded",
        "models_loaded": models_loaded,
        "available_models": list(models_cache.keys()),
        "expected_models": expected_models,
        "models_count": models_count,
    }
if __name__ == "__main__":
    # Set multi-threading for CPU: 8 intra-op threads, 1 inter-op thread.
    # These settings apply only when this file is executed as a script.
    torch.set_num_threads(8)
    torch.set_num_interop_threads(1)
    # Serve on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)