sksameermujahid commited on
Commit
fae03ad
·
verified ·
1 Parent(s): 6181b5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -131,8 +131,15 @@ class CustomRagRetriever:
131
  # Load PCA if it exists
132
  pca_path = os.path.join(os.path.dirname(model_path), "pca_model.pkl")
133
  if os.path.exists(pca_path):
134
- with open(pca_path, 'rb') as f:
135
- self.pca = pickle.load(f)
 
 
 
 
 
 
 
136
 
137
  def retrieve(self, query, top_k=10):
138
  print(f"Retrieving properties for query: {query}")
@@ -145,11 +152,15 @@ class CustomRagRetriever:
145
  device=device,
146
  normalize_embeddings=True
147
  )
148
- # Convert to FP16 after encoding
149
  query_embedding = query_embedding.astype(np.float32)
150
 
 
151
  if self.pca is not None:
152
- query_embedding = self.pca.transform(query_embedding)
 
 
 
153
 
154
  distances, indices = self.index.search(query_embedding, top_k)
155
 
 
131
  # Load PCA if it exists
132
  pca_path = os.path.join(os.path.dirname(model_path), "pca_model.pkl")
133
  if os.path.exists(pca_path):
134
+ try:
135
+ with open(pca_path, 'rb') as f:
136
+ self.pca = pickle.load(f)
137
+ except ModuleNotFoundError:
138
+ print("Warning: Could not load PCA model due to numpy version mismatch. Continuing without PCA.")
139
+ self.pca = None
140
+ except Exception as e:
141
+ print(f"Warning: Error loading PCA model: {str(e)}. Continuing without PCA.")
142
+ self.pca = None
143
 
144
  def retrieve(self, query, top_k=10):
145
  print(f"Retrieving properties for query: {query}")
 
152
  device=device,
153
  normalize_embeddings=True
154
  )
155
+ # Convert to FP32
156
  query_embedding = query_embedding.astype(np.float32)
157
 
158
+ # Only apply PCA if it was successfully loaded
159
  if self.pca is not None:
160
+ try:
161
+ query_embedding = self.pca.transform(query_embedding)
162
+ except Exception as e:
163
+ print(f"Warning: Error applying PCA transformation: {str(e)}")
164
 
165
  distances, indices = self.index.search(query_embedding, top_k)
166