import os
import sys

import gradio as gr
import numpy as np
from datasets import Dataset, load_dataset
from PIL import Image

from embedding import get_device, get_model_and_processor, to_embedding


def initialize_model():
    """Load the embedding model/processor and the FAISS-indexed reference dataset.

    Returns:
        A ``(model, processor, ref_dataset, device)`` tuple. ``ref_dataset`` is
        the ``train`` split of ``HadrienCr/embeddeDior`` with a FAISS index
        built over its ``embeddings`` column.
    """
    device = get_device()
    model, processor = get_model_and_processor("patrickjohncyh/fashion-clip", device)
    ref_dataset = load_dataset("HadrienCr/embeddeDior", split="train")
    ref_dataset = ref_dataset.add_faiss_index("embeddings")
    return model, processor, ref_dataset, device


def search(
    image: np.ndarray,
    reference_dataset: Dataset,
    model,
    processor,
    device: str,
    num_neighbors: int = 4,
):
    """Embed *image* and return its nearest neighbours in *reference_dataset*.

    Args:
        image: A single image as a NumPy array (presumably ``(H, W, 3)`` RGB;
            confirm against ``to_embedding``).
        reference_dataset: Dataset carrying a FAISS index named ``embeddings``.
        model: Embedding model passed through to ``to_embedding``.
        processor: Pre-processor passed through to ``to_embedding``.
        device: Device passed through to ``to_embedding``.
        num_neighbors: How many neighbours to retrieve. It is coerced to
            ``int`` because the FAISS ``k`` argument must be an integer.

    Returns:
        The retrieved examples from ``get_nearest_examples``: a mapping from
        column name to a list of values, one entry per neighbour.
    """
    # Add a leading axis so to_embedding receives a batch of one image.
    query = to_embedding(np.expand_dims(image, 0), processor, model, device)
    _scores, retrieved_examples = reference_dataset.get_nearest_examples(
        "embeddings", query, k=int(num_neighbors)
    )
    return retrieved_examples


def process_image(image, num_results, remove_bg, model, processor, ref_dataset, device):
    """Gradio callback: retrieve reference images similar to *image*.

    Args:
        image: Uploaded image (PIL image or NumPy array), or ``None``.
        num_results: Number of neighbours to return. This is the slider value
            and may arrive as a float.
        remove_bg: If true, remove the background with ``rembg`` before searching.
        model, processor, ref_dataset, device: As returned by ``initialize_model``.

    Returns:
        A list of ``(image_array, caption)`` tuples for ``gr.Gallery``. The
        caption is the basename of the example's ``path`` column. Returns
        ``[]`` when no image is provided.
    """
    if image is None:
        return []

    if isinstance(image, Image.Image):
        # Normalise the mode (e.g. RGBA / L / P uploads) so the array has 3 channels.
        image = np.array(image.convert("RGB"))

    if remove_bg:
        # Imported lazily: rembg is only needed when background removal is requested.
        from rembg import remove

        # Keep only the first three channels of rembg's output
        # (presumably RGBA; the alpha channel is dropped).
        image = remove(image)[:, :, 0:3]

    results = search(
        image, ref_dataset, model, processor, device, num_neighbors=int(num_results)
    )

    return [
        (np.array(img), os.path.basename(path))
        for img, path in zip(results["image"], results["path"])
    ]


def main():
    """Load the model and dataset, build the Gradio UI, and launch it."""
    print("Initializing model and loading reference dataset...")
    model, processor, ref_dataset, device = initialize_model()
    print("Initialization complete!")

    examples_folder = "examples/"
    image_extensions = (".png", ".jpg", ".jpeg")
    # A missing examples folder should not prevent the app from starting.
    # Sort the file names so the example order does not depend on the filesystem.
    example_files = (
        [
            os.path.join(examples_folder, fname)
            for fname in sorted(os.listdir(examples_folder))
            if fname.lower().endswith(image_extensions)
        ]
        if os.path.isdir(examples_folder)
        else []
    )

    with gr.Blocks() as demo:
        gr.Markdown("# Image Retrieval System")
        gr.Markdown("Upload an image to find similar images in the reference dataset.")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Upload Image", type="pil")
                num_results = gr.Slider(
                    minimum=1, maximum=10, value=5, step=1, label="Number of results"
                )
                remove_bg = gr.Checkbox(label="Remove Background")
                submit_btn = gr.Button("Search")

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Retrieved Images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                )

        # Only render the examples panel when there is something to show.
        if example_files:
            gr.Examples(
                examples=example_files, inputs=input_image, label="Example Images"
            )

        # The lambda binds the loaded model/dataset so the UI only supplies
        # the three user-controlled inputs.
        submit_btn.click(
            fn=lambda img, num, bg: process_image(
                img, num, bg, model, processor, ref_dataset, device
            ),
            inputs=[input_image, num_results, remove_bg],
            outputs=gallery,
        )

    # NOTE: share=True publishes a public Gradio link to this app.
    demo.launch(share=True)


if __name__ == "__main__":
    main()