Spaces:
Sleeping
Sleeping
Commit
·
3b1c99a
1
Parent(s):
554b5f1
Update app.py
Browse files
app.py
CHANGED
@@ -23,72 +23,93 @@ import time
|
|
23 |
app = FastAPI()
|
24 |
|
25 |
class ArticleEmbeddingUnpickler(pickle.Unpickler):
|
26 |
-
"""Custom unpickler
|
27 |
-
def find_class(self, module, name):
|
28 |
-
# Handle numpy arrays specially
|
29 |
if module == 'numpy':
|
30 |
return getattr(np, name)
|
31 |
-
# Handle the SentenceTransformer case
|
32 |
if module == 'sentence_transformers.SentenceTransformer':
|
33 |
from sentence_transformers import SentenceTransformer
|
34 |
return SentenceTransformer
|
35 |
return super().find_class(module, name)
|
36 |
|
37 |
-
def persistent_load(self, pid):
|
38 |
-
"""
|
39 |
try:
|
40 |
-
#
|
41 |
if isinstance(pid, bytes):
|
42 |
-
|
43 |
-
|
|
|
|
|
44 |
except Exception as e:
|
45 |
-
print(f"Error in persistent_load: {str(e)}")
|
46 |
-
return
|
47 |
|
48 |
-
def safe_load_embeddings(file_path='embeddings.pkl'):
|
49 |
-
"""Load embeddings with enhanced error handling
|
50 |
try:
|
51 |
if not os.path.exists(file_path):
|
52 |
-
|
53 |
-
return None
|
54 |
|
55 |
with open(file_path, 'rb') as file:
|
56 |
unpickler = ArticleEmbeddingUnpickler(file)
|
57 |
embeddings_data = unpickler.load()
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
print(f"
|
78 |
continue
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
return valid_embeddings
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
|
|
|
|
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
# Models and data structures
|
94 |
class GlobalModels:
|
@@ -129,16 +150,16 @@ class DocumentResponse(BaseModel):
|
|
129 |
|
130 |
# Modified startup event handler
|
131 |
@app.on_event("startup")
|
|
|
132 |
async def load_models():
|
133 |
-
"""Initialize models with enhanced embeddings loading"""
|
134 |
try:
|
135 |
-
|
136 |
embeddings_data = safe_load_embeddings()
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
)
|
142 |
# Load embedding models first
|
143 |
global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
144 |
|
|
|
23 |
app = FastAPI()
|
24 |
|
25 |
class ArticleEmbeddingUnpickler(pickle.Unpickler):
    """Unpickler tailored to article-embedding files.

    Resolves numpy attributes and the SentenceTransformer class by hand,
    and flattens pickle persistent IDs into plain strings.
    """

    def find_class(self, module: str, name: str) -> Any:
        # numpy objects are resolved straight off the np module.
        if module == 'numpy':
            return getattr(np, name)
        # Import SentenceTransformer lazily, only when the pickle asks for it.
        if module == 'sentence_transformers.SentenceTransformer':
            from sentence_transformers import SentenceTransformer
            return SentenceTransformer
        # Everything else follows the standard resolution path.
        return super().find_class(module, name)

    def persistent_load(self, pid: Any) -> str:
        """Render any persistent ID as a string, never raising to the caller."""
        try:
            if isinstance(pid, bytes):
                # Replace undecodable bytes instead of failing the load.
                return pid.decode('utf-8', errors='replace')
            return str(pid) if isinstance(pid, (str, int, float)) else repr(pid)
        except Exception as exc:
            print(f"Warning: Error in persistent_load: {str(exc)}")
            return repr(pid)
|
47 |
|
48 |
+
def _coerce_embedding(key_str, value):
    """Validate one raw embedding value; return a 1-D float32 array or None.

    None means the entry should be skipped; a message is printed explaining why.
    """
    if isinstance(value, list):
        value = np.array(value, dtype=np.float32)
    elif isinstance(value, np.ndarray):
        value = value.astype(np.float32)
    else:
        print(f"Skipping invalid embedding type for key {key_str}: {type(value)}")
        return None
    # Embeddings are expected to be flat vectors.
    if value.ndim != 1:
        print(f"Skipping invalid embedding shape for key {key_str}: {value.shape}")
        return None
    # NaN/inf would poison any downstream similarity computation.
    if np.isnan(value).any() or np.isinf(value).any():
        print(f"Skipping embedding with invalid values for key {key_str}")
        return None
    return value


def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
    """Load article embeddings from a pickle file with per-entry validation.

    Each entry is coerced to a 1-D float32 numpy array; entries with empty
    keys, wrong types/shapes, or NaN/inf values are skipped with a printed
    notice instead of aborting the whole load.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        Mapping of cleaned string keys to validated 1-D float32 arrays.

    Raises:
        FileNotFoundError: If file_path does not exist.
        ValueError: If the pickle is not a dict or yields no valid entries.
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Embeddings file not found at {file_path}")

        # NOTE(security): pickle can execute arbitrary code on load; only
        # open embedding files produced by this application itself.
        with open(file_path, 'rb') as file:
            unpickler = ArticleEmbeddingUnpickler(file)
            embeddings_data = unpickler.load()

        if not isinstance(embeddings_data, dict):
            raise ValueError(f"Invalid data structure: expected dict, got {type(embeddings_data)}")

        # Process and validate embeddings, keeping only well-formed entries.
        valid_embeddings = {}
        for key, value in embeddings_data.items():
            try:
                key_str = str(key).strip()
                if not key_str:
                    continue
                coerced = _coerce_embedding(key_str, value)
                if coerced is not None:
                    valid_embeddings[key_str] = coerced
            except Exception as e:
                # One bad entry must not abort the whole load.
                print(f"Error processing embedding for key {key}: {str(e)}")
                continue

        if not valid_embeddings:
            raise ValueError("No valid embeddings found in file")

        print(f"Successfully loaded {len(valid_embeddings)} valid embeddings")
        return valid_embeddings

    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")
        raise
