HadrienCr commited on
Commit
8d75374
·
1 Parent(s): c0f6daf

Add application file

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ import numpy as np
5
+ from PIL import Image
6
+ from embedding import get_device, get_model_and_processor, to_embedding
7
+ from datasets import Dataset, load_dataset
8
+
9
+ def initialize_model():
10
+ device = get_device()
11
+ model, processor = get_model_and_processor("patrickjohncyh/fashion-clip", device)
12
+ ref_dataset = load_dataset("HadrienCr/embeddeDior", split="train")
13
+ ref_dataset = ref_dataset.add_faiss_index("embeddings")
14
+ return model, processor, ref_dataset, device
15
+
16
+ def search(
17
+ image: np.ndarray,
18
+ reference_dataset: Dataset,
19
+ model,
20
+ processor,
21
+ device: str,
22
+ num_neighbors: int = 4,
23
+ ):
24
+ """a function that embeds a new image and returns the most probable results"""
25
+ scores, retrieved_examples = reference_dataset.get_nearest_examples(
26
+ "embeddings",
27
+ to_embedding(np.expand_dims(image, 0), processor, model, device),
28
+ k=num_neighbors,
29
+ )
30
+ return retrieved_examples
31
+
32
+ def process_image(image, num_results, remove_bg, model, processor, ref_dataset, device):
33
+ if image is None:
34
+ return [] # Return an empty list if no image is provided
35
+
36
+ # Ensure the input image is a numpy array
37
+ if isinstance(image, Image.Image):
38
+ image = np.array(image)
39
+
40
+ # Handle background removal
41
+ if remove_bg:
42
+ from rembg import remove
43
+ image = remove(image)[:,:,0:3]
44
+ # Perform the search
45
+ results = search(
46
+ image,
47
+ ref_dataset,
48
+ model,
49
+ processor,
50
+ device,
51
+ num_neighbors=num_results
52
+ )
53
+ images = results['image']
54
+ paths = results['path']
55
+ # Prepare the output
56
+ output_images = []
57
+
58
+ for img, path in zip(images, paths):
59
+ output_images.append((np.array(img), os.path.basename(path)))
60
+
61
+ return output_images # Return the list of tuples
62
+
63
+
64
+ def main():
65
+ print("Initializing model and loading reference dataset...")
66
+ model, processor, ref_dataset, device = initialize_model()
67
+ print("Initialization complete!")
68
+
69
+ # Path to the examples folder
70
+ examples_folder = "examples/"
71
+ example_files = [
72
+ os.path.join(examples_folder, fname)
73
+ for fname in os.listdir(examples_folder)
74
+ if fname.lower().endswith(('png', 'jpg', 'jpeg'))
75
+ ]
76
+
77
+ with gr.Blocks() as demo:
78
+ gr.Markdown("# Image Retrieval System")
79
+ gr.Markdown("Upload an image to find similar images in the reference dataset.")
80
+
81
+ with gr.Row():
82
+ with gr.Column(scale=1):
83
+ input_image = gr.Image(
84
+ label="Upload Image",
85
+ type="pil" # Changed to PIL format
86
+ )
87
+ num_results = gr.Slider(
88
+ minimum=1,
89
+ maximum=10,
90
+ value=5,
91
+ step=1,
92
+ label="Number of results"
93
+ )
94
+ remove_bg = gr.Checkbox(label="Remove Background")
95
+ submit_btn = gr.Button("Search")
96
+
97
+ with gr.Column(scale=2):
98
+ gallery = gr.Gallery(
99
+ label="Retrieved Images",
100
+ show_label=True,
101
+ columns=3,
102
+ object_fit="contain"
103
+ )
104
+
105
+ # Add the Examples component
106
+ gr.Examples(
107
+ examples=example_files,
108
+ inputs=input_image,
109
+ label="Example Images"
110
+ )
111
+
112
+ submit_btn.click(
113
+ fn=lambda img, num, bg: process_image(img, num, bg, model, processor, ref_dataset, device),
114
+ inputs=[input_image, num_results, remove_bg],
115
+ outputs=gallery
116
+ )
117
+
118
+ demo.launch(share=True)
119
+
120
+ if __name__ == "__main__":
121
+ main()