thechaiexperiment committed on
Commit
3b1c99a
·
1 Parent(s): 554b5f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -49
app.py CHANGED
@@ -23,72 +23,93 @@ import time
23
  app = FastAPI()
24
 
25
  class ArticleEmbeddingUnpickler(pickle.Unpickler):
26
- """Custom unpickler specifically designed for article embeddings"""
27
- def find_class(self, module, name):
28
- # Handle numpy arrays specially
29
  if module == 'numpy':
30
  return getattr(np, name)
31
- # Handle the SentenceTransformer case
32
  if module == 'sentence_transformers.SentenceTransformer':
33
  from sentence_transformers import SentenceTransformer
34
  return SentenceTransformer
35
  return super().find_class(module, name)
36
 
37
- def persistent_load(self, pid):
38
- """Handle persistent IDs during unpickling"""
39
  try:
40
- # Convert to string if bytes
41
  if isinstance(pid, bytes):
42
- pid = pid.decode('utf-8', errors='ignore')
43
- return str(pid)
 
 
44
  except Exception as e:
45
- print(f"Error in persistent_load: {str(e)}")
46
- return str(pid)
47
 
48
- def safe_load_embeddings(file_path='embeddings.pkl'):
49
- """Load embeddings with enhanced error handling for article embeddings"""
50
  try:
51
  if not os.path.exists(file_path):
52
- print(f"Embeddings file not found at {file_path}")
53
- return None
54
 
55
  with open(file_path, 'rb') as file:
56
  unpickler = ArticleEmbeddingUnpickler(file)
57
  embeddings_data = unpickler.load()
58
 
59
- # Validate the dictionary structure
60
- if not isinstance(embeddings_data, dict):
61
- print(f"Invalid data structure: expected dict, got {type(embeddings_data)}")
62
- return None
63
-
64
- # Validate each embedding
65
- valid_embeddings = {}
66
- for key, value in embeddings_data.items():
67
- # Ensure key is string and value is numpy array
68
- try:
69
- key_str = str(key)
70
- if isinstance(value, list):
71
- value = np.array(value)
72
- if isinstance(value, np.ndarray):
73
- valid_embeddings[key_str] = value
74
- else:
75
- print(f"Skipping invalid embedding for key {key}: {type(value)}")
76
- except Exception as e:
77
- print(f"Error processing embedding {key}: {str(e)}")
78
  continue
79
 
80
- if not valid_embeddings:
81
- print("No valid embeddings found in file")
82
- return None
 
 
 
 
 
83
 
84
- print(f"Successfully loaded {len(valid_embeddings)} embeddings")
85
- return valid_embeddings
86
 
87
- except Exception as e:
88
- print(f"Error loading embeddings: {str(e)}")
89
- return None
90
 
 
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Models and data structures
94
  class GlobalModels:
@@ -129,16 +150,16 @@ class DocumentResponse(BaseModel):
129
 
130
  # Modified startup event handler
131
  @app.on_event("startup")
 
132
  async def load_models():
133
- """Initialize models with enhanced embeddings loading"""
134
  try:
135
- # Load embeddings first
136
  embeddings_data = safe_load_embeddings()
137
- if embeddings_data is None:
138
- raise HTTPException(
139
- status_code=500,
140
- detail="Failed to load embeddings data. Check logs for details."
141
- )
142
  # Load embedding models first
143
  global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
144
 
 
23
  app = FastAPI()
24
 
