enable showing top-k searches for semantic query
app.py CHANGED
@@ -70,6 +70,7 @@ w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
 w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
 w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching")
 w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
+w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=50, disabled=True)
 
 w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered semantic search")
 w_search_button = pn.widgets.Button(name="Search", button_type="primary")
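For reference, a minimal standalone sketch of the new Top-K control and its defaults; it assumes only the panel package, and the final print just shows how search() later reads the value:

```python
# Minimal sketch of the new Top-K control (assumes only the `panel` package).
import panel as pn
pn.extension()

w_topk = pn.widgets.Select(
    name="Top-K (semantic)",
    options=[5, 10, 20, 50, 100],  # selectable result counts
    value=50,                      # default matches the previously hard-coded 50
    disabled=True,                 # greyed out until a semantic query is entered
)

print(int(w_topk.value))  # -> 50; search() reads it the same way via int(w_topk.value)
```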
@@ -133,10 +134,10 @@ def search(event=None):
         result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=["Score", "country", "year", "question_text", "answer_text"])
         return
 
+    top_k = min(int(w_topk.value), len(filtered_indices))
     filtered_embs = emb_tensor[filtered_indices]
     q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
     sims = util.cos_sim(q_vec, filtered_embs)[0]
-    top_k = min(50, len(filtered_indices))
     top_vals, top_idx = torch.topk(sims, k=top_k)
 
     top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
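The moved line lets the ranking step honour the widget instead of a hard-coded 50. Below is a standalone sketch of the same cosine-similarity top-k pattern, using an illustrative model name and toy corpus rather than the app's data:

```python
# Standalone sketch of the retrieval step this hunk changes: cosine similarity
# against pre-computed embeddings, then torch.topk capped by a user-chosen k.
import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")   # illustrative model choice
corpus = ["How many hospital beds are available?",
          "What is the national literacy rate?",
          "How is air quality monitored?"]
emb_tensor = model.encode(corpus, convert_to_tensor=True)   # shape (n_rows, dim)

query = "healthcare capacity"
q_vec = model.encode(query, convert_to_tensor=True)
sims = util.cos_sim(q_vec, emb_tensor)[0]                   # shape (n_rows,)

top_k = min(2, len(corpus))            # mirrors min(int(w_topk.value), len(filtered_indices))
top_vals, top_idx = torch.topk(sims, k=top_k)               # best scores and their row positions
for score, i in zip(top_vals.tolist(), top_idx.tolist()):
    print(f"{score:.3f}  {corpus[i]}")
```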
@@ -148,11 +149,13 @@ def search(event=None):
     result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
 
 
+
 def clear_filters(event=None):
     w_countries.value = []
     w_years.value = []
     w_keyword.value = ""
     w_semquery.value = ""
+    w_topk.disabled = True
     result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
 
 w_search_button.on_click(search)
@@ -167,6 +170,12 @@ w_years.param.watch(lambda e: search(), 'value')
 w_semquery.param.watch(lambda e: search(), 'enter_pressed')
 w_keyword.param.watch(lambda e: search(), 'enter_pressed')
 
+# Enable/disable Top-K based on semantic query presence
+def _toggle_topk_disabled(event=None):
+    w_topk.disabled = (w_semquery.value.strip() == '')
+_toggle_topk_disabled()
+w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value')
+
 # Show all data at startup
 result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
 
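This watcher block carries the UX change: Top-K stays greyed out until a semantic query is typed. A self-contained sketch of the enable/disable pattern, with the rest of the app omitted and the final assignment simulating a user typing:

```python
# Self-contained sketch of the enable/disable pattern used above: a watcher on
# the query widget's 'value' parameter toggles the Select's `disabled` flag.
import panel as pn
pn.extension()

w_semquery = pn.widgets.TextInput(name="Semantic Query")
w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100],
                           value=50, disabled=True)

def _toggle_topk_disabled(event=None):
    # Grey out Top-K whenever the semantic query box is empty.
    w_topk.disabled = (w_semquery.value.strip() == "")

_toggle_topk_disabled()                                    # set the initial state
w_semquery.param.watch(lambda e: _toggle_topk_disabled(), "value")

w_semquery.value = "housing"   # programmatic change fires the watcher, like typing would
print(w_topk.disabled)         # -> False
```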