satishjasthij committed on
Commit
d1df841
·
1 Parent(s): 767749c
.gitignore ADDED
@@ -0,0 +1,2 @@
+ data/
+ .DS_Store
README.md CHANGED
@@ -1,13 +1,162 @@
  ---
- title: PicMatch
- emoji: 📉
- colorFrom: gray
- colorTo: blue
  sdk: gradio
  sdk_version: 4.39.0
  app_file: app.py
- pinned: false
- license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: 'PicMatch: Your Visual Search Companion'
+ emoji: 📷🔍
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
+ python_version: 3.9
  sdk_version: 4.39.0
+ suggested_hardware: t4-small
+ suggested_storage: medium
  app_file: app.py
+ fullWidth: true
+ header: mini
+ short_description: Search images using text or other images as queries.
+ models:
+ - wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M
+ - Salesforce/blip-image-captioning-base
+
+ tags:
+ - image search
+ - visual search
+ - image processing
+ - CLIP
+ - image captioning
+ thumbnail: https://example.com/thumbnail.png
+ pinned: true
+ hf_oauth: false
+ disable_embedding: false
+ startup_duration_timeout: 30m
+ custom_headers:
+   cross-origin-embedder-policy: require-corp
+   cross-origin-opener-policy: same-origin
+   cross-origin-resource-policy: cross-origin
+
  ---
 
+ # 📸 PicMatch: Your Visual Search Companion 🔍
+
+ PicMatch lets you effortlessly search through your image archive using either a text description or another image as your query. Find those needle-in-a-haystack photos in a flash! ✨
+
+ ## 🚀 Getting Started: Let the Fun Begin!
+
+ 1. **Prerequisites:** Ensure you have Python 3.9 or higher installed on your system. 🐍
+
+ 2. **Create a Virtual Environment:**
+ ```bash
+ python -m venv env
+ ```
+
+ 3. **Activate the Environment:**
+ ```bash
+ source ./env/bin/activate
+ ```
+
+ 4. **Install Dependencies:**
+ ```bash
+ python -m pip install -r requirements.txt
+ ```
+
+ 5. **Start the App (with Sample Data):**
+ ```bash
+ python app.py
+ ```
+
+ 6. **Open Your Browser:** Head to `http://localhost:7860` to access the PicMatch interface. 🌐
+
+ ## 📂 Data: Organize Your Visual Treasures
+
+ Make sure you have the following folders in your project's root directory:
+
+ ```
+ data
+ ├── images
+ └── features
+ ```
+
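+ If they don't exist yet, a quick way to create them (a minimal sketch; run it from the project root and adjust the paths if your layout differs):
+
+ ```python
+ from pathlib import Path
+
+ # Create the folders the app expects for images and their embeddings.
+ for sub in ("images", "features"):
+     (Path("data") / sub).mkdir(parents=True, exist_ok=True)
+ ```
+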
+ ## 🛠️ Image Pipeline: Download & Process with Speed ⚡
+
+ The `engine/download_data.py` script streamlines downloading and processing images from a list of URLs. It's designed for performance and reliability:
+
+ - **Async Operations:** Uses `asyncio` for concurrent image downloading and processing. ⏩
+ - **Rate Limiting:** Uses a `RateLimiter` to respect API usage limits and avoid getting blocked. 🚦
+ - **Parallel Resizing:** Employs a `ProcessPoolExecutor` for fast image resizing. ⚙️
+ - **State Management:** Saves progress in a JSON file so you can resume later. 💾
+
+ ### 🏗️ Key Components:
+
+ - **`ImagePipeline` Class:** Manages the entire pipeline, its state, and rate limiting. 🎛️
+ - **Functions:** Handle URL feeding (`url_feeder`), downloading (`image_downloader`), and processing (`image_processor`). 📥
+ - **`ImageSaver` Class:** Defines how images are processed and saved. 🖼️
+ - **`resize_image` Function:** Ensures image resizing maintains the correct aspect ratio. 📏
+
+ ### 🏃 How it Works:
+
+ 1. **Start:** Configure the pipeline with your URL list, download limits, and rate settings.
+ 2. **Feed URLs:** Asynchronously read URLs from your file.
+ 3. **Download:** Download images concurrently while respecting rate limits.
+ 4. **Process:** Save the original images and resize them in parallel.
+ 5. **Save State:** Regularly save progress to avoid starting over if interrupted.
+
+ To get the sample data, run:
+ ```bash
+ cd engine && python download_data.py
+ ```
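+
+ For intuition, here is a condensed, self-contained sketch of the token-bucket rate limiting the pipeline relies on (the full `RateLimiter` lives in `engine/download_data.py`; the rates below are illustrative):
+
+ ```python
+ import asyncio
+ import time
+
+ class TokenBucket:
+     """Allow `rate` requests per `per` seconds, with bursts of up to `burst`."""
+
+     def __init__(self, rate: float, per: float = 1.0, burst: int = 1):
+         self.rate, self.per, self.burst = rate, per, burst
+         self.tokens = float(burst)
+         self.updated_at = time.monotonic()
+
+     async def wait(self) -> None:
+         while True:
+             now = time.monotonic()
+             # Refill tokens for the elapsed time, capped at the burst size.
+             self.tokens = min(self.burst, self.tokens + (now - self.updated_at) * (self.rate / self.per))
+             self.updated_at = now
+             if self.tokens >= 1:
+                 self.tokens -= 1  # spend one token and let the request proceed
+                 return
+             # Sleep just long enough for the next token to become available.
+             await asyncio.sleep((1 - self.tokens) / (self.rate / self.per))
+
+ async def demo() -> None:
+     limiter = TokenBucket(rate=5, per=1.0, burst=5)
+     for i in range(8):  # the first 5 pass immediately, the rest are throttled
+         await limiter.wait()
+         print(f"request {i} at t={time.monotonic():.2f}s")
+
+ if __name__ == "__main__":
+     asyncio.run(demo())
+ ```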
+
+ ## ✨ Feature Creation: Making Your Images Searchable ✨
+
+ This step prepares your images for searching. We generate two types of embeddings:
+
+ - **Visual Embeddings (CLIP):** Capture the visual content of your images. 👁️‍🗨️
+ - **Textual Embeddings:** Create embeddings from image captions for text-based search. 💬
+
+ To generate these features, run:
+ ```bash
+ cd engine && python generate_features.py
+ ```
+
+ This process uses these awesome models from Hugging Face:
+
+ - TinyCLIP: `wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M`
+ - BLIP Image Captioning: `Salesforce/blip-image-captioning-base`
+ - SentenceTransformer: `all-MiniLM-L6-v2`
+
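+ For reference, a minimal synchronous sketch of what the feature generator computes per image (the real script runs this asynchronously with checkpointing; the image path below is hypothetical):
+
+ ```python
+ from PIL import Image
+ from sentence_transformers import SentenceTransformer
+ from transformers import (BlipForConditionalGeneration, BlipProcessor,
+                           CLIPModel, CLIPProcessor)
+
+ clip_model = CLIPModel.from_pretrained("wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M")
+ clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
+
+ image = Image.open("data/images/resized_example.jpg")  # hypothetical sample image
+
+ # 512-dim visual embedding from TinyCLIP.
+ clip_inputs = clip_processor(images=image, return_tensors="pt")
+ clip_embedding = clip_model.get_image_features(**clip_inputs).detach().numpy().flatten()
+
+ # Caption the image with BLIP, then embed the caption (384-dim) with MiniLM.
+ blip_inputs = blip_processor(images=image, return_tensors="pt")
+ caption = blip_processor.decode(blip_model.generate(**blip_inputs)[0], skip_special_tokens=True)
+ caption_embedding = text_encoder.encode(caption)
+
+ print(caption, clip_embedding.shape, caption_embedding.shape)
+ ```
+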
+ ## ⚡ Asynchronous Feature Extraction: Supercharge Your Process ⚡
+
+ This script extracts image features (both visual and textual) efficiently:
+
+ - **Asynchronous:** Loads images, extracts features, and saves them concurrently. ⚡
+ - **Dual Embeddings:** Creates both CLIP (visual) and caption (textual) embeddings. 🖼️📝
+ - **Checkpoints:** Keeps track of progress and allows resuming from interruptions. 🔄
+ - **Parallel:** Uses multiple CPU cores for feature extraction. ⚙️
+
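+ Conceptually, the extractor is a producer/consumer pipeline: an async loader fills a queue, workers hand the CPU-heavy work to a process pool, and a sentinel value signals completion. A toy sketch of that pattern (the function names here are illustrative, not the script's API):
+
+ ```python
+ import asyncio
+ from concurrent.futures import ProcessPoolExecutor
+
+ def extract_features(item: int) -> int:
+     # Stand-in for the CPU-heavy per-image work (CLIP + captioning).
+     return item * item
+
+ async def worker(queue: asyncio.Queue, pool: ProcessPoolExecutor) -> None:
+     loop = asyncio.get_running_loop()
+     while True:
+         item = await queue.get()
+         if item is None:           # sentinel: no more work
+             await queue.put(None)  # propagate so sibling workers also stop
+             return
+         result = await loop.run_in_executor(pool, extract_features, item)
+         print(f"processed {item} -> {result}")
+
+ async def main() -> None:
+     queue: asyncio.Queue = asyncio.Queue()
+     for i in range(5):             # producer: enqueue work items
+         await queue.put(i)
+     await queue.put(None)          # end-of-input sentinel
+     with ProcessPoolExecutor() as pool:
+         await asyncio.gather(*(worker(queue, pool) for _ in range(2)))
+
+ if __name__ == "__main__":
+     asyncio.run(main())
+ ```
+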
+ ## 📊 Vector Database Module: Milvus for Fast Search 🚀
+
+ This module connects to the Milvus vector database to store and search your image embeddings:
+
+ - **Milvus:** A high-performance database built for handling vector data. 📊
+ - **Easy Interface:** Provides a simple way to manage embeddings and perform searches. 🔍
+ - **Single Server:** Ensures only one Milvus server is running for efficiency.
+ - **Indexing:** Automatically creates an index to speed up your searches. 🚀
+ - **Similarity Search:** Find the most similar images using cosine similarity. 💯
+
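+ Under the hood this uses `pymilvus`'s `MilvusClient` with Milvus Lite (a local, file-backed database). A minimal sketch of the flow implemented in `engine/vector_database.py`, with random vectors standing in for real embeddings:
+
+ ```python
+ import numpy as np
+ from pymilvus import MilvusClient
+
+ client = MilvusClient(uri="milvus.db")  # local Milvus Lite database file
+
+ if not client.has_collection("image_embeddings"):
+     client.create_collection(
+         collection_name="image_embeddings",
+         vector_field_name="embedding",
+         dimension=512,          # TinyCLIP image embeddings are 512-dim
+         auto_id=True,
+         enable_dynamic_field=True,
+         metric_type="COSINE",
+     )
+
+ # Insert one record: the embedding plus the image filename as a dynamic field.
+ client.insert(
+     collection_name="image_embeddings",
+     data={"embedding": np.random.rand(512).astype(np.float32), "filename": "example.jpg"},
+ )
+
+ # Cosine-similarity search for the top 5 most similar images.
+ hits = client.search(
+     collection_name="image_embeddings",
+     data=[np.random.rand(512).astype(np.float32)],
+     output_fields=["filename"],
+     search_params={"metric_type": "COSINE"},
+     limit=5,
+ )[0]
+ for hit in hits:
+     print(hit["entity"]["filename"], hit["distance"])
+ ```
+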
+ ## 📚 References: The Brains Behind PicMatch 🧠
+
+ PicMatch leverages these incredible open-source projects:
+
+ - **TinyCLIP:** The visual powerhouse for understanding your images.
+   - 👉 [https://huggingface.co/wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M](https://huggingface.co/wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M)
+
+ - **Image Captioning:** The wordsmith that describes your photos in detail.
+   - 👉 [https://huggingface.co/Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base)
+
+ - **Sentence Transformers:** Turns captions into embeddings for text-based search.
+   - 👉 [https://sbert.net](https://sbert.net)
+
+ - **Unsplash:** Images were taken from Unsplash's open dataset.
+   - 👉 [https://github.com/unsplash/datasets](https://github.com/unsplash/datasets)
+
+ Let's give credit where credit is due! 🙌 These projects make PicMatch smarter and more capable.
app.py ADDED
@@ -0,0 +1,101 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ from engine.search import ImageSearchModule
5
+ import os
6
+ from pathlib import Path
7
+
8
+ PROJECT_ROOT = Path(__file__).resolve().parent
9
+
10
+ def check_dirs():
11
+ dirs = {
12
+ "Data": (PROJECT_ROOT / "data"),
13
+ "Images": (PROJECT_ROOT / "data" / "images"),
14
+ "Features": (PROJECT_ROOT / "data" / "features")
15
+ }
16
+ for dir_name, dir_path in dirs.items():
17
+ if not dir_path.exists():
18
+ raise FileNotFoundError(f"{dir_name} directory not found: {dir_path}")
19
+
20
+ print("All data directories exist βœ…")
21
+
22
+
23
+ check_dirs()
24
+
25
+ # Initialize the ImageSearchModule
26
+ search = ImageSearchModule(
27
+ image_embeddings_dir=str(PROJECT_ROOT / "data/features"),
28
+ original_images_dir=str(PROJECT_ROOT / "data/images"),
29
+ )
30
+ print("Add image embeddings and caption embeddings to vector database")
31
+ search.add_images()
32
+
33
+
34
+ def search_images(input_data, search_type):
35
+ if search_type == "image" and input_data is not None:
36
+ # Fix: Get the file path directly from the input data
37
+ results = search.search_by_image(input_data, top_k=10, similarity_threshold=0)
38
+ elif search_type == "text" and input_data.strip():
39
+ results = search.search_by_text(input_data, top_k=10, similarity_threshold=0)
40
+ else:
41
+ return [(Image.new("RGB", (100, 100), color="gray"), "No results")] * 10
42
+
43
+ images_with_captions = []
44
+ for image_name, similarity in results:
45
+ image_path = os.path.join(search.original_images_dir, f"resized_{image_name}")
46
+ matching_files = [
47
+ f
48
+ for f in os.listdir(search.original_images_dir)
49
+ if f.startswith(f"resized_{image_name}")
50
+ ]
51
+ if matching_files:
52
+ img = Image.open(
53
+ os.path.join(search.original_images_dir, matching_files[0])
54
+ )
55
+ images_with_captions.append((img, f"Similarity: {similarity:.2f}"))
56
+ else:
57
+ images_with_captions.append(
58
+ (Image.new("RGB", (100, 100), color="gray"), "Image not found")
59
+ )
60
+
61
+ # Pad the results if less than 10 images are found
62
+ while len(images_with_captions) < 10:
63
+ images_with_captions.append(
64
+ (Image.new("RGB", (100, 100), color="gray"), "No result")
65
+ )
66
+
67
+ return images_with_captions
68
+
69
+
70
+ with gr.Blocks() as demo:
71
+ gr.Markdown("# Image Search App")
72
+ with gr.Tab("Image Search"):
73
+ # Fix: Change input type to 'filepath'
74
+ image_input = gr.Image(type="filepath", label="Upload an image")
75
+ image_button = gr.Button("Search by Image")
76
+
77
+ with gr.Tab("Text Search"):
78
+ text_input = gr.Textbox(label="Enter text query")
79
+ text_button = gr.Button("Search by Text")
80
+
81
+ gallery = gr.Gallery(
82
+ label="Search Results",
83
+ show_label=False,
84
+ elem_id="gallery",
85
+ columns=2,
86
+ height="auto",
87
+ )
88
+
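+ # The hidden Textbox passed alongside each input tells search_images whether the query is an "image" or "text" search.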
89
+ image_button.click(
90
+ fn=search_images,
91
+ inputs=[image_input, gr.Textbox(value="image", visible=False)],
92
+ outputs=[gallery],
93
+ )
94
+
95
+ text_button.click(
96
+ fn=search_images,
97
+ inputs=[text_input, gr.Textbox(value="text", visible=False)],
98
+ outputs=[gallery],
99
+ )
100
+
101
+ demo.launch()
copy_images_features.py ADDED
@@ -0,0 +1,64 @@
1
+ import os
2
+ import random
3
+ import shutil
4
+ import glob
5
+
6
+ from tqdm import tqdm
7
+
8
+ def sample_images_and_features(image_folder, feature_folder, sample_size, dest_image_folder, dest_feature_folder):
9
+ """
10
+ Randomly samples a specified number of resized images along with their corresponding
11
+ CLIP and caption features, and copies them to new folders.
12
+
13
+ Args:
14
+ image_folder (str): Path to the folder containing resized images.
15
+ feature_folder (str): Path to the folder containing feature files.
16
+ sample_size (int): Number of images to sample.
17
+ dest_image_folder (str): Destination folder for sampled images.
18
+ dest_feature_folder (str): Destination folder for sampled feature files.
19
+ """
20
+
21
+ # Ensure destination folders exist
22
+ os.makedirs(dest_image_folder, exist_ok=True)
23
+ os.makedirs(dest_feature_folder, exist_ok=True)
24
+
25
+ # Get all resized image file names
26
+ image_files = glob.glob(os.path.join(image_folder, "resized_*.jpg"))
27
+ image_files.extend(glob.glob(os.path.join(image_folder, "resized_*.png")))
28
+ image_files.extend(glob.glob(os.path.join(image_folder, "resized_*.jpeg")))
29
+
30
+ # Check if there are enough images
31
+ if len(image_files) < sample_size:
32
+ raise ValueError("Not enough resized images in the source folder.")
33
+
34
+ # Sample a subset of image files
35
+ sampled_images = random.sample(image_files, sample_size)
36
+
37
+ # Copy images and corresponding feature files
38
+ for image_path in tqdm(sampled_images):
39
+ image_name = os.path.basename(image_path)
40
+ base_name, _ = os.path.splitext(image_name)
41
+
42
+ # Construct paths for CLIP and caption feature files
43
+ clip_feature_path = os.path.join(feature_folder, f"{base_name}_clip.npy")
44
+ caption_feature_path = os.path.join(feature_folder, f"{base_name}_caption.npy")
45
+
46
+ # Copy image file
47
+ shutil.copy2(image_path, dest_image_folder) # copy2 preserves metadata
48
+
49
+ # Copy feature files (if they exist)
50
+ if os.path.exists(clip_feature_path):
51
+ shutil.copy2(clip_feature_path, dest_feature_folder)
52
+ if os.path.exists(caption_feature_path):
53
+ shutil.copy2(caption_feature_path, dest_feature_folder)
54
+
55
+ if __name__ == "__main__":
56
+ from pathlib import Path
57
+
58
+ PROJECT_ROOT = Path(__file__).resolve().parent
59
+ image_folder = str(PROJECT_ROOT / "data/images")
60
+ feature_folder = str(PROJECT_ROOT / "data/features")
61
+ sample_size = 10
62
+ dest_image_folder = str(PROJECT_ROOT / "data_temp/images")
63
+ dest_feature_folder = str(PROJECT_ROOT / "data_temp/features")
64
+ sample_images_and_features(image_folder, feature_folder, sample_size, dest_image_folder, dest_feature_folder)
engine/__init__.py ADDED
File without changes
engine/download_data.py ADDED
@@ -0,0 +1,296 @@
1
+ import csv
2
+ from pathlib import Path
3
+ import time
4
+ import json
5
+ import os, io
6
+
7
+ import aiofiles
8
+ import aiohttp
9
+ import asyncio
10
+ from PIL import Image
11
+ from abc import ABC, abstractmethod
12
+ from concurrent.futures import ProcessPoolExecutor
13
+ from dataclasses import asdict, dataclass
14
+
15
+
16
+ @dataclass
17
+ class ProcessState:
18
+ urls_processed: int = 0
19
+ images_downloaded: int = 0
20
+ images_saved: int = 0
21
+ images_resized: int = 0
22
+
23
+
24
+ class ImageProcessor(ABC):
25
+ @abstractmethod
26
+ def process(self, image: bytes, filename: str) -> None:
27
+ pass
28
+
29
+
30
+ class ImageSaver(ImageProcessor):
31
+ async def process(self, image: bytes, filename: str) -> None:
32
+ async with aiofiles.open(filename, "wb") as f:
33
+ await f.write(image)
34
+
35
+
36
+ def resize_image(image: bytes, filename: str, max_size: int = 300) -> None:
37
+ with Image.open(io.BytesIO(image)) as img:
38
+ img.thumbnail((max_size, max_size))
39
+ img.save(filename, optimize=True, quality=85)
40
+
41
+
42
+ class RateLimiter:
43
+ """
44
+ High-Level Concept: The Token Bucket Algorithm
45
+ ==============================================
46
+ The Rate_Limiter class implements what's known as the "Token Bucket" algorithm. Imagine you have a bucket that can hold a certain number of tokens. Here's how it works:
47
+
48
+ The bucket is filled with tokens at a constant rate.
49
+ When you want to perform an action (in our case, make an API request), you need to take a token from the bucket.
50
+ If there's a token available, you can perform the action immediately.
51
+ If there are no tokens, you have to wait until a new token is added to the bucket.
52
+ The bucket has a maximum capacity, so tokens don't accumulate indefinitely when not used.
53
+
54
+ This mechanism allows for both steady-state rate limiting and handling short bursts of activity.
55
+
56
+ In the constructor:
57
+ ===================
58
+ rate: is how many tokens we add per time period (e.g., 10 tokens per second)
59
+
60
+ per: is the time period (usually 1 second)
61
+
62
+ burst: is the bucket size (maximum number of tokens)
63
+
64
+ We start with a full bucket (self.tokens = burst)
65
+ We note the current time (self.updated_at)
66
+
67
+ Logic:
68
+ ======
69
+ 1. Calculate how much time has passed since we last updated the token count.
70
+
71
+ 2. Add tokens based on the time passed and our rate:
72
+ self.tokens += time_passed * (self.rate / self.per)
73
+
74
+ 3. If we've added too many tokens, cap it at our maximum (burst size).
75
+
76
+ 4. Update our "last updated" time.
77
+
78
+ 5. If we have at least one token:
79
+ Remove a token (self.tokens -= 1)
80
+ Return immediately, allowing the API call to proceed
81
+
82
+ 6. If we don't have a token:
83
+ Calculate how long we need to wait for the next token
84
+ Sleep for that duration
85
+
86
+ Let's walk through an example:
87
+ ==============================
88
+ Suppose we set up our RateLimiter like this:
89
+
90
+ limiter = RateLimiter(rate=10, per=1, burst=10)
91
+
92
+ This means:
93
+ - We allow 10 requests per second on average
94
+ - We can burst up to 10 requests at once
95
+ - After the burst, we'll be limited to 1 request every 0.1 seconds
96
+
97
+ Now, imagine a sequence of API calls:
98
+
99
+ 1. The first 10 calls will happen immediately (burst capacity)
100
+ 2. The 11th call will wait for 0.1 seconds (time to generate 1 token)
101
+ 3. Subsequent calls will each wait about 0.1 seconds
102
+
103
+ If there's a pause in API calls, tokens will accumulate (up to the burst limit), allowing for another burst of activity.
104
+
105
+ This mechanism ensures that:
106
+ 1. We respect the average rate limit (10 per second in this example)
107
+ 2. We can handle short bursts of activity (up to 10 at once)
108
+ 3. We smoothly regulate requests when operating at capacity
109
+ """
110
+
111
+ def __init__(self, rate: float, per: float = 1.0, burst: int = 1):
112
+ self.rate = rate
113
+ self.per = per
114
+ self.burst = burst
115
+ self.tokens = burst
116
+ self.updated_at = time.monotonic()
117
+
118
+ async def wait(self):
119
+ while True:
120
+ now = time.monotonic()
121
+ time_passed = now - self.updated_at
122
+ self.tokens += time_passed * (self.rate / self.per)
123
+ if self.tokens > self.burst:
124
+ self.tokens = self.burst
125
+ self.updated_at = now
126
+
127
+ if self.tokens >= 1:
128
+ self.tokens -= 1
129
+ return
130
+ else:
131
+ await asyncio.sleep((1 - self.tokens) / (self.rate / self.per))
132
+
133
+
134
+ class ImagePipeline:
135
+ def __init__(
136
+ self,
137
+ txt_file: str,
138
+ loop: asyncio.AbstractEventLoop,
139
+ max_concurrent_downloads: int = 10,
140
+ max_workers: int = max(os.cpu_count() - 4, 4),
141
+ rate_limit: float = 10,
142
+ rate_limit_period: float = 1,
143
+ downloaded_images_dir: str = "",
144
+ ):
145
+ self.txt_file = txt_file
146
+ self.loop = loop
147
+ self.url_queue = asyncio.Queue(maxsize=1000)
148
+ self.image_queue = asyncio.Queue(maxsize=100)
149
+ self.semaphore = asyncio.Semaphore(max_concurrent_downloads)
150
+ self.state = ProcessState()
151
+ self.state_file = "pipeline_state.json"
152
+ self.saver = ImageSaver()
153
+ self.process_pool = ProcessPoolExecutor(max_workers=max_workers)
154
+ self.rate_limiter = RateLimiter(
155
+ rate=rate_limit, per=rate_limit_period, burst=max_concurrent_downloads
156
+ )
157
+ self.downloaded_images_dir = Path(downloaded_images_dir)
158
+
159
+ async def url_feeder(self):
160
+ try:
161
+ print(f"Starting to read URLs from {self.txt_file}")
162
+ async with aiofiles.open(self.txt_file, mode="r") as f:
163
+ line_number = 0
164
+ async for line in f:
165
+ line_number += 1
166
+ if line_number <= self.state.urls_processed:
167
+ continue
168
+
169
+ url = line.strip()
170
+ if url: # Skip empty lines
171
+ await self.url_queue.put(url)
172
+ self.state.urls_processed += 1
173
+
174
+ # Check if we need to wait for the queue to have space
175
+ if self.url_queue.qsize() >= self.url_queue.maxsize - 1:
176
+ await asyncio.sleep(0.1)
177
+ except Exception as e:
178
+ print(f"Error in url_feeder: {e}")
179
+ finally:
180
+ await self.url_queue.put(None)
181
+
182
+ async def image_downloader(self):
183
+ print("Starting image downloader")
184
+ async with aiohttp.ClientSession() as session:
185
+ while True:
186
+ url = await self.url_queue.get()
187
+ if url is None:
188
+ print("Finished downloading images")
189
+ await self.image_queue.put(None)
190
+ break
191
+ try:
192
+ await self.rate_limiter.wait() # Wait for rate limit
193
+ async with self.semaphore:
194
+ async with session.get(url) as response:
195
+ if response.status == 200:
196
+ image = await response.read()
197
+ await self.image_queue.put((image, url))
198
+ self.state.images_downloaded += 1
199
+ if self.state.images_downloaded % 100 == 0:
200
+ print(
201
+ f"Downloaded {self.state.images_downloaded} images"
202
+ )
203
+ except Exception as e:
204
+ print(f"Error downloading {url}: {e}")
205
+ finally:
206
+ self.url_queue.task_done()
207
+
208
+ async def image_processor(self):
209
+ print("Starting image processor")
210
+ while True:
211
+ item = await self.image_queue.get()
212
+ if item is None:
213
+ print("Finished processing images")
214
+ break
215
+ image, url = item
216
+ filename = os.path.basename(url)
217
+ if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
218
+ filename += ".png"
219
+ try:
220
+ # Save the original image
221
+ await self.saver.process(
222
+ image, str(self.downloaded_images_dir / f"original_{filename}")
223
+ )
224
+ self.state.images_saved += 1
225
+ if self.state.images_resized % 100 == 0:
226
+ print(f"Processed {self.state.images_resized} images")
227
+
228
+ # Resize the image using the process pool
229
+ # loop = asyncio.get_running_loop()
230
+ await self.loop.run_in_executor(
231
+ self.process_pool,
232
+ resize_image,
233
+ image,
234
+ str(self.downloaded_images_dir / f"resized_{filename}"),
235
+ )
236
+ self.state.images_resized += 1
237
+ except Exception as e:
238
+ print(f"Error processing {url}: {e}")
239
+ finally:
240
+ self.image_queue.task_done()
241
+
242
+ def save_state(self):
243
+ with open(self.state_file, "w") as f:
244
+ json.dump(asdict(self.state), f)
245
+
246
+ def load_state(self):
247
+ if os.path.exists(self.state_file):
248
+ with open(self.state_file, "r") as f:
249
+ self.state = ProcessState(**json.load(f))
250
+
251
+ async def run(self):
252
+ print("Starting pipeline")
253
+ self.load_state()
254
+ print(f"Loaded state: {self.state}")
255
+ tasks = [
256
+ asyncio.create_task(self.url_feeder()),
257
+ asyncio.create_task(self.image_downloader()),
258
+ asyncio.create_task(self.image_processor()),
259
+ ]
260
+ try:
261
+ await asyncio.gather(*tasks)
262
+ except Exception as e:
263
+ print(f"Pipeline error: {e}")
264
+ finally:
265
+ self.save_state()
266
+ print(f"Final state: {self.state}")
267
+ self.process_pool.shutdown()
268
+ print("Pipeline finished")
269
+
270
+
271
+ if __name__ == "__main__":
272
+ from pathlib import Path
273
+
274
+ PROJECT_ROOT = Path(__file__).resolve().parent
275
+ loop = asyncio.get_event_loop()
276
+ text_file = PROJECT_ROOT / "data/image_urls.txt"
277
+ if not text_file.exists():
278
+ import pandas as pd
279
+
280
+ dataframe = pd.read_csv(PROJECT_ROOT / "data/photos.tsv000", sep="\t")
281
+ num_image_urls = len(dataframe)
282
+ print(f"Number of image urls: {num_image_urls}")
283
+ with open(text_file, "w") as f:
284
+ for url in dataframe["photo_image_url"]:
285
+ f.write(url + "\n")
286
+ print("Started downloading images")
287
+ pipeline = ImagePipeline(
288
+ txt_file=text_file,
289
+ loop=loop,
290
+ rate_limit=100,
291
+ rate_limit_period=1,
292
+ downloaded_images_dir=str(PROJECT_ROOT / "data/data/images"),
293
+ )
294
+ # asyncio.run(pipeline.run())
295
+ loop.run_until_complete(pipeline.run())
296
+ print("Finished downloading images")
engine/generate_features.py ADDED
@@ -0,0 +1,251 @@
1
+ import asyncio
2
+ import os
3
+ import logging
4
+ from PIL import Image
5
+ import torch
6
+ from transformers import (
7
+ CLIPProcessor,
8
+ CLIPModel,
9
+ BlipProcessor,
10
+ BlipForConditionalGeneration,
11
+ )
12
+ from sentence_transformers import SentenceTransformer
13
+ import numpy as np
14
+ import aiofiles
15
+ import json
16
+ from abc import ABC, abstractmethod
17
+ from typing import Set, Tuple
18
+ from concurrent.futures import ProcessPoolExecutor
19
+ from dataclasses import dataclass, field
20
+
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ device = "cpu"
25
+
26
+
27
+ @dataclass
28
+ class State:
29
+ processed_files: Set[str] = field(default_factory=set)
30
+
31
+ def to_dict(self) -> dict:
32
+ return {"processed_files": list(self.processed_files)}
33
+
34
+ @staticmethod
35
+ def from_dict(state_dict: dict) -> "State":
36
+ return State(processed_files=set(state_dict.get("processed_files", [])))
37
+
38
+
39
+ class ImageProcessor(ABC):
40
+ @abstractmethod
41
+ def process(self, image: Image.Image) -> np.ndarray:
42
+ pass
43
+
44
+
45
+ class CLIPImageProcessor(ImageProcessor):
46
+ def __init__(self):
47
+ self.model = CLIPModel.from_pretrained(
48
+ "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
49
+ ).to(device)
50
+ self.processor = CLIPProcessor.from_pretrained(
51
+ "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
52
+ )
53
+ print("Initialized CLIP model and processor")
54
+
55
+ def process(self, image: Image.Image) -> np.ndarray:
56
+ inputs = self.processor(images=image, return_tensors="pt").to(device)
57
+ outputs = self.model.get_image_features(**inputs)
58
+ return outputs.detach().cpu().numpy()
59
+
60
+
61
+ class ImageCaptioningProcessor(ImageProcessor):
62
+ def __init__(self):
63
+ self.image_caption_model = BlipForConditionalGeneration.from_pretrained(
64
+ "Salesforce/blip-image-captioning-base"
65
+ ).to(device)
66
+ self.image_caption_processor = BlipProcessor.from_pretrained(
67
+ "Salesforce/blip-image-captioning-base"
68
+ )
69
+ self.text_embedding_model = SentenceTransformer(
70
+ "all-MiniLM-L6-v2", device=device
71
+ )
72
+ print("Initialized BLIP model and processor")
73
+
74
+ def process(self, image: Image.Image) -> np.ndarray:
75
+ inputs = self.image_caption_processor(images=image, return_tensors="pt").to(
76
+ device
77
+ )
78
+ output = self.image_caption_model.generate(**inputs)
79
+ caption = self.image_caption_processor.decode(
80
+ output[0], skip_special_tokens=True
81
+ )
82
+ # embedding dim 384
83
+ return self.text_embedding_model.encode(caption).flatten()
84
+
85
+
86
+ class ImageFeatureExtractor:
87
+ def __init__(
88
+ self,
89
+ clip_processor: CLIPImageProcessor,
90
+ caption_processor: ImageCaptioningProcessor,
91
+ max_queue_size: int = 100,
92
+ checkpoint_file: str = "checkpoint.json",
93
+ ):
94
+ self.clip_processor = clip_processor
95
+ self.caption_processor = caption_processor
96
+ self.image_queue = asyncio.Queue(maxsize=max_queue_size)
97
+ self.processed_images_queue = asyncio.Queue()
98
+ self.checkpoint_file = checkpoint_file
99
+ self.state = self.load_state()
100
+ self.executor = ProcessPoolExecutor()
101
+ self.total_images = 0
102
+ self.processed_count = 0
103
+ print(
104
+ "Initialized ImageFeatureExtractor with checkpoint file:", checkpoint_file
105
+ )
106
+
107
+ async def image_loader(self, input_folder: str):
108
+ print(f"Loading images from {input_folder}")
109
+ for filename in os.listdir(input_folder):
110
+ if "resized_" in filename and filename not in self.state.processed_files:
111
+ try:
112
+ file_path = os.path.join(input_folder, filename)
113
+ await self.image_queue.put((filename, file_path))
114
+ self.total_images += 1
115
+ print(f"Loaded image {filename} into queue")
116
+ except Exception as e:
117
+ logger.error(f"Error loading image {filename}: {e}")
118
+ await self.image_queue.put(None) # Sentinel to signal end of images
119
+ print(f"Total images to process: {self.total_images}")
120
+
121
+ async def image_processor_worker(self, loop: asyncio.AbstractEventLoop):
122
+ while True:
123
+ item = await self.image_queue.get()
124
+ if item is None:
125
+ await self.image_queue.put(None) # Propagate sentinel
126
+ break
127
+ filename, file_path = item
128
+ try:
129
+ print(f"Processing image {filename}")
130
+ image = Image.open(file_path)
131
+ clip_embedding, caption_embedding = await asyncio.gather(
132
+ loop.run_in_executor(
133
+ self.executor, self.clip_processor.process, image
134
+ ),
135
+ loop.run_in_executor(
136
+ self.executor, self.caption_processor.process, image
137
+ ),
138
+ )
139
+ await self.processed_images_queue.put(
140
+ (filename, clip_embedding, caption_embedding)
141
+ )
142
+ print(f"Processed image {filename}")
143
+ except Exception as e:
144
+ logger.error(f"Error processing image {filename}: {e}")
145
+ finally:
146
+ self.image_queue.task_done()
147
+
148
+ async def save_processed_images(self, output_folder: str):
149
+ while self.processed_count < self.total_images:
150
+ filename, clip_embedding, caption_embedding = (
151
+ await self.processed_images_queue.get()
152
+ )
153
+ try:
154
+ clip_output_path = os.path.join(
155
+ output_folder, f"{os.path.splitext(filename)[0]}_clip.npy"
156
+ )
157
+ caption_output_path = os.path.join(
158
+ output_folder, f"{os.path.splitext(filename)[0]}_caption.npy"
159
+ )
160
+
161
+ await asyncio.gather(
162
+ self.save_embedding(clip_output_path, clip_embedding),
163
+ self.save_embedding(caption_output_path, caption_embedding),
164
+ )
165
+
166
+ self.state.processed_files.add(filename)
167
+ self.save_state()
168
+ self.processed_count += 1
169
+ print(f"Saved processed embeddings for {filename}")
170
+ except Exception as e:
171
+ logger.error(f"Error saving processed image {filename}: {e}")
172
+ finally:
173
+ self.processed_images_queue.task_done()
174
+
175
+ async def save_embedding(self, output_path: str, embedding: np.ndarray):
176
+ async with aiofiles.open(output_path, "wb") as f:
177
+ await f.write(embedding.tobytes())
178
+
179
+ def load_state(self) -> State:
180
+ try:
181
+ with open(self.checkpoint_file, "r") as f:
182
+ state_dict = json.load(f)
183
+ print("Loaded state from checkpoint")
184
+ return State.from_dict(state_dict)
185
+ except (FileNotFoundError, json.JSONDecodeError):
186
+ print("No checkpoint found, starting with empty state")
187
+ return State()
188
+
189
+ def save_state(self):
190
+ with open(self.checkpoint_file, "w") as f:
191
+ json.dump(self.state.to_dict(), f)
192
+ print("Saved state to checkpoint")
193
+
194
+ async def run(
195
+ self,
196
+ input_folder: str,
197
+ output_folder: str,
198
+ loop: asyncio.AbstractEventLoop,
199
+ num_workers: int = 2,
200
+ ):
201
+ os.makedirs(output_folder, exist_ok=True)
202
+ print(f"Output folder {output_folder} created")
203
+
204
+ tasks = [
205
+ loop.create_task(self.image_loader(input_folder)),
206
+ loop.create_task(self.save_processed_images(output_folder)),
207
+ ]
208
+ tasks.extend(
209
+ [
210
+ loop.create_task(self.image_processor_worker(loop))
211
+ for _ in range(num_workers)
212
+ ]
213
+ )
214
+
215
+ await asyncio.gather(*tasks)
216
+
217
+
218
+ class ImageFeatureExtractorFactory:
219
+ @staticmethod
220
+ def create() -> ImageFeatureExtractor:
221
+ print(
222
+ "Creating ImageFeatureExtractor with CLIPImageProcessor and ImageCaptioningProcessor"
223
+ )
224
+ return ImageFeatureExtractor(CLIPImageProcessor(), ImageCaptioningProcessor())
225
+
226
+
227
+ async def main(loop: asyncio.AbstractEventLoop, input_folder: str, output_folder: str):
228
+ print("Starting main function")
229
+
230
+ extractor = ImageFeatureExtractorFactory.create()
231
+
232
+ try:
233
+ await extractor.run(input_folder, output_folder, loop)
234
+ except Exception as e:
235
+ logger.error(f"An error occurred during execution: {e}")
236
+ finally:
237
+ logger.info("Image processing completed.")
238
+
239
+
240
+ if __name__ == "__main__":
241
+ from pathlib import Path
242
+
243
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
244
+ loop = asyncio.new_event_loop()
245
+ asyncio.set_event_loop(loop)
246
+ print("Event loop created and set")
247
+ input_folder = str(PROJECT_ROOT / "data/images")
248
+ output_folder = str(PROJECT_ROOT / "data/features")
249
+ loop.run_until_complete(main(loop, input_folder, output_folder))
250
+ loop.close()
251
+ print("Event loop closed")
engine/search.py ADDED
@@ -0,0 +1,216 @@
1
+ import os
2
+ import numpy as np
3
+ from typing import List, Tuple
4
+ import torch
5
+ from glob import glob
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from transformers import CLIPProcessor, CLIPModel
10
+ from sentence_transformers import SentenceTransformer
11
+ import sqlite3
12
+ from .vector_database import (
13
+ VectorDB,
14
+ ImageEmbeddingCollectionSchema,
15
+ TextEmbeddingCollectionSchema,
16
+ )
17
+
18
+
19
+ class ImageSearchModule:
20
+ def __init__(
21
+ self,
22
+ image_embeddings_dir: str,
23
+ original_images_dir: str,
24
+ sqlite_db_path: str = "image_tracker.db",
25
+ ):
26
+ self.image_embeddings_dir = image_embeddings_dir
27
+ self.original_images_dir = original_images_dir
28
+ self.vector_db = VectorDB()
29
+ self.vector_db.create_collection(ImageEmbeddingCollectionSchema)
30
+ self.vector_db.create_collection(TextEmbeddingCollectionSchema)
31
+
32
+ self.clip_model = CLIPModel.from_pretrained(
33
+ "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
34
+ )
35
+ self.clip_preprocess = CLIPProcessor.from_pretrained(
36
+ "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"
37
+ )
38
+ self.text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
39
+
40
+ self.sqlite_conn = sqlite3.connect(sqlite_db_path)
41
+ self._create_sqlite_table()
42
+
43
+ def _create_sqlite_table(self):
44
+ cursor = self.sqlite_conn.cursor()
45
+ cursor.execute(
46
+ """
47
+ CREATE TABLE IF NOT EXISTS added_images (
48
+ image_name TEXT PRIMARY KEY
49
+ )
50
+ """
51
+ )
52
+ self.sqlite_conn.commit()
53
+
54
+ def add_images(self):
55
+ print("Adding images to vector databases")
56
+ cursor = self.sqlite_conn.cursor()
57
+
58
+ for filename in tqdm(os.listdir(self.image_embeddings_dir)):
59
+ if filename.startswith("resized_") and filename.endswith("_clip.npy"):
60
+ image_name = filename[
61
+ 8:-9
62
+ ] # Remove "resized_" prefix and "_clip.npy" suffix
63
+
64
+ cursor.execute(
65
+ "SELECT 1 FROM added_images WHERE image_name = ?", (image_name,)
66
+ )
67
+ if cursor.fetchone() is None:
68
+ clip_embedding_path = os.path.join(
69
+ self.image_embeddings_dir, filename
70
+ )
71
+ caption_embedding_path = os.path.join(
72
+ self.image_embeddings_dir, f"resized_{image_name}_caption.npy"
73
+ )
74
+
75
+ if os.path.exists(clip_embedding_path) and os.path.exists(
76
+ caption_embedding_path
77
+ ):
78
+ with open(clip_embedding_path, "rb") as buffer:
79
+ image_embedding = np.frombuffer(
80
+ buffer.read(), dtype=np.float32
81
+ ).reshape(512)
82
+ with open(caption_embedding_path, "rb") as buffer:
83
+ text_embedding = np.frombuffer(
84
+ buffer.read(), dtype=np.float32
85
+ ).reshape(384)
86
+
87
+ if self.vector_db.insert_record(
88
+ ImageEmbeddingCollectionSchema.collection_name,
89
+ image_embedding,
90
+ image_name,
91
+ ):
92
+ self.vector_db.insert_record(
93
+ TextEmbeddingCollectionSchema.collection_name,
94
+ text_embedding,
95
+ image_name,
96
+ )
97
+ cursor.execute(
98
+ "INSERT INTO added_images (image_name) VALUES (?)",
99
+ (image_name,),
100
+ )
101
+ self.sqlite_conn.commit()
102
+
103
+ print("Finished adding images to vector databases")
104
+
105
+ def search_by_image(
106
+ self, query_image_path: str, top_k: int = 5, similarity_threshold: float = 0.5
107
+ ) -> List[Tuple[str, float]]:
108
+ if not os.path.exists(query_image_path):
109
+ print(f"Image file not found: {query_image_path}")
110
+ return []
111
+ try:
112
+ query_image = Image.open(query_image_path)
113
+ query_embedding = self._get_image_embedding(query_image)
114
+ results = self.vector_db.client.search(
115
+ collection_name=ImageEmbeddingCollectionSchema.collection_name,
116
+ data=[query_embedding],
117
+ output_fields=["filename"],
118
+ search_params={"metric_type": "COSINE"},
119
+ limit=top_k,
120
+ ).pop()
121
+ return [(item["entity"]["filename"], item["distance"]) for item in results if item["distance"] >= similarity_threshold]
122
+ except Exception as e:
123
+ print(f"Error processing image: {e}")
124
+ return []
125
+
126
+ def search_by_text(
127
+ self, query_text: str, top_k: int = 5,similarity_threshold: float = 0.5
128
+ ) -> List[Tuple[str, float]]:
129
+ if not query_text.strip():
130
+ print("Empty text query")
131
+ return []
132
+ try:
133
+ query_embedding = self._get_text_embedding(query_text)
134
+ results = self.vector_db.client.search(
135
+ collection_name=TextEmbeddingCollectionSchema.collection_name,
136
+ data=[query_embedding],
137
+ search_params={"metric_type": "COSINE"},
138
+ output_fields=["filename"],
139
+ limit=top_k,
140
+ ).pop()
141
+ return [(item["entity"]["filename"], item["distance"]) for item in results if item["distance"] >= similarity_threshold]
142
+ except Exception as e:
143
+ print(f"Error processing text: {e}")
144
+ return []
145
+
146
+ def _get_image_embedding(self, image: Image.Image) -> np.ndarray:
147
+ with torch.no_grad():
148
+ image_input = self.clip_preprocess(images=image, return_tensors="pt")[
149
+ "pixel_values"
150
+ ].to(self.clip_model.device)
151
+ image_features = self.clip_model.get_image_features(image_input)
152
+ return image_features.cpu().numpy().flatten()
153
+
154
+ def _get_text_embedding(self, text: str) -> np.ndarray:
155
+ with torch.no_grad():
156
+ embedding = self.text_embedding_model.encode(text).flatten()
157
+ return embedding
158
+
159
+ def display_results(self, results: List[Tuple[str, float]]):
160
+ if not results:
161
+ print("No results to display.")
162
+ return
163
+
164
+ num_images = min(5, len(results))
165
+ fig, axes = plt.subplots(1, num_images, figsize=(20, 4))
166
+ axes = [axes] if num_images == 1 else axes
167
+
168
+ for i, (image_name, similarity) in enumerate(results[:num_images]):
169
+ pattern = os.path.join(
170
+ self.original_images_dir, f"resized_{image_name}" + "*"
171
+ )
172
+ matching_files = glob(pattern)
173
+ if matching_files:
174
+ image_path = matching_files[0]
175
+ img = Image.open(image_path)
176
+ axes[i].imshow(img)
177
+ axes[i].set_title(f"Similarity: {similarity:.2f}")
178
+ axes[i].axis("off")
179
+ else:
180
+ print(f"No matching image found for {image_name}")
181
+ axes[i].text(0.5, 0.5, "Image not found", ha="center", va="center")
182
+ axes[i].axis("off")
183
+
184
+ plt.tight_layout()
185
+ plt.show()
186
+
187
+ def __del__(self):
188
+ if hasattr(self, "sqlite_conn"):
189
+ self.sqlite_conn.close()
190
+
191
+
192
+ if __name__ == "__main__":
193
+ from pathlib import Path
194
+ import requests
195
+
196
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
197
+ search = ImageSearchModule(
198
+ image_embeddings_dir=str(PROJECT_ROOT / "data/features"),
199
+ original_images_dir=str(PROJECT_ROOT / "data/images"),
200
+ )
201
+ search.add_images()
202
+
203
+ # Search by image
204
+ img_url = (
205
+ "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
206
+ )
207
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
208
+ raw_image.save(PROJECT_ROOT / "test.jpg")
209
+ image_results = search.search_by_image(str(PROJECT_ROOT / "test.jpg"))
210
+ print("Image search results:")
211
+ search.display_results(image_results)
212
+
213
+ # Search by text
214
+ text_results = search.search_by_text("Images of Nature")
215
+ print("Text search results:")
216
+ search.display_results(text_results)
engine/upload_data_to_hf.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ from huggingface_hub import HfApi
3
+ from pathlib import Path
4
+
5
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
6
+ api = HfApi()
7
+ print("Uploading data.....")
8
+ api.upload_folder(
9
+ folder_path=str(PROJECT_ROOT / "data"),
10
+ repo_id="satishjasthij/Unsplash-Visual-Semantic",
11
+ repo_type="space",
12
+ token=os.getenv("HUGGINGFACE_TOKEN"),
13
+ commit_message="add dataset",
14
+ create_pr=True,
15
+ )
16
+ print("Finished uploading data")
engine/vector_database.py ADDED
@@ -0,0 +1,68 @@
1
+ from dataclasses import dataclass, asdict
2
+ from pathlib import Path
3
+ import random
4
+ import numpy as np
5
+ from pymilvus import MilvusClient
6
+
7
+
8
+ @dataclass
9
+ class MilvusServer:
10
+ uri: str = "milvus.db"
11
+
12
+
13
+ @dataclass
14
+ class EmbeddingCollectionSchema:
15
+ collection_name: str
16
+ vector_field_name: str
17
+ dimension: int
18
+ auto_id: bool
19
+ enable_dynamic_field: bool
20
+ metric_type: str
21
+
22
+
23
+ ImageEmbeddingCollectionSchema = EmbeddingCollectionSchema(
24
+ collection_name="image_embeddings",
25
+ vector_field_name="embedding",
26
+ dimension=512,
27
+ auto_id=True,
28
+ enable_dynamic_field=True,
29
+ metric_type="COSINE",
30
+ )
31
+
32
+ TextEmbeddingCollectionSchema = EmbeddingCollectionSchema(
33
+ collection_name="text_embeddings",
34
+ vector_field_name="embedding",
35
+ dimension=384,
36
+ auto_id=True,
37
+ enable_dynamic_field=True,
38
+ metric_type="COSINE",
39
+ )
40
+
41
+
42
+ class VectorDB:
43
+
44
+ def __init__(self, client: MilvusClient = MilvusClient(uri=MilvusServer.uri)):
45
+ self.client = client
46
+
47
+ def create_collection(self, schema: EmbeddingCollectionSchema):
48
+ if self.client.has_collection(collection_name=schema.collection_name):
49
+ print(f"Collection {schema.collection_name} already exists")
50
+ return True
51
+ # self.client.drop_collection(collection_name=schema.collection_name)
52
+ print(f"Creating collection {schema.collection_name}")
53
+ self.client.create_collection(**asdict(schema))
54
+ print(f"Collection {schema.collection_name} created")
55
+ return True
56
+
57
+ def insert_record(
58
+ self, collection_name: str, embedding: np.ndarray, file_path: str
59
+ ) -> bool:
60
+ try:
61
+ self.client.insert(
62
+ collection_name=collection_name,
63
+ data={"embedding": embedding, "filename": file_path},
64
+ )
65
+ except Exception as e:
66
+ print(f"Error inserting record: {e}")
67
+ return False
68
+ return True
requirements.txt ADDED
@@ -0,0 +1,115 @@
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ annotated-types==0.7.0
5
+ anyio==4.4.0
6
+ asttokens==2.4.1
7
+ async-timeout==4.0.3
8
+ attrs==23.2.0
9
+ black==24.4.2
10
+ certifi==2024.7.4
11
+ charset-normalizer==3.3.2
12
+ click==8.1.7
13
+ contourpy==1.2.1
14
+ cycler==0.12.1
15
+ decorator==5.1.1
16
+ dnspython==2.6.1
17
+ email-validator==2.2.0
18
+ environs==9.5.0
19
+ exceptiongroup==1.2.2
20
+ executing==2.0.1
21
+ fastapi==0.111.1
22
+ fastapi-cli==0.0.4
23
+ ffmpy==0.3.2
24
+ filelock==3.15.4
25
+ fonttools==4.53.1
26
+ frozenlist==1.4.1
27
+ fsspec==2024.6.1
28
+ gradio==4.39.0
29
+ gradio-client==1.1.1
30
+ grpcio==1.63.0
31
+ h11==0.14.0
32
+ httpcore==1.0.5
33
+ httptools==0.6.1
34
+ httpx==0.27.0
35
+ huggingface-hub==0.24.0
36
+ idna==3.7
37
+ importlib-resources==6.4.0
38
+ ipython==8.18.1
39
+ jedi==0.19.1
40
+ jinja2==3.1.4
41
+ joblib==1.4.2
42
+ kiwisolver==1.4.5
43
+ markdown-it-py==3.0.0
44
+ markupsafe==2.1.5
45
+ marshmallow==3.21.3
46
+ matplotlib==3.9.1
47
+ matplotlib-inline==0.1.7
48
+ mdurl==0.1.2
49
+ milvus-lite==2.4.8
50
+ mpmath==1.3.0
51
+ multidict==6.0.5
52
+ mypy-extensions==1.0.0
53
+ networkx==3.2.1
54
+ numpy==1.26.4
55
+ orjson==3.10.6
56
+ packaging==24.1
57
+ pandas==2.2.2
58
+ parso==0.8.4
59
+ pathspec==0.12.1
60
+ pexpect==4.9.0
61
+ pillow==10.4.0
62
+ platformdirs==4.2.2
63
+ prompt-toolkit==3.0.47
64
+ protobuf==5.27.2
65
+ psutil==6.0.0
66
+ ptyprocess==0.7.0
67
+ pure-eval==0.2.3
68
+ pydantic==2.8.2
69
+ pydantic-core==2.20.1
70
+ pydub==0.25.1
71
+ pygments==2.18.0
72
+ pymilvus==2.4.4
73
+ pyparsing==3.1.2
74
+ python-dateutil==2.9.0.post0
75
+ python-dotenv==1.0.1
76
+ python-multipart==0.0.9
77
+ pytz==2024.1
78
+ pyyaml==6.0.1
79
+ regex==2024.5.15
80
+ requests==2.32.3
81
+ rich==13.7.1
82
+ ruff==0.5.5
83
+ safetensors==0.4.3
84
+ scikit-learn==1.5.1
85
+ scipy==1.13.1
86
+ semantic-version==2.10.0
87
+ sentence-transformers==3.0.1
88
+ setuptools==71.1.0
89
+ shellingham==1.5.4
90
+ six==1.16.0
91
+ sniffio==1.3.1
92
+ stack-data==0.6.3
93
+ starlette==0.37.2
94
+ sympy==1.13.1
95
+ threadpoolctl==3.5.0
96
+ tokenizers==0.19.1
97
+ tomli==2.0.1
98
+ tomlkit==0.12.0
99
+ torch==2.3.1
100
+ torchvision==0.18.1
101
+ tqdm==4.66.4
102
+ traitlets==5.14.3
103
+ transformers==4.42.4
104
+ typer==0.12.3
105
+ typing-extensions==4.12.2
106
+ tzdata==2024.1
107
+ ujson==5.10.0
108
+ urllib3==2.2.2
109
+ uvicorn==0.30.3
110
+ uvloop==0.19.0
111
+ watchfiles==0.22.0
112
+ wcwidth==0.2.13
113
+ websockets==11.0.3
114
+ yarl==1.9.4
115
+ zipp==3.19.2