|
103 |
+
|
104 |
+
def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
    """Persist embeddings to disk with ASCII-safe keys.

    Keys are forced to ASCII (non-ASCII characters become '?'), presumably to
    avoid encoding issues when the file is read back — confirm against the
    loader before changing.

    Args:
        embeddings_dict: Mapping of keys to embedding vectors.
        file_path: Destination pickle path.
    """
    # Convert all keys to ASCII-safe strings.
    cleaned_embeddings = {}
    for key, value in embeddings_dict.items():
        safe_key = str(key).encode('ascii', errors='replace').decode('ascii')
        # Two distinct keys can collapse to the same ASCII form; warn instead
        # of silently dropping the earlier entry.
        if safe_key in cleaned_embeddings:
            print(f"Warning: duplicate key after ASCII cleanup: {safe_key}")
        cleaned_embeddings[safe_key] = value

    # protocol=0 keeps the pickle ASCII-only, matching the key cleanup above.
    with open(file_path, 'wb') as f:
        pickle.dump(cleaned_embeddings, f, protocol=0)
|
113 |
|
114 |
# Models and data structures
|
115 |
class GlobalModels:
|
|
|
150 |
|
151 |
# Modified startup event handler
|
152 |
@app.on_event("startup")
|
153 |
+
@app.on_event("startup")
|
154 |
async def load_models():
|
|
|
155 |
try:
|
156 |
+
print("Starting to load embeddings...")
|
157 |
embeddings_data = safe_load_embeddings()
|
158 |
+
print(f"Embeddings data type: {type(embeddings_data)}")
|
159 |
+
if embeddings_data:
|
160 |
+
print(f"Number of embeddings: {len(embeddings_data)}")
|
161 |
+
# Print sample of keys
|
162 |
+
print("Sample keys:", list(embeddings_data.keys())[:3])
|
163 |
# Load embedding models first
|
164 |
global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
165 |
|