|
import os |
|
import sys |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
from embedding import get_device, get_model_and_processor, to_embedding |
|
from datasets import Dataset, load_dataset |
|
|
|
def initialize_model(
    *,
    model_name: str = "patrickjohncyh/fashion-clip",
    dataset_name: str = "HadrienCr/embeddeDior",
    split: str = "train",
):
    """Load the embedding model and the reference dataset, and index the dataset.

    Args:
        model_name: Model identifier passed to ``get_model_and_processor``.
        dataset_name: Dataset name or path passed to ``datasets.load_dataset``.
        split: Dataset split to load.

    Returns:
        A ``(model, processor, ref_dataset, device)`` tuple. ``ref_dataset`` has a
        FAISS index built over its ``"embeddings"`` column.
    """
    device = get_device()
    model, processor = get_model_and_processor(model_name, device)

    ref_dataset = load_dataset(dataset_name, split=split)
    # The index name defaults to the column name; search() queries it as
    # "embeddings", so this column name is intentionally not configurable.
    ref_dataset = ref_dataset.add_faiss_index("embeddings")

    return model, processor, ref_dataset, device
|
|
|
def search(
    image: np.ndarray,
    reference_dataset: Dataset,
    model,
    processor,
    device: str,
    num_neighbors: int = 4,
):
    """Embed a query image and return its nearest neighbours in the reference dataset.

    Args:
        image: The query image as a numpy array. A leading batch axis is added
            before embedding.
        reference_dataset: Dataset carrying a FAISS index named ``"embeddings"``.
        model, processor, device: Forwarded unchanged to ``to_embedding``.
        num_neighbors: Number of neighbours to retrieve. Cast to ``int``
            because FAISS rejects a float ``k``, and UI widgets such as
            ``gr.Slider`` may deliver one.

    Returns:
        The examples mapping from ``Dataset.get_nearest_examples`` (column name
        -> list of values for the retrieved rows). Similarity scores are discarded.
    """
    # to_embedding is called with a batch of one image.
    query = to_embedding(np.expand_dims(image, 0), processor, model, device)

    _scores, retrieved_examples = reference_dataset.get_nearest_examples(
        "embeddings",
        query,
        k=int(num_neighbors),
    )
    return retrieved_examples
|
|
|
def process_image(image, num_results, remove_bg, model, processor, ref_dataset, device):
    """Search for images similar to ``image`` and format the hits for a Gradio gallery.

    Args:
        image: The query image, as a PIL image or a numpy array. ``None`` when
            nothing has been uploaded.
        num_results: Number of neighbours to retrieve. Cast to ``int`` because
            ``gr.Slider`` values may arrive as floats.
        remove_bg: When truthy, strip the background with ``rembg`` before
            searching.
        model, processor, device: Forwarded unchanged to ``search``.
        ref_dataset: Reference dataset forwarded to ``search``.

    Returns:
        A list of ``(image_array, caption)`` tuples, where the caption is the
        basename of the retrieved example's ``path``. Returns an empty list
        when ``image`` is ``None``.
    """
    if image is None:
        return []

    if isinstance(image, Image.Image):
        # Normalise to 3-channel RGB so RGBA / grayscale / palette images do not
        # reach the embedding step with an unexpected channel count.
        image = np.array(image.convert("RGB"))

    if remove_bg:
        # Imported lazily so rembg is only needed when this option is used.
        from rembg import remove

        # Keep only the first three (RGB) channels of rembg's output.
        image = remove(image)[:, :, 0:3]

    results = search(
        image,
        ref_dataset,
        model,
        processor,
        device,
        num_neighbors=int(num_results),
    )

    return [
        (np.array(img), os.path.basename(path))
        for img, path in zip(results["image"], results["path"])
    ]
|
|
|
|
|
def _list_example_files(folder, extensions=(".png", ".jpg", ".jpeg")):
    """Return the sorted paths of image files directly inside ``folder``, or [] if it is missing."""
    if not os.path.isdir(folder):
        return []
    return sorted(
        os.path.join(folder, fname)
        for fname in os.listdir(folder)
        if fname.lower().endswith(extensions)
    )


def main(*, share: bool = True, examples_folder: str = "examples/"):
    """Load the model and reference dataset, build the Gradio UI, and launch it.

    Args:
        share: Passed to ``demo.launch``. When True, Gradio also creates a
            public share link, so anyone with the URL can use the app.
        examples_folder: Folder scanned for example images (.png/.jpg/.jpeg).
            A missing folder simply yields no examples.
    """
    print("Initializing model and loading reference dataset...")
    model, processor, ref_dataset, device = initialize_model()
    print("Initialization complete!")

    example_files = _list_example_files(examples_folder)

    with gr.Blocks() as demo:
        gr.Markdown("# Image Retrieval System")
        gr.Markdown("Upload an image to find similar images in the reference dataset.")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                )
                num_results = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Number of results",
                )
                remove_bg = gr.Checkbox(label="Remove Background")
                submit_btn = gr.Button("Search")

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Retrieved Images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                )

        # Only render the examples section when there is something to show.
        if example_files:
            gr.Examples(
                examples=example_files,
                inputs=input_image,
                label="Example Images",
            )

        # The lambda closes over the loaded model/dataset so the UI only has to
        # supply the three user-controlled inputs.
        submit_btn.click(
            fn=lambda img, num, bg: process_image(img, num, bg, model, processor, ref_dataset, device),
            inputs=[input_image, num_results, remove_bg],
            outputs=gallery,
        )

    demo.launch(share=share)
|
|
|
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()