Jordi Catafal committed on
Commit
0610fdd
·
1 Parent(s): dda5c3b

Fix API problem: load models on the first embedding request instead of at application startup

Browse files
Files changed (3) hide show
  1. app.py +22 -25
  2. app_old.py +159 -0
  3. test_api.py +1 -1
app.py CHANGED
@@ -1,6 +1,5 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from contextlib import asynccontextmanager
4
  from typing import List
5
  import torch
6
  import uvicorn
@@ -8,31 +7,10 @@ import uvicorn
8
  from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
9
  from utils.helpers import load_models, get_embeddings, cleanup_memory
10
 
11
- # Global model cache
12
- models_cache = {}
13
-
14
- @asynccontextmanager
15
- async def lifespan(app: FastAPI):
16
- """Application lifespan handler for startup and shutdown"""
17
- # Startup
18
- try:
19
- global models_cache
20
- print("Loading models...")
21
- models_cache = load_models()
22
- print("All models loaded successfully!")
23
- yield
24
- except Exception as e:
25
- print(f"Failed to load models: {str(e)}")
26
- raise
27
- finally:
28
- # Shutdown - cleanup resources
29
- cleanup_memory()
30
-
31
  app = FastAPI(
32
  title="Multilingual & Legal Embedding API",
33
  description="Multi-model embedding API for Spanish, Catalan, English and Legal texts",
34
- version="3.0.0",
35
- lifespan=lifespan
36
  )
37
 
38
  # Add CORS middleware to allow cross-origin requests
