Spaces:
Sleeping
Sleeping
Commit
·
554b5f1
1
Parent(s):
2a5cca5
Update app.py
Browse files
app.py
CHANGED
@@ -22,38 +22,74 @@ import time
|
|
22 |
# Initialize FastAPI app first
|
23 |
app = FastAPI()
|
24 |
|
25 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def persistent_load(self, pid):
|
|
|
27 |
try:
|
|
|
28 |
if isinstance(pid, bytes):
|
29 |
pid = pid.decode('utf-8', errors='ignore')
|
30 |
-
|
31 |
-
|
32 |
-
if pid == "sentence_transformer_model":
|
33 |
-
return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
34 |
-
return pid
|
35 |
except Exception as e:
|
36 |
-
|
|
|
37 |
|
38 |
-
def safe_load_embeddings():
|
|
|
39 |
try:
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
embeddings_data = unpickler.load()
|
43 |
-
|
|
|
44 |
if not isinstance(embeddings_data, dict):
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
print(f"Error loading embeddings: {str(e)}")
|
55 |
return None
|
56 |
|
|
|
|
|
57 |
# Models and data structures
|
58 |
class GlobalModels:
|
59 |
embedding_model = None
|
@@ -91,10 +127,18 @@ class DocumentResponse(BaseModel):
|
|
91 |
text: str
|
92 |
score: float
|
93 |
|
|
|
94 |
@app.on_event("startup")
|
95 |
async def load_models():
|
96 |
-
"""Initialize
|
97 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
# Load embedding models first
|
99 |
global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
100 |
|
|
|
22 |
# Initialize FastAPI app first
|
23 |
app = FastAPI()
|
24 |
|
25 |
+
class ArticleEmbeddingUnpickler(pickle.Unpickler):
|
26 |
+
"""Custom unpickler specifically designed for article embeddings"""
|
27 |
+
def find_class(self, module, name):
|
28 |
+
# Handle numpy arrays specially
|
29 |
+
if module == 'numpy':
|
30 |
+
return getattr(np, name)
|
31 |
+
# Handle the SentenceTransformer case
|
32 |
+
if module == 'sentence_transformers.SentenceTransformer':
|
33 |
+
from sentence_transformers import SentenceTransformer
|
34 |
+
return SentenceTransformer
|
35 |
+
return super().find_class(module, name)
|
36 |
+
|
37 |
def persistent_load(self, pid):
|
38 |
+
"""Handle persistent IDs during unpickling"""
|
39 |
try:
|
40 |
+
# Convert to string if bytes
|
41 |
if isinstance(pid, bytes):
|
42 |
pid = pid.decode('utf-8', errors='ignore')
|
43 |
+
return str(pid)
|
|
|
|
|
|
|
|
|
44 |
except Exception as e:
|
45 |
+
print(f"Error in persistent_load: {str(e)}")
|
46 |
+
return str(pid)
|
47 |
|
48 |
+
def safe_load_embeddings(file_path='embeddings.pkl'):
|
49 |
+
"""Load embeddings with enhanced error handling for article embeddings"""
|
50 |
try:
|
51 |
+
if not os.path.exists(file_path):
|
52 |
+
print(f"Embeddings file not found at {file_path}")
|
53 |
+
return None
|
54 |
+
|
55 |
+
with open(file_path, 'rb') as file:
|
56 |
+
unpickler = ArticleEmbeddingUnpickler(file)
|
57 |
embeddings_data = unpickler.load()
|
58 |
+
|
59 |
+
# Validate the dictionary structure
|
60 |
if not isinstance(embeddings_data, dict):
|
61 |
+
print(f"Invalid data structure: expected dict, got {type(embeddings_data)}")
|
62 |
+
return None
|
63 |
+
|
64 |
+
# Validate each embedding
|
65 |
+
valid_embeddings = {}
|
66 |
+
for key, value in embeddings_data.items():
|
67 |
+
# Ensure key is string and value is numpy array
|
68 |
+
try:
|
69 |
+
key_str = str(key)
|
70 |
+
if isinstance(value, list):
|
71 |
+
value = np.array(value)
|
72 |
+
if isinstance(value, np.ndarray):
|
73 |
+
valid_embeddings[key_str] = value
|
74 |
+
else:
|
75 |
+
print(f"Skipping invalid embedding for key {key}: {type(value)}")
|
76 |
+
except Exception as e:
|
77 |
+
print(f"Error processing embedding {key}: {str(e)}")
|
78 |
+
continue
|
79 |
+
|
80 |
+
if not valid_embeddings:
|
81 |
+
print("No valid embeddings found in file")
|
82 |
+
return None
|
83 |
+
|
84 |
+
print(f"Successfully loaded {len(valid_embeddings)} embeddings")
|
85 |
+
return valid_embeddings
|
86 |
+
|
87 |
+
except Exception as e:
|
88 |
print(f"Error loading embeddings: {str(e)}")
|
89 |
return None
|
90 |
|
91 |
+
|
92 |
+
|
93 |
# Models and data structures
|
94 |
class GlobalModels:
|
95 |
embedding_model = None
|
|
|
127 |
text: str
|
128 |
score: float
|
129 |
|
130 |
+
# Modified startup event handler
|
131 |
@app.on_event("startup")
|
132 |
async def load_models():
|
133 |
+
"""Initialize models with enhanced embeddings loading"""
|
134 |
try:
|
135 |
+
# Load embeddings first
|
136 |
+
embeddings_data = safe_load_embeddings()
|
137 |
+
if embeddings_data is None:
|
138 |
+
raise HTTPException(
|
139 |
+
status_code=500,
|
140 |
+
detail="Failed to load embeddings data. Check logs for details."
|
141 |
+
)
|
142 |
# Load embedding models first
|
143 |
global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
144 |
|