myshirk committed on
Commit
5968656
·
verified ·
1 Parent(s): 1e66d1d

Fix search logic error: apply filters before semantic ranking

Browse files
Files changed (1) hide show
  1. app.py +33 -6
app.py CHANGED
@@ -119,17 +119,44 @@ def semantic_search(event=None):
119
  if not query:
120
  return
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  model, ids_list, emb_tensor = get_semantic_resources()
 
 
 
 
 
 
 
 
 
 
123
  q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
124
-
125
- sims = util.cos_sim(q_vec, emb_tensor)[0]
126
  top_vals, top_idx = torch.topk(sims, k=50)
127
-
128
- sem_ids = [ids_list[i] for i in top_idx.tolist()]
129
- sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
130
- score_map = dict(zip(sem_ids, top_vals.tolist()))
131
  sem_rows["Score"] = sem_rows["id"].map(score_map)
132
  sem_rows = sem_rows.sort_values("Score", ascending=False)
 
 
 
 
133
 
134
  filt = df.copy()
135
  if w_countries.value:
 
119
  if not query:
120
  return
121
 
122
+ # Step 1: Filter the full dataframe
123
+ filt = df.copy()
124
+ if w_countries.value:
125
+ filt = filt[filt["country"].isin(w_countries.value)]
126
+ if w_years.value:
127
+ filt = filt[filt["year"].isin(w_years.value)]
128
+ if w_keyword.value:
129
+ filt = filt[
130
+ filt["question_text"].str.contains(w_keyword.value, case=False, na=False) |
131
+ filt["answer_text"].str.contains(w_keyword.value, case=False, na=False) |
132
+ filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False)
133
+ ]
134
+
135
+ # Step 2: Fetch the model, ID list and embedding tensor; the tensor is narrowed to the filtered rows below
136
  model, ids_list, emb_tensor = get_semantic_resources()
137
+
138
+ # Create a mask for filtered IDs
139
+ filtered_ids = filt["id"].tolist()
140
+ id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
141
+ filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index]
142
+
143
+ # Subset the embedding tensor
144
+ filtered_embs = emb_tensor[filtered_indices]
145
+
146
+ # Step 3: Semantic search only within filtered subset
147
  q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
148
+ sims = util.cos_sim(q_vec, filtered_embs)[0]
 
149
  top_vals, top_idx = torch.topk(sims, k=50)
150
+
151
+ top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
152
+ sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
153
+ score_map = dict(zip(top_filtered_ids, top_vals.tolist()))
154
  sem_rows["Score"] = sem_rows["id"].map(score_map)
155
  sem_rows = sem_rows.sort_values("Score", ascending=False)
156
+
157
+ # Final output
158
+ result_table.value = sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
159
+
160
 
161
  filt = df.copy()
162
  if w_countries.value: