Spaces:
Build error
Build error
| import streamlit as st | |
| from helper import load_dataset, parallel_load_and_combine,search, get_file_paths, get_cordinates, get_images_from_s3_to_display, get_images_with_bounding_boxes_from_s3, batch_search | |
| import os | |
| import time | |
# AWS credentials are read from the process environment; each value is None
# when the corresponding variable is not set.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Dataset names offered in the "Select Dataset" dropdown.
datasets = [
    "WayveScenes",
    "MajorTom-Europe",
]

# Caption text shown under the dropdown, keyed by dataset name.
description = {
    "StopSign_test": "A test dataset for me",
    "WayveScenes": (
        "A large-scale dataset featuring diverse urban driving scenes, "
        "captured from autonomous vehicles to advance AI perception and "
        "navigation in complex environments."
    ),
    "MajorTom-Europe": (
        "A geospatial dataset containing satellite imagery from across "
        "Europe, designed for tasks like land-use classification, "
        "environmental monitoring, and earth observation analytics."
    ),
}

# Name of the AWS S3 bucket the result images are fetched from.
bucket_name = "datasets-quasara-io"
| # Streamlit App | |
def _simulate_progress(bar, delay):
    """Step *bar* through 25, 50, 75 and 100, sleeping *delay* seconds before each step."""
    for percent in range(25, 101, 25):
        time.sleep(delay)  # Simulated work; no real progress is measured
        bar.progress(percent)


def _take_top(items, limit):
    """Return the first *limit* entries of *items*; empty/None values pass through unchanged."""
    # The falsy pass-through keeps the "No results found." checks working
    # whatever empty value the helper functions return.
    return items[:limit] if items else items


def main():
    """Render the semantic-search page and show the images matching a query.

    The page lets the user pick a dataset, type a query, choose how many
    results to display and optionally enable small-object search. When the
    "Search" button is pressed with a non-empty query, the search runs against
    the combined dataframes and the matching images are fetched from the S3
    bucket ``bucket_name`` for display (with bounding boxes in small-object
    mode).
    """
    st.title("Semantic Search and Image Display")

    # Select dataset from dropdown
    dataset_name = st.selectbox("Select Dataset", datasets)
    # 'StopSign_test' uses an empty key prefix; every other dataset uses
    # a "<dataset name>/" prefix.
    folder_path = "" if dataset_name == 'StopSign_test' else f'{dataset_name}/'

    # A dataset listed without a description simply gets no caption
    # instead of crashing the page with a KeyError.
    caption = description.get(dataset_name)
    if caption is not None:
        st.caption(caption)

    # Progress bar for loading dataset (the steps are simulated)
    loading_text = st.empty()  # Placeholder for dynamic text
    loading_text.text("Loading dataset...")
    progress_bar = st.progress(0)
    _simulate_progress(progress_bar, 0.2)

    # Load the selected dataset
    dataset = load_dataset(f"quasara-io/{dataset_name}")

    # Complete progress when loading is done
    progress_bar.progress(100)
    loading_text.text("Dataset loaded successfully!")

    # Input search query
    query = st.text_input("Enter your search query")

    # Number of results to display; cast so it can be used as a slice bound
    limit = int(
        st.number_input("Number of results to display", min_value=1, max_value=10, value=10)
    )

    search_in_small_objects = bool(st.checkbox("Enable Small Object Search"))
    if search_in_small_objects:
        st.text("Small Object Search Enabled")
    else:
        st.text("Small Object Search Disabled")

    # Nothing more to do until the user presses the search button
    if not st.button("Search"):
        return

    # Validate input
    if not query:
        st.warning("Please enter a search query.")
        return

    # Progress bar for search (the steps are simulated)
    search_loading_text = st.empty()
    search_loading_text.text("Performing search...")
    search_progress_bar = st.progress(0)
    _simulate_progress(search_progress_bar, 0.3)

    # Get dataset keys to speed up processing/search
    main_df, split_df = parallel_load_and_combine(dataset.keys(), dataset)

    # NOTE(review): the helper return types are not visible here; the
    # truncation below assumes sliceable sequences (the truthiness checks
    # already imply list-like values) and, for small-object search, that
    # paths and coordinates are index-aligned -- confirm against helper.py.
    if search_in_small_objects:
        # Small-object search runs against the split dataframe
        results = batch_search(query, split_df)
        top_k_paths = _take_top(get_file_paths(split_df, results), limit)
        top_k_cordinates = _take_top(get_cordinates(split_df, results), limit)

        # Complete the search progress
        search_progress_bar.progress(100)
        search_loading_text.text("Search completed!")

        # Load images with bounding boxes
        if top_k_paths and top_k_cordinates:
            get_images_with_bounding_boxes_from_s3(
                bucket_name,
                top_k_paths,
                top_k_cordinates,
                AWS_ACCESS_KEY_ID,
                AWS_SECRET_ACCESS_KEY,
                folder_path,
            )
        else:
            st.write("No results found.")
        return

    # Normal search runs against the main dataframe
    results = batch_search(query, main_df)
    top_k_paths = _take_top(get_file_paths(main_df, results), limit)

    # Complete the search progress
    search_progress_bar.progress(100)
    search_loading_text.text("Search completed!")

    # Display images from S3
    if top_k_paths:
        st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':")
        get_images_from_s3_to_display(
            bucket_name,
            top_k_paths,
            AWS_ACCESS_KEY_ID,
            AWS_SECRET_ACCESS_KEY,
            folder_path,
        )
    else:
        st.write("No results found.")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()