#!/usr/bin/env python3
# app.py – CGD Survey Explorer (keyword + semantic in one table)
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util

# ─────────────────────────────────────────────────────────────
# 1) Database credentials (HF Secrets or env vars)
# ─────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")


@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Read the survey_info table, re-querying at most once every 10 minutes.

    ``st.cache_data(ttl=600)`` makes reruns within a 600 s window reuse the
    cached DataFrame instead of querying Postgres again.

    Returns:
        DataFrame with columns id, country, year, section, question_code,
        question_text, answer_code, answer_text.
    """
    conn = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        sslmode="require",
    )
    # try/finally so the connection is released even when the query fails;
    # psycopg2's ``with conn:`` only ends the transaction, it does not close.
    try:
        return pd.read_sql_query(
            """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """,
            conn,
        )
    finally:
        conn.close()


df = get_data()
# Maps each survey_info id to its DataFrame index label.
row_lookup = dict(zip(df["id"], df.index))

# ─────────────────────────────────────────────────────────────
# 2) Cached resources
# ─────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the id list and embedding matrix from S3.

    ``st.cache_resource`` keeps the result for the lifetime of the server
    process, so the download happens once and is shared by all sessions
    (not once per session).

    Returns:
        ``(ids, embeddings)``; ``ids[i]`` is the survey_info id whose vector
        is row ``i`` of ``embeddings``.
    """
    BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
        buf.seek(0)
        # NOTE(review): torch.load unpickles the file, which is only safe
        # because the bucket is trusted. Consider weights_only=True if the
        # checkpoint holds only tensors and plain Python containers.
        ckpt = torch.load(buf, map_location="cpu")
    gc.collect()  # drop the raw download buffer before the app continues

    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()

    ids, embeddings = ckpt["ids"], ckpt["embeddings"]
    # Search maps embedding row i -> ids[i], so the two must line up exactly.
    if len(ids) != len(embeddings):
        st.error(
            f"Checkpoint mismatch in survey_info_embeddings.pt: "
            f"{len(ids)} ids vs {len(embeddings)} embeddings"
        )
        st.stop()
    return ids, embeddings


ids_list, emb_tensor = load_embeddings()


@st.cache_resource
def get_st_model():
    """Load the MiniLM sentence-transformer once per process, pinned to CPU.

    Forcing ``device="cpu"`` avoids the meta-tensor bug seen when the model
    is placed on another device.
    """
    return SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
    )
# ─────────────────────────────────────────────────────────────
# 3) Streamlit UI
# ─────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")
st.sidebar.header("🔎 Filter Questions")

country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())

sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
sel_years = st.sidebar.multiselect("Select Year(s)", year_opts)
keyword = st.sidebar.text_input("Keyword Search (Question / Answer / Code)")
group_by_q = st.sidebar.checkbox("Group by Question Text")

# ── Semantic search panel
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())

SEM_TOP_K = 50  # maximum number of semantic candidates to show


def _filter_mask(frame, countries, years, kw):
    """Return a boolean row mask for the dropdown and keyword filters.

    An empty selection or an empty keyword does not restrict anything. The
    keyword is matched literally (``regex=False``) and case-insensitively
    against question text, answer text and question code, so user input
    like ``C++`` or ``(`` cannot raise a regex error.
    """
    mask = pd.Series(True, index=frame.index)
    if countries:
        mask &= frame["country"].isin(countries)
    if years:
        mask &= frame["year"].isin(years)
    if kw:
        def hit(col):
            return col.str.contains(kw, case=False, na=False, regex=False)

        mask &= (
            hit(frame["question_text"])
            | hit(frame["answer_text"])
            | hit(frame["question_code"].astype(str))
        )
    return mask


def _sorted_unique(values):
    """Sorted distinct non-null values; NaN would break sorting of strings."""
    return sorted(set(values.dropna()))


# ── Always build the keyword/dropdown subset
filtered = df[_filter_mask(df, sel_countries, sel_years, keyword)]

# ─────────────────────────────────────────────────────────────
# 4) Semantic Search → merged table
# ─────────────────────────────────────────────────────────────
if search_clicked:
    with st.spinner("Embedding & searching…"):
        model = get_st_model()
        q_vec = model.encode(
            sem_query.strip(), convert_to_tensor=True, device="cpu"
        ).cpu()
        sims = util.cos_sim(q_vec, emb_tensor)[0]
        # torch.topk raises if k exceeds the number of embeddings; clamp it.
        top_k = min(SEM_TOP_K, sims.shape[0])
        top_vals, top_idx = torch.topk(sims, k=top_k)

        sem_ids = [ids_list[i] for i in top_idx.tolist()]
        sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)

        # rows that matched keyword/dropdown but not semantic
        remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""  # blank score

        combined = pd.concat([sem_rows, remainder], ignore_index=True)

    st.subheader(f"🔍 Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    st.stop()  # skip original display logic below when semantic ran

# ─────────────────────────────────────────────────────────────
# 5) Original display (keyword / filters only)
# ─────────────────────────────────────────────────────────────
if group_by_q:
    st.subheader("📊 Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            "country": _sorted_unique,
            "year": _sorted_unique,
            "answer_text": lambda x: list(x)[:3],
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year": "Years",
            "answer_text": "Sample Answers",
        })
    )
    st.dataframe(grouped, use_container_width=True)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    hdr = []
    if sel_countries:
        hdr.append("Countries: " + ", ".join(map(str, sel_countries)))
    if sel_years:
        hdr.append("Years: " + ", ".join(map(str, sel_years)))
    st.markdown("### Results for " + (" | ".join(hdr) if hdr else "All Countries and Years"))
    st.dataframe(
        filtered[["country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    if filtered.empty:
        st.info("No matching questions found.")