Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -131,8 +131,15 @@ class CustomRagRetriever:
|
|
| 131 |
# Load PCA if it exists
|
| 132 |
pca_path = os.path.join(os.path.dirname(model_path), "pca_model.pkl")
|
| 133 |
if os.path.exists(pca_path):
|
| 134 |
-
with open(pca_path, 'rb') as f:
|
| 135 |
-
self.pca = pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def retrieve(self, query, top_k=10):
|
| 138 |
print(f"Retrieving properties for query: {query}")
|
|
@@ -145,11 +152,15 @@ class CustomRagRetriever:
|
|
| 145 |
device=device,
|
| 146 |
normalize_embeddings=True
|
| 147 |
)
|
| 148 |
-
# Convert to
|
| 149 |
query_embedding = query_embedding.astype(np.float32)
|
| 150 |
|
|
|
|
| 151 |
if self.pca is not None:
|
| 152 |
-
query_embedding = self.pca.transform(query_embedding)
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
distances, indices = self.index.search(query_embedding, top_k)
|
| 155 |
|
|
|
|
| 131 |
# Load PCA if it exists
|
| 132 |
pca_path = os.path.join(os.path.dirname(model_path), "pca_model.pkl")
|
| 133 |
if os.path.exists(pca_path):
|
| 134 |
+
try:
|
| 135 |
+
with open(pca_path, 'rb') as f:
|
| 136 |
+
self.pca = pickle.load(f)
|
| 137 |
+
except ModuleNotFoundError:
|
| 138 |
+
print("Warning: Could not load PCA model due to numpy version mismatch. Continuing without PCA.")
|
| 139 |
+
self.pca = None
|
| 140 |
+
except Exception as e:
|
| 141 |
+
print(f"Warning: Error loading PCA model: {str(e)}. Continuing without PCA.")
|
| 142 |
+
self.pca = None
|
| 143 |
|
| 144 |
def retrieve(self, query, top_k=10):
|
| 145 |
print(f"Retrieving properties for query: {query}")
|
|
|
|
| 152 |
device=device,
|
| 153 |
normalize_embeddings=True
|
| 154 |
)
|
| 155 |
+
# Convert to FP32
|
| 156 |
query_embedding = query_embedding.astype(np.float32)
|
| 157 |
|
| 158 |
+
# Only apply PCA if it was successfully loaded
|
| 159 |
if self.pca is not None:
|
| 160 |
+
try:
|
| 161 |
+
query_embedding = self.pca.transform(query_embedding)
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"Warning: Error applying PCA transformation: {str(e)}")
|
| 164 |
|
| 165 |
distances, indices = self.index.search(query_embedding, top_k)
|
| 166 |
|