from fastapi import FastAPI, HTTPException
from typing import List
import torch
import uvicorn

from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
from utils.helpers import load_models, get_embeddings, cleanup_memory
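
# models/schemas.py is not shown in this file; the imported Pydantic models
# are assumed to look roughly like the sketch below, with field names
# inferred from how they are used in the endpoints (defaults illustrative):
#
#     class EmbeddingRequest(BaseModel):
#         texts: List[str]
#         model: str = "jina"               # "jina" or "robertalex"
#         normalize: bool = True
#         max_length: int | None = None
#
#     class EmbeddingResponse(BaseModel):
#         embeddings: List[List[float]]
#         model_used: str
#         dimensions: int
#         num_texts: int
#
#     class ModelInfo(BaseModel):
#         model_id: str
#         name: str
#         dimensions: int
#         max_sequence_length: int
#         languages: List[str]
#         model_type: str
#         description: str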

app = FastAPI(
    title="Spanish Embedding API",
    description="Dual Spanish embedding models API",
    version="1.0.0",
)

# Global model cache, populated on startup
models_cache = {}


# Note: newer FastAPI versions deprecate on_event in favor of lifespan handlers
@app.on_event("startup")
async def startup_event():
    """Load models on startup."""
    global models_cache
    models_cache = load_models()
    print("Models loaded successfully!")
@app.get("/")
async def root():
return {
"message": "Spanish Embedding API",
"models": ["jina", "robertalex"],
"status": "running",
"docs": "/docs"
}
@app.post("/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
"""Generate embeddings for input texts"""
try:
if not request.texts:
raise HTTPException(status_code=400, detail="No texts provided")
if len(request.texts) > 50: # Rate limiting
raise HTTPException(status_code=400, detail="Maximum 50 texts per request")
embeddings = get_embeddings(
request.texts,
request.model,
models_cache,
request.normalize,
request.max_length
)
# Cleanup memory after large batches
if len(request.texts) > 20:
cleanup_memory()
return EmbeddingResponse(
embeddings=embeddings,
model_used=request.model,
dimensions=len(embeddings[0]) if embeddings else 0,
num_texts=len(request.texts)
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
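
# get_embeddings() and cleanup_memory() also live in utils/helpers.py. A
# plausible sketch, assuming mean pooling over the last hidden state (the
# actual helper may pool differently):
#
#     import gc
#
#     def get_embeddings(texts, model_id, cache, normalize=True, max_length=None):
#         tokenizer, model = cache[model_id]
#         batch = tokenizer(texts, padding=True, truncation=True,
#                           max_length=max_length, return_tensors="pt")
#         with torch.no_grad():
#             hidden = model(**batch).last_hidden_state
#         mask = batch["attention_mask"].unsqueeze(-1).float()
#         emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # mean pooling
#         if normalize:
#             emb = torch.nn.functional.normalize(emb, p=2, dim=1)
#         return emb.tolist()
#
#     def cleanup_memory():
#         gc.collect()  # reclaim tensors left over from large batches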
@app.get("/models", response_model=List[ModelInfo])
async def list_models():
"""List available models and their specifications"""
return [
ModelInfo(
model_id="jina",
name="jinaai/jina-embeddings-v2-base-es",
dimensions=768,
max_sequence_length=8192,
languages=["Spanish", "English"],
model_type="bilingual",
description="Bilingual Spanish-English embeddings with long context support"
),
ModelInfo(
model_id="robertalex",
name="PlanTL-GOB-ES/RoBERTalex",
dimensions=768,
max_sequence_length=512,
languages=["Spanish"],
model_type="legal domain",
description="Spanish legal domain specialized embeddings"
)
]
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"models_loaded": len(models_cache) == 2,
"available_models": list(models_cache.keys())
}


if __name__ == "__main__":
    # Configure CPU threading for inference
    torch.set_num_threads(8)
    torch.set_num_interop_threads(1)
    uvicorn.run(app, host="0.0.0.0", port=7860)
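
# Example request once the server is running (field names follow the
# EmbeddingRequest sketch near the top of this file):
#
#   curl -X POST http://localhost:7860/embed \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["Hola mundo", "Buenos días"], "model": "jina"}'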