myshirk commited on
Commit
b1d5a3b
·
verified ·
1 Parent(s): c59bc5d

enable showing top-k searches for semantic query

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -70,6 +70,7 @@ w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
70
  w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
71
  w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching")
72
  w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
 
73
 
74
  w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered semantic search")
75
  w_search_button = pn.widgets.Button(name="Search", button_type="primary")
@@ -133,10 +134,10 @@ def search(event=None):
133
  result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=["Score", "country", "year", "question_text", "answer_text"])
134
  return
135
 
 
136
  filtered_embs = emb_tensor[filtered_indices]
137
  q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
138
  sims = util.cos_sim(q_vec, filtered_embs)[0]
139
- top_k = min(50, len(filtered_indices))
140
  top_vals, top_idx = torch.topk(sims, k=top_k)
141
 
142
  top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
@@ -148,11 +149,13 @@ def search(event=None):
148
  result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
149
 
150
 
 
151
  def clear_filters(event=None):
152
  w_countries.value = []
153
  w_years.value = []
154
  w_keyword.value = ""
155
  w_semquery.value = ""
 
156
  result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
157
 
158
  w_search_button.on_click(search)
@@ -167,6 +170,12 @@ w_years.param.watch(lambda e: search(), 'value')
167
  w_semquery.param.watch(lambda e: search(), 'enter_pressed')
168
  w_keyword.param.watch(lambda e: search(), 'enter_pressed')
169
 
 
 
 
 
 
 
170
  # Show all data at startup
171
  result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
172
 
 
70
  w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
71
  w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching")
72
  w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
73
+ w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=50, disabled=True)
74
 
75
  w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered semantic search")
76
  w_search_button = pn.widgets.Button(name="Search", button_type="primary")
 
134
  result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=["Score", "country", "year", "question_text", "answer_text"])
135
  return
136
 
137
+ top_k = min(int(w_topk.value), len(filtered_indices))
138
  filtered_embs = emb_tensor[filtered_indices]
139
  q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
140
  sims = util.cos_sim(q_vec, filtered_embs)[0]
 
141
  top_vals, top_idx = torch.topk(sims, k=top_k)
142
 
143
  top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
 
149
  result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
150
 
151
 
152
+
153
  def clear_filters(event=None):
154
  w_countries.value = []
155
  w_years.value = []
156
  w_keyword.value = ""
157
  w_semquery.value = ""
158
+ w_topk.disabled = True
159
  result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
160
 
161
  w_search_button.on_click(search)
 
170
  w_semquery.param.watch(lambda e: search(), 'enter_pressed')
171
  w_keyword.param.watch(lambda e: search(), 'enter_pressed')
172
 
173
+ # Enable/disable Top-K based on semantic query presence
174
+ def _toggle_topk_disabled(event=None):
175
+ w_topk.disabled = (w_semquery.value.strip() == '')
176
+ _toggle_topk_disabled()
177
+ w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value')
178
+
179
  # Show all data at startup
180
  result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
181