inie2003 commited on
Commit
1a36398
·
verified ·
1 Parent(s): a3902aa

Added small object search

Browse files
Files changed (1) hide show
  1. helper.py +154 -21
helper.py CHANGED
@@ -8,8 +8,11 @@ import torch.nn as nn
8
  import boto3
9
  import streamlit as st
10
  from PIL import Image
 
11
  from io import BytesIO
 
12
  from typing import List, Union
 
13
 
14
 
15
  # Initialize the model globally to avoid reloading each time
@@ -21,12 +24,10 @@ tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
21
  def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
22
  """
23
  Encode the query using the OpenCLIP model.
24
-
25
  Parameters
26
  ----------
27
  query : Union[str, Image.Image]
28
  The query, which can be a text string or an Image object.
29
-
30
  Returns
31
  -------
32
  torch.Tensor
@@ -45,21 +46,49 @@ def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
45
 
46
  return query_embedding
47
 
48
- def load_hf_datasets(dataset_name):
49
  """
50
  Load Datasets from Hugging Face as DF
51
  ---------------------------------------
52
  dataset_name: str - name of dataset on Hugging Face
53
  ---------------------------------------
54
-
55
  RETURNS: dataset as pandas dataframe
56
  """
57
- dataset = load_dataset(f"quasara-io/{dataset_name}")
58
- # Access only the 'Main' split
59
- main_dataset = dataset['Main_1']
60
- # Convert to Pandas DataFrame
61
- df = main_dataset.to_pandas()
62
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def get_image_vectors(df):
65
  # Get the image vectors from the dataframe
@@ -67,7 +96,7 @@ def get_image_vectors(df):
67
  return torch.tensor(image_vectors, dtype=torch.float32)
68
 
69
 
70
- def search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects):
71
  if search_in_images:
72
  # Encode the image query
73
  query_vector = encode_query(query)
@@ -79,7 +108,7 @@ def search(query, df, limit, offset, scoring_func, search_in_images, search_in_s
79
 
80
  # Calculate the cosine similarity between the query vector and each image vector
81
  query_vector = query_vector[0, :].detach().numpy() # Detach and convert to a NumPy array
82
- image_vectors = image_vectors.detach().numpy() # Convert the image vectors to a NumPy array
83
  cosine_similarities = cosine_similarity([query_vector], image_vectors)
84
 
85
  # Get the top K indices of the most similar image vectors
@@ -88,6 +117,29 @@ def search(query, df, limit, offset, scoring_func, search_in_images, search_in_s
88
  # Return the top K indices
89
  return top_k_indices
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def get_file_paths(df, top_k_indices, column_name = 'File_Path'):
92
  """
93
  Retrieve the file paths (or any specific column) from the DataFrame using the top K indices.
@@ -97,6 +149,21 @@ def get_file_paths(df, top_k_indices, column_name = 'File_Path'):
97
  - top_k_indices: numpy array of the top K indices
98
  - column_name: str, the name of the column to fetch (e.g., 'ImagePath')
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  Returns:
101
  - top_k_paths: list of file paths or values from the specified column
102
  """
@@ -104,8 +171,7 @@ def get_file_paths(df, top_k_indices, column_name = 'File_Path'):
104
  top_k_paths = df.iloc[top_k_indices][column_name].tolist()
105
  return top_k_paths
106
 
107
-
108
- def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY, folder_name= None):
109
  """
110
  Retrieve and display images from AWS S3 in a Streamlit app.
111
 
@@ -135,21 +201,88 @@ def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID,AWS
135
 
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def main():
139
- dataset_name = "StopSign_test"
 
140
  query = "black car"
141
  limit = 10
142
  offset = 0
143
  scoring_func = "cosine"
144
  search_in_images = True
145
- search_in_small_objects = False
146
-
147
- df = load_hf_datasets(dataset_name)
148
- results = search(query, df, limit, offset, scoring_func, search_in_images, search_in_small_objects)
149
- top_k_paths = get_file_paths(df,results)
150
- return top_k_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
  if __name__ == "__main__":
154
  main()
155
-
 
8
  import boto3
9
  import streamlit as st
10
  from PIL import Image
11
+ from PIL import ImageDraw
12
  from io import BytesIO
13
+ import pandas as pd
14
  from typing import List, Union
15
+ import concurrent.futures
16
 
17
 
18
  # Initialize the model globally to avoid reloading each time
 
24
  def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
25
  """
26
  Encode the query using the OpenCLIP model.
 
27
  Parameters
28
  ----------
29
  query : Union[str, Image.Image]
30
  The query, which can be a text string or an Image object.
 
31
  Returns
32
  -------
33
  torch.Tensor
 
46
 
47
  return query_embedding
48
 
49
+ def load_hf_datasets(key,dataset):
50
  """
