Spaces:
Running
Running
File size: 6,433 Bytes
f7d7a98 9a96b62 f7d7a98 9a96b62 f7d7a98 9a96b62 f7d7a98 9a96b62 f7d7a98 9a96b62 f7d7a98 9a96b62 f7d7a98 9a96b62 f7d7a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# app_panel.py — Panel-based CGD Survey Explorer
import os, io, json, gc
import panel as pn
import pandas as pd
import boto3, torch
from sentence_transformers import SentenceTransformer, util
import psycopg2
pn.extension()

# ───────────────────────────────────────────────
# 1) Data / Embeddings Loaders
# ───────────────────────────────────────────────
# PostgreSQL connection settings, taken from the process environment.
# Only the port has a fallback ("5432"); the rest are None when unset.
DB_HOST = os.environ.get("DB_HOST")
DB_PORT = os.environ.get("DB_PORT", "5432")
DB_NAME = os.environ.get("DB_NAME")
DB_USER = os.environ.get("DB_USER")
DB_PASSWORD = os.environ.get("DB_PASSWORD")
@pn.cache()
def get_data():
    """Load the ``survey_info`` table from PostgreSQL into a DataFrame.

    Connection settings come from the module-level ``DB_*`` constants and the
    connection requires SSL (``sslmode="require"``). The function is wrapped
    in ``pn.cache`` so repeated calls can reuse the loaded frame.

    Returns
    -------
    pandas.DataFrame
        Columns: id, country, year, section, question_code, question_text,
        answer_code, answer_text.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        return pd.read_sql_query(
            """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """,
            conn,
        )
    finally:
        # Close even when the query raises. psycopg2's ``with conn:`` only
        # ends the transaction, so it would leave the connection open.
        conn.close()
df = get_data()
# Map each survey ``id`` to its index label in ``df`` (later duplicates win).
# Built straight from the column and index; ``iterrows`` would construct a
# Series for every row just to read one field.
# NOTE(review): not referenced elsewhere in this module — confirm before removing.
row_lookup = dict(zip(df["id"], df.index))
@pn.cache()
def load_embeddings(bucket="cgd-embeddings-bucket", key="survey_info_embeddings.pt"):
    """Download the precomputed survey embeddings checkpoint from S3.

    Parameters
    ----------
    bucket : str
        S3 bucket holding the checkpoint (defaults to the original location).
    key : str
        Object key of the checkpoint within ``bucket``.

    Returns
    -------
    tuple
        ``(ids, embeddings)`` read from the checkpoint's ``"ids"`` and
        ``"embeddings"`` entries, with tensors mapped to CPU. Semantic search
        maps embedding row ``i`` to ``ids[i]``, so the two must stay aligned.
    """
    # The context manager releases the buffer even if the download or load fails.
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(bucket, key, buf)
        buf.seek(0)
        # SECURITY NOTE(review): torch.load unpickles its input. Only load
        # checkpoints from a trusted bucket. Consider weights_only=True on
        # torch versions that support it, after confirming the checkpoint's
        # "ids" contain only allow-listed types.
        ckpt = torch.load(buf, map_location="cpu")
    gc.collect()
    return ckpt["ids"], ckpt["embeddings"]
ids_list, emb_tensor = load_embeddings()
@pn.cache()
def get_st_model():
    """Return the sentence-transformers model used to encode semantic queries.

    The model runs on CPU, and ``pn.cache`` lets calls share one loaded instance.
    NOTE(review): this presumably must be the same model that produced the
    S3 embeddings checkpoint, since query vectors are compared against it.
    """
    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    st_model = SentenceTransformer(model_id, device="cpu")
    return st_model
# ───────────────────────────────────────────────
# 2) Widgets
# ───────────────────────────────────────────────
# Filter choices are the distinct non-null values found in the loaded table.
country_opts, year_opts = (
    sorted(df[column].dropna().unique()) for column in ("country", "year")
)

# Sidebar filter controls (keyword_filter depends on these four).
w_countries = pn.widgets.MultiSelect(options=country_opts, name="Countries")
w_years = pn.widgets.MultiSelect(options=year_opts, name="Years")
w_keyword = pn.widgets.TextInput(
    placeholder="Search questions or answers",
    name="Keyword Search",
)
w_group = pn.widgets.Checkbox(value=False, name="Group by Question Text")

# Semantic search
w_semquery = pn.widgets.TextInput(name="Semantic Query")
w_search_button = pn.widgets.Button(
    disabled=False,
    button_type="primary",
    name="Search",
)
# ───────────────────────────────────────────────
# 3) Filtering Logic
# ───────────────────────────────────────────────
@pn.depends(w_countries, w_years, w_keyword, w_group)
def keyword_filter(countries, years, keyword, group):
    """Filter the survey table by the sidebar widgets and wrap it in a pane.

    Parameters
    ----------
    countries : list
        Selected countries; an empty selection applies no country filter.
    years : list
        Selected years; an empty selection applies no year filter.
    keyword : str
        Case-insensitive *literal* substring matched against question text,
        answer text and question code. Blank or whitespace-only input applies
        no keyword filter.
    group : bool
        When true, collapse rows to one per distinct ``question_text``.

    Returns
    -------
    pn.pane.DataFrame
        Either the grouped summary (Countries, Years, Sample Answers per
        question) or the matching rows' country/year/question/answer columns.
        The row view keeps ``df``'s index labels.
    """
    filt = df  # every filter step below returns a new frame; no copy needed
    if countries:
        filt = filt[filt["country"].isin(countries)]
    if years:
        filt = filt[filt["year"].isin(years)]

    keyword = (keyword or "").strip()
    if keyword:
        # regex=False: user input is matched literally, so characters such as
        # "(" or "?" cannot raise re.error or act as wildcards.
        def matches(col):
            return col.str.contains(keyword, case=False, na=False, regex=False)

        filt = filt[
            matches(filt["question_text"])
            | matches(filt["answer_text"])
            | matches(filt["question_code"].astype(str))
        ]

    if group:
        grouped = (
            filt.groupby("question_text")
            .agg({
                # dropna(): sorting None/NaN alongside strings raises TypeError.
                "country": lambda s: sorted(set(s.dropna())),
                "year": lambda s: sorted(set(s.dropna())),
                "answer_text": lambda s: s.head(3).tolist(),
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers",
            })
        )
        return pn.pane.DataFrame(grouped, sizing_mode="stretch_width", height=400)

    return pn.pane.DataFrame(
        filt[["country", "year", "question_text", "answer_text"]],
        sizing_mode="stretch_width", height=400,
    )
# ───────────────────────────────────────────────
# 4) Semantic Search Callback
# ───────────────────────────────────────────────
def semantic_search(event=None, top_k=50):
    """Rank survey rows against the semantic query and show them in ``result_pane``.

    The ``top_k`` rows whose embeddings have the highest cosine similarity to
    the query are listed first (highest score first). They are followed by
    the remaining rows that pass the sidebar's country/year/keyword filters,
    which get a blank ``Score``. An empty query leaves the pane unchanged.

    Parameters
    ----------
    event : optional
        Click event passed by ``Button.on_click``; not used.
    top_k : int
        Maximum number of semantic matches to rank (capped at the number of
        available embeddings).

    NOTE(review): the semantic matches are not restricted by the sidebar
    filters; only the remainder is. Confirm this is intended.
    """
    query = (w_semquery.value or "").strip()
    if not query:
        return

    model = get_st_model()
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, emb_tensor)[0]
    # torch.topk raises when k exceeds the number of elements, so cap it.
    k = min(top_k, sims.shape[0])
    top_vals, top_idx = torch.topk(sims, k=k)
    # Embedding row i corresponds to survey id ids_list[i].
    sem_ids = [ids_list[i] for i in top_idx.tolist()]
    score_map = dict(zip(sem_ids, top_vals.tolist()))

    sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)

    # keyword_filter's row view only carries display columns (no "id"), but
    # it keeps df's index labels, so recover full rows (and ids) through df.
    keyword_df = keyword_filter(
        w_countries.value,
        w_years.value,
        w_keyword.value,
        False,
    ).object
    filtered = df.loc[keyword_df.index]
    remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
    remainder["Score"] = ""

    combined = pd.concat([sem_rows, remainder], ignore_index=True)
    result_pane.object = combined[["Score", "country", "year", "question_text", "answer_text"]]
# Output pane for the semantic tab; semantic_search replaces its ``object``
# each time the Search button is clicked.
result_pane = pn.pane.DataFrame(sizing_mode="stretch_width", height=500)
w_search_button.on_click(semantic_search)
# ───────────────────────────────────────────────
# 5) Layout
# ───────────────────────────────────────────────
# Sidebar: filter controls first, then the semantic-search controls.
sidebar = pn.Column(
    "## π Filter Questions",
    w_countries,
    w_years,
    w_keyword,
    w_group,
    pn.Spacer(height=20),
    "## π§ Semantic Search",
    w_semquery,
    w_search_button,
    width=300,
)

# Main area: a title over two tabs. One tab holds the reactive filter view;
# the other holds the pane that semantic_search fills.
title_md = pn.pane.Markdown("## π CGD Survey Explorer")
results_tabs = pn.Tabs(
    ("Filtered Results", keyword_filter),
    ("Semantic Search Results", result_pane),
)
main = pn.Column(title_md, results_tabs)

app_template = pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
)
app_template.servable()
|