thechaiexperiment committed on
Commit
5635397
·
1 Parent(s): f342c38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -23
app.py CHANGED
@@ -86,15 +86,18 @@ import pickle
86
  import numpy as np
87
  import os
88
  from typing import Dict, Optional
 
89
 
90
- class EmbeddingsUnpickler(pickle.Unpickler):
91
  def persistent_load(self, pid):
92
- # Handle persistent IDs by returning them as-is
93
- return pid
 
 
94
 
95
  def load_embeddings(embeddings_path: str = 'embeddings.pkl') -> Optional[Dict[str, np.ndarray]]:
96
  """
97
- Load embeddings from a pickle file with support for persistent IDs.
98
 
99
  Args:
100
  embeddings_path (str): Path to the pickle file containing embeddings
@@ -107,37 +110,53 @@ def load_embeddings(embeddings_path: str = 'embeddings.pkl') -> Optional[Dict[st
107
  return None
108
 
109
  try:
110
- with open(embeddings_path, 'rb') as f:
111
- # Use custom unpickler with persistent_load support
112
- unpickler = EmbeddingsUnpickler(f)
113
- embeddings = unpickler.load()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  # Validate the loaded data
116
  if not isinstance(embeddings, dict):
117
  print(f"Error: Expected dict, got {type(embeddings)}")
118
  return None
119
 
120
- # Convert values to numpy arrays if they aren't already
121
  processed_embeddings = {}
122
  for key, value in embeddings.items():
123
- # Handle both direct arrays and persistent IDs
124
- if isinstance(value, (list, np.ndarray)):
125
- processed_embeddings[key] = np.array(value)
126
- else:
127
- # If it's a persistent ID, convert it to a numpy array
128
- try:
129
- processed_embeddings[key] = np.array(value)
130
- except Exception as e:
131
- print(f"Warning: Could not convert embedding for {key}: {e}")
132
- continue
133
 
134
- # Print sample for verification
135
  if processed_embeddings:
136
  sample_key = next(iter(processed_embeddings))
137
  print(f"Data type: {type(processed_embeddings)}")
138
- print(f"First few keys and values:")
139
- print(f"Key: {sample_key}, Value: {processed_embeddings[sample_key][:20]}")
140
- print(f"Successfully loaded {len(processed_embeddings)} embeddings")
141
  return processed_embeddings
142
  else:
143
  print("Error: No valid embeddings were processed")
@@ -145,8 +164,10 @@ def load_embeddings(embeddings_path: str = 'embeddings.pkl') -> Optional[Dict[st
145
 
146
  except Exception as e:
147
  print(f"Error loading embeddings: {str(e)}")
 
148
  return None
149
 
 
150
  def load_documents_data():
151
  """Load document data with error handling"""
152
  try:
 
86
  import numpy as np
87
  import os
88
  from typing import Dict, Optional
89
+ import codecs
90
 
91
+ class LFSEmbeddingsUnpickler(pickle.Unpickler):
92
  def persistent_load(self, pid):
93
+ # Ensure persistent ID is ASCII string
94
+ if isinstance(pid, bytes):
95
+ return pid.decode('ascii')
96
+ return str(pid)
97
 
98
  def load_embeddings(embeddings_path: str = 'embeddings.pkl') -> Optional[Dict[str, np.ndarray]]:
99
  """
100
+ Load embeddings from a pickle file with support for Git LFS and protocol 0 requirements.
101
 
102
  Args:
103
  embeddings_path (str): Path to the pickle file containing embeddings
 
110
  return None
111
 
112
  try:
113
+ # Open file in binary mode with buffering
114
+ with open(embeddings_path, 'rb', buffering=1024*1024) as f:
115
+ # Check if it's a Git LFS pointer file
116
+ first_line = f.peek(100)[:100].decode('utf-8', errors='ignore')
117
+ if 'version https://git-lfs.github.com/spec/' in first_line:
118
+ print("Warning: This appears to be a Git LFS pointer file.")
119
+ print("Please ensure you've properly downloaded the actual embeddings file using Git LFS")
120
+ return None
121
+
122
+ # Use custom unpickler with ASCII string handling
123
+ unpickler = LFSEmbeddingsUnpickler(f)
124
+
125
+ # Set encoding for protocol 0 compatibility
126
+ if hasattr(unpickler, 'encoding'):
127
+ unpickler.encoding = 'ascii'
128
+
129
+ try:
130
+ embeddings = unpickler.load()
131
+ except UnicodeDecodeError:
132
+ # If ASCII decode fails, try UTF-8
133
+ f.seek(0)
134
+ unpickler = pickle.Unpickler(f)
135
+ embeddings = unpickler.load()
136
 
137
  # Validate the loaded data
138
  if not isinstance(embeddings, dict):
139
  print(f"Error: Expected dict, got {type(embeddings)}")
140
  return None
141
 
142
+ # Convert values to numpy arrays
143
  processed_embeddings = {}
144
  for key, value in embeddings.items():
145
+ try:
146
+ # Handle various input types
147
+ if isinstance(value, np.ndarray):
148
+ processed_embeddings[key] = value
149
+ else:
150
+ processed_embeddings[key] = np.array(value, dtype=np.float32)
151
+ except Exception as e:
152
+ print(f"Warning: Could not process embedding for {key}: {e}")
153
+ continue
 
154
 
 
155
  if processed_embeddings:
156
  sample_key = next(iter(processed_embeddings))
157
  print(f"Data type: {type(processed_embeddings)}")
158
+ print(f"Total embeddings loaded: {len(processed_embeddings)}")
159
+ print(f"Sample embedding shape: {processed_embeddings[sample_key].shape}")
 
160
  return processed_embeddings
161
  else:
162
  print("Error: No valid embeddings were processed")
 
164
 
165
  except Exception as e:
166
  print(f"Error loading embeddings: {str(e)}")
167
+ print("If using Git LFS, ensure you've run 'git lfs pull' to download the actual file")
168
  return None
169
 
170
+
171
  def load_documents_data():
172
  """Load document data with error handling"""
173
  try: