Spaces:
Running
Running
File size: 7,183 Bytes
d3a33c8 e35f77b d3a33c8 23381bb e35f77b 23381bb e35f77b 23381bb e35f77b 23381bb 35c1ade e35f77b 23381bb e35f77b 23381bb e35f77b 23381bb 4b5445a 23381bb e35f77b 23381bb e35f77b 23381bb d3a33c8 cc9cf8b 4e71d04 e35f77b 4e71d04 e35f77b fad7dca 9cb0a2b e35f77b cc9cf8b 81d02b5 e35f77b 81d02b5 4e71d04 e35f77b cca0254 9cb0a2b e35f77b cca0254 81d02b5 e35f77b cc9cf8b e35f77b cc9cf8b e35f77b cc9cf8b 4e71d04 cc9cf8b e35f77b cc9cf8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util
# ────────────────────────────────────────────────────────────────────────
# 1) DB credentials (from HF secrets or env) — original
# ────────────────────────────────────────────────────────────────────────
# Connection settings are read from the process environment; a missing
# variable yields None, except the port, which falls back to "5432".
DB_HOST, DB_NAME, DB_USER, DB_PASSWORD = (
    os.environ.get(var_name)
    for var_name in ("DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD")
)
DB_PORT = os.environ.get("DB_PORT", "5432")
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Load the ``survey_info`` table into a DataFrame.

    The result is cached by Streamlit (``ttl=600``), so the database is not
    queried on every script rerun.

    Returns:
        A DataFrame with the columns ``id``, ``country``, ``year``,
        ``section``, ``question_code``, ``question_text``, ``answer_code``
        and ``answer_text``.

    On any failure (connecting or querying) an error is rendered with
    ``st.error`` and ``st.stop()`` is called instead of raising.
    """
    query = """
        SELECT id, country, year, section,
               question_code, question_text,
               answer_code, answer_text
        FROM survey_info;
    """
    try:
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            sslmode="require",
        )
        try:
            return pd.read_sql_query(query, conn)
        finally:
            # Close the connection even when the query itself fails, so a
            # bad query does not leak a server-side session.
            conn.close()
    except Exception as e:
        # Top-level UI boundary: surface the problem to the user and halt.
        st.error(f"Failed to connect to the database: {e}")
        st.stop()
# Survey table for this script run (served from get_data()'s cache when valid).
df = get_data()

# Map each database id to its *positional* row number so that search hits can
# be resolved with df.iloc.  Enumerating the id column avoids materialising a
# Series per row (as iterrows() does) and keeps the stored value a position
# even if the frame's index is not 0..n-1.  For duplicate ids the last
# occurrence wins.
row_lookup = {db_id: pos for pos, db_id in enumerate(df["id"].tolist())}
# ────────────────────────────────────────────────────────────────────────
# 2) Load embeddings + ids once per session (S3) — new, cached
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the embedding checkpoint from S3 and return ``(ids, embeddings)``.

    The checkpoint must be a dict containing the keys ``"ids"`` and
    ``"embeddings"``; otherwise an error is shown via ``st.error`` and
    ``st.stop()`` is called.  The function is wrapped in
    ``st.cache_resource``.
    """
    bucket_name = "cgd-embeddings-bucket"
    object_key = "survey_info_embeddings.pt"  # expected: dict {'ids', 'embeddings'}

    # No credentials are passed explicitly; boto3 is expected to resolve them
    # from the environment (HF secrets).
    s3_client = boto3.client("s3")
    with io.BytesIO() as stream:
        s3_client.download_fileobj(bucket_name, object_key, stream)
        stream.seek(0)
        # NOTE(review): torch.load unpickles the payload - the bucket content
        # must be trusted.
        checkpoint = torch.load(stream, map_location="cpu")
    gc.collect()

    if not isinstance(checkpoint, dict) or not {"ids", "embeddings"} <= checkpoint.keys():
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return checkpoint["ids"], checkpoint["embeddings"]


# Ids and embeddings are used by the semantic-search panel further down.
ids_list, emb_tensor = load_embeddings()
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI — original filters + new semantic search
# ────────────────────────────────────────────────────────────────────────
def _sorted_distinct(column):
    """Return the sorted, non-null distinct values of ``df[column]``."""
    distinct_values = df[column].dropna().unique()
    return sorted(distinct_values)


