import numpy as np
import torch
import pandas as pd
import boto3
import streamlit as st
import concurrent.futures
from io import BytesIO
from typing import Union

from PIL import Image, ImageDraw
from datasets import load_dataset
from open_clip import create_model_from_pretrained, get_tokenizer
from sklearn.metrics.pairwise import cosine_similarity

# Load the SigLIP model, its preprocessing transform, and its tokenizer once
# at module import so every query reuses them.
model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-SO400M-14-SigLIP-384')
tokenizer = get_tokenizer('hf-hub:timm/ViT-SO400M-14-SigLIP-384')


def encode_query(query: Union[str, Image.Image]) -> torch.Tensor:
    """
    Encode the query using the OpenCLIP model.

    Parameters
    ----------
    query : Union[str, Image.Image]
        The query, which can be a text string or a PIL Image object.

    Returns
    -------
    torch.Tensor
        The encoded query embedding of shape (1, embedding_dim).
    """
    if isinstance(query, Image.Image):
        # Apply the model's image transform and add a batch dimension.
        query = preprocess(query).unsqueeze(0)
        with torch.no_grad():
            query_embedding = model.encode_image(query)
    elif isinstance(query, str):
        text = tokenizer(query, context_length=model.context_length)
        with torch.no_grad():
            query_embedding = model.encode_text(text)
    else:
        raise ValueError("Query must be either a string or an Image.")

    return query_embedding
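
# Example usage (a sketch; 'street.jpg' is a hypothetical local file):
#
#   text_vec = encode_query("black car")                # shape: (1, dim)
#   image_vec = encode_query(Image.open("street.jpg"))  # shape: (1, dim)
#
# The embeddings are not L2-normalised here; similarity is computed later
# with cosine similarity, which normalises internally.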


def load_hf_datasets(key, dataset):
    """
    Load one key of a Hugging Face dataset as a DataFrame
    ---------------------------------------
    key: str - name of the key (split) within the dataset
    dataset: DatasetDict - the loaded Hugging Face dataset
    ---------------------------------------
    RETURNS: the requested key as a pandas DataFrame
    """
    df = dataset[key].to_pandas()
    return df


def parallel_load_and_combine(dataset_keys, dataset):
    """
    Load dataset keys in parallel and combine the Main and Split keys
    ----------------------------------------------------------
    dataset_keys: list - keys of the dataset (e.g., ['Main_1', 'Split_1', ...])
    dataset: DatasetDict - the loaded Hugging Face dataset
    ----------------------------------------------------------
    RETURNS: one combined DataFrame for the Main keys and one for the Split keys
    """
    main_keys = [key for key in dataset_keys if key.startswith('Main')]
    split_keys = [key for key in dataset_keys if key.startswith('Split')]

    # Convert the keys to DataFrames concurrently; a single executor handles
    # both groups.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        main_dfs = list(executor.map(lambda key: load_hf_datasets(key, dataset), main_keys))
        split_dfs = list(executor.map(lambda key: load_hf_datasets(key, dataset), split_keys))

    main_combined_df = pd.concat(main_dfs, ignore_index=True) if main_dfs else pd.DataFrame()
    split_combined_df = pd.concat(split_dfs, ignore_index=True) if split_dfs else pd.DataFrame()

    return main_combined_df, split_combined_df
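
# Example (sketch): for dataset_keys == ['Main_1', 'Main_2', 'Split_1'], the
# first returned DataFrame concatenates Main_1 and Main_2 and the second
# contains Split_1; keys with any other prefix are ignored.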


def get_image_vectors(df):
    """Stack the precomputed embeddings in df['Vector'] into a (N, dim) float32 tensor."""
    image_vectors = np.vstack(df['Vector'].to_numpy())
    return torch.tensor(image_vectors, dtype=torch.float32)


def search(query, df, limit, offset, scoring_func, search_in_images):
    """
    Score every image vector in df against the query and return the indices
    of the top `limit` matches. `offset` and `scoring_func` are accepted for
    API compatibility but not used yet: scoring is always cosine similarity
    over the full DataFrame.
    """
    if not search_in_images:
        raise ValueError("Only image search is currently supported.")

    query_vector = encode_query(query)
    image_vectors = get_image_vectors(df)

    # Move both sides to NumPy for scikit-learn's cosine_similarity.
    query_vector = query_vector[0, :].detach().numpy()
    image_vectors = image_vectors.detach().numpy()
    cosine_similarities = cosine_similarity([query_vector], image_vectors)

    # Negate the scores so argsort yields descending-similarity order.
    top_k_indices = np.argsort(-cosine_similarities[0])[:limit]

    return top_k_indices
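
# Example usage (sketch):
#
#   top = search("black car", main_df, limit=10, offset=0,
#                scoring_func="cosine", search_in_images=True)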


def batch_search(query, df, batch_size=100000, limit=10):
    """
    Score image vectors against the query in batches of `batch_size` and
    return the indices of the global top `limit` matches. Batching bounds
    the per-call working set while giving the same result as a full search.
    """
    vectors = get_image_vectors(df).numpy()
    query_vector = encode_query(query)[0].detach().numpy()

    # Score every row, one batch at a time, into a single global array so
    # the indices stay aligned with the DataFrame rows.
    all_scores = np.empty(len(vectors), dtype=np.float32)
    for i in range(0, len(vectors), batch_size):
        batch_vectors = vectors[i:i + batch_size]
        all_scores[i:i + len(batch_vectors)] = cosine_similarity([query_vector], batch_vectors)[0]

    # Take the top K across all batches; per-batch top Ks would return
    # batch-local indices and could miss globally strong matches.
    top_k_indices = np.argsort(-all_scores)[:limit]
    return top_k_indices
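
# Example usage (sketch): lower batch_size to bound the temporary arrays
# cosine_similarity allocates per call, at the cost of more loop iterations.
#
#   results = batch_search("black car", split_df, batch_size=50000, limit=5)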


def get_file_paths(df, top_k_indices, column_name='File_Path'):
    """
    Retrieve the file paths (or any specific column) from the DataFrame using the top K indices.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: numpy array of the top K indices
    - column_name: str, the name of the column to fetch (default: 'File_Path')

    Returns:
    - top_k_paths: list of file paths or values from the specified column
    """
    top_k_paths = df.iloc[top_k_indices][column_name].tolist()
    return top_k_paths


def get_coordinates(df, top_k_indices, column_name='Coordinate'):
    """
    Retrieve the bounding-box coordinates (or any specific column) from the DataFrame using the top K indices.

    Parameters:
    - df: pandas DataFrame containing the data
    - top_k_indices: numpy array of the top K indices
    - column_name: str, the name of the column to fetch (default: 'Coordinate')

    Returns:
    - top_k_coordinates: list of coordinate values from the specified column
    """
    top_k_coordinates = df.iloc[top_k_indices][column_name].tolist()
    return top_k_coordinates
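
# Example (sketch): map batch_search results back to files and boxes,
#
#   paths = get_file_paths(split_df, results)
#   boxes = get_coordinates(split_df, results)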


def get_images_from_s3_to_display(bucket_name, file_paths, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Retrieve and display images from AWS S3 in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths to retrieve from S3
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
    - folder_name: str, the folder prefix in the S3 bucket where the images are stored

    Returns:
    - None (directly displays images in the Streamlit app)
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )

    for file_path in file_paths:
        # Fetch the object and decode its bytes into a PIL image.
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
        img_data = s3_object['Body'].read()
        img = Image.open(BytesIO(img_data))
        st.image(img, caption=file_path, use_column_width=True)


def get_images_with_bounding_boxes_from_s3(bucket_name, file_paths, bounding_boxes, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, folder_name):
    """
    Retrieve and display images from AWS S3 with corresponding bounding boxes in a Streamlit app.

    Parameters:
    - bucket_name: str, the name of the S3 bucket
    - file_paths: list, a list of file paths to retrieve from S3
    - bounding_boxes: list of numpy arrays or lists, each containing coordinates of bounding boxes (in the form [x_min, y_min, x_max, y_max])
    - AWS_ACCESS_KEY_ID: str, AWS access key ID for authentication
    - AWS_SECRET_ACCESS_KEY: str, AWS secret access key for authentication
    - folder_name: str, the folder prefix in the S3 bucket where the images are stored

    Returns:
    - None (directly displays images in the Streamlit app with bounding boxes)
    """
    s3 = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
    )

    for file_path, box_coords in zip(file_paths, bounding_boxes):
        s3_object = s3.get_object(Bucket=bucket_name, Key=f"{folder_name}{file_path}")
        img_data = s3_object['Body'].read()
        img = Image.open(BytesIO(img_data))
        draw = ImageDraw.Draw(img)

        if isinstance(box_coords, (np.ndarray, list)):
            if len(box_coords) > 0 and isinstance(box_coords[0], (np.ndarray, list)):
                # A nested structure: draw every box for this image.
                for box in box_coords:
                    x_min, y_min, x_max, y_max = map(int, box)
                    draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
            else:
                # A single flat box.
                x_min, y_min, x_max, y_max = map(int, box_coords)
                draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
        else:
            raise ValueError(f"Bounding box data for {file_path} is not in an iterable format.")

        st.image(img, caption=file_path, use_column_width=True)
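
# Example (sketch): each entry of bounding_boxes may be one flat box or a
# list of boxes for the corresponding image, e.g.
#
#   bounding_boxes = [
#       [10, 20, 110, 220],                    # image 1: a single box
#       [[5, 5, 50, 50], [60, 60, 120, 90]],   # image 2: two boxes
#   ]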


def main():
    print('Begin Main')
    dataset_name = "WayveScenes"
    query = "black car"
    limit = 10
    offset = 0
    scoring_func = "cosine"
    search_in_images = True
    search_in_small_objects = True

    dataset = load_dataset(f"quasara-io/{dataset_name}")
    print('Loaded dataset')
    dataset_keys = dataset.keys()
    main_df, split_df = parallel_load_and_combine(dataset_keys, dataset)
    print('Processed datasets')

    if search_in_small_objects:
        # Small-object search runs over the Split crops and also returns
        # the matching bounding-box coordinates.
        results = batch_search(query, split_df, limit=limit)
        print(results)
        top_k_paths = get_file_paths(split_df, results)
        top_k_coordinates = get_coordinates(split_df, results)
        print(top_k_paths)
        print(top_k_coordinates)
        return top_k_paths, top_k_coordinates
    else:
        results = search(query, main_df, limit, offset, scoring_func, search_in_images)
        top_k_paths = get_file_paths(main_df, results)
        print(top_k_paths)
        return top_k_paths


if __name__ == "__main__":
    main()