import numpy as np
from sentence_transformers import util
from open_clip import create_model_from_pretrained, get_tokenizer
import torch
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
import boto3
import streamlit as st
from PIL import Image
from PIL import ImageDraw
from io import BytesIO
import pandas as pd
from typing import List, Union
import concurrent.futures


# Initialize the model globally to avoid reloading each time
model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')

# Query encoding below uses this SigLIP model for both text and images.

def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Encode the query using the OpenCLIP model.
    Parameters
    ----------
    query : Union[str, Image.Image]
        The query, which can be a text string or an Image object.
    Returns
    -------
    torch.Tensor
        The encoded query vector.
    """
    if isinstance(query, Image.Image):
        query = preprocess(query).unsqueeze(0)  # Preprocess the image and add batch dimension
        with torch.no_grad():
            query_embedding = model.encode_image(query)  # Get image embedding
    elif isinstance(query, str):
        text = tokenizer(query, context_length=model.context_length)
        with torch.no_grad():
            query_embedding = model.encode_text(text)  # Get text embedding
    else:
        raise ValueError("Query must be either a string or an Image.")
    
    return query_embedding
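
# Example usage (illustrative sketch, not part of the pipeline): encode a text
# query and an image query with the SigLIP model loaded above, then compare
# them. 'example.jpg' is a hypothetical local file used only for illustration.
#
#   text_embedding = encode_query("a red truck on a highway")   # shape (1, dim)
#   image_embedding = encode_query(Image.open("example.jpg"))   # shape (1, dim)
#   similarity = util.cos_sim(text_embedding, image_embedding)  # 1x1 tensor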

def load_hf_datasets(key, dataset):
    """
    Load a single split of a Hugging Face dataset as a pandas DataFrame
    ---------------------------------------
    key: str - name of the split within the dataset (e.g., 'Main_1')
    dataset: DatasetDict - the loaded Hugging Face dataset
    ---------------------------------------
    RETURNS: the split as a pandas DataFrame
    """
    df = dataset[key].to_pandas()
    return df

def parallel_load_and_combine(dataset_keys, dataset):
    """
    Load dataset splits in parallel and combine the Main and Split keys
    ----------------------------------------------------------
    dataset_keys: list - keys of the dataset (e.g., ['Main_1', 'Split_1', ...])
    dataset: DatasetDict - the loaded Hugging Face dataset
    ----------------------------------------------------------
    RETURNS: tuple of combined DataFrames (main_combined_df, split_combined_df)
    """
    # Separate keys into Main and Split lists
    main_keys = [key for key in dataset_keys if key.startswith('Main')]
    split_keys = [key for key in dataset_keys if key.startswith('Split')]

    # Load both groups of splits in parallel
    with concurrent.futures.ThreadPoolExecutor() as executor:
        main_dfs = list(executor.map(lambda key: load_hf_datasets(key, dataset), main_keys))
        split_dfs = list(executor.map(lambda key: load_hf_datasets(key, dataset), split_keys))

    # Combine the Main DataFrames and the Split DataFrames
    main_combined_df = pd.concat(main_dfs, ignore_index=True) if main_dfs else pd.DataFrame()
    split_combined_df = pd.concat(split_dfs, ignore_index=True) if split_dfs else pd.DataFrame()

    return main_combined_df, split_combined_df
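
# Example usage (sketch): given a DatasetDict with keys like 'Main_1', 'Main_2',
# and 'Split_1', the call below returns one combined DataFrame per key family.
#
#   dataset = load_dataset("quasara-io/WayveScenes")
#   main_df, split_df = parallel_load_and_combine(dataset.keys(), dataset)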
    
def get_image_vectors(df):
    # Stack the per-row embeddings in the 'Vector' column into an (N, dim) matrix
    image_vectors = np.vstack(df['Vector'].to_numpy())
    return torch.tensor(image_vectors, dtype=torch.float32)
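
# Note (assumption based on usage below): each row of df['Vector'] is expected
# to be an equal-length array-like embedding, so np.vstack yields an (N, dim)
# matrix, where dim is the embedding size of the SigLIP model loaded above.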


def search(query, df, limit, offset, scoring_func, search_in_images):
    """
    Return the indices of the top matches for the query in the dataframe.
    Only cosine similarity over image vectors is currently supported, so
    scoring_func is accepted but not yet used.
    """
    if not search_in_images:
        raise NotImplementedError("Only image search is currently supported.")

    # Encode the query and convert it to a NumPy array
    query_vector = encode_query(query)[0, :].detach().numpy()

    # Get the image vectors from the dataframe as a NumPy array
    image_vectors = get_image_vectors(df).numpy()

    # Calculate the cosine similarity between the query vector and each image vector
    cosine_similarities = cosine_similarity([query_vector], image_vectors)

    # Return the indices of the top K most similar image vectors, honoring the offset
    top_k_indices = np.argsort(-cosine_similarities[0])[offset:offset + limit]
    return top_k_indices
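
# Example usage (sketch): full-image search over the combined Main dataframe,
# mirroring the call in main() below.
#
#   indices = search("black car", main_df, limit=10, offset=0,
#                    scoring_func="cosine", search_in_images=True)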

def batch_search(query, df, batch_size=100000, limit=10):
    """
    Cosine-similarity search computed in batches to bound memory usage.
    Returns the indices of the top `limit` matches across all batches.
    """
    # Get the image vectors from the dataframe as a NumPy array
    vectors = get_image_vectors(df).numpy()

    # Encode the query; the first row of the output is the query embedding
    query_vector = encode_query(query)[0].detach().numpy()

    # Collect per-batch candidates, offsetting indices back to global positions
    candidate_indices = []
    candidate_scores = []
    for i in range(0, len(vectors), batch_size):
        batch_vectors = vectors[i:i + batch_size]  # Extract a batch of vectors

        # Compute cosine similarity between the query vector and the batch
        batch_similarities = cosine_similarity([query_vector], batch_vectors)[0]

        # Keep the top-k within this batch, converted to dataframe-level indices
        top_in_batch = np.argsort(-batch_similarities)[:limit]
        candidate_indices.extend(top_in_batch + i)
        candidate_scores.extend(batch_similarities[top_in_batch])

    # Re-rank the per-batch candidates to get the global top-k
    order = np.argsort(-np.asarray(candidate_scores))[:limit]
    return np.asarray(candidate_indices)[order]
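
# Example usage (sketch): retrieve the ten best matches for a text query in
# batches; `split_df` is assumed to hold a 'Vector' column, as in main().
#
#   indices = batch_search("black car", split_df, batch_size=50000, limit=10)
#   paths = get_file_paths(split_df, indices)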


def get_file_paths(df, top_k_indices, column_name='File_Path'):
    """
    Retrieve the file paths from the DataFrame using the top K indices.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: numpy array of the top K indices
    - column_name: str, the name of the column to fetch (default 'File_Path')

    Returns:
    - top_k_paths: list of file paths from the specified column
    """
    # Fetch the specific column corresponding to the top K indices
    top_k_paths = df.iloc[top_k_indices][column_name].tolist()
    return top_k_paths


def get_coordinates(df, top_k_indices, column_name='Coordinate'):
    """
    Retrieve the bounding-box coordinates from the DataFrame using the top K indices.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: numpy array of the top K indices
    - column_name: str, the name of the column to fetch (default 'Coordinate')

    Returns:
    - top_k_coordinates: list of coordinate values from the specified column
    """
    # Fetch the specific column corresponding to the top K indices
    top_k_coordinates = df.iloc[top_k_indices][column_name].tolist()
    return top_k_coordinates

def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Retrieve and display images from AWS S3 in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths to retrieve from S3
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
    - folder_name: str, the folder prefix in the S3 bucket where the images are stored

    Returns:
    - None (directly displays images in the Streamlit app)
    """
    # Initialize S3 client
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )

    # Iterate over file paths and display each image
    for file_path in file_paths:
        # Retrieve the image from S3
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
        img_data = s3_object['Body'].read()
        
        # Open the image using PIL and display it using Streamlit
        img = Image.open(BytesIO(img_data))
        st.image(img, caption=file_path, use_column_width=True)
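
# Example usage (sketch): the bucket name, credentials, and folder below are
# placeholders; substitute your deployment's values.
#
#   get_images_from_s3_to_display(
#       bucket_name="my-image-bucket",
#       file_paths=top_k_paths,
#       AWS_ACCESS_KEY_ID="...",
#       AWS_SECRET_ACCESS_KEY="...",
#       folder_name="WayveScenes/",
#   )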



def get_images_with_bounding_boxes_from_s3(bucket_name, file_paths, bounding_boxes, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Retrieve and display images from AWS S3 with corresponding bounding boxes in a Streamlit app.
    
    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths to retrieve from S3
    - bounding_boxes: list of numpy arrays or lists, each containing coordinates of bounding boxes (in the form [x_min, y_min, x_max, y_max])
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
    - folder_name: str, the folder prefix in S3 bucket where the images are stored
    
    Returns:
    - None (directly displays images in the Streamlit app with bounding boxes)
    """
    # Initialize S3 client
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )

    # Iterate over file paths and corresponding bounding boxes
    for file_path, box_coords in zip(file_paths, bounding_boxes):
        # Retrieve the image from S3
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
        img_data = s3_object['Body'].read()

        # Open the image using PIL
        img = Image.open(BytesIO(img_data))
        
        # Draw bounding boxes on the image
        draw = ImageDraw.Draw(img)

        # Ensure box_coords is iterable, in case it's a single numpy array or float value
        if isinstance(box_coords, (np.ndarray, list)):
            # Check if we have multiple bounding boxes or a single one
            if len(box_coords) > 0 and isinstance(box_coords[0], (np.ndarray, list)):
                # Multiple bounding boxes
                for box in box_coords:
                    x_min, y_min, x_max, y_max = map(int, box)  # Convert to integers
                    draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
            else:
                # Single bounding box
                x_min, y_min, x_max, y_max = map(int, box_coords)
                draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
        else:
            raise ValueError(f"Bounding box data for {file_path} is not in an iterable format.")
        
        # Display the image with bounding boxes using Streamlit
        st.image(img, caption=file_path, use_column_width=True)


def main():
    print('Begin Main')
    dataset_name = "WayveScenes"
    query = "black car"
    limit = 10
    offset = 0
    scoring_func = "cosine"
    search_in_images = True
    search_in_small_objects = True
    dataset = load_dataset(f"quasara-io/{dataset_name}")
    print('Loaded dataset')
    dataset_keys = dataset.keys()
    main_df, split_df = parallel_load_and_combine(dataset_keys, dataset)
    print('Processed datasets')
    if search_in_small_objects:
        # Small-object search runs over the Split dataframe and also returns bounding boxes
        results = batch_search(query, split_df, limit=limit)
        top_k_paths = get_file_paths(split_df, results)
        top_k_coordinates = get_coordinates(split_df, results)
        print(top_k_paths)
        print(top_k_coordinates)
        return top_k_paths, top_k_coordinates
    else:
        # Full-image search runs over the Main dataframe
        results = search(query, main_df, limit, offset, scoring_func, search_in_images)
        top_k_paths = get_file_paths(main_df, results)
        print(top_k_paths)
        return top_k_paths


if __name__ == "__main__":
    main()