# app.py — GIZ Project Database search UI (Streamlit)
# Provenance: annikwag's Hugging Face Space, revision 5c71cde (verified), 5.44 kB
import streamlit as st
import pandas as pd
from appStore.prep_data import process_giz_worldwide
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import load_region_data, get_country_name
from torch import cuda
import json
# Select the compute device: GPU if available, otherwise CPU.
device = 'cuda' if cuda.is_available() else 'cpu'
st.set_page_config(page_title="SEARCH IATI",layout='wide')
st.title("GIZ Project Database")
# Free-text search query entered by the user; consumed by the search block below.
var = st.text_input("Enter Search Query")
# Load the region lookup CSV (maps ISO country codes to display names/regions).
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)
#################### Create the embeddings collection and save ######################
# The steps below need to be performed only once; keep them commented out afterwards
# to avoid unnecessary compute re-runs on every app start.
##### First we process and create the chunks for the relevant data source
#chunks = process_giz_worldwide()
##### Convert to langchain documents
#temp_doc = create_documents(chunks,'chunks')
##### Embed and store docs; if the collection already exists you need to update it
collection_name = "giz_worldwide"
#hybrid_embed_chunks(docs= temp_doc, collection_name = collection_name)
################### Hybrid Search ######################################################
# Vector-store client — presumably Qdrant; confirm in appStore.prep_utils.get_client.
client = get_client()
print(client.get_collections())
# Fetch unique country codes from the collection and map display names -> ISO codes.
@st.cache_data
def get_country_name_mapping(_client, collection_name, region_df):
    """Build a mapping of country display names to ISO country codes.

    Runs an empty hybrid search to sample payloads from the collection,
    collects every country code found in either result list, and translates
    each code to a human-readable name via the region lookup table.

    Args:
        _client: search client (leading underscore stops Streamlit hashing it).
        collection_name: name of the vector-store collection to query.
        region_df: region lookup DataFrame consumed by get_country_name.

    Returns:
        dict mapping country display name -> ISO country code. Codes that
        resolve to the same name collapse to the last one seen.
    """
    results = hybrid_search(_client, "", collection_name)
    country_set = set()
    # results[0] = semantic hits, results[1] = lexical hits; scan both.
    for res in results[0] + results[1]:
        countries = res.payload.get('metadata', {}).get('countries', "[]")
        try:
            # Payload usually stores a stringified, single-quoted list; accept
            # a real list too so a payload-schema change does not crash the app.
            if isinstance(countries, str):
                countries = json.loads(countries.replace("'", '"'))
            country_set.update(code.upper() for code in countries)  # normalize case
        except (json.JSONDecodeError, AttributeError, TypeError):
            pass  # skip malformed payloads rather than failing the whole mapping
    return {get_country_name(code, region_df): code for code in country_set}
# Build the country-name dropdown options. The client created above is reused;
# the original code opened a second, redundant connection here.
country_name_mapping = get_country_name_mapping(client, collection_name, region_df)
unique_country_names = sorted(country_name_mapping.keys())  # country display names
# Layout filters in columns; col3 is intentionally unused and only keeps the
# two filter widgets packed to the left of the page.
col1, col2, col3 = st.columns([1, 1, 4])
with col1:
    # Dropdown shows human-readable country names; "All" disables the filter.
    country_filter = st.selectbox("Country", ["All"] + unique_country_names)  # Display country names
with col2:
    # Inclusive (start, end) year range applied to each project's end_year.
    end_year_range = st.slider("Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))
# Checkbox to control whether to show only exact matches
show_exact_matches = st.checkbox("Show only exact matches", value=False)
# Search runs only when the button is pressed (see the `if button:` block below).
button = st.button("Search")
if button:
    # results[0] = semantic hits, results[1] = lexical hits.
    results = hybrid_search(client, var, collection_name)

    def filter_results(results, country_filter, end_year_range):
        """Keep only results matching the selected country and end-year range."""
        # The name -> ISO code translation is loop-invariant: resolve it once
        # instead of once per result as the original did.
        selected_iso_code = country_name_mapping.get(country_filter, None)
        filtered = []
        for res in results:
            metadata = res.payload.get('metadata', {})
            # end_year may be missing or malformed; fall back to 0 so the
            # result simply fails the range test instead of crashing the app.
            try:
                end_year = float(metadata.get('end_year', 0))
            except (TypeError, ValueError):
                end_year = 0.0
            countries = metadata.get('countries', "[]")
            # Payload usually stores a stringified, single-quoted list; accept
            # a real list too so a payload-schema change does not crash.
            try:
                if isinstance(countries, str):
                    countries = json.loads(countries.replace("'", '"'))
                country_list = [code.upper() for code in countries]
            except (json.JSONDecodeError, AttributeError, TypeError):
                country_list = []
            # Apply country and year filters.
            if (country_filter == "All" or selected_iso_code in country_list) \
                    and (end_year_range[0] <= end_year <= end_year_range[1]):
                filtered.append(res)
        return filtered

    def _render_results(res_list, heading):
        """Render the heading, then up to 10 results as linked titles + content."""
        st.write(heading)
        for res in res_list[:10]:
            project_name = res.payload['metadata'].get('project_name', 'Project Link')
            url = res.payload['metadata'].get('url', '#')
            st.markdown(f"#### [{project_name}]({url})")
            st.write(res.payload['page_content'])
            st.divider()

    # Check user preference for exact (lexical) vs semantic matches.
    if show_exact_matches:
        filtered_lexical_results = filter_results(results[1], country_filter, end_year_range)
        _render_results(filtered_lexical_results,
                        f"Showing **Top 10 Lexical Search results** for query: {var}")
    else:
        filtered_semantic_results = filter_results(results[0], country_filter, end_year_range)
        _render_results(filtered_semantic_results,
                        f"Showing **Top 10 Semantic Search results** for query: {var}")
# for i in results:
# st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
# st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
# st.write(i.page_content)
# st.divider()