annikwag commited on
Commit
17d08d8
·
verified ·
1 Parent(s): 9254d49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -34,9 +34,9 @@ collection_name = "giz_worldwide"
34
  client = get_client()
35
  print(client.get_collections())
36
 
37
- # Fetch unique country codes from the metadata for the dropdown
38
  @st.cache_data
39
- def get_unique_countries_with_names(_client, collection_name, region_df):
40
  results = hybrid_search(_client, "", collection_name)
41
  country_set = set()
42
  for res in results[0] + results[1]:
@@ -47,28 +47,27 @@ def get_unique_countries_with_names(_client, collection_name, region_df):
47
  except json.JSONDecodeError:
48
  pass
49
 
50
- # Map ISO codes to country names
51
- country_names = [get_country_name(code, region_df) for code in country_set]
52
- return sorted(country_names)
53
 
 
54
  client = get_client()
55
- unique_countries = get_unique_countries_with_names(client, collection_name, region_df)
 
56
 
57
  # Layout filters in columns
58
  col1, col2, col3 = st.columns([1, 1, 4])
59
 
60
  with col1:
61
- country_filter = st.selectbox("Country Code", ["All"] + unique_countries)
62
  with col2:
63
  end_year_range = st.slider("Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))
64
 
65
  # Checkbox to control whether to show only exact matches
66
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
67
 
68
- button=st.button("Search")
69
- #found_docs = vectorstore.similarity_search(var)
70
- #print(found_docs)
71
- # results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
72
  if button:
73
  results = hybrid_search(client, var, collection_name)
74
 
@@ -86,8 +85,11 @@ if button:
86
  except json.JSONDecodeError:
87
  country_list = []
88
 
 
 
 
89
  # Apply country and year filters
90
- if (country_filter == "All" or country_filter in country_list) and (end_year_range[0] <= end_year <= end_year_range[1]):
91
  filtered.append(res)
92
  return filtered
93
 
 
34
  client = get_client()
35
  print(client.get_collections())
36
 
37
+ # Fetch unique country codes and map to country names
38
  @st.cache_data
39
+ def get_country_name_mapping(_client, collection_name, region_df):
40
  results = hybrid_search(_client, "", collection_name)
41
  country_set = set()
42
  for res in results[0] + results[1]:
 
47
  except json.JSONDecodeError:
48
  pass
49
 
50
+ # Create a mapping of country names to ISO codes
51
+ country_name_to_code = {get_country_name(code, region_df): code for code in country_set}
52
+ return country_name_to_code
53
 
54
+ # Get country name mapping
55
  client = get_client()
56
+ country_name_mapping = get_country_name_mapping(client, collection_name, region_df)
57
+ unique_country_names = sorted(country_name_mapping.keys()) # List of country names
58
 
59
  # Layout filters in columns
60
  col1, col2, col3 = st.columns([1, 1, 4])
61
 
62
  with col1:
63
+ country_filter = st.selectbox("Country", ["All"] + unique_country_names) # Display country names
64
  with col2:
65
  end_year_range = st.slider("Project End Year", min_value=2010, max_value=2030, value=(2010, 2030))
66
 
67
  # Checkbox to control whether to show only exact matches
68
  show_exact_matches = st.checkbox("Show only exact matches", value=False)
69
 
70
+ button = st.button("Search")
 
 
 
71
  if button:
72
  results = hybrid_search(client, var, collection_name)
73
 
 
85
  except json.JSONDecodeError:
86
  country_list = []
87
 
88
+ # Translate selected country name back to ISO code
89
+ selected_iso_code = country_name_mapping.get(country_filter, None)
90
+
91
  # Apply country and year filters
92
+ if (country_filter == "All" or selected_iso_code in country_list) and (end_year_range[0] <= end_year <= end_year_range[1]):
93
  filtered.append(res)
94
  return filtered
95