File size: 2,865 Bytes
8ba7a1f
 
 
f6019ba
8ba7a1f
f6019ba
8ba7a1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6019ba
 
 
 
 
 
 
 
 
 
 
4650494
 
f6019ba
 
 
 
 
8ba7a1f
 
 
 
 
 
 
 
 
 
 
 
f6019ba
 
 
 
 
 
 
 
 
8ba7a1f
 
 
 
f6019ba
 
 
 
8ba7a1f
 
 
 
 
 
f6019ba
8ba7a1f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import streamlit as st
from helper import load_hf_datasets, search, get_file_paths, get_images_from_s3_to_display
import os
import time

# S3 bucket that holds the images for every selectable dataset.
bucket_name = "datasets-quasara-io"

# Datasets offered in the UI dropdown, in display order.
datasets = [
    "WayveScenes",
    "StopSign_test",
]

# AWS credentials read from the process environment; each is None when the
# corresponding variable is not set.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")

# Streamlit App
def main():
    """Render the semantic-search page and display matching images from S3.

    Streamlit re-executes the script top-to-bottom on every widget
    interaction, so this function runs once per rerun.

    Flow:
      1. Pick a dataset from ``datasets`` and load it with ``load_hf_datasets``.
      2. Read a free-text query and a result limit (1-10, default 10).
      3. When "Search" is clicked with a non-blank query, call ``search``
         (metric ``"cosine"``, images only, no small objects), resolve the
         hits to file paths with ``get_file_paths`` and render them from
         ``bucket_name`` via ``get_images_from_s3_to_display``.
    """
    st.title("Semantic Search and Image Display")

    # Key prefix of each dataset inside the S3 bucket; datasets not listed
    # here are stored at the bucket root.
    dataset_folders = {"WayveScenes": "WayveScenes/"}

    dataset_name = st.selectbox("Select Dataset", datasets)
    folder_path = dataset_folders.get(dataset_name, "")

    # Show a spinner for the duration of the real load. The previous
    # simulated progress bar filled to 100% via sleeps *before* loading began.
    # NOTE(review): the dataset is reloaded on every rerun (each keystroke or
    # click); consider caching load_hf_datasets with st.cache_data once its
    # return type is confirmed to be picklable.
    with st.spinner("Loading dataset..."):
        df = load_hf_datasets(dataset_name)
    st.text("Dataset loaded successfully!")

    query = st.text_input("Enter your search query")
    limit = st.number_input(
        "Number of results to display", min_value=1, max_value=10, value=10
    )

    if not st.button("Search"):
        return

    # Reject empty and whitespace-only queries alike.
    if not query.strip():
        st.warning("Please enter a search query.")
        return

    with st.spinner("Performing search..."):
        results = search(
            query,
            df,
            limit,
            0,
            "cosine",
            search_in_images=True,
            search_in_small_objects=False,
        )
    st.text("Search completed!")

    # Map the search hits to file paths inside the bucket.
    top_k_paths = get_file_paths(df, results)

    if top_k_paths:
        st.write(f"Displaying top {len(top_k_paths)} results for query '{query}':")
        get_images_from_s3_to_display(
            bucket_name,
            top_k_paths,
            AWS_ACCESS_KEY_ID,
            AWS_SECRET_ACCESS_KEY,
            folder_path,
        )
    else:
        st.write("No results found.")


# `streamlit run <file>` executes the script with __name__ == "__main__".
if __name__ == "__main__":
    main()