"""Streamlit app for semantic search over image datasets, with matching images displayed from S3."""

import os
import time

import streamlit as st

# `get_cordinates` follows the helper module's own spelling.
from helper import (
    load_dataset,
    parallel_load_and_combine,
    get_file_paths,
    get_cordinates,
    get_images_from_s3_to_display,
    get_images_with_bounding_boxes_from_s3,
    batch_search,
)

# AWS credentials are read from the environment so they are never hard-coded.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
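

# Fail-fast guard (a minimal sketch, not wired in by default): the S3 display
# helpers need these credentials, so calling this at the top of main() would
# surface a clear message instead of a stack trace from inside the S3 helpers.
def require_aws_credentials():
    if not (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY):
        st.error(
            "AWS credentials not found in the environment "
            "(AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY)."
        )
        st.stop()  # halt this Streamlit run cleanly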


datasets = ["WayveScenes", "MajorTom-Europe"]

# "StopSign_test" is not offered in the selectbox, so the branch in main()
# that checks for it is currently dormant; the entry is kept so the caption
# lookup stays complete if that dataset is re-enabled.
description = {
    "StopSign_test": "A small dataset used for testing.",
    "WayveScenes": "A large-scale dataset featuring diverse urban driving scenes, captured from autonomous vehicles to advance AI perception and navigation in complex environments.",
    "MajorTom-Europe": "A geospatial dataset containing satellite imagery from across Europe, designed for tasks like land-use classification, environmental monitoring, and earth observation analytics.",
}

bucket_name = "datasets-quasara-io"
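

# Caching sketch (assumes the object returned by load_dataset can be pickled):
# st.cache_data keeps Streamlit from re-downloading the dataset on every
# rerun; st.cache_resource is the drop-in alternative for unpicklable objects.
@st.cache_data(show_spinner=False)
def cached_load_dataset(path):
    return load_dataset(path)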


def main():
    st.title("Semantic Search and Image Display")

    dataset_name = st.selectbox("Select Dataset", datasets)

    # "StopSign_test" images sit at the bucket root; every other dataset is
    # stored under a folder named after it.
    if dataset_name == "StopSign_test":
        folder_path = ""
    else:
        folder_path = f"{dataset_name}/"
    st.caption(description[dataset_name])

    loading_text = st.empty()
    loading_text.text("Loading dataset...")
    progress_bar = st.progress(0)

    # Purely cosmetic animation; the real work happens in load_dataset below.
    for i in range(0, 100, 25):
        time.sleep(0.2)
        progress_bar.progress(i + 25)
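
    # Note (sketch): load_dataset fetches f"quasara-io/{dataset_name}", a
    # remote download on first run, so a network failure would surface below;
    # wrapping the call in try/except with st.error(...) and st.stop() would
    # fail more gracefully than a raw traceback.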

    dataset = load_dataset(f"quasara-io/{dataset_name}")  # see cached_load_dataset above

    progress_bar.progress(100)
    loading_text.text("Dataset loaded successfully!")

    query = st.text_input("Enter your search query")

    # NOTE: `limit` is read from the UI but is not yet passed to the search
    # helpers below, so it currently has no effect on the result count.
    limit = st.number_input("Number of results to display", min_value=1, max_value=10, value=10)
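
    # Sketch only: the helper signatures aren't shown here, so a safe way to
    # honor `limit` would be to truncate the relevance-ordered paths after
    # retrieval, e.g. top_k_paths = get_file_paths(main_df, results)[:limit].
    # Passing limit into batch_search directly works only if the helper
    # accepts such a parameter (an assumption).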

    search_in_small_objects = st.checkbox("Enable Small Object Search")
    status = "Enabled" if search_in_small_objects else "Disabled"
    st.text(f"Small Object Search {status}")

    if st.button("Search"):
        if not query:
            st.warning("Please enter a search query.")
        else:
            search_loading_text = st.empty()
            search_loading_text.text("Performing search...")
            search_progress_bar = st.progress(0)

            # Cosmetic progress animation, as with dataset loading above.
            for i in range(0, 100, 25):
                time.sleep(0.3)
                search_progress_bar.progress(i + 25)

            # Build the whole-image and small-object frames from the raw
            # dataset splits.
            dataset_keys = dataset.keys()
            main_df, split_df = parallel_load_and_combine(dataset_keys, dataset)

            if search_in_small_objects:
                # Small-object hits come from image crops, so the display
                # helper draws bounding boxes on the parent images.
                results = batch_search(query, split_df)
                top_k_paths = get_file_paths(split_df, results)
                top_k_cordinates = get_cordinates(split_df, results)

                search_progress_bar.progress(100)
                search_loading_text.text("Search completed!")

                if top_k_paths and top_k_cordinates:
                    st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':")
                    get_images_with_bounding_boxes_from_s3(bucket_name, top_k_paths, top_k_cordinates, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_path)
                else:
                    st.write("No results found.")
            else:
                # Whole-image search over the main dataframe.
                results = batch_search(query, main_df)
                top_k_paths = get_file_paths(main_df, results)

                search_progress_bar.progress(100)
                search_loading_text.text("Search completed!")

                if top_k_paths:
                    st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':")
                    get_images_from_s3_to_display(bucket_name, top_k_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_path)
                else:
                    st.write("No results found.")


if __name__ == "__main__":
    main()
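
# To run locally (assuming this file is saved as app.py and the AWS
# credentials are exported in the shell):
#
#     streamlit run app.py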