Update app.py
Add UI features: year slicer, exact match checkbox, format search results

app.py CHANGED
@@ -10,8 +10,8 @@ device = 'cuda' if cuda.is_available() else 'cpu'
 
 
 st.set_page_config(page_title="SEARCH IATI",layout='wide')
-st.title("
-var=st.text_input("
+st.title("GIZ Project Database")
+var = st.text_input("Enter Search Query")
 
 #################### Create the embeddings collection and save ######################
 # the steps below need to be performed only once and then commented out any unnecssary compute over-run
@@ -26,17 +26,86 @@ collection_name = "giz_worldwide"
 ################### Hybrid Search ######################################################
 client = get_client()
 print(client.get_collections())
+
+# Fetch unique country codes from the metadata for the dropdown
+@st.cache_data
+def get_unique_countries(_client, collection_name):
+    results = hybrid_search(_client, "", collection_name)
+    country_set = set()
+    for res in results[0] + results[1]:
+        countries = res.payload.get('metadata', {}).get('countries', "[]")
+        try:
+            country_list = json.loads(countries.replace("'", '"'))
+            country_set.update(country_list)
+        except json.JSONDecodeError:
+            pass
+    return sorted(list(country_set))
+
+unique_countries = get_unique_countries(client, collection_name)
+
+# Layout filters in columns
+col1, col2 = st.columns([1, 1])
+
+with col1:
+    country_filter = st.selectbox("Filter by Country Code", ["All"] + unique_countries)
+with col2:
+    end_year_range = st.slider("Filter by Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))
+
+
+# Checkbox to control whether to show only exact matches
+show_exact_matches = st.checkbox("Show only exact matches", value=False)
+
+
 button=st.button("search")
 #found_docs = vectorstore.similarity_search(var)
 #print(found_docs)
 # results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
 if button:
     results = hybrid_search(client, var, collection_name)
-
-
-
-
-
+
+    # Filter results based on the user's input
+    def filter_results(results, country_filter, end_year_range):
+        filtered = []
+        for res in results:
+            metadata = res.payload.get('metadata', {})
+            countries = metadata.get('countries', "[]")
+            end_year = float(metadata.get('end_year', 0))
+
+            # Process countries string to a list
+            try:
+                country_list = json.loads(countries.replace("'", '"'))
+            except json.JSONDecodeError:
+                country_list = []
+
+            # Apply country and year filters
+            if (country_filter == "All" or country_filter in country_list) and (end_year_range[0] <= end_year <= end_year_range[1]):
+                filtered.append(res)
+        return filtered
+
+    # Check user preference for exact matches
+    if show_exact_matches:
+        st.write(f"Showing **Top 10 Lexical Search results** for query: {var}")
+        lexical_results = results[1]  # Lexical results are in index 1
+        filtered_lexical_results = filter_results(lexical_results, country_filter, end_year_range)
+        for i, res in enumerate(filtered_lexical_results[:10]):
+            st.markdown(f"#### Result {i+1}")
+            st.write(res.payload['page_content'])
+            url = res.payload['metadata'].get('url', '#')
+            project_name = res.payload['metadata'].get('project_name', 'Project Link')
+            st.caption(f"**Source:** [{project_name}]({url})")
+            st.divider()
+    else:
+        st.write(f"Showing **Top 10 Semantic Search results** for query: {var}")
+        semantic_results = results[0]  # Semantic results are in index 0
+        filtered_semantic_results = filter_results(semantic_results, country_filter, end_year_range)
+        for i, res in enumerate(filtered_semantic_results[:10]):
+            st.markdown(f"#### Result {i+1}")
+            st.write(res.payload['page_content'])
+            url = res.payload['metadata'].get('url', '#')
+            project_name = res.payload['metadata'].get('project_name', 'Project Link')
+            st.caption(f"**Source:** [{project_name}]({url})")
+            st.divider()
+
 
 # for i in results:
 #     st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
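A note on the data shape the added code relies on: hybrid_search is treated as returning a pair of result lists, semantic hits at index 0 and lexical hits at index 1, where each hit exposes a Qdrant-style payload dict containing 'page_content' plus a 'metadata' sub-dict ('countries' stored as a stringified list, 'end_year', 'url', 'project_name'). The sketch below is a minimal illustration of that assumed contract using a stand-in record instead of the real client; the field names mirror the diff, everything else (FakeHit, the sample values, the URL) is illustrative only:

# Stand-in for a Qdrant-style hit: only the .payload attribute matters here.
from dataclasses import dataclass, field

@dataclass
class FakeHit:
    payload: dict = field(default_factory=dict)

semantic_hits = [
    FakeHit(payload={
        "page_content": "Water supply project in rural districts ...",
        "metadata": {
            "countries": "['KE', 'TZ']",          # stored as a stringified list
            "end_year": "2026",
            "url": "https://example.org/project/123",   # illustrative URL
            "project_name": "Rural Water Supply",
        },
    }),
]
lexical_hits = []

# hybrid_search(...) is assumed to return (semantic_hits, lexical_hits);
# the Streamlit code above indexes results[0] / results[1] accordingly.
results = (semantic_hits, lexical_hits)

# Same end-year window the app applies in filter_results.
filtered = [
    hit for hit in results[0]
    if 2010 <= float(hit.payload["metadata"].get("end_year", 0)) <= 2030
]
print(len(filtered))  # -> 1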
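One fragile spot in the parsing: both get_unique_countries and filter_results decode the stringified 'countries' field via json.loads(countries.replace("'", '"')), which fails whenever a value itself contains an apostrophe. If the field is a Python-style list literal, ast.literal_eval from the standard library is a more forgiving option. The helper below is a sketch of that alternative, not what the commit itself ships:

import ast
import json

def parse_countries(raw: str) -> list:
    """Parse a stringified country list like "['KE', 'TZ']" into a Python list."""
    # Try the Python literal first (handles single quotes natively),
    # then fall back to JSON; return an empty list if neither parses.
    for parser in (ast.literal_eval, json.loads):
        try:
            value = parser(raw)
            return list(value) if isinstance(value, (list, tuple)) else []
        except (ValueError, SyntaxError):
            continue
    return []

print(parse_countries("['KE', 'TZ']"))   # -> ['KE', 'TZ']
print(parse_countries('["DE"]'))         # -> ['DE']
print(parse_countries("not a list"))     # -> []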