# annikwag — "Update app.py" (commit 4681487, verified)
# raw | history | blame | 4.8 kB
import ast
import json

import pandas as pd
import streamlit as st
from torch import cuda

from appStore.embed import hybrid_embed_chunks
from appStore.prep_data import process_giz_worldwide
from appStore.prep_utils import create_documents, get_client
from appStore.search import hybrid_search
# Select the torch device: the GPU when CUDA is available, the CPU otherwise.
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Page chrome and the free-text query box.
st.set_page_config(page_title="SEARCH IATI", layout="wide")
st.title("GIZ Project Database")
var = st.text_input("Enter Search Query")

# ------------------------------------------------------------------------------
# One-time ingestion: build the embeddings collection and save it.
# Run the steps below only when the collection needs to be (re)built, then
# comment them out again to avoid unnecessary recomputation.
# ------------------------------------------------------------------------------
# 1. Process the source data and create chunks for the relevant data source.
# chunks = process_giz_worldwide()
# 2. Convert the chunks to LangChain documents.
# temp_doc = create_documents(chunks, 'chunks')
# 3. Embed and store the documents. If the collection already exists, it needs
#    to be updated rather than created.
collection_name = "giz_worldwide"
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name)
################### Hybrid Search ######################################################
@st.cache_resource
def _get_cached_client():
    """Create the search client once and reuse it across Streamlit reruns.

    Streamlit re-executes this script from the top on every widget
    interaction. Without caching, ``get_client()`` (which presumably opens a
    new connection to the vector store; confirm in ``appStore.prep_utils``)
    would run again each time. ``st.cache_resource`` returns the same object
    on every rerun instead.

    The available collections are printed once, when the client is first
    created, as a startup sanity check. Previously they were printed (and
    fetched) on every rerun.
    """
    new_client = get_client()
    print(new_client.get_collections())
    return new_client


client = _get_cached_client()
# ---------------------------------------------------------------------------
# Helpers for reading result metadata
# ---------------------------------------------------------------------------
def _parse_countries(raw):
    """Return the countries stored in a hit's ``countries`` metadata as a list.

    ``raw`` is normally a stringified list such as ``"['DE', 'FR']"`` (assumed
    to come from ``str(list)`` at ingestion time; TODO confirm). An actual
    list or tuple is accepted as well.

    ``ast.literal_eval`` is tried first. Unlike the quote-swap + ``json.loads``
    fallback, it keeps names that contain an apostrophe (e.g.
    ``"Côte d'Ivoire"``) instead of failing on them. Missing, malformed or
    non-list values yield ``[]``, so one bad record cannot break the page.
    """
    if isinstance(raw, (list, tuple)):
        return list(raw)
    if not isinstance(raw, str):
        return []
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        try:
            parsed = json.loads(raw.replace("'", '"'))
        except json.JSONDecodeError:
            return []
    return list(parsed) if isinstance(parsed, (list, tuple)) else []


def _parse_end_year(raw):
    """Return ``raw`` as a float, or 0.0 when it is missing or not numeric.

    0.0 is below the slider's minimum (2010), so hits without a usable end
    year are excluded by the year filter. That matches what the code already
    did for a missing ``end_year`` key, and None/garbage values no longer
    crash the page.
    """
    try:
        return float(raw)
    except (TypeError, ValueError):
        return 0.0


# Fetch unique country codes from the metadata for the dropdown
@st.cache_data
def get_unique_countries(_client, collection_name):
    """Return the sorted, de-duplicated countries found in the search results.

    Runs ``hybrid_search`` with an empty query and collects the ``countries``
    metadata from both result lists (semantic at index 0, lexical at index 1).

    The leading underscore on ``_client`` tells ``st.cache_data`` not to hash
    that argument, so the cache is keyed on ``collection_name``.

    NOTE(review): an empty-query search presumably returns only the top-k
    hits, so countries that occur only in other documents may be missing
    from the dropdown. Confirm against ``appStore.search.hybrid_search``.
    """
    results = hybrid_search(_client, "", collection_name)
    country_set = set()
    for res in results[0] + results[1]:
        metadata = res.payload.get('metadata', {}) or {}
        country_set.update(_parse_countries(metadata.get('countries', "[]")))
    return sorted(country_set)


unique_countries = get_unique_countries(client, collection_name)

# Layout filters in columns; the wide third column is unused and acts as spacing.
col1, col2, col3 = st.columns([1, 1, 4])
with col1:
    country_filter = st.selectbox("Country Code", ["All"] + unique_countries)
with col2:
    end_year_range = st.slider("Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))
# Checkbox to control whether to show only exact (lexical) matches
show_exact_matches = st.checkbox("Show only exact matches", value=False)
button = st.button("Search")


def filter_results(results, country_filter, end_year_range):
    """Keep the hits that match the selected country and end-year range.

    Args:
        results: search hits whose ``payload['metadata']`` may carry
            ``countries`` and ``end_year``.
        country_filter: a country from the dropdown, or ``"All"`` to disable
            country filtering.
        end_year_range: inclusive ``(min_year, max_year)`` pair from the slider.

    Returns:
        The matching hits, in their original order.
    """
    low, high = end_year_range
    filtered = []
    for res in results:
        metadata = res.payload.get('metadata', {}) or {}
        end_year = _parse_end_year(metadata.get('end_year', 0))
        if not low <= end_year <= high:
            continue
        countries = _parse_countries(metadata.get('countries', "[]"))
        if country_filter == "All" or country_filter in countries:
            filtered.append(res)
    return filtered


def _render_results(hits, label, query, limit=10):
    """Write a heading, then up to ``limit`` hits as linked project entries."""
    st.write(f"Showing **Top {limit} {label} Search results** for query: {query}")
    for res in hits[:limit]:
        metadata = res.payload.get('metadata', {}) or {}
        project_name = metadata.get('project_name', 'Project Link')
        url = metadata.get('url', '#')
        st.markdown(f"#### [{project_name}]({url})")
        st.write(res.payload.get('page_content', ''))
        st.divider()


if button:
    # hybrid_search returns a pair: semantic hits at index 0, lexical at index 1.
    # NOTE(review): filters are applied after retrieval, so fewer than 10 hits
    # may be shown when the filters are selective.
    results = hybrid_search(client, var, collection_name)
    if show_exact_matches:
        _render_results(filter_results(results[1], country_filter, end_year_range), "Lexical", var)
    else:
        _render_results(filter_results(results[0], country_filter, end_year_range), "Semantic", var)
# for i in results:
# st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
# st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
# st.write(i.page_content)
# st.divider()