# Source: cgd-ui-panel Space — app.py ("auto update cols", commit 91a6a7f)
# app.py – Unified Panel App with Semantic Search + Filterable Tabulator
import os, io, gc
import panel as pn
import pandas as pd
import boto3, torch
import psycopg2
from sentence_transformers import SentenceTransformer, util
pn.extension('tabulator')
# ──────────────────────────────────────────────────────────────────────
# 1) Database and Resource Loading
# ──────────────────────────────────────────────────────────────────────
# Database connection settings come from the environment (deployment
# secrets); only DB_PORT has a fallback — the Postgres default.
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
@pn.cache()
def get_data():
    """Load the full ``survey_info`` table into a DataFrame (cached app-wide).

    Returns
    -------
    pd.DataFrame
        One row per survey answer record. The ``year`` column is normalized
        to strings so missing years render as '' instead of NaN/<NA>.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require"
    )
    try:
        df_ = pd.read_sql_query("""
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """, conn)
    finally:
        # Close even when the query raises — the original leaked the
        # connection on error.
        conn.close()
    # Ensure year column is int-like text, showing blank instead of NaN.
    if "year" in df_.columns:
        df_["year"] = (
            pd.to_numeric(df_["year"], errors="coerce")
            .astype("Int64")
            .astype(str)
            .replace({'<NA>': ''})
        )
    return df_
# Materialize the survey table once at startup; all filtering below works
# on in-memory copies of this frame.
df = get_data()
@pn.cache()
def load_embeddings():
    """Fetch the precomputed survey-embedding checkpoint from S3.

    Returns the ``(ids, embeddings)`` pair stored in the checkpoint dict.
    """
    bucket = "cgd-embeddings-bucket"
    key = "survey_info_embeddings.pt"
    stream = io.BytesIO()
    boto3.client("s3").download_fileobj(bucket, key, stream)
    stream.seek(0)
    # NOTE(review): torch.load unpickles arbitrary objects — acceptable only
    # because this bucket is owned/trusted by the project; confirm.
    checkpoint = torch.load(stream, map_location="cpu")
    stream.close()
    gc.collect()
    return checkpoint["ids"], checkpoint["embeddings"]
@pn.cache()
def get_st_model():
    """Return the (cached) CPU sentence-transformer used to encode queries."""
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    return SentenceTransformer(model_name, device="cpu")
@pn.cache()
def get_semantic_resources():
    """Bundle the encoder model with the precomputed ids and embeddings."""
    encoder = get_st_model()
    ids_list, emb_tensor = load_embeddings()
    return encoder, ids_list, emb_tensor
# ──────────────────────────────────────────────────────────────────────
# 2) Widgets
# ──────────────────────────────────────────────────────────────────────
# NOTE: df["year"] holds strings at this point (see get_data), so year_opts
# sorts lexicographically and may contain '' for missing years.
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())
# "Score" only exists on semantic-search results; _selected_cols() strips it
# from the selection when no semantic query is active.
ALL_COLUMNS = ["country","year","section","question_code","question_text","answer_code","answer_text","Score"]
w_columns = pn.widgets.MultiChoice(
    name="Columns to show",
    options=ALL_COLUMNS,
    value=["country","year","question_text","answer_text"]
)
w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching")
w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
# Top-K starts disabled; _toggle_topk_disabled (below) enables it once a
# semantic query has been typed.
w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=10, disabled=True)
w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered semantic search")
w_search_button = pn.widgets.Button(name="Search", button_type="primary")
w_clear_filters = pn.widgets.Button(name="Clear Filters", button_type="warning")
# ──────────────────────────────────────────────────────────────────────
# 3) Unified Results Table (Tabulator)
# ──────────────────────────────────────────────────────────────────────
# Single results widget shared by keyword and semantic search; its .value
# DataFrame is replaced wholesale by search() / clear_filters().
result_table = pn.widgets.Tabulator(
    pagination='remote',  # server-side paging keeps large frames responsive
    page_size=15,
    sizing_mode="stretch_width",
    layout='fit_columns',
    show_index=False
)
# ──────────────────────────────────────────────────────────────────────
# 4) Search Logic
# ──────────────────────────────────────────────────────────────────────
def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame:
if df_in.empty:
return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"])
tmp = df_in.copy()
tmp["year"] = tmp["year"].replace('', pd.NA)
grouped = (
tmp.groupby("question_text", dropna=False)
.agg({
"country": lambda x: sorted({v for v in x if pd.notna(v)}),
"year": lambda x: sorted({str(v) for v in x if pd.notna(v)}),
"answer_text": lambda x: list(x.dropna())[:3],
})
.reset_index()
.rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"})
)
return grouped
def _selected_cols(has_score=False):
    """Return the columns the user selected, validated against ALL_COLUMNS.

    When no semantic search ran (``has_score`` False), "Score" is also
    removed from the widget itself so the UI selection stays consistent.
    Falls back to the default display columns if nothing valid is selected.
    """
    valid = set(ALL_COLUMNS)
    if not has_score and "Score" in w_columns.value:
        w_columns.value = [col for col in w_columns.value if col != "Score"]
    chosen = [col for col in w_columns.value if col in valid]
    return chosen or ["country", "year", "question_text", "answer_text"]
def search(event=None):
    """Filter the survey table and (optionally) rank it semantically.

    Applies the country/year/keyword widget filters to the global ``df``,
    then, if a semantic query is present, scores the filtered rows against
    the precomputed embeddings and keeps the Top-K. Writes the result into
    ``result_table.value`` (grouped by question when w_group is checked).
    """
    query = w_semquery.value.strip()
    filt = df.copy()
    if w_countries.value:
        filt = filt[filt["country"].isin(w_countries.value)]
    if w_years.value:
        filt = filt[filt["year"].isin(w_years.value)]
    if w_keyword.value:
        kw = w_keyword.value
        # regex=False: the UI promises exact substring matching, and a raw
        # keyword like "(" or "?" would otherwise raise a regex error.
        filt = filt[
            filt["question_text"].str.contains(kw, case=False, na=False, regex=False) |
            filt["answer_text"].str.contains(kw, case=False, na=False, regex=False) |
            filt["question_code"].astype(str).str.contains(kw, case=False, na=False, regex=False)
        ]
    if not query:
        result_table.value = _group_by_question(filt) if w_group.value else filt[_selected_cols(False)]
        return
    model, ids_list, emb_tensor = get_semantic_resources()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
    # Keep ids and embedding rows aligned: only ids that actually have an
    # embedding participate, in the same order as the sliced tensor. (The
    # original indexed the unfiltered id list with tensor-row positions,
    # returning wrong rows whenever an id was missing from the index.)
    present_ids = [id_ for id_ in filt["id"].tolist() if id_ in id_to_index]
    if not present_ids:
        result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=_selected_cols(True))
        return
    filtered_indices = [id_to_index[id_] for id_ in present_ids]
    top_k = min(int(w_topk.value), len(present_ids))
    filtered_embs = emb_tensor[filtered_indices]
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, filtered_embs)[0]
    top_vals, top_idx = torch.topk(sims, k=top_k)
    # top_idx positions refer to rows of filtered_embs == entries of present_ids.
    top_filtered_ids = [present_ids[i] for i in top_idx.tolist()]
    sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
    score_map = dict(zip(top_filtered_ids, top_vals.tolist()))
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)
    result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[_selected_cols(True)]
def clear_filters(event=None):
    """Reset every filter widget and restore the default full-table view."""
    for widget in (w_countries, w_years):
        widget.value = []
    w_keyword.value = ""
    w_semquery.value = ""
    # No semantic query anymore, so Top-K is meaningless again.
    w_topk.disabled = True
    result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
# Button callbacks.
w_search_button.on_click(search)
w_clear_filters.on_click(clear_filters)
# Live updates for filters (except semantic query and keyword, which only
# re-run on Enter or the Search button, not on every keystroke).
w_group.param.watch(lambda e: search(), 'value')
w_countries.param.watch(lambda e: search(), 'value')
w_years.param.watch(lambda e: search(), 'value')
w_columns.param.watch(lambda e: search(), 'value')
# Allow pressing Enter in semantic query or keyword to trigger search.
# NOTE(review): 'enter_pressed' exists only on newer Panel TextInput
# releases — confirm the deployed Panel version exposes it.
w_semquery.param.watch(lambda e: search(), 'enter_pressed')
w_keyword.param.watch(lambda e: search(), 'enter_pressed')
# Enable/disable Top-K based on semantic query presence
def _toggle_topk_disabled(event=None):
    # Top-K only makes sense while a semantic query is typed.
    w_topk.disabled = (w_semquery.value.strip() == '')
_toggle_topk_disabled()
w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value')
# Show all data at startup (default display columns only).
result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
# ──────────────────────────────────────────────────────────────────────
# 5) Layout
# ──────────────────────────────────────────────────────────────────────
# Sidebar: plain filters on top, semantic-search controls underneath.
sidebar = pn.Column(
    "## πŸ”Ž Filters",
    w_countries, w_years, w_keyword, w_group, w_columns,
    pn.Spacer(height=20),
    "## 🧠 Semantic Search",
    w_semquery,
    w_topk,
    w_search_button,
    pn.Spacer(height=20),
    w_clear_filters,
    width=300
)
# Main area: title plus the shared results table.
main = pn.Column(
    pn.pane.Markdown("## 🌍 CGD Survey Explorer"),
    result_table
)
# FastListTemplate supplies header/sidebar chrome and the theme toggle;
# .servable() registers the app with `panel serve`.
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()