myshirk committed on
Commit
baa583b
Β·
verified Β·
1 Parent(s): 9cb0a2b

add gigi's changes

Browse files
Files changed (1) hide show
  1. app.py +93 -31
app.py CHANGED
@@ -1,67 +1,130 @@
1
-
2
  import streamlit as st
3
  import pandas as pd
4
  import psycopg2
5
- import os
 
6
 
7
- # Load DB credentials from Hugging Face secrets or environment variables
 
 
8
  DB_HOST = os.getenv("DB_HOST")
9
  DB_PORT = os.getenv("DB_PORT", "5432")
10
  DB_NAME = os.getenv("DB_NAME")
11
- DB_USER = os.getenv("DB_USER")
12
  DB_PASSWORD = os.getenv("DB_PASSWORD")
13
 
14
  @st.cache_data(ttl=600)
15
- def get_data():
16
  try:
17
  conn = psycopg2.connect(
18
  host=DB_HOST,
19
- port=DB_PORT,
20
  dbname=DB_NAME,
21
  user=DB_USER,
22
  password=DB_PASSWORD,
23
- sslmode="require"
24
 
25
  )
26
- query = "SELECT country, year, section, question_code, question_text, answer_code, answer_text FROM survey_info;"
27
- df = pd.read_sql_query(query, conn)
 
 
 
 
 
28
  conn.close()
29
- return df
30
  except Exception as e:
31
  st.error(f"Failed to connect to the database: {e}")
32
  st.stop()
33
 
34
- # Load data
35
- df = get_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Streamlit UI
 
 
38
  st.title("🌍 CGD Survey Explorer (Live DB)")
39
 
40
  st.sidebar.header("πŸ”Ž Filter Questions")
41
 
42
- # Multiselect filters with default = show all
43
  country_options = sorted(df["country"].dropna().unique())
44
- year_options = sorted(df["year"].dropna().unique())
45
 
46
  selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
47
- selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
48
  keyword = st.sidebar.text_input(
49
  "Keyword Search (Question text / Answer text / Question code)", ""
50
- ) #NEW
51
  group_by_question = st.sidebar.checkbox("Group by Question Text")
52
 
53
- # Apply filters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  filtered = df[
55
  (df["country"].isin(selected_countries) if selected_countries else True) &
56
- (df["year"].isin(selected_years) if selected_years else True) &
57
  (
58
  df["question_text"].str.contains(keyword, case=False, na=False) |
59
  df["answer_text"].str.contains(keyword, case=False, na=False) |
60
- df["question_code"].astype(str).str.contains(keyword, case=False, na=False) # NEW
61
  )
62
  ]
63
 
64
- # Output
65
  if group_by_question:
66
  st.subheader("πŸ“Š Grouped by Question Text")
67
 
@@ -69,13 +132,13 @@ if group_by_question:
69
  filtered.groupby("question_text")
70
  .agg({
71
  "country": lambda x: sorted(set(x)),
72
- "year": lambda x: sorted(set(x)),
73
- "answer_text": lambda x: list(x)[:3] # preview up to 3 answers
74
  })
75
  .reset_index()
76
  .rename(columns={
77
  "country": "Countries",
78
- "year": "Years",
79
  "answer_text": "Sample Answers"
80
  })
81
  )
@@ -86,19 +149,18 @@ if group_by_question:
86
  st.info("No questions found with current filters.")
87
 
88
  else:
89
- # Context-aware heading
90
  heading_parts = []
91
  if selected_countries:
92
  heading_parts.append("Countries: " + ", ".join(selected_countries))
93
  if selected_years:
94
  heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
95
- if heading_parts:
96
- st.markdown("### Results for " + " | ".join(heading_parts))
97
- else:
98
- st.markdown("### Results for All Countries and Years")
99
 
100
  st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
101
 
102
  if filtered.empty:
103
- st.info("No matching questions found.")
104
-
 
1
+ import os, io, json, gc
2
  import streamlit as st
3
  import pandas as pd
4
  import psycopg2
5
+ import boto3, torch
6
+ from sentence_transformers import SentenceTransformer, util
7
 
8
# ────────────────────────────────────────────────────────────────────────
# 1) DB credentials (from HF secrets or env) – original
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")  # sensible default for PostgreSQL
DB_NAME = os.getenv("DB_NAME")
# Restored: this assignment was dropped in the change, but get_data()
# still passes user=DB_USER — without it the first DB call raises NameError.
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
15
 
16
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Fetch the full ``survey_info`` table as a DataFrame (cached 10 min).

    Shows a Streamlit error banner and stops the script if the connection
    or the query fails.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,  # restored: was dropped, silently ignoring DB_PORT
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            sslmode="require",
        )
        query = """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
        """
        # NOTE: pandas warns when handed a raw DBAPI connection instead of a
        # SQLAlchemy engine; the query result is unaffected.
        return pd.read_sql_query(query, conn)
    except Exception as e:
        st.error(f"Failed to connect to the database: {e}")
        st.stop()
    finally:
        # Close even when the *query* (not just the connect) raises, so a
        # failed read no longer leaks the connection.
        if conn is not None:
            conn.close()
