thechaiexperiment committed on
Commit
554b5f1
·
1 Parent(s): 2a5cca5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -21
app.py CHANGED
@@ -22,38 +22,74 @@ import time
22
  # Initialize FastAPI app first
23
  app = FastAPI()
24
 
25
- class CustomUnpickler(pickle.Unpickler):
 
 
 
 
 
 
 
 
 
 
 
26
  def persistent_load(self, pid):
 
27
  try:
 
28
  if isinstance(pid, bytes):
29
  pid = pid.decode('utf-8', errors='ignore')
30
- pid = str(pid).encode('ascii', errors='ignore').decode('ascii')
31
-
32
- if pid == "sentence_transformer_model":
33
- return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
34
- return pid
35
  except Exception as e:
36
- raise pickle.UnpicklingError(f"Error handling persistent ID: {e}")
 
37
 
38
- def safe_load_embeddings():
 
39
  try:
40
- with open('embeddings.pkl', 'rb') as file:
41
- unpickler = CustomUnpickler(file)
 
 
 
 
42
  embeddings_data = unpickler.load()
43
-
 
44
  if not isinstance(embeddings_data, dict):
45
- raise ValueError("Loaded data is not a dictionary")
46
-
47
- first_key = next(iter(embeddings_data))
48
- if not isinstance(embeddings_data[first_key], (np.ndarray, list)):
49
- raise ValueError("Embeddings are not in the expected format")
50
-
51
- return embeddings_data
52
-
53
- except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  print(f"Error loading embeddings: {str(e)}")
55
  return None
56
 
 
 
57
  # Models and data structures
58
  class GlobalModels:
59
  embedding_model = None
@@ -91,10 +127,18 @@ class DocumentResponse(BaseModel):
91
  text: str
92
  score: float
93
 
 
94
  @app.on_event("startup")
95
  async def load_models():
96
- """Initialize all models and data on startup"""
97
  try:
 
 
 
 
 
 
 
98
  # Load embedding models first
99
  global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
100
 
 
22
  # Initialize FastAPI app first
23
  app = FastAPI()
24
 
25
+ class ArticleEmbeddingUnpickler(pickle.Unpickler):
26
+ """Custom unpickler specifically designed for article embeddings"""
27
+ def find_class(self, module, name):
28
+ # Handle numpy arrays specially
29
+ if module == 'numpy':
30
+ return getattr(np, name)
31
+ # Handle the SentenceTransformer case
32
+ if module == 'sentence_transformers.SentenceTransformer':
33
+ from sentence_transformers import SentenceTransformer
34
+ return SentenceTransformer
35
+ return super().find_class(module, name)
36
+
37
  def persistent_load(self, pid):
38
+ """Handle persistent IDs during unpickling"""
39
  try:
40
+ # Convert to string if bytes
41
  if isinstance(pid, bytes):
42
  pid = pid.decode('utf-8', errors='ignore')
43
+ return str(pid)
 
 
 
 
44
  except Exception as e:
45
+ print(f"Error in persistent_load: {str(e)}")
46
+ return str(pid)
47
 
48
+ def safe_load_embeddings(file_path='embeddings.pkl'):
49
+ """Load embeddings with enhanced error handling for article embeddings"""
50
  try:
51
+ if not os.path.exists(file_path):
52
+ print(f"Embeddings file not found at {file_path}")
53
+ return None
54
+
55
+ with open(file_path, 'rb') as file:
56
+ unpickler = ArticleEmbeddingUnpickler(file)
57
  embeddings_data = unpickler.load()
58
+
59
+ # Validate the dictionary structure
60
  if not isinstance(embeddings_data, dict):
61
+ print(f"Invalid data structure: expected dict, got {type(embeddings_data)}")
62
+ return None
63
+
64
+ # Validate each embedding
65
+ valid_embeddings = {}
66
+ for key, value in embeddings_data.items():
67
+ # Ensure key is string and value is numpy array
68
+ try:
69
+ key_str = str(key)
70
+ if isinstance(value, list):
71
+ value = np.array(value)
72
+ if isinstance(value, np.ndarray):
73
+ valid_embeddings[key_str] = value
74
+ else:
75
+ print(f"Skipping invalid embedding for key {key}: {type(value)}")
76
+ except Exception as e:
77
+ print(f"Error processing embedding {key}: {str(e)}")
78
+ continue
79
+
80
+ if not valid_embeddings:
81
+ print("No valid embeddings found in file")
82
+ return None
83
+
84
+ print(f"Successfully loaded {len(valid_embeddings)} embeddings")
85
+ return valid_embeddings
86
+
87
+ except Exception as e:
88
  print(f"Error loading embeddings: {str(e)}")
89
  return None
90
 
91
+
92
+
93
  # Models and data structures
94
  class GlobalModels:
95
  embedding_model = None
 
127
  text: str
128
  score: float
129
 
130
+ # Modified startup event handler
131
  @app.on_event("startup")
132
  async def load_models():
133
+ """Initialize models with enhanced embeddings loading"""
134
  try:
135
+ # Load embeddings first
136
+ embeddings_data = safe_load_embeddings()
137
+ if embeddings_data is None:
138
+ raise HTTPException(
139
+ status_code=500,
140
+ detail="Failed to load embeddings data. Check logs for details."
141
+ )
142
  # Load embedding models first
143
  global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
144