# Page title, then the sidebar widgets that drive the table filters below.
st.title("π CGD Survey Explorer (Live DB)")
st.sidebar.header("π Filter Questions")

country_options = _sorted_distinct("country")
year_options = _sorted_distinct("year")

selected_countries = st.sidebar.multiselect(
    "Select Country/Countries",
    country_options,
)
selected_years = st.sidebar.multiselect(
    "Select Year(s)",
    year_options,
)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)",
    "",
)
group_by_question = st.sidebar.checkbox("Group by Question Text")
# -- Semantic search panel ------------------------------------------------
@st.cache_resource
def _load_query_encoder():
    """Build the sentence-embedding model once and reuse it across reruns.

    Without caching the model would be re-instantiated on every click of the
    "Search" button.
    """
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


st.sidebar.markdown("---")
st.sidebar.subheader("π§ Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")

# The button is disabled while the query is blank, so the branch below only
# runs with a non-empty query.
if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searchingβ¦"):
        model = _load_query_encoder()
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        # assumes emb_tensor is a 2-D CPU tensor (one row per id) - TODO confirm
        scores = util.cos_sim(q_vec, emb_tensor)[0]
        # torch.topk raises when k exceeds the number of candidates, so cap
        # it; 10 are fetched (more than the 5 shown) because some hits are
        # dropped below.
        top_k = min(10, scores.numel())
        top_vals, top_idx = torch.topk(scores, k=top_k)

    results = []
    for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
        db_id = ids_list[emb_row]
        # Skip embeddings whose id is not present in the loaded table.
        if db_id not in row_lookup:
            continue
        row = df.iloc[row_lookup[db_id]]
        # Skip rows with an empty question or answer.
        if not (row["question_text"] and row["answer_text"]):
            continue
        results.append({
            "Score": f"{score:.3f}",
            "Country": row["country"],
            "Year": row["year"],
            "Question": row["question_text"],
            "Answer": row["answer_text"],
        })

    if results:
        st.subheader(f"π Semantic Results ({len(results)} found)")
        st.dataframe(pd.DataFrame(results).head(5))
    else:
        st.info("No semantic matches found.")

    # NOTE(review): indentation was lost in the reviewed copy; the divider is
    # assumed to belong to the search branch - confirm against the original.
    st.markdown("---")
# -- Apply the sidebar filters ---------------------------------------------
# Start from "keep every row" and narrow down only for the filters that are
# actually set; an empty selection or empty keyword means "no restriction".
row_mask = pd.Series(True, index=df.index)
if selected_countries:
    row_mask &= df["country"].isin(selected_countries)
if selected_years:
    row_mask &= df["year"].isin(selected_years)
if keyword:
    # regex=False: the keyword is user input and must be matched literally.
    # As a regex, input such as "(" or "C++" raises re.error or matches
    # something other than what was typed.
    row_mask &= (
        df["question_text"].str.contains(keyword, case=False, na=False, regex=False)
        | df["answer_text"].str.contains(keyword, case=False, na=False, regex=False)
        | df["question_code"].astype(str).str.contains(
            keyword, case=False, na=False, regex=False
        )
    )
filtered = df[row_mask]
# -- Render the filtered rows ------------------------------------------------
def _sorted_unique(values):
    """Return the sorted distinct non-null values of a group's column.

    Nulls are dropped first: sorting a set that mixes ``None`` with strings
    raises ``TypeError``.
    """
    return sorted(set(values.dropna()))


if group_by_question:
    # One row per question text: where/when it was asked plus up to three
    # sample answers.
    st.subheader("π Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            "country": _sorted_unique,
            "year": _sorted_unique,
            "answer_text": lambda x: list(x)[:3],
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year": "Years",
            "answer_text": "Sample Answers",
        })
    )
    st.dataframe(grouped)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    # Flat listing; the heading summarises the active country/year filters.
    heading_parts = []
    if selected_countries:
        heading_parts.append("Countries: " + ", ".join(map(str, selected_countries)))
    if selected_years:
        heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
    st.markdown(
        "### Results for "
        + (" | ".join(heading_parts) if heading_parts else "All Countries and Years")
    )
    st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
    if filtered.empty:
        st.info("No matching questions found.")
|