#!/usr/bin/env python3
# app.py – CGD Survey Explorer + merged semantic search

import os, io, json, gc

import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util

# ────────────────────────────────────────────────────────────────────────
# 1) Database credentials (provided via HF Secrets / env vars)
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")  # set these in the Space’s Secrets
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")


@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Pull the full survey_info table (cached for 10 min).

    Returns:
        DataFrame with columns id, country, year, section, question_code,
        question_text, answer_code, answer_text.
    """
    conn = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        # try/finally guarantees the connection is closed even when the
        # query raises (the original leaked it on error).
        return pd.read_sql_query(
            """SELECT id, country, year, section,
                      question_code, question_text,
                      answer_code, answer_text
               FROM survey_info;""",
            conn,
        )
    finally:
        conn.close()


df = get_data()

# Map row id → positional index.
# NOTE(review): this dict is never read anywhere in this file — kept only in
# case an external consumer imports it; candidate for removal.
row_lookup = {row.id: i for i, row in df.iterrows()}

# ────────────────────────────────────────────────────────────────────────
# 2) Pre-computed embeddings (ids + tensor) – download once per session
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the embeddings checkpoint from S3 and return (ids, tensor).

    The checkpoint is expected to be a dict with keys 'ids' and
    'embeddings'; anything else aborts the app with an error message.
    """
    BUCKET = "cgd-embeddings-bucket"
    KEY = "survey_info_embeddings.pt"  # contains {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    # SECURITY: torch.load unpickles arbitrary Python objects. This is only
    # acceptable because the bucket is first-party and trusted — never point
    # this at user-supplied data.
    ckpt = torch.load(buf, map_location="cpu")
    buf.close()
    gc.collect()  # release the download buffer promptly
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return ckpt["ids"], ckpt["embeddings"]


ids_list, emb_tensor = load_embeddings()


@st.cache_resource
def load_model() -> SentenceTransformer:
    """Load the query encoder once per session.

    The original code instantiated the model inside the button handler,
    paying the full model-load cost on every single search click.
    """
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")
st.sidebar.header("🔎 Filter Questions")

country_options = sorted(df["country"].dropna().unique())
year_options = sorted(df["year"].dropna().unique())

sel_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
sel_years = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")

# --- Semantic-search input (kept in sidebar) ---------------------------
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())

# ── base_filtered: applies dropdown + keyword logic (always computed) ──
# An empty multiselect means "no filter", hence the scalar True fallback
# (True & Series broadcasts). regex=False treats the keyword literally —
# the original passed it as a regex, so input like "(" crashed the app
# with re.error.
base_filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True)
    & (df["year"].isin(sel_years) if sel_years else True)
    & (
        df["question_text"].str.contains(keyword, case=False, na=False, regex=False)
        | df["answer_text"].str.contains(keyword, case=False, na=False, regex=False)
        | df["question_code"].astype(str).str.contains(
            keyword, case=False, na=False, regex=False
        )
    )
]

# ────────────────────────────────────────────────────────────────────────
# 4) When the Search button is clicked → build merged table
# ────────────────────────────────────────────────────────────────────────
if search_clicked:
    with st.spinner("Embedding & searching…"):
        model = load_model()  # cached — no per-click model load
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        sims = util.cos_sim(q_vec, emb_tensor)[0]

        # Guard: torch.topk raises if k exceeds the number of embeddings.
        k = min(50, sims.shape[0])
        top_vals, top_idx = torch.topk(sims, k=k)  # up to 50 candidates

        # Semantic hits, highest similarity first.
        sem_ids = [ids_list[i] for i in top_idx.tolist()]
        sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)

        # Keyword-filtered rows not already covered by the semantic hits;
        # they get a blank Score so they sort/display after the scored rows.
        remainder = base_filtered.loc[~base_filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""  # blank score for keyword-only rows

        combined = pd.concat([sem_rows, remainder], ignore_index=True)

    st.subheader(f"🔍 Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )

# ────────────────────────────────────────────────────────────────────────
# 5) No semantic query → use original keyword filter logic / grouping
# ────────────────────────────────────────────────────────────────────────
else:
    if group_by_question:
        st.subheader("📊 Grouped by Question Text")
        grouped = (
            base_filtered.groupby("question_text")
            .agg({
                "country": lambda x: sorted(set(x)),
                "year": lambda x: sorted(set(x)),
                "answer_text": lambda x: list(x)[:3],  # first 3 sample answers
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers",
            })
        )
        st.dataframe(grouped, use_container_width=True)
        if grouped.empty:
            st.info("No questions found with current filters.")
    else:
        # Build a human-readable heading from whichever filters are active.
        heading = []
        if sel_countries:
            heading.append("Countries: " + ", ".join(sel_countries))
        if sel_years:
            heading.append("Years: " + ", ".join(map(str, sel_years)))
        st.markdown(
            "### Results for "
            + (" | ".join(heading) if heading else "All Countries and Years")
        )
        st.dataframe(
            base_filtered[["country", "year", "question_text", "answer_text"]],
            use_container_width=True,
        )
        if base_filtered.empty:
            st.info("No matching questions found.")