# cgd-ui-TEST / app.py
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util
# ────────────────────────────────────────────────────────────────────────
# 1) DB credentials (from HF secrets or env) – original
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    try:
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            sslmode="require",
        )
        query = """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
        """
        df_ = pd.read_sql_query(query, conn)
        conn.close()
        return df_
    except Exception as e:
        st.error(f"Failed to connect to the database: {e}")
        st.stop()
df = get_data()  # ← original DataFrame

# Quick lookup: database id → positional row index, for resolving search hits later
row_lookup = dict(zip(df["id"], range(len(df))))
# ────────────────────────────────────────────────────────────────────────
# 2) Load embeddings + ids once per session (S3) – new, cached
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    # credentials already in env (HF secrets) – boto3 will pick them up
    BUCKET = "cgd-embeddings-bucket"
    KEY = "survey_info_embeddings.pt"  # dict with keys {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close()
    gc.collect()
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return ckpt["ids"], ckpt["embeddings"]
ids_list, emb_tensor = load_embeddings()
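
# For reference, a minimal sketch of how survey_info_embeddings.pt is assumed
# to have been produced. The real export script is not part of this app, so
# the embedded text, column choices, and helper name are illustrative only.
def _export_embeddings_sketch(df_: pd.DataFrame, path: str = "survey_info_embeddings.pt") -> None:
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    # Embed question and answer text together, keyed by the DB id column
    texts = (df_["question_text"].fillna("") + " " + df_["answer_text"].fillna("")).tolist()
    emb = model.encode(texts, convert_to_tensor=True).cpu()
    torch.save({"ids": df_["id"].tolist(), "embeddings": emb}, path)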
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI – original filters + new semantic search
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")
st.sidebar.header("🔎 Filter Questions")
country_options = sorted(df["country"].dropna().unique())
year_options = sorted(df["year"].dropna().unique())
selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
"Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")
# ── new semantic search panel ───────────────────────────────────────────
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
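
# Cache the sentence-transformer so repeated searches don't reload the model
# on every button press (a small optimization over re-instantiating it inline).
@st.cache_resource
def get_model() -> SentenceTransformer:
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")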
if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        model = get_model()
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        scores = util.cos_sim(q_vec, emb_tensor)[0]
        # Over-fetch: hits missing from the DB or with blank text are dropped below
        top_vals, top_idx = torch.topk(scores, k=min(10, scores.shape[0]))
        results = []
        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
            db_id = ids_list[emb_row]
            if db_id in row_lookup:
                row = df.iloc[row_lookup[db_id]]
                # NaN is truthy in Python, so test for missing text explicitly
                if pd.notna(row["question_text"]) and pd.notna(row["answer_text"]):
                    results.append({
                        "Score": f"{score:.3f}",
                        "Country": row["country"],
                        "Year": row["year"],
                        "Question": row["question_text"],
                        "Answer": row["answer_text"],
                    })
    if results:
        st.subheader(f"🔍 Semantic Results ({len(results)} found)")
        st.dataframe(pd.DataFrame(results).head(5))
    else:
        st.info("No semantic matches found.")
st.markdown("---")
# ── apply original filters ──────────────────────────────────────────────
# A scalar True broadcasts over the boolean Series when a filter is unused.
filtered = df[
    (df["country"].isin(selected_countries) if selected_countries else True) &
    (df["year"].isin(selected_years) if selected_years else True) &
    (
        df["question_text"].str.contains(keyword, case=False, na=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False) |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
    )
]
# ── original output logic ───────────────────────
if group_by_question:
    st.subheader("📊 Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            "country": lambda x: sorted(set(x)),
            "year": lambda x: sorted(set(x)),
            "answer_text": lambda x: list(x)[:3],
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year": "Years",
            "answer_text": "Sample Answers",
        })
    )
    st.dataframe(grouped)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    heading_parts = []
    if selected_countries:
        heading_parts.append("Countries: " + ", ".join(selected_countries))
    if selected_years:
        heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
    st.markdown(
        "### Results for "
        + (" | ".join(heading_parts) if heading_parts else "All Countries and Years")
    )
    st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
    if filtered.empty:
        st.info("No matching questions found.")