import streamlit as st
import pandas as pd
from appStore.prep_data import process_giz_worldwide
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from torch import cuda
import json

# Select the device to be used: GPU if available, otherwise CPU
device = 'cuda' if cuda.is_available() else 'cpu'

st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("GIZ Project Database")
var = st.text_input("Enter Search Query")
#################### Create the embeddings collection and save ######################
# The steps below need to be performed only once; comment them out afterwards
# to avoid unnecessary compute over-runs.
##### First we process and create the chunks for the relevant data source
#chunks = process_giz_worldwide()
##### Convert to langchain documents
#temp_doc = create_documents(chunks, 'chunks')
##### Embed and store the docs; if the collection already exists, update it
##### instead of re-embedding (see the guard sketch below)
collection_name = "giz_worldwide"
#hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name)
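
# A minimal sketch of the run-once guard, assuming get_client() returns a
# qdrant_client.QdrantClient (collection_exists() is an assumption about that
# client's API, not confirmed by this repo):
#
# client = get_client()
# if not client.collection_exists(collection_name):
#     chunks = process_giz_worldwide()
#     temp_doc = create_documents(chunks, 'chunks')
#     hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name)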
################### Hybrid Search ######################################################
client = get_client()
print(client.get_collections())
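
# hybrid_search(client, query, collection_name) is assumed to return a pair of
# result lists: index 0 holds the semantic (dense) hits and index 1 the lexical
# (sparse) hits, as the result handling below relies on. Each hit exposes a
# .payload dict with 'metadata' and 'page_content' keys, roughly of this shape
# (values here are illustrative, not taken from the actual collection):
# res.payload = {
#     'metadata': {'countries': "['DE', 'KE']", 'end_year': '2025',
#                  'project_name': '...', 'url': '...'},
#     'page_content': '...'
# }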
# Fetch unique country codes from the metadata for the dropdown.
# Cached so the lookup query runs only once per session; the leading underscore
# on _client tells Streamlit not to hash the client object.
@st.cache_data
def get_unique_countries(_client, collection_name):
    results = hybrid_search(_client, "", collection_name)
    country_set = set()
    for res in results[0] + results[1]:
        countries = res.payload.get('metadata', {}).get('countries', "[]")
        try:
            # The field is a stringified Python list, so swap the single
            # quotes for double quotes before parsing it as JSON.
            country_list = json.loads(countries.replace("'", '"'))
            country_set.update(country_list)
        except json.JSONDecodeError:
            pass
    return sorted(country_set)
unique_countries = get_unique_countries(client, collection_name)

# Layout filters in columns
col1, col2, col3 = st.columns([1, 1, 4])
with col1:
    country_filter = st.selectbox("Country Code", ["All"] + unique_countries)
with col2:
    end_year_range = st.slider("Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))

# Checkbox to control whether to show only exact matches
show_exact_matches = st.checkbox("Show only exact matches", value=False)
button = st.button("Search")
#found_docs = vectorstore.similarity_search(var)
#print(found_docs)
# results = get_context(vectorstore, f"find the relevant paragraphs for: {var}")
if button:
    results = hybrid_search(client, var, collection_name)

    # Filter results based on the user's input
    def filter_results(results, country_filter, end_year_range):
        filtered = []
        for res in results:
            metadata = res.payload.get('metadata', {})
            countries = metadata.get('countries', "[]")
            # Guard against missing or non-numeric end_year values
            try:
                end_year = float(metadata.get('end_year', 0))
            except (TypeError, ValueError):
                end_year = 0.0
            # Process the countries string into a list
            try:
                country_list = json.loads(countries.replace("'", '"'))
            except json.JSONDecodeError:
                country_list = []
            # Apply country and year filters
            if (country_filter == "All" or country_filter in country_list) and (end_year_range[0] <= end_year <= end_year_range[1]):
                filtered.append(res)
        return filtered
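
    # e.g. filter_results(results[0], "KE", (2015, 2030)) keeps only the
    # semantic hits tagged with country code "KE" whose end_year falls within
    # 2015-2030 ("KE" here is purely illustrative).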
    # Shared renderer for the top-10 hits of either branch
    def display_results(res_list):
        for res in res_list[:10]:
            project_name = res.payload['metadata'].get('project_name', 'Project Link')
            url = res.payload['metadata'].get('url', '#')
            st.markdown(f"#### [{project_name}]({url})")
            st.write(res.payload['page_content'])
            st.divider()

    # Check user preference for exact matches
    if show_exact_matches:
        st.write(f"Showing **Top 10 Lexical Search results** for query: {var}")
        lexical_results = results[1]  # Lexical results are at index 1
        display_results(filter_results(lexical_results, country_filter, end_year_range))
    else:
        st.write(f"Showing **Top 10 Semantic Search results** for query: {var}")
        semantic_results = results[0]  # Semantic results are at index 0
        display_results(filter_results(semantic_results, country_filter, end_year_range))

    # for i in results:
    #     st.subheader(str(i.metadata['id']) + ": " + str(i.metadata['title_main']))
    #     st.caption(f"Status: {str(i.metadata['status'])}, Country: {str(i.metadata['country_name'])}")
    #     st.write(i.page_content)
    #     st.divider()