File size: 3,741 Bytes
8d75374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b32144d
8d75374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys
import gradio as gr
import numpy as np
from PIL import Image
from embedding import get_device, get_model_and_processor, to_embedding
from datasets import Dataset, load_dataset

def initialize_model():
    """Load the embedding model and the indexed reference dataset.

    Returns:
        A ``(model, processor, ref_dataset, device)`` tuple, where ``device``
        is whatever ``get_device`` reports, ``model``/``processor`` come from
        ``get_model_and_processor`` for the ``patrickjohncyh/fashion-clip``
        checkpoint, and ``ref_dataset`` is the ``train`` split of
        ``HadrienCr/embeddeDior`` with a FAISS index on its ``"embeddings"``
        column.
    """
    target_device = get_device()
    clip_model, clip_processor = get_model_and_processor(
        "patrickjohncyh/fashion-clip",
        target_device,
    )

    # Index the "embeddings" column so nearest-neighbour queries can be run
    # against it later.
    indexed_dataset = load_dataset(
        "HadrienCr/embeddeDior", split="train"
    ).add_faiss_index("embeddings")

    return clip_model, clip_processor, indexed_dataset, target_device

def search(
    image: np.ndarray,
    reference_dataset: Dataset,
    model,
    processor,
    device: str,
    num_neighbors: int = 4,
):
    """Embed ``image`` and fetch its nearest neighbours from the dataset.

    Args:
        image: The query image as an array.
        reference_dataset: Dataset queried through its ``"embeddings"`` index.
        model: Passed through to ``to_embedding``.
        processor: Passed through to ``to_embedding``.
        device: Passed through to ``to_embedding``.
        num_neighbors: How many neighbours to request (``k``).

    Returns:
        The retrieved examples as returned by
        ``reference_dataset.get_nearest_examples``; the accompanying scores
        are discarded.
    """
    # to_embedding receives the image with an extra leading axis of size one.
    single_image_batch = np.expand_dims(image, 0)
    query_embedding = to_embedding(single_image_batch, processor, model, device)

    _scores, neighbours = reference_dataset.get_nearest_examples(
        "embeddings",
        query_embedding,
        k=num_neighbors,
    )
    return neighbours

def process_image(image, num_results, remove_bg, model, processor, ref_dataset, device):
    """Search the reference dataset for ``image`` and format hits for display.

    Args:
        image: Query image as a PIL image or NumPy array, or ``None`` when no
            image was provided.
        num_results: Number of neighbours to retrieve. It is coerced to
            ``int`` before the search because UI sliders may deliver the
            value as a float, while the nearest-neighbour ``k`` must be an
            integer.
        remove_bg: When truthy, run ``rembg.remove`` on the image before
            searching.
        model: Forwarded to ``search``.
        processor: Forwarded to ``search``.
        ref_dataset: Reference dataset forwarded to ``search``; its results
            must expose ``"image"`` and ``"path"`` entries.
        device: Forwarded to ``search``.

    Returns:
        A list of ``(image_array, file_name)`` tuples, one per retrieved
        example, where ``file_name`` is the base name of the example's path.
        The list is empty when ``image`` is ``None``.
    """
    if image is None:
        return []  # Nothing to search for.

    # Everything downstream works on NumPy arrays.
    if isinstance(image, Image.Image):
        image = np.array(image)

    if remove_bg:
        # Imported lazily so rembg is only needed when the option is used.
        from rembg import remove

        # Keep only the first three channels of rembg's output
        # (presumably dropping an alpha channel -- TODO confirm).
        image = remove(image)[:, :, 0:3]

    results = search(
        image,
        ref_dataset,
        model,
        processor,
        device,
        num_neighbors=int(num_results),
    )

    # Pair each retrieved image with the file name it came from.
    return [
        (np.array(img), os.path.basename(path))
        for img, path in zip(results["image"], results["path"])
    ]


def _list_example_images(folder):
    """Return the sorted paths of PNG/JPEG files found directly in ``folder``.

    Extensions are matched case-insensitively and must include the leading
    dot, so a file merely ending in the letters ``png`` is not picked up.
    A missing folder yields an empty list instead of raising.
    """
    try:
        names = os.listdir(folder)
    except FileNotFoundError:
        return []
    # os.listdir returns entries in arbitrary order; sort for a stable UI.
    return sorted(
        os.path.join(folder, name)
        for name in names
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )


def main():
    """Load the model and dataset, build the Gradio interface, and launch it.

    The model, processor, reference dataset and device are loaded once via
    ``initialize_model`` and captured by the search callback. Example images
    are read from the ``examples/`` folder; when that folder is missing or
    holds no PNG/JPEG files the examples widget is left out.
    """
    print("Initializing model and loading reference dataset...")
    model, processor, ref_dataset, device = initialize_model()
    print("Initialization complete!")

    # Example images offered below the form.
    example_files = _list_example_images("examples/")

    with gr.Blocks() as demo:
        gr.Markdown("# Image Retrieval System")
        gr.Markdown("Upload an image to find similar images in the reference dataset.")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",  # process_image converts PIL images to arrays
                )
                num_results = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Number of results",
                )
                remove_bg = gr.Checkbox(label="Remove Background")
                submit_btn = gr.Button("Search")

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Retrieved Images",
                    show_label=True,
                    columns=3,
                    object_fit="contain",
                )

        # Only add the Examples component when there is something to show.
        if example_files:
            gr.Examples(
                examples=example_files,
                inputs=input_image,
                label="Example Images",
            )

        # The lambda binds the loaded model objects so the callback only
        # takes the three UI inputs.
        submit_btn.click(
            fn=lambda img, num, bg: process_image(
                img, num, bg, model, processor, ref_dataset, device
            ),
            inputs=[input_image, num_results, remove_bg],
            outputs=gallery,
        )

    # share=True requests a public share link from Gradio.
    demo.launch(share=True)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()