51
  Load Datasets from Hugging Face as DF
52
  ---------------------------------------
53
  dataset_name: str - name of dataset on Hugging Face
54
  ---------------------------------------
 
55
  RETURNS: dataset as pandas dataframe
56
  """
57
+ df = dataset[key].to_pandas()
58
+
 
 
 
59
  return df
60
+
61
+ def parallel_load_and_combine(dataset_keys, dataset):
62
+ """
63
+ Load datasets in parallel and combine Main and Split keys
64
+ ----------------------------------------------------------
65
+ dataset_keys: list - keys of the dataset (e.g., ['Main_1', 'Split_1', ...])
66
+ dataset: DatasetDict - the loaded Hugging Face dataset
67
+ ----------------------------------------------------------
68
+ RETURNS: combined DataFrame from both Main and Split keys
69
+ """
70
+ # Separate keys into Main and Split lists
71
+ main_keys = [key for key in dataset_keys if key.startswith('Main')]
72
+ split_keys = [key for key in dataset_keys if key.startswith('Split')]
73
+
74
+ def process_key(key, key_type):
75
+ df = load_hf_datasets(key, dataset)
76
+ return df
77
+
78
+ # Parallel loading of Main keys
79
+ with concurrent.futures.ThreadPoolExecutor() as executor:
80
+ main_dfs = list(executor.map(lambda key: process_key(key, 'Main'), main_keys))
81
+
82
+ # Parallel loading of Split keys
83
+ with concurrent.futures.ThreadPoolExecutor() as executor:
84
+ split_dfs = list(executor.map(lambda key: process_key(key, 'Split'), split_keys))
85
+
86
+ # Combine Main DataFrames and Split DataFrames
87
+ main_combined_df = pd.concat(main_dfs, ignore_index=True) if main_dfs else pd.DataFrame()
88
+ split_combined_df = pd.concat(split_dfs, ignore_index=True) if split_dfs else pd.DataFrame()
89
+
90
+
91
+ return main_combined_df, split_combined_df
92
 
93
  def get_image_vectors(df):
94
  # Get the image vectors from the dataframe
 
96
  return torch.tensor(image_vectors, dtype=torch.float32)
97
 
98
 
99
+ def search(query, df, limit, offset, scoring_func, search_in_images):
100
  if search_in_images:
101
  # Encode the image query
102
  query_vector = encode_query(query)
 
108
 
109
  # Calculate the cosine similarity between the query vector and each image vector
110
  query_vector = query_vector[0, :].detach().numpy() # Detach and convert to a NumPy array
111
+ image_vectors = image_vectoßrs.detach().numpy() # Convert the image vectors to a NumPy array
112
  cosine_similarities = cosine_similarity([query_vector], image_vectors)
113
 
114
  # Get the top K indices of the most similar image vectors
 
117
  # Return the top K indices
118
  return top_k_indices
119
 
120
+ #Try Batch Search
121
+ def batch_search(query, df, batch_size=100000, limit=10):
122
+ top_k_indices = []
123
+
124
+ # Get the image vectors from the dataframe and ensure they are NumPy arrays
125
+ vectors = get_image_vectors(df).numpy() # Convert to NumPy array if it's a tensor
126
+
127
+ # Encode the query and ensure it's a NumPy array
128
+ query_vector = encode_query(query)[0].detach().numpy() # Assuming the first element is the query embedding
129
+
130
+ # Iterate over the batches and compute cosine similarities
131
+ for i in range(0, len(vectors), batch_size):
132
+ batch_vectors = vectors[i:i + batch_size] # Extract a batch of vectors
133
+
134
+ # Compute cosine similarity between the query vector and the batch
135
+ batch_similarities = cosine_similarity([query_vector], batch_vectors)
136
+
137
+ # Get the top-k similar vectors within this batch
138
+ top_k_indices.extend(np.argsort(-batch_similarities[0])[:limit])
139
+
140
+ return top_k_indices
141
+
142
+
143
  def get_file_paths(df, top_k_indices, column_name = 'File_Path'):
144
  """
145
  Retrieve the file paths (or any specific column) from the DataFrame using the top K indices.
 
149
  - top_k_indices: numpy array of the top K indices
150
  - column_name: str, the name of the column to fetch (e.g., 'ImagePath')
151
 
