File size: 3,209 Bytes
c3aef13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from fastapi import FastAPI, HTTPException
from typing import List
import torch
import uvicorn
import gc
import os

from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
from utils.helpers import load_models, get_embeddings, cleanup_memory

# ASGI application object. The title, description and version below are the
# metadata passed to the FastAPI constructor.
app = FastAPI(
    title="Spanish Embedding API",
    description="Dual Spanish embedding models API",
    version="1.0.0"
)

# Global model cache.
# Starts empty; the startup handler replaces it with the return value of
# load_models(), and the /embed and /health handlers read from it.
models_cache = {}

@app.on_event("startup")
async def startup_event():
    """Fill the module-level model cache when the application starts.

    Rebinds the global ``models_cache`` to whatever ``load_models()``
    returns, then prints a confirmation message.
    """
    global models_cache
    loaded = load_models()
    models_cache = loaded
    print("Models loaded successfully!")

@app.get("/")
async def root():
    """Return a static summary of the service.

    The payload names the API, lists the two model identifiers, reports a
    fixed "running" status and points to the docs path.
    """
    model_ids = ["jina", "robertalex"]
    summary = {
        "message": "Spanish Embedding API",
        "models": model_ids,
        "status": "running",
        "docs": "/docs",
    }
    return summary

@app.post("/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for the texts carried by *request*.

    Args:
        request: Payload whose ``texts``, ``model``, ``normalize`` and
            ``max_length`` fields are forwarded to ``get_embeddings``.

    Returns:
        An ``EmbeddingResponse`` holding the embeddings, the requested model
        id, the length of the first embedding (0 when there are none) and
        the number of input texts.

    Raises:
        HTTPException: 400 when ``texts`` is empty, when more than 50 texts
            are supplied, or when ``get_embeddings`` raises ``ValueError``;
            500 for any other failure while computing the embeddings.
    """
    texts = request.texts

    # Validate *before* entering the try block. HTTPException derives from
    # Exception, so raising it inside the try would be caught by the generic
    # handler below and reported to the client as a 500 instead of a 400.
    if not texts:
        raise HTTPException(status_code=400, detail="No texts provided")

    if len(texts) > 50:  # Per-request batch size cap
        raise HTTPException(status_code=400, detail="Maximum 50 texts per request")

    try:
        embeddings = get_embeddings(
            texts,
            request.model,
            models_cache,
            request.normalize,
            request.max_length
        )

        # Cleanup memory after large batches
        if len(texts) > 20:
            cleanup_memory()

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used=request.model,
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(texts)
        )

    except HTTPException:
        # Already a well-formed HTTP error; pass it through with its
        # original status code rather than wrapping it as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}") from e

@app.get("/models", response_model=List[ModelInfo])
async def list_models():
    """Describe the two embedding models exposed by this API.

    Returns a list with one ``ModelInfo`` per model, "jina" first and
    "robertalex" second; all values are hard-coded here.
    """
    jina_info = ModelInfo(
        model_id="jina",
        name="jinaai/jina-embeddings-v2-base-es",
        dimensions=768,
        max_sequence_length=8192,
        languages=["Spanish", "English"],
        model_type="bilingual",
        description="Bilingual Spanish-English embeddings with long context support",
    )
    robertalex_info = ModelInfo(
        model_id="robertalex",
        name="PlanTL-GOB-ES/RoBERTalex",
        dimensions=768,
        max_sequence_length=512,
        languages=["Spanish"],
        model_type="legal domain",
        description="Spanish legal domain specialized embeddings",
    )
    return [jina_info, robertalex_info]

@app.get("/health")
async def health_check():
    """Report service health together with the state of the model cache.

    ``models_loaded`` is true only when the cache holds exactly two entries;
    ``available_models`` lists the keys currently in the cache.
    """
    all_loaded = len(models_cache) == 2
    cached_names = [*models_cache.keys()]
    return {
        "status": "healthy",
        "models_loaded": all_loaded,
        "available_models": cached_names,
    }

# Script entry point: this branch runs only when the file is executed
# directly, not when the module is imported.
# NOTE(review): if the app is launched by an external server command that
# imports this module, the torch thread settings below are never applied --
# confirm how the service is deployed.
if __name__ == "__main__":
    # Set multi-threading for CPU: configure torch's thread counts
    # (8 and 1 respectively) before the server is started.
    torch.set_num_threads(8)
    torch.set_num_interop_threads(1)
    
    # Serve the app with uvicorn, bound to all interfaces on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)