39
 
40
df = get_data()  # ← original DataFrame

# Map DB id → *positional* row index for the df.iloc[...] lookups below.
# enumerate() over the id column is O(n) without materialising a Series per
# row, and stays correct even if df ever gets a non-RangeIndex (iterrows()
# yields index labels, which df.iloc would misinterpret as positions).
row_lookup = {db_id: pos for pos, db_id in enumerate(df["id"])}
44
+
45
# ────────────────────────────────────────────────────────────────────────
# 2) Load embeddings + ids once per session (S3) – new, cached
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the precomputed embedding checkpoint from S3 and return
    ``(ids, embeddings)``.

    Cached once per process via ``st.cache_resource``. AWS credentials are
    picked up from the environment (HF secrets) by boto3 automatically.
    """
    # Env-overridable, defaulting to the original hard-coded values.
    bucket = os.getenv("EMBEDDINGS_BUCKET", "cgd-embeddings-bucket")
    key = os.getenv("EMBEDDINGS_KEY", "survey_info_embeddings.pt")

    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(bucket, key, buf)
    buf.seek(0)
    # SECURITY: torch.load unpickles arbitrary objects — only ever load
    # checkpoints from a bucket we control.
    ckpt = torch.load(buf, map_location="cpu")
    buf.close()
    gc.collect()  # free the in-memory download immediately

    # Expect a dict with at least {'ids', 'embeddings'} keys.
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error(f"Bad checkpoint format in {key}")
        st.stop()

    return ckpt["ids"], ckpt["embeddings"]

ids_list, emb_tensor = load_embeddings()
65
 
66
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI – original filters + new semantic search
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")

st.sidebar.header("🔎 Filter Questions")

# Derive the filter choices from the data itself so the widgets always
# reflect exactly what is present in the table.
country_options = sorted(set(df["country"].dropna()))
year_options = sorted(set(df["year"].dropna()))

selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")
83
 
84
# ── new semantic search panel ───────────────────────────────────────────
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")

@st.cache_resource
def _get_encoder():
    """Load the sentence encoder once per process.

    Previously the model was re-instantiated on every button click, paying
    the full weight-loading cost each search.
    """
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def _has_text(value) -> bool:
    # NaN is truthy, so a plain `if value` would let missing text through.
    return pd.notna(value) and bool(value)

if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        q_vec = _get_encoder().encode(sem_query.strip(), convert_to_tensor=True).cpu()
        scores = util.cos_sim(q_vec, emb_tensor)[0]
        # Over-fetch so rows dropped below still leave enough hits, but
        # never ask topk for more entries than exist (k > n raises).
        k = min(10, scores.shape[0])
        top_vals, top_idx = torch.topk(scores, k=k)

        results = []
        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
            db_id = ids_list[emb_row]
            if db_id not in row_lookup:
                continue  # embedding refers to a row no longer in the DB
            row = df.iloc[row_lookup[db_id]]
            if _has_text(row["question_text"]) and _has_text(row["answer_text"]):
                results.append({
                    "Score": f"{score:.3f}",
                    "Country": row["country"],
                    "Year": row["year"],
                    "Question": row["question_text"],
                    "Answer": row["answer_text"],
                })
        if results:
            st.subheader(f"🔍 Semantic Results ({len(results)} found)")
            st.dataframe(pd.DataFrame(results).head(5))
        else:
            st.info("No semantic matches found.")

st.markdown("---")
115
+
116
# ── apply original filters ──────────────────────────────────────────────
# Start from an all-True mask and narrow it with each active filter; the
# keyword matches question text, answer text, or question code
# (case-insensitive, NaN treated as no match).
mask = pd.Series(True, index=df.index)
if selected_countries:
    mask &= df["country"].isin(selected_countries)
if selected_years:
    mask &= df["year"].isin(selected_years)

keyword_hit = (
    df["question_text"].str.contains(keyword, case=False, na=False)
    | df["answer_text"].str.contains(keyword, case=False, na=False)
    | df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
)
filtered = df[mask & keyword_hit]
126
 
127
+ # ── original output logic ───────────────────────
128
  if group_by_question:
129
  st.subheader("πŸ“Š Grouped by Question Text")
130
 
 
132
  filtered.groupby("question_text")
133
  .agg({
134
  "country": lambda x: sorted(set(x)),
135
+ "year": lambda x: sorted(set(x)),
136
+ "answer_text": lambda x: list(x)[:3]
137
  })
138
  .reset_index()
139
  .rename(columns={
140
  "country": "Countries",
141
+ "year": "Years",
142
  "answer_text": "Sample Answers"
143
  })
144
  )
 
149
  st.info("No questions found with current filters.")
150
 
151
  else:
152
+
153
  heading_parts = []
154
  if selected_countries:
155
  heading_parts.append("Countries: " + ", ".join(selected_countries))
156
  if selected_years:
157
  heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
158
+ st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))
159
+
160
+
161
+
162
 
163
  st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
164
 
165
  if filtered.empty:
166
+ st.info("No matching questions found.")