Spaces:
Running
Running
# app_panel.py — Panel-based CGD Survey Explorer
import functools
import gc
import io
import json
import os

import boto3
import pandas as pd
import panel as pn
import psycopg2
import torch
from sentence_transformers import SentenceTransformer, util
# Load Panel's extension before any of the widgets and panes below are built.
pn.extension()
# ───────────────────────────────────────────────
# 1) Data / Embeddings Loaders
# ───────────────────────────────────────────────
# PostgreSQL connection settings, read from the process environment at import
# time. Only the port has a fallback ("5432"); the rest are None when unset.
DB_HOST = os.environ.get("DB_HOST")
DB_PORT = os.environ.get("DB_PORT", "5432")
DB_NAME = os.environ.get("DB_NAME")
DB_USER = os.environ.get("DB_USER")
DB_PASSWORD = os.environ.get("DB_PASSWORD")
def get_data():
    """Load the survey table from PostgreSQL into a DataFrame.

    Connects with the module-level ``DB_*`` settings (``sslmode="require"``)
    and selects a fixed set of columns from ``survey_info``.

    Returns:
        pandas.DataFrame: the query result with columns ``id``, ``country``,
        ``year``, ``section``, ``question_code``, ``question_text``,
        ``answer_code`` and ``answer_text``.

    Raises:
        Exception: whatever ``psycopg2.connect`` or ``pd.read_sql_query``
            raises; the connection is closed before the error propagates.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        return pd.read_sql_query(
            """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """,
            conn,
        )
    finally:
        # Close on the failure path too, so a failed query does not leak the
        # database connection.
        conn.close()
# Load the survey table once at import time; the widgets, filters and search
# below all read from this frame.
df = get_data()

# Map each survey ``id`` to its index label in ``df``. Zipping the column with
# the index yields the same mapping as looping over ``df.iterrows()`` without
# building a Series object for every row.
row_lookup = dict(zip(df["id"], df.index))
def load_embeddings():
    """Download the embedding checkpoint from S3 and return its contents.

    The object is streamed into an in-memory buffer and deserialised onto the
    CPU with ``torch.load``.

    Returns:
        tuple: ``(ids, embeddings)`` — the checkpoint's ``"ids"`` and
        ``"embeddings"`` entries.
    """
    bucket, key = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    # The context manager releases the buffer even if the download or the
    # deserialisation raises.
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(bucket, key, buf)
        buf.seek(0)
        # SECURITY NOTE(review): torch.load unpickles the payload, which can
        # execute arbitrary code. This is only safe while the bucket contents
        # are fully trusted — confirm, or restrict loading to plain tensors.
        ckpt = torch.load(buf, map_location="cpu")
    gc.collect()
    return ckpt["ids"], ckpt["embeddings"]


# Loaded once at import time and reused by every semantic search.
ids_list, emb_tensor = load_embeddings()
@functools.lru_cache(maxsize=1)
def get_st_model():
    """Return the sentence-embedding model used to encode search queries.

    The model is constructed on the first call and the same instance is
    returned afterwards, so repeated searches do not reload it each time.

    Returns:
        SentenceTransformer: ``sentence-transformers/all-MiniLM-L6-v2`` on CPU.
    """
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
# ───────────────────────────────────────────────
# 2) Widgets
# ───────────────────────────────────────────────
# Choices offered by the two multi-selects: the distinct non-null values of
# each column in the loaded table, in sorted order.
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())

# Inputs read by keyword_filter().
w_countries = pn.widgets.MultiSelect(
    name="Countries",
    options=country_opts,
)
w_years = pn.widgets.MultiSelect(
    name="Years",
    options=year_opts,
)
w_keyword = pn.widgets.TextInput(
    name="Keyword Search",
    placeholder="Search questions or answers",
)
w_group = pn.widgets.Checkbox(
    name="Group by Question Text",
    value=False,
)

# Inputs for the semantic search: a free-text query and the button that
# triggers it.
w_semquery = pn.widgets.TextInput(name="Semantic Query")
w_search_button = pn.widgets.Button(
    name="Search",
    button_type="primary",
    disabled=False,
)
# ───────────────────────────────────────────────
# 3) Filtering Logic
# ───────────────────────────────────────────────
def keyword_filter(countries, years, keyword, group):
    """Filter the survey table and wrap the result in a DataFrame pane.

    Args:
        countries: values to keep from the ``country`` column; empty keeps all.
        years: values to keep from the ``year`` column; empty keeps all.
        keyword: literal, case-insensitive substring searched for in
            ``question_text``, ``answer_text`` and ``question_code``; empty
            disables the keyword filter.
        group: when truthy, collapse the rows to one per ``question_text``.

    Returns:
        pn.pane.DataFrame: with ``group`` false, the matching rows limited to
        ``country``, ``year``, ``question_text`` and ``answer_text`` (original
        index labels preserved). With ``group`` true, one row per question
        with columns ``Countries``, ``Years`` and ``Sample Answers`` (at most
        three answers each).
    """
    # No defensive copy needed: each step below builds a new frame and the
    # shared table is never modified in place.
    filt = df
    if countries:
        filt = filt[filt["country"].isin(countries)]
    if years:
        filt = filt[filt["year"].isin(years)]
    if keyword:
        # regex=False: the keyword is free text typed by the user, so
        # characters such as "(" or "?" must match literally instead of being
        # compiled as a regular expression (which can raise re.error).
        match_opts = {"case": False, "na": False, "regex": False}
        hits = (
            filt["question_text"].str.contains(keyword, **match_opts)
            | filt["answer_text"].str.contains(keyword, **match_opts)
            | filt["question_code"].astype(str).str.contains(keyword, **match_opts)
        )
        filt = filt[hits]
    if group:
        grouped = (
            filt.groupby("question_text")
            .agg({
                "country": lambda x: sorted(set(x)),
                "year": lambda x: sorted(set(x)),
                # Keep only the first three answers as a sample.
                "answer_text": lambda x: list(x)[:3],
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers",
            })
        )
        return pn.pane.DataFrame(grouped, sizing_mode="stretch_width", height=400)
    return pn.pane.DataFrame(
        filt[["country", "year", "question_text", "answer_text"]],
        sizing_mode="stretch_width", height=400,
    )
# ───────────────────────────────────────────────
# 4) Semantic Search Callback
# ───────────────────────────────────────────────
def semantic_search(event=None):
    """Run a semantic search for the current query and show the results.

    Encodes the text in ``w_semquery``, ranks the stored embeddings by cosine
    similarity, and writes a table into ``result_pane``: the best-scoring rows
    first (highest score on top), followed by the rows that pass the current
    keyword filters but are not among the semantic hits, with an empty score.

    Args:
        event: click event supplied by the button; unused.

    Returns:
        None. Does nothing when the query is empty or whitespace only.
    """
    query = w_semquery.value.strip()
    if not query:
        return

    model = get_st_model()
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, emb_tensor)[0]
    # torch.topk raises when k exceeds the number of candidates, so cap it at
    # the number of stored embeddings.
    top_vals, top_idx = torch.topk(sims, k=min(50, sims.shape[0]))
    sem_ids = [ids_list[i] for i in top_idx.tolist()]
    score_map = dict(zip(sem_ids, top_vals.tolist()))

    sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)

    # Rows passing the sidebar filters (ungrouped).
    filtered = keyword_filter(
        w_countries.value,
        w_years.value,
        w_keyword.value,
        False,
    ).object
    # The pane's frame only carries the display columns (no "id"), but it
    # keeps df's index labels, so look the full rows up again through them.
    keyword_df = df.loc[filtered.index]
    remainder = keyword_df.loc[~keyword_df["id"].isin(sem_ids)].copy()
    remainder["Score"] = ""

    combined = pd.concat([sem_rows, remainder], ignore_index=True)
    result_pane.object = combined[["Score", "country", "year", "question_text", "answer_text"]]
# Pane that semantic_search() writes its combined result table into.
result_pane = pn.pane.DataFrame(sizing_mode="stretch_width", height=500)

# Run the semantic search each time the button is clicked.
w_search_button.on_click(semantic_search)
# ───────────────────────────────────────────────
# 5) Layout
# ───────────────────────────────────────────────
# Sidebar: keyword filters on top, semantic search controls below.
# NOTE(review): the glyphs in the heading strings look mis-encoded (probably
# emoji originally) — confirm the intended characters.
sidebar = pn.Column(
    "## π Filter Questions",
    w_countries, w_years, w_keyword, w_group,
    pn.Spacer(height=20),
    "## π§ Semantic Search",
    w_semquery, w_search_button,
    width=300,
)

# Main area: one tab per result view.
main = pn.Column(
    pn.pane.Markdown("## π CGD Survey Explorer"),
    pn.Tabs(
        # keyword_filter needs four arguments, so it cannot be rendered as a
        # bare function. Binding it to the widgets supplies their current
        # values and re-renders the table whenever one of them changes.
        (
            "Filtered Results",
            pn.bind(
                keyword_filter,
                countries=w_countries,
                years=w_years,
                keyword=w_keyword,
                group=w_group,
            ),
        ),
        ("Semantic Search Results", result_pane),
    ),
)

# Assemble the page and mark it servable.
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()