import json

import pandas as pd
import streamlit as st
from torch import cuda

from appStore.embed import hybrid_embed_chunks
from appStore.prep_data import process_giz_worldwide
from appStore.prep_utils import create_documents, get_client
from appStore.region_utils import get_country_name, load_region_data
from appStore.search import hybrid_search

# Pick the device to use: GPU when available, otherwise CPU.
# NOTE(review): `device` is not referenced elsewhere in this script.
device = 'cuda' if cuda.is_available() else 'cpu'

st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("GIZ Project Database")

var = st.text_input("Enter Search Query")

# Maximum number of hits rendered per search.
RESULTS_LIMIT = 10

# Load the region lookup CSV (used to turn ISO country codes into names).
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)

#################### Create the embeddings collection and save ######################
# The steps below only need to run once, to build the collection. Keep them
# commented out afterwards to avoid unnecessary recomputation on every rerun.

##### First, process the relevant data source and create the chunks
# chunks = process_giz_worldwide()

##### Convert the chunks to LangChain documents
# temp_doc = create_documents(chunks, 'chunks')

##### Embed and store the docs; if the collection already exists it must be updated instead
collection_name = "giz_worldwide"
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name)

################### Hybrid Search ######################################################
# One client per script run (previously created twice).
client = get_client()


def _parse_country_codes(raw):
    """Return upper-cased country codes from a metadata ``countries`` value.

    Accepts either a list or a string holding a list literal such as
    ``"['de', 'fr']"``; single quotes are swapped for double quotes so the
    string can be parsed as JSON. Anything that cannot be parsed into a list,
    and any non-string entries, are skipped (best effort, as before).
    """
    if isinstance(raw, (list, tuple)):
        codes = raw
    elif isinstance(raw, str):
        try:
            codes = json.loads(raw.replace("'", '"'))
        except json.JSONDecodeError:
            return []
    else:
        return []
    if not isinstance(codes, list):
        # e.g. a bare JSON string: don't iterate it character by character.
        return []
    return [code.upper() for code in codes if isinstance(code, str)]


def _parse_end_year(raw):
    """Return ``raw`` as a float, or ``None`` when it is missing/non-numeric."""
    try:
        return float(raw)
    except (TypeError, ValueError):
        return None


@st.cache_data
def get_country_name_mapping(_client, collection_name, region_df):
    """Build a ``{country name: ISO code}`` mapping for the country filter.

    Runs ``hybrid_search`` with an empty query and collects the country codes
    found in the metadata of both result lists. ``_client`` is underscore-
    prefixed so ``st.cache_data`` does not try to hash it.

    NOTE(review): only the hits returned for the empty query are scanned, so
    countries absent from those hits will not appear in the filter — confirm
    the search limit covers the collection.
    """
    results = hybrid_search(_client, "", collection_name)
    country_set = set()
    for res in results[0] + results[1]:
        metadata = res.payload.get('metadata') or {}
        country_set.update(_parse_country_codes(metadata.get('countries', "[]")))

    country_name_to_code = {}
    # Sorted so that name collisions resolve deterministically across runs.
    for code in sorted(country_set):
        name = get_country_name(code, region_df)
        # Fall back to the code itself when the lookup yields no usable name,
        # otherwise sorting the names below could mix str with None/NaN.
        if not isinstance(name, str) or not name:
            name = code
        country_name_to_code[name] = code
    return country_name_to_code


# Get the country name mapping and the sorted list of names for the selectbox.
country_name_mapping = get_country_name_mapping(client, collection_name, region_df)
unique_country_names = sorted(country_name_mapping.keys())


def filter_results(results, country_filter, end_year_range):
    """Return the hits whose metadata matches the selected filters.

    Args:
        results: search hits exposing a ``payload`` dict with a ``metadata`` dict.
        country_filter: a country name from ``country_name_mapping`` or ``"All"``.
        end_year_range: inclusive ``(low, high)`` bounds for ``end_year``.

    Hits whose ``end_year`` is missing default to 0 (as before); values that
    are not numeric are excluded instead of crashing the page.
    """
    # Translate the selected country name back to its ISO code once, not per hit.
    selected_iso_code = country_name_mapping.get(country_filter)
    low, high = end_year_range
    filtered = []
    for res in results:
        metadata = res.payload.get('metadata') or {}
        end_year = _parse_end_year(metadata.get('end_year', 0))
        if end_year is None or not (low <= end_year <= high):
            continue
        if country_filter != "All":
            country_list = _parse_country_codes(metadata.get('countries', "[]"))
            if selected_iso_code not in country_list:
                continue
        filtered.append(res)
    return filtered


def _render_results(hits, label, query):
    """Write a header, then up to ``RESULTS_LIMIT`` hits as title link + text."""
    st.write(f"Showing **Top {RESULTS_LIMIT} {label}** for query: {query}")
    if not hits:
        st.info("No projects match the selected filters.")
        return
    for res in hits[:RESULTS_LIMIT]:
        metadata = res.payload.get('metadata') or {}
        project_name = metadata.get('project_name', 'Project Link')
        url = metadata.get('url', '#')
        st.markdown(f"#### [{project_name}]({url})")
        st.write(res.payload.get('page_content', ''))
        st.divider()


# Layout filters in columns
col1, col2, col3 = st.columns([1, 1, 4])
with col1:
    country_filter = st.selectbox("Country", ["All"] + unique_country_names)  # Display country names
with col2:
    end_year_range = st.slider("Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))

# Checkbox to control whether to show only exact (lexical) matches
show_exact_matches = st.checkbox("Show only exact matches", value=False)

button = st.button("Search")

if button:
    # hybrid_search returns semantic hits at index 0 and lexical hits at index 1.
    results = hybrid_search(client, var, collection_name)
    if show_exact_matches:
        lexical_results = filter_results(results[1], country_filter, end_year_range)
        _render_results(lexical_results, "Lexical Search results", var)
    else:
        semantic_results = filter_results(results[0], country_filter, end_year_range)
        _render_results(semantic_results, "Semantic Search results", var)