152
+ Returns:
153
+ - top_k_paths: list of file paths or values from the specified column
154
+ """
155
+ # Fetch the specific column corresponding to the top K indices
156
+ top_k_paths = df.iloc[top_k_indices][column_name].tolist()
157
+ return top_k_paths
158
+ def get_cordinates(df, top_k_indices, column_name = 'Coordinate'):
159
+ """
160
+ Retrieve the file paths (or any specific column) from the DataFrame using the top K indices.
161
+
162
+ Parameters:
163
+ - df: pandas DataFrame containing the data
164
+ - top_k_indices: numpy array of the top K indices
165
+ - column_name: str, the name of the column to fetch (e.g., 'ImagePath')
166
+
167
  Returns:
168
  - top_k_paths: list of file paths or values from the specified column
169
  """
 
171
  top_k_paths = df.iloc[top_k_indices][column_name].tolist()
172
  return top_k_paths
173
 
174
+ def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY, folder_name):
 
175
  """
176
  Retrieve and display images from AWS S3 in a Streamlit app.
177
 
 
201
 
202
 
203
 
204
+ def get_images_with_bounding_boxes_from_s3(bucket_name, file_paths, bounding_boxes, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
205
+ """
206
+ Retrieve and display images from AWS S3 with corresponding bounding boxes in a Streamlit app.
207
+
208
+ Parameters:
209
+ - bucket_name: str, the name of the S3 bucket
210
+ - file_paths: list, a list of file paths to retrieve from S3
211
+ - bounding_boxes: list of numpy arrays or lists, each containing coordinates of bounding boxes (in the form [x_min, y_min, x_max, y_max])
212
+ - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
213
+ - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
214
+ - folder_name: str, the folder prefix in S3 bucket where the images are stored
215
+
216
+ Returns:
217
+ - None (directly displays images in the Streamlit app with bounding boxes)
218
+ """
219
+ # Initialize S3 client
220
+ s3 = boto3.client(
221
+ 's3',
222
+ aws_access_key_id=AWS_ACCESS_KEY_ID,
223
+ aws_secret_access_key=AWS_SECRET_ACCESS_KEY
224
+ )
225
+
226
+ # Iterate over file paths and corresponding bounding boxes
227
+ for file_path, box_coords in zip(file_paths, bounding_boxes):
228
+ # Retrieve the image from S3
229
+ s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
230
+ img_data = s3_object['Body'].read()
231
+
232
+ # Open the image using PIL
233
+ img = Image.open(BytesIO(img_data))
234
+
235
+ # Draw bounding boxes on the image
236
+ draw = ImageDraw.Draw(img)
237
+
238
+ # Ensure box_coords is iterable, in case it's a single numpy array or float value
239
+ if isinstance(box_coords, (np.ndarray, list)):
240
+ # Check if we have multiple bounding boxes or a single one
241
+ if len(box_coords) > 0 and isinstance(box_coords[0], (np.ndarray, list)):
242
+ # Multiple bounding boxes
243
+ for box in box_coords:
244
+ x_min, y_min, x_max, y_max = map(int, box) # Convert to integers
245
+ draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
246
+ else:
247
+ # Single bounding box
248
+ x_min, y_min, x_max, y_max = map(int, box_coords)
249
+ draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
250
+ else:
251
+ raise ValueError(f"Bounding box data for {file_path} is not in an iterable format.")
252
+
253
+ # Display the image with bounding boxes using Streamlit
254
+ st.image(img, caption=file_path, use_column_width=True)
255
+
256
+
257
  def main():
258
+ print('Begin Main')
259
+ dataset_name = "WayveScenes"
260
  query = "black car"
261
  limit = 10
262
  offset = 0
263
  scoring_func = "cosine"
264
  search_in_images = True
265
+ search_in_small_objects = True
266
+ dataset = load_dataset(f"quasara-io/{dataset_name}")
267
+ print('loaded dataset')
268
+ dataset_keys = dataset.keys()
269
+ main_df, split_df = parallel_load_and_combine(dataset_keys, dataset)
270
+ #Now we get the coordinates and the stuff
271
+ print('processed datasets')
272
+ if search_in_small_objects:
273
+ results = batch_search(query, split_df)
274
+ print(results)
275
+ top_k_paths = get_file_paths(split_df,results)
276
+ top_k_cordinates = get_cordinates(split_df, results)
277
+ print(top_k_paths)
278
+ print(top_k_cordinates)
279
+ return top_k_paths, top_k_cordinates
280
+ else:
281
+ results = search(query, main_df, limit, offset, scoring_func, search_in_images)
282
+ top_k_paths = get_file_paths(main_df,results)
283
+ print(top_k_paths)
284
+ return top_k_paths
285
 
286
 
287
  if __name__ == "__main__":
288
  main()