fix search logic error
app.py CHANGED
@@ -119,17 +119,44 @@ def semantic_search(event=None):
     if not query:
         return

+    # Step 1: Filter the full dataframe
+    filt = df.copy()
+    if w_countries.value:
+        filt = filt[filt["country"].isin(w_countries.value)]
+    if w_years.value:
+        filt = filt[filt["year"].isin(w_years.value)]
+    if w_keyword.value:
+        filt = filt[
+            filt["question_text"].str.contains(w_keyword.value, case=False, na=False) |
+            filt["answer_text"].str.contains(w_keyword.value, case=False, na=False) |
+            filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False)
+        ]
+
+    # Step 2: Load only embeddings for the filtered rows
     model, ids_list, emb_tensor = get_semantic_resources()
+
+    # Create a mask for filtered IDs
+    filtered_ids = filt["id"].tolist()
+    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
+    filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index]
+
+    # Subset the embedding tensor
+    filtered_embs = emb_tensor[filtered_indices]
+
+    # Step 3: Semantic search only within filtered subset
     q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
-
-    sims = util.cos_sim(q_vec, emb_tensor)[0]
+    sims = util.cos_sim(q_vec, filtered_embs)[0]
     top_vals, top_idx = torch.topk(sims, k=50)
-
-
-    sem_rows =
-    score_map = dict(zip(
+
+    top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
+    sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
+    score_map = dict(zip(top_filtered_ids, top_vals.tolist()))
     sem_rows["Score"] = sem_rows["id"].map(score_map)
     sem_rows = sem_rows.sort_values("Score", ascending=False)
+
+    # Final output
+    result_table.value = sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
+

     filt = df.copy()
     if w_countries.value: