import os
import sqlite3
from glob import glob
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

from .vector_database import (
    VectorDB,
    ImageEmbeddingCollectionSchema,
    TextEmbeddingCollectionSchema,
)


class ImageSearchModule:
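    """Index precomputed CLIP image embeddings and caption embeddings in a
    vector database, and serve image-to-image and text-to-image search."""
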
    def __init__(
        self,
        image_embeddings_dir: str,
        original_images_dir: str,
        sqlite_db_path: str = "image_tracker.db",
    ):
        self.image_embeddings_dir = image_embeddings_dir
        self.original_images_dir = original_images_dir
        self.vector_db = VectorDB()
        self.vector_db.create_collection(ImageEmbeddingCollectionSchema)
        self.vector_db.create_collection(TextEmbeddingCollectionSchema)
        self.clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
        )
        self.clip_preprocess = CLIPProcessor.from_pretrained(
            "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
        )
        self.text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
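        # SQLite tracks which images are already indexed, so repeated
        # add_images() calls skip embeddings that were inserted before.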
        self.sqlite_conn = sqlite3.connect(sqlite_db_path)
        self._create_sqlite_table()

    def _create_sqlite_table(self):
        cursor = self.sqlite_conn.cursor()
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS added_images (
                image_name TEXT PRIMARY KEY
            )
            """
        )
        self.sqlite_conn.commit()

    def add_images(self):
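        """Scan the embeddings directory and index any images not yet added."""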
print("Adding images to vector databases") | |
cursor = self.sqlite_conn.cursor() | |
for filename in tqdm(os.listdir(self.image_embeddings_dir)): | |
if filename.startswith("resized_") and filename.endswith("_clip.npy"): | |
image_name = filename[ | |
8:-9 | |
] # Remove "resized_" prefix and "_clip.npy" suffix | |
cursor.execute( | |
"SELECT 1 FROM added_images WHERE image_name = ?", (image_name,) | |
) | |
if cursor.fetchone() is None: | |
                    clip_embedding_path = os.path.join(
                        self.image_embeddings_dir, filename
                    )
                    caption_embedding_path = os.path.join(
                        self.image_embeddings_dir, f"resized_{image_name}_caption.npy"
                    )
                    if os.path.exists(clip_embedding_path) and os.path.exists(
                        caption_embedding_path
                    ):
                        # The files are read as raw float32 buffers: 512-d CLIP
                        # image vectors and 384-d MiniLM caption vectors.
                        with open(clip_embedding_path, "rb") as buffer:
                            image_embedding = np.frombuffer(
                                buffer.read(), dtype=np.float32
                            ).reshape(512)
                        with open(caption_embedding_path, "rb") as buffer:
                            text_embedding = np.frombuffer(
                                buffer.read(), dtype=np.float32
                            ).reshape(384)
                        if self.vector_db.insert_record(
                            ImageEmbeddingCollectionSchema.collection_name,
                            image_embedding,
                            image_name,
                        ):
                            self.vector_db.insert_record(
                                TextEmbeddingCollectionSchema.collection_name,
                                text_embedding,
                                image_name,
                            )
                            cursor.execute(
                                "INSERT INTO added_images (image_name) VALUES (?)",
                                (image_name,),
                            )
                            self.sqlite_conn.commit()
        print("Finished adding images to vector databases")

    def search_by_image(
        self, query_image_path: str, top_k: int = 5, similarity_threshold: float = 0.5
    ) -> List[Tuple[str, float]]:
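        """Return up to top_k (filename, similarity) pairs whose image
        embeddings score at least similarity_threshold under cosine similarity."""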
        if not os.path.exists(query_image_path):
            print(f"Image file not found: {query_image_path}")
            return []
        try:
            query_image = Image.open(query_image_path)
            query_embedding = self._get_image_embedding(query_image)
            results = self.vector_db.client.search(
                collection_name=ImageEmbeddingCollectionSchema.collection_name,
                data=[query_embedding],
                output_fields=["filename"],
                search_params={"metric_type": "COSINE"},
                limit=top_k,
            )[0]
            return [
                (item["entity"]["filename"], item["distance"])
                for item in results
                if item["distance"] >= similarity_threshold
            ]
        except Exception as e:
            print(f"Error processing image: {e}")
            return []

    def search_by_text(
        self, query_text: str, top_k: int = 5, similarity_threshold: float = 0.5
    ) -> List[Tuple[str, float]]:
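        """Return up to top_k (filename, similarity) pairs whose caption
        embeddings score at least similarity_threshold under cosine similarity."""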
        if not query_text.strip():
            print("Empty text query")
            return []
        try:
            query_embedding = self._get_text_embedding(query_text)
            results = self.vector_db.client.search(
                collection_name=TextEmbeddingCollectionSchema.collection_name,
                data=[query_embedding],
                search_params={"metric_type": "COSINE"},
                output_fields=["filename"],
                limit=top_k,
            )[0]
            return [
                (item["entity"]["filename"], item["distance"])
                for item in results
                if item["distance"] >= similarity_threshold
            ]
        except Exception as e:
            print(f"Error processing text: {e}")
            return []

    def _get_image_embedding(self, image: Image.Image) -> np.ndarray:
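        """Embed a PIL image with TinyCLIP into a 512-d vector, matching the
        stored CLIP embeddings."""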
        with torch.no_grad():
            image_input = self.clip_preprocess(images=image, return_tensors="pt")[
                "pixel_values"
            ].to(self.clip_model.device)
            image_features = self.clip_model.get_image_features(image_input)
        return image_features.cpu().numpy().flatten()

    def _get_text_embedding(self, text: str) -> np.ndarray:
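        """Embed a text query with MiniLM into a 384-d vector, matching the
        stored caption embeddings."""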
        with torch.no_grad():
            embedding = self.text_embedding_model.encode(text).flatten()
        return embedding

    def display_results(self, results: List[Tuple[str, float]]):
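        """Plot up to five result images side by side with similarity scores."""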
        if not results:
            print("No results to display.")
            return
        num_images = min(5, len(results))
        fig, axes = plt.subplots(1, num_images, figsize=(20, 4))
        axes = [axes] if num_images == 1 else axes
        for i, (image_name, similarity) in enumerate(results[:num_images]):
            # Original images keep the "resized_" prefix but may carry any
            # extension, so match by glob.
            pattern = os.path.join(self.original_images_dir, f"resized_{image_name}*")
            matching_files = glob(pattern)
            if matching_files:
                image_path = matching_files[0]
                img = Image.open(image_path)
                axes[i].imshow(img)
                axes[i].set_title(f"Similarity: {similarity:.2f}")
                axes[i].axis("off")
            else:
                print(f"No matching image found for {image_name}")
                axes[i].text(0.5, 0.5, "Image not found", ha="center", va="center")
                axes[i].axis("off")
        plt.tight_layout()
        plt.show()

    def __del__(self):
        # Close the SQLite connection when the object is garbage-collected.
        if hasattr(self, "sqlite_conn"):
            self.sqlite_conn.close()


if __name__ == "__main__":
    # Run as a module (python -m <package>.<this_module>) so the relative
    # import of .vector_database resolves.
    from pathlib import Path

    import requests

    PROJECT_ROOT = Path(__file__).resolve().parent.parent
    search = ImageSearchModule(
        image_embeddings_dir=str(PROJECT_ROOT / "data/features"),
        original_images_dir=str(PROJECT_ROOT / "data/images"),
    )
    search.add_images()

    # Search by image
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    raw_image.save(PROJECT_ROOT / "test.jpg")
    image_results = search.search_by_image(str(PROJECT_ROOT / "test.jpg"))
    print("Image search results:")
    search.display_results(image_results)

    # Search by text
    text_results = search.search_by_text("Images of Nature")
    print("Text search results:")
    search.display_results(text_results)