Files changed (1)
  1. app.py +87 -34
app.py CHANGED
@@ -1,10 +1,14 @@
 
+import os, io, json, gc
 import streamlit as st
 import pandas as pd
 import psycopg2
-import os
+import boto3, torch
+from sentence_transformers import SentenceTransformer, util
 
-# Load DB credentials from Hugging Face secrets or environment variables
+# ────────────────────────────────────────────────────────────────────────
+# 1) DB credentials (from HF secrets or env) – original
+# ────────────────────────────────────────────────────────────────────────
 DB_HOST = os.getenv("DB_HOST")
 DB_PORT = os.getenv("DB_PORT", "5432")
 DB_NAME = os.getenv("DB_NAME")
@@ -12,7 +16,7 @@ DB_USER = os.getenv("DB_USER")
 DB_PASSWORD = os.getenv("DB_PASSWORD")
 
 @st.cache_data(ttl=600)
-def get_data():
+def get_data() -> pd.DataFrame:
     try:
         conn = psycopg2.connect(
             host=DB_HOST,
@@ -20,85 +24,134 @@ def get_data():
             dbname=DB_NAME,
             user=DB_USER,
             password=DB_PASSWORD,
-            sslmode="require"
-
+            sslmode="require",
         )
-        query = "SELECT country, year, section, question_code, question_text, answer_code, answer_text FROM survey_info;"
-        df = pd.read_sql_query(query, conn)
+        query = """
+            SELECT id, country, year, section,
+                   question_code, question_text,
+                   answer_code, answer_text
+            FROM survey_info;
+        """
+        df_ = pd.read_sql_query(query, conn)
         conn.close()
-        return df
+        return df_
     except Exception as e:
         st.error(f"Failed to connect to the database: {e}")
         st.stop()
 
-# Load data
-df = get_data()
+df = get_data() # ← original DataFrame
+
+# Build a quick lookup row-index → DataFrame row for later
+row_lookup = {row.id: i for i, row in df.iterrows()}
+
+# ────────────────────────────────────────────────────────────────────────
+# 2) Load embeddings + ids once per session (S3) – new, cached
+# ────────────────────────────────────────────────────────────────────────
+@st.cache_resource
+def load_embeddings():
+    # credentials already in env (HF secrets) – boto3 will pick them up
+    BUCKET = "cgd-embeddings-bucket"
+    KEY = "survey_info_embeddings.pt" # dict {'ids', 'embeddings'}
+    buf = io.BytesIO()
+    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
+    buf.seek(0)
+    ckpt = torch.load(buf, map_location="cpu")
+    buf.close(); gc.collect()
+
+    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
+        st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
 
-# Streamlit UI
+    return ckpt["ids"], ckpt["embeddings"]
+
+ids_list, emb_tensor = load_embeddings()
+
+# ────────────────────────────────────────────────────────────────────────
+# 3) Streamlit UI – original filters + new semantic search
+# ────────────────────────────────────────────────────────────────────────
 st.title("🌍 CGD Survey Explorer (Live DB)")
 
 st.sidebar.header("🔎 Filter Questions")
 
-# Multiselect filters with default = show all
 country_options = sorted(df["country"].dropna().unique())
-year_options = sorted(df["year"].dropna().unique())
+year_options = sorted(df["year"].dropna().unique())
 
 selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
-selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
+selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
 keyword = st.sidebar.text_input(
     "Keyword Search (Question text / Answer text / Question code)", ""
-) #NEW
+)
 group_by_question = st.sidebar.checkbox("Group by Question Text")
 
-# Apply filters
+# ── new semantic search panel ───────────────────────────────────────────
+st.sidebar.markdown("---")
+st.sidebar.subheader("🧠 Semantic Search")
+sem_query = st.sidebar.text_input("Enter a natural-language query")
+if st.sidebar.button("Search", disabled=not sem_query.strip()):
+    with st.spinner("Embedding & searching…"):
+        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
+        scores = util.cos_sim(q_vec, emb_tensor)[0]
+        top_vals, top_idx = torch.topk(scores, k=10) # grab extra
+
+        results = []
+        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
+            db_id = ids_list[emb_row]
+            if db_id in row_lookup:
+                row = df.iloc[row_lookup[db_id]]
+                if row["question_text"] and row["answer_text"]:
+                    results.append({
+                        "Score": f"{score:.3f}",
+                        "Country": row["country"],
+                        "Year": row["year"],
+                        "Question": row["question_text"],
+                        "Answer": row["answer_text"],
+                    })
+        if results:
+            st.subheader(f"🔍 Semantic Results ({len(results)} found)")
+            st.dataframe(pd.DataFrame(results).head(5))
+        else:
+            st.info("No semantic matches found.")
+
+st.markdown("---")
+
+# ── apply original filters ──────────────────────────────────────────────
 filtered = df[
     (df["country"].isin(selected_countries) if selected_countries else True) &
-    (df["year"].isin(selected_years) if selected_years else True) &
+    (df["year"].isin(selected_years) if selected_years else True) &
     (
         df["question_text"].str.contains(keyword, case=False, na=False) |
         df["answer_text"].str.contains(keyword, case=False, na=False) |
-        df["question_code"].astype(str).str.contains(keyword, case=False, na=False) # NEW
+        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
     )
 ]
 
-# Output
+# ── original output logic ───────────────────────
 if group_by_question:
     st.subheader("📊 Grouped by Question Text")
-
     grouped = (
         filtered.groupby("question_text")
         .agg({
             "country": lambda x: sorted(set(x)),
-            "year": lambda x: sorted(set(x)),
-            "answer_text": lambda x: list(x)[:3] # preview up to 3 answers
+            "year": lambda x: sorted(set(x)),
+            "answer_text": lambda x: list(x)[:3]
         })
         .reset_index()
         .rename(columns={
             "country": "Countries",
-            "year": "Years",
+            "year": "Years",
             "answer_text": "Sample Answers"
         })
     )
-
     st.dataframe(grouped)
-
     if grouped.empty:
         st.info("No questions found with current filters.")
-
 else:
-    # Context-aware heading
     heading_parts = []
     if selected_countries:
         heading_parts.append("Countries: " + ", ".join(selected_countries))
     if selected_years:
         heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
-    if heading_parts:
-        st.markdown("### Results for " + " | ".join(heading_parts))
-    else:
-        st.markdown("### Results for All Countries and Years")
-
+    st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))
     st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
-
     if filtered.empty:
         st.info("No matching questions found.")
-
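
Note: the new load_embeddings() helper assumes a survey_info_embeddings.pt object already exists in the cgd-embeddings-bucket S3 bucket, saved as a dict with matching 'ids' and 'embeddings' entries whose ids correspond to survey_info.id. This commit does not show how that checkpoint is produced; the sketch below is a hypothetical offline helper, assuming the embeddings come from question_text encoded with the same all-MiniLM-L6-v2 model the app queries with (the exact source text and script name are assumptions, not part of the commit).

# build_embeddings.py: hypothetical offline helper, not part of this commit.
# Reuses the same DB env vars and AWS credentials that app.py relies on.
import os

import boto3
import pandas as pd
import psycopg2
import torch
from sentence_transformers import SentenceTransformer

conn = psycopg2.connect(
    host=os.getenv("DB_HOST"),
    port=os.getenv("DB_PORT", "5432"),
    dbname=os.getenv("DB_NAME"),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    sslmode="require",
)
df = pd.read_sql_query("SELECT id, question_text FROM survey_info;", conn)
conn.close()

# Encode with the same model the app uses so cosine scores are comparable.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb = model.encode(df["question_text"].fillna("").tolist(), convert_to_tensor=True).cpu()

# Save in the {'ids', 'embeddings'} layout load_embeddings() checks for,
# with ids[i] aligned to embeddings[i].
torch.save({"ids": df["id"].tolist(), "embeddings": emb}, "survey_info_embeddings.pt")
boto3.client("s3").upload_file(
    "survey_info_embeddings.pt", "cgd-embeddings-bucket", "survey_info_embeddings.pt"
)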