annikwag committed on
Commit
d845358
·
verified ·
1 Parent(s): ead7309

Update app.py

Browse files

Add UI features: year slicer, exact match checkbox, format search results

Files changed (1) hide show
  1. app.py +76 -7
app.py CHANGED
@@ -10,8 +10,8 @@ device = 'cuda' if cuda.is_available() else 'cpu'
10
 
11
 
12
  st.set_page_config(page_title="SEARCH IATI",layout='wide')
13
- st.title("SEARCH IATI Database")
14
- var=st.text_input("enter keyword")
15
 
16
  #################### Create the embeddings collection and save ######################
17
  # the steps below need to be performed only once and then commented out to avoid any unnecessary compute overrun
@@ -26,17 +26,86 @@ collection_name = "giz_worldwide"
26
  ################### Hybrid Search ######################################################
27
  client = get_client()
28
  print(client.get_collections())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  button=st.button("search")
30
  #found_docs = vectorstore.similarity_search(var)
31
  #print(found_docs)
32
  # results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
33
  if button:
34
  results = hybrid_search(client, var, collection_name)
35
- st.write(f"Showing Top 10 results for query:{var}")
36
- st.write(f"Semantic: {len(results[0])}")
37
- st.write(results[0])
38
- st.write(f"Lexical: {len(results[1])}")
39
- st.write(results[1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # for i in results:
42
  # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
 
10
 
11
 
12
  st.set_page_config(page_title="SEARCH IATI",layout='wide')
13
+ st.title("GIZ Project Database")
14
+ var = st.text_input("Enter Search Query")
15
 
16
  #################### Create the embeddings collection and save ######################
17
  # the steps below need to be performed only once and then commented out to avoid any unnecessary compute overrun
 
26
  ################### Hybrid Search ######################################################
27
  client = get_client()
28
  print(client.get_collections())
29
+
30
+ # Fetch unique country codes from the metadata for the dropdown
31
@st.cache_data
def get_unique_countries(_client, collection_name):
    """Collect the distinct country codes found in search-result metadata.

    Runs ``hybrid_search`` with an empty query and gathers every entry listed
    under ``payload['metadata']['countries']`` of the returned hits (both
    result lists), so the output only reflects what that one search returns.

    Args:
        _client: Client handed to ``hybrid_search``. The leading underscore
            keeps ``st.cache_data`` from trying to hash this argument.
        collection_name: Name of the collection to search.

    Returns:
        Sorted list of the unique country codes that were found.
    """
    import ast  # local import so the block does not rely on file-level imports

    results = hybrid_search(_client, "", collection_name)
    country_set = set()
    for res in results[0] + results[1]:
        countries = (res.payload.get('metadata') or {}).get('countries', "[]")

        # The field is normally a stringified Python list such as "['KE', 'TZ']".
        # literal_eval parses it directly; swapping quote characters to feed a
        # JSON parser breaks on any value that itself contains an apostrophe.
        if isinstance(countries, str):
            try:
                countries = ast.literal_eval(countries)
            except (ValueError, TypeError, SyntaxError):
                continue  # unparsable entry: skip it, best effort

        # Only merge real collections; a bare string would otherwise be
        # added to the set one character at a time.
        if isinstance(countries, (list, tuple, set)):
            country_set.update(countries)
    return sorted(country_set)
43
+
44
+ unique_countries = get_unique_countries(client, collection_name)
45
+
46
+ # Layout filters in columns
47
+ col1, col2 = st.columns([1, 1])
48
+
49
+ with col1:
50
+ country_filter = st.selectbox("Filter by Country Code", ["All"] + unique_countries)
51
+ with col2:
52
+ end_year_range = st.slider("Filter by Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))
53
+
54
+
55
+ # Checkbox to control whether to show only exact matches
56
+ show_exact_matches = st.checkbox("Show only exact matches", value=False)
57
+
58
+
59
  button=st.button("search")
60
  #found_docs = vectorstore.similarity_search(var)
61
  #print(found_docs)
62
  # results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
63
  if button:
64
  results = hybrid_search(client, var, collection_name)
65
+
66
+ # Filter results based on the user's input
67
def filter_results(results, country_filter, end_year_range):
    """Keep only the hits matching the selected country and end-year range.

    Args:
        results: Iterable of search hits. Each hit exposes a ``payload`` dict
            whose ``'metadata'`` entry may contain ``'countries'`` and
            ``'end_year'``.
        country_filter: Country code a hit must list, or ``"All"`` to skip
            the country check.
        end_year_range: Pair ``(first_year, last_year)``; both bounds are
            inclusive.

    Returns:
        List of the hits that pass both filters, in their original order.
    """
    import ast  # local import so the block does not rely on file-level imports

    first_year, last_year = end_year_range[0], end_year_range[1]
    filtered = []
    for res in results:
        metadata = res.payload.get('metadata') or {}

        # A missing end year counts as 0; treat None / non-numeric values the
        # same way instead of letting one bad record abort the whole search.
        try:
            end_year = float(metadata.get('end_year', 0))
        except (TypeError, ValueError):
            end_year = 0.0

        # Process the countries value (usually a stringified list) to a list.
        country_list = metadata.get('countries', "[]")
        if isinstance(country_list, str):
            try:
                country_list = ast.literal_eval(country_list)
            except (ValueError, TypeError, SyntaxError):
                country_list = []
        if not isinstance(country_list, (list, tuple, set)):
            # Guard against scalars so ``in`` below is never a substring test.
            country_list = []

        # Apply country and year filters
        country_ok = country_filter == "All" or country_filter in country_list
        if country_ok and first_year <= end_year <= last_year:
            filtered.append(res)
    return filtered
84
+
85
+ # Check user preference for exact matches
86
# Pick which result list to show. Both modes render identically, so only the
# label and the source list differ between them.
if show_exact_matches:
    search_kind, hits = "Lexical", results[1]  # Lexical results are in index 1
else:
    search_kind, hits = "Semantic", results[0]  # Semantic results are in index 0

st.write(f"Showing **Top 10 {search_kind} Search results** for query: {var}")

# Apply the country / end-year filters, then render at most ten hits.
visible_hits = filter_results(hits, country_filter, end_year_range)[:10]
for position, res in enumerate(visible_hits, start=1):
    st.markdown(f"#### Result {position}")
    st.write(res.payload['page_content'])
    # Link back to the project; fall back to placeholders when the metadata
    # carries no URL or project name.
    url = res.payload['metadata'].get('url', '#')
    project_name = res.payload['metadata'].get('project_name', 'Project Link')
    st.caption(f"**Source:** [{project_name}]({url})")
    st.divider()
108
+
109
 
110
  # for i in results:
111
  # st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))