cgd-ui-panel / app.py
myshirk's picture
add our app
f7d7a98 verified
raw
history blame
6.43 kB
# app_panel.py – Panel-based CGD Survey Explorer
import os, io, json, gc
import panel as pn
import pandas as pd
import boto3, torch
from sentence_transformers import SentenceTransformer, util
import psycopg2
pn.extension()

# ───────────────────────────────────────────────
# 1) Data / Embeddings Loaders
# ───────────────────────────────────────────────
# Postgres connection settings, read from the environment. Only the port has
# a fallback ("5432"); every other setting is None when its variable is unset.
DB_HOST, DB_NAME, DB_USER, DB_PASSWORD = (
    os.getenv(var) for var in ("DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD")
)
DB_PORT = os.getenv("DB_PORT", "5432")
@pn.cache()
def get_data():
    """Load the whole ``survey_info`` table from Postgres into a DataFrame.

    Connection settings come from the module-level DB_* constants and SSL is
    required. The function is wrapped in ``pn.cache``, so Panel memoizes the
    result instead of re-querying on every call.

    Returns:
        DataFrame with columns id, country, year, section, question_code,
        question_text, answer_code, answer_text.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        df_ = pd.read_sql_query("""
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
        """, conn)
    finally:
        # Close the connection even when the query fails, so an error does
        # not leak an open server connection. (psycopg2's own ``with conn:``
        # only ends the transaction; it does not close the connection.)
        conn.close()
    return df_
df = get_data()
# Map each survey_info id to its row label in df. With get_data's default
# index, the label is also the row position. Built column-wise because
# iterrows() materializes a Series for every row.
# NOTE(review): row_lookup is not referenced elsewhere in this file as seen;
# confirm whether it is still needed.
row_lookup = dict(zip(df["id"], df.index))
@pn.cache()
def load_embeddings():
    """Download the precomputed survey embeddings checkpoint from S3.

    The checkpoint is streamed into memory (nothing is written to disk) and
    deserialized onto the CPU. The result is memoized by ``pn.cache``.

    Returns:
        (ids, embeddings): the checkpoint's "ids" and "embeddings" entries.
        semantic_search treats ids[i] as the survey_info id for row i of
        embeddings. That pairing is assumed to hold in the stored file
        (TODO confirm against the embedding build script).

    Raises:
        KeyError: if the checkpoint lacks "ids" or "embeddings".
    """
    BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
        buf.seek(0)
        # NOTE(review): torch.load unpickles the file, so it is only as
        # trustworthy as whoever can write to this bucket. If the checkpoint
        # holds only tensors and plain lists, consider weights_only=True.
        ckpt = torch.load(buf, map_location="cpu")
    # The raw download buffer is released at this point; reclaim it eagerly.
    gc.collect()
    return ckpt["ids"], ckpt["embeddings"]


# Loaded once at import; pn.cache makes later calls reuse the same objects.
ids_list, emb_tensor = load_embeddings()
@pn.cache()
def get_st_model():
    """Return the MiniLM sentence-embedding model, loaded on the CPU.

    Memoized by ``pn.cache``, so repeated searches reuse one model instance
    instead of reloading its weights.
    """
    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    model = SentenceTransformer(model_id, device="cpu")
    return model
# ───────────────────────────────────────────────
# 2) Widgets
# ───────────────────────────────────────────────
# Filter choices are the distinct, non-null values present in the data.
country_opts, year_opts = (
    sorted(df[column].dropna().unique()) for column in ("country", "year")
)

w_countries = pn.widgets.MultiSelect(options=country_opts, name="Countries")
w_years = pn.widgets.MultiSelect(options=year_opts, name="Years")
w_keyword = pn.widgets.TextInput(
    placeholder="Search questions or answers",
    name="Keyword Search",
)
w_group = pn.widgets.Checkbox(value=False, name="Group by Question Text")

# Semantic search
w_semquery = pn.widgets.TextInput(name="Semantic Query")
w_search_button = pn.widgets.Button(
    button_type="primary",
    disabled=False,
    name="Search",
)
# ───────────────────────────────────────────────
# 3) Filtering Logic
# ───────────────────────────────────────────────
@pn.depends(w_countries, w_years, w_keyword, w_group)
def keyword_filter(countries, years, keyword, group):
    """Filter the survey rows by the sidebar widgets and render them.

    Args:
        countries: selected countries; an empty selection disables the filter.
        years: selected years; an empty selection disables the filter.
        keyword: case-insensitive text searched for in question_text,
            answer_text and question_code. It is matched literally, not as a
            regular expression.
        group: when true, collapse the rows to one per question_text.

    Returns:
        A ``pn.pane.DataFrame``. The ungrouped view shows country, year,
        question_text and answer_text, and keeps df's index labels. The
        grouped view lists the distinct countries and years per question,
        plus up to three sample answers.
    """
    # Boolean masks and column selection already build new frames, so the
    # shared df is never mutated and a defensive full copy is unnecessary.
    filt = df
    if countries:
        filt = filt[filt["country"].isin(countries)]
    if years:
        filt = filt[filt["year"].isin(years)]
    if keyword:
        def _has_keyword(col):
            # regex=False: user text such as "(" or "?" would otherwise be
            # compiled as a pattern and raise re.error.
            return col.str.contains(keyword, case=False, na=False, regex=False)

        filt = filt[
            _has_keyword(filt["question_text"])
            | _has_keyword(filt["answer_text"])
            | _has_keyword(filt["question_code"].astype(str))
        ]
    if group:
        def _distinct_sorted(values):
            # Drop missing values first: sorting a mix of strings and NaN
            # raises TypeError.
            return sorted(set(values.dropna()))

        grouped = (
            filt.groupby("question_text")
            .agg({
                "country": _distinct_sorted,
                "year": _distinct_sorted,
                "answer_text": lambda x: list(x)[:3],
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers",
            })
        )
        return pn.pane.DataFrame(grouped, sizing_mode="stretch_width", height=400)
    return pn.pane.DataFrame(
        filt[["country", "year", "question_text", "answer_text"]],
        sizing_mode="stretch_width", height=400,
    )
# ───────────────────────────────────────────────
# 4) Semantic Search Callback
# ───────────────────────────────────────────────
def semantic_search(event=None, top_k=50):
    """Rank survey rows by similarity to the semantic query and display them.

    Bound to the Search button. The text in ``w_semquery`` is embedded with
    ``get_st_model()`` and compared against ``emb_tensor`` by cosine
    similarity. The best ``top_k`` matches come first, sorted by Score. They
    are followed by the rows that pass the sidebar country/year/keyword
    filters but were not among the matches; those rows get a blank Score.

    NOTE(review): the semantic matches themselves are not restricted by the
    sidebar filters. Confirm that this is intended.

    Args:
        event: the click event passed by Panel; unused.
        top_k: maximum number of semantic matches. It is capped at the number
            of stored embeddings.

    Returns:
        None. The result is written to ``result_pane.object``. An empty query
        leaves the pane unchanged.
    """
    query = (w_semquery.value or "").strip()
    if not query:
        return

    model = get_st_model()
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, emb_tensor)[0]
    # torch.topk raises when k exceeds the number of candidates.
    k = min(top_k, sims.shape[0])
    top_vals, top_idx = torch.topk(sims, k=k)
    sem_ids = [ids_list[i] for i in top_idx.tolist()]

    sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
    score_map = dict(zip(sem_ids, top_vals.tolist()))
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)

    # keyword_filter's pane holds only the display columns (no "id"), so
    # indexing it by "id" raised KeyError. It does keep df's index labels,
    # so recover the full rows from df by label.
    shown = keyword_filter(
        w_countries.value,
        w_years.value,
        w_keyword.value,
        False,
    ).object
    keyword_df = df.loc[shown.index]
    remainder = keyword_df.loc[~keyword_df["id"].isin(sem_ids)].copy()
    remainder["Score"] = ""

    combined = pd.concat([sem_rows, remainder], ignore_index=True)
    result_pane.object = combined[["Score", "country", "year", "question_text", "answer_text"]]
# Output pane for semantic_search, which replaces its .object on each click.
result_pane = pn.pane.DataFrame(height=500, sizing_mode="stretch_width")
# Run the semantic search whenever the Search button is pressed.
w_search_button.on_click(semantic_search)
# ───────────────────────────────────────────────
# 5) Layout
# ───────────────────────────────────────────────
# The heading emoji were stored mis-encoded (e.g. "πŸ”"). They are
# restored here to the intended characters.
sidebar = pn.Column(
    "## 🔍 Filter Questions",
    w_countries, w_years, w_keyword, w_group,
    pn.Spacer(height=20),
    "## 🧠 Semantic Search",
    w_semquery, w_search_button,
    width=300,
)
main = pn.Column(
    pn.pane.Markdown("## 🌍 CGD Survey Explorer"),
    pn.Tabs(
        # keyword_filter is a pn.depends function: Panel re-runs it when any
        # of its declared widgets change and shows the returned pane.
        ("Filtered Results", keyword_filter),
        # Filled in by semantic_search when the Search button is clicked.
        ("Semantic Search Results", result_pane),
    ),
)
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()