@@ -44,6 +22,21 @@ app.add_middleware(
44
  allow_headers=["*"],
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @app.get("/")
48
  async def root():
49
  return {
@@ -58,6 +51,9 @@ async def root():
58
  async def create_embeddings(request: EmbeddingRequest):
59
  """Generate embeddings for input texts"""
60
  try:
 
 
 
61
  if not request.texts:
62
  raise HTTPException(status_code=400, detail="No texts provided")
63
 
@@ -144,11 +140,12 @@ async def health_check():
144
  """Health check endpoint"""
145
  models_loaded = len(models_cache) == 5
146
  return {
147
- "status": "healthy" if models_loaded else "degraded",
148
  "models_loaded": models_loaded,
149
  "available_models": list(models_cache.keys()),
150
  "expected_models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
151
- "models_count": len(models_cache)
 
152
  }
153
 
154
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  from typing import List
4
  import torch
5
  import uvicorn
 
7
  from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
8
  from utils.helpers import load_models, get_embeddings, cleanup_memory
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  app = FastAPI(
11
  title="Multilingual & Legal Embedding API",
12
  description="Multi-model embedding API for Spanish, Catalan, English and Legal texts",
13
+ version="3.0.0"
 
14
  )
15
 
16
  # Add CORS middleware to allow cross-origin requests
 
22
  allow_headers=["*"],
23
  )
24
 
25
+ # Global model cache - loaded on demand
26
+ models_cache = {}
27
+
28
+ def ensure_models_loaded():
29
+ """Load models on first request if not already loaded"""
30
+ global models_cache
31
+ if not models_cache:
32
+ try:
33
+ print("Loading models on demand...")
34
+ models_cache = load_models()
35
+ print("All models loaded successfully!")
36
+ except Exception as e:
37
+ print(f"Failed to load models: {str(e)}")
38
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
39
+
40
  @app.get("/")
41
  async def root():
42
  return {
 
51
  async def create_embeddings(request: EmbeddingRequest):
52
  """Generate embeddings for input texts"""
53
  try:
54
+ # Load models on first request
55
+ ensure_models_loaded()
56
+
57
  if not request.texts:
58
  raise HTTPException(status_code=400, detail="No texts provided")
59
 
 
140
  """Health check endpoint"""
141
  models_loaded = len(models_cache) == 5
142
  return {
143
+ "status": "healthy" if models_loaded else "ready",
144
  "models_loaded": models_loaded,
145
  "available_models": list(models_cache.keys()),
146
  "expected_models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
147
+ "models_count": len(models_cache),
148
+ "note": "Models load on first embedding request" if not models_loaded else "All models ready"
149
  }
150
 
151
  if __name__ == "__main__":
app_old.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from contextlib import asynccontextmanager
4
+ from typing import List
5
+ import torch
6
+ import uvicorn
7
+
8
+ from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
9
+ from utils.helpers import load_models, get_embeddings, cleanup_memory
10
+
11
+ # Global model cache
12
+ models_cache = {}
13
+
14
+ @asynccontextmanager
15
+ async def lifespan(app: FastAPI):
16
+ """Application lifespan handler for startup and shutdown"""
17
+ # Startup
18
+ try:
19
+ global models_cache
20
+ print("Loading models...")
21
+ models_cache = load_models()
22
+ print("All models loaded successfully!")
23
+ yield
24
+ except Exception as e:
25
+ print(f"Failed to load models: {str(e)}")
26
+ raise
27
+ finally:
28
+ # Shutdown - cleanup resources
29
+ cleanup_memory()
30
+
31
+ app = FastAPI(
32
+ title="Multilingual & Legal Embedding API",
33
+ description="Multi-model embedding API for Spanish, Catalan, English and Legal texts",
34
+ version="3.0.0",
35
+ lifespan=lifespan
36
+ )
37
+
38
+ # Add CORS middleware to allow cross-origin requests
39
+ app.add_middleware(
40
+ CORSMiddleware,
41
+ allow_origins=["*"], # In production, specify actual domains
42
+ allow_credentials=True,
43
+ allow_methods=["*"],
44
+ allow_headers=["*"],
45
+ )
46
+
47
+ @app.get("/")
48
+ async def root():
49
+ return {
50
+ "message": "Multilingual & Legal Embedding API",
51
+ "models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
52
+ "status": "running",
53
+ "docs": "/docs",
54
+ "total_models": 5
55
+ }
56
+
57
+ @app.post("/embed", response_model=EmbeddingResponse)
58
+ async def create_embeddings(request: EmbeddingRequest):
59
+ """Generate embeddings for input texts"""
60
+ try:
61
+ if not request.texts:
62
+ raise HTTPException(status_code=400, detail="No texts provided")
63
+
64
+ if len(request.texts) > 50: # Rate limiting
65
+ raise HTTPException(status_code=400, detail="Maximum 50 texts per request")
66
+
67
+ embeddings = get_embeddings(
68
+ request.texts,
69
+ request.model,
70
+ models_cache,
71
+ request.normalize,
72
+ request.max_length
73
+ )
74
+
75
+ # Cleanup memory after large batches
76
+ if len(request.texts) > 20:
77
+ cleanup_memory()
78
+
79
+ return EmbeddingResponse(
80
+ embeddings=embeddings,
81
+ model_used=request.model,
82
+ dimensions=len(embeddings[0]) if embeddings else 0,
83
+ num_texts=len(request.texts)
84
+ )
85
+
86
+ except ValueError as e:
87
+ raise HTTPException(status_code=400, detail=str(e))
88
+ except Exception as e:
89
+ raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
90
+
91
+ @app.get("/models", response_model=List[ModelInfo])
92
+ async def list_models():
93
+ """List available models and their specifications"""
94
+ return [
95
+ ModelInfo(
96
+ model_id="jina",
97
+ name="jinaai/jina-embeddings-v2-base-es",
98
+ dimensions=768,
99
+ max_sequence_length=8192,
100
+ languages=["Spanish", "English"],
101
+ model_type="bilingual",
102
+ description="Bilingual Spanish-English embeddings with long context support"
103
+ ),
104
+ ModelInfo(
105
+ model_id="robertalex",
106
+ name="PlanTL-GOB-ES/RoBERTalex",
107
+ dimensions=768,
108
+ max_sequence_length=512,
109
+ languages=["Spanish"],
110
+ model_type="legal domain",
111
+ description="Spanish legal domain specialized embeddings"
112
+ ),
113
+ ModelInfo(
114
+ model_id="jina-v3",
115
+ name="jinaai/jina-embeddings-v3",
116
+ dimensions=1024,
117
+ max_sequence_length=8192,
118
+ languages=["Multilingual"],
119
+ model_type="multilingual",
120
+ description="Latest Jina v3 with superior multilingual performance"
121
+ ),
122
+ ModelInfo(
123
+ model_id="legal-bert",
124
+ name="nlpaueb/legal-bert-base-uncased",
125
+ dimensions=768,
126
+ max_sequence_length=512,
127
+ languages=["English"],
128
+ model_type="legal domain",
129
+ description="English legal domain BERT model"
130
+ ),
131
+ ModelInfo(
132
+ model_id="roberta-ca",
133
+ name="projecte-aina/roberta-large-ca-v2",
134
+ dimensions=1024,
135
+ max_sequence_length=512,
136
+ languages=["Catalan"],
137
+ model_type="general",
138
+ description="Catalan RoBERTa-large model trained on large corpus"
139
+ )
140
+ ]
141
+
142
+ @app.get("/health")
143
+ async def health_check():
144
+ """Health check endpoint"""
145
+ models_loaded = len(models_cache) == 5
146
+ return {
147
+ "status": "healthy" if models_loaded else "degraded",
148
+ "models_loaded": models_loaded,
149
+ "available_models": list(models_cache.keys()),
150
+ "expected_models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
151
+ "models_count": len(models_cache)
152
+ }
153
+
154
+ if __name__ == "__main__":
155
+ # Set multi-threading for CPU
156
+ torch.set_num_threads(8)
157
+ torch.set_num_interop_threads(1)
158
+
159
+ uvicorn.run(app, host="0.0.0.0", port=7860)
test_api.py CHANGED
@@ -7,7 +7,7 @@ import requests
7
  import json
8
  import time
9
 
10
- def test_api(base_url="http://localhost:7860"):
11
  """Test the API endpoints"""
12
 
13
  print(f"Testing API at {base_url}")
 
7
  import json
8
  import time
9
 
10
+ def test_api(base_url="https://aurasystems-spanish-embeddings-api.hf.space"):
11
  """Test the API endpoints"""
12
 
13
  print(f"Testing API at {base_url}")