25
class ArticleEmbeddingUnpickler(pickle.Unpickler):
    """Unpickler specialised for article-embedding payloads.

    Resolves ``numpy`` attributes and ``SentenceTransformer`` references
    explicitly, and coerces persistent IDs of any type to plain strings.

    NOTE(review): ``find_class`` hands back arbitrary ``numpy`` attributes
    during unpickling — only use this on trusted embedding files, never on
    untrusted input.
    """

    def find_class(self, module: str, name: str) -> Any:
        # numpy types are resolved straight off the imported module object.
        if module == 'numpy':
            return getattr(np, name)
        # Import SentenceTransformer lazily, only when the pickle asks for it.
        if module == 'sentence_transformers.SentenceTransformer':
            from sentence_transformers import SentenceTransformer
            return SentenceTransformer
        # Everything else follows the default resolution rules.
        return super().find_class(module, name)

    def persistent_load(self, pid: Any) -> str:
        """Return a string form of a persistent ID, whatever its type."""
        try:
            if isinstance(pid, bytes):
                # Undecodable bytes become U+FFFD rather than raising.
                decoded = pid.decode('utf-8', errors='replace')
                return decoded
            # Scalars stringify cleanly; anything else falls back to repr().
            return str(pid) if isinstance(pid, (str, int, float)) else repr(pid)
        except Exception as exc:
            print(f"Warning: Error in persistent_load: {str(exc)}")
            return repr(pid)
47
 
48
def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
    """Read the pickled embeddings file and return only well-formed entries.

    Keys are coerced to non-empty, stripped strings; values must end up as
    finite 1-D float32 numpy arrays.  Malformed entries are skipped with a
    printed warning.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        ValueError: if the payload is not a dict, or nothing survives
            validation.  All failures are logged before re-raising.
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Embeddings file not found at {file_path}")

        with open(file_path, 'rb') as handle:
            loaded = ArticleEmbeddingUnpickler(handle).load()

        if not isinstance(loaded, dict):
            raise ValueError(f"Invalid data structure: expected dict, got {type(loaded)}")

        kept: Dict[str, np.ndarray] = {}
        for raw_key, raw_value in loaded.items():
            try:
                # Keys must be non-empty once stringified and stripped.
                name = str(raw_key).strip()
                if not name:
                    continue

                # Coerce the value to a float32 ndarray; reject other types.
                if isinstance(raw_value, list):
                    vector = np.array(raw_value, dtype=np.float32)
                elif isinstance(raw_value, np.ndarray):
                    vector = raw_value.astype(np.float32)
                else:
                    print(f"Skipping invalid embedding type for key {name}: {type(raw_value)}")
                    continue

                # Embeddings must be flat vectors with finite entries.
                if vector.ndim != 1:
                    print(f"Skipping invalid embedding shape for key {name}: {vector.shape}")
                    continue

                if np.isnan(vector).any() or np.isinf(vector).any():
                    print(f"Skipping embedding with invalid values for key {name}")
                    continue

                kept[name] = vector
            except Exception as exc:
                # One bad entry must not abort the whole load.
                print(f"Error processing embedding for key {raw_key}: {str(exc)}")
                continue

        if not kept:
            raise ValueError("No valid embeddings found in file")

        print(f"Successfully loaded {len(kept)} valid embeddings")
        return kept

    except Exception as exc:
        # Log at the boundary, then let the caller decide how to react.
        print(f"Error loading embeddings: {str(exc)}")
        raise
103
+
104
def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
    """Persist *embeddings_dict* to *file_path* with ASCII-sanitised keys.

    Every key is stringified and any non-ASCII character is replaced with
    ``?`` so the file remains readable by protocol-0 (text-based) pickle
    loaders such as the custom unpickler used on the load path.
    """
    # Force each key into an ASCII-only string; undecodable chars become '?'.
    sanitised = {}
    for key, value in embeddings_dict.items():
        safe_key = str(key).encode('ascii', errors='replace').decode('ascii')
        sanitised[safe_key] = value

    # Protocol 0 keeps the pickle in its ASCII/text-compatible form.
    with open(file_path, 'wb') as out:
        pickle.dump(sanitised, out, protocol=0)
113
 
114
  # Models and data structures
115
  class GlobalModels:
 
150
 
151
  # Modified startup event handler
152
  @app.on_event("startup")
153
+ @app.on_event("startup")
154
  async def load_models():
 
155
  try:
156
+ print("Starting to load embeddings...")
157
  embeddings_data = safe_load_embeddings()
158
+ print(f"Embeddings data type: {type(embeddings_data)}")
159
+ if embeddings_data:
160
+ print(f"Number of embeddings: {len(embeddings_data)}")
161
+ # Print sample of keys
162
+ print("Sample keys:", list(embeddings_data.keys())[:3])
163
  # Load embedding models first
164
  global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
165