Tianyu Ding committed on
Commit
fd9106d
·
1 Parent(s): 6f6bd67
Files changed (1) hide show
  1. app.py +295 -8
app.py CHANGED
@@ -1,12 +1,299 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name, intensity):
4
- return "Hello, " + name + "!" * int(intensity)
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(
7
- fn=greet,
8
- inputs=["text", "slider"],
9
- outputs=["text"],
10
- )
11
 
12
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import os
5
+ # from utils import get_image_paths, show_images2
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
 
9
+ import matplotlib.pyplot as plt
10
+ import os
11
+ import tensorflow as tf
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+ from pathlib import Path
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+ from collections import OrderedDict
17
+ import compress_pickle
18
+ import concurrent
19
 
 
 
 
 
 
20
 
21
class ImageSimilarity:
    """Rank the images of a directory by visual similarity to a query image.

    Each image found under ``img_dir`` is embedded with an ImageNet-pretrained
    MobileNetV2 (``include_top=False``) followed by global max pooling, and the
    embeddings are stored as float16. A query image is embedded the same way
    and compared against the index with cosine similarity.

    The path -> feature mapping can be cached in ``self.filename``
    (``"image_dict.lzma"``, a relative path) via ``compress_pickle``.
    """

    def __init__(
        self,
        img_dir: Path,
        recursive: bool = False,
        BATCH_SIZE: int = 64,
        IMG_SIZE: int = 224,
        save_model: bool = True,
    ):
        """Create the feature extractor; no images are read yet.

        Args:
            img_dir: Directory whose images form the search index.
            recursive: Also index images in sub-directories.
            BATCH_SIZE: Number of images embedded per model call.
            IMG_SIZE: Side length images are resized to before embedding.
            save_model: Whether ``__call__`` may persist the feature cache.
        """
        self.batch_size = BATCH_SIZE
        self.img_size = IMG_SIZE
        self.img_dir = img_dir
        self.model = tf.keras.applications.MobileNetV2(
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            alpha=1.0,
            include_top=False,
            weights="imagenet",
            input_tensor=None,
            pooling=None,
            classifier_activation="softmax",
        )

        self.model.trainable = False
        self.model.compile()

        self.save_model = save_model
        self.recursive = recursive
        self.ifeatures = None
        self.filename = "image_dict.lzma"
        self.image_dict = None
        self.images_found = None
        # Set when build_image_features() recomputed the index, so __call__
        # knows the on-disk cache (if any) is stale.
        self._features_rebuilt = False

    def get_image_paths(self, directory_path: Path, recursive: bool = False) -> list:
        """Return the absolute paths (as ``str``) of images in *directory_path*.

        Files are selected by suffix (.jpg, .jpeg, .png; case-insensitive).
        With ``recursive=True`` sub-directories are searched as well. The
        result is sorted so that repeated scans of an unchanged directory
        compare equal.
        """
        image_extensions = (".jpg", ".jpeg", ".png")  # Add more extensions if needed
        image_paths = []

        for file_path in directory_path.iterdir():
            if file_path.is_file():
                if file_path.suffix.lower() in image_extensions:
                    image_paths.append(str(file_path.absolute()))
            elif recursive and file_path.is_dir():
                image_paths.extend(self.get_image_paths(file_path, recursive))

        # iterdir() yields entries in arbitrary order, but the cache is
        # validated by comparing path lists, so the order must be stable.
        image_paths.sort()
        return image_paths

    def load_image(self, x):
        """Read the image file at path *x* and resize it to the model input size."""
        image_data = tf.io.read_file(x)
        # NOTE(review): files with a .png suffix also go through decode_jpeg;
        # confirm the installed TensorFlow accepts PNG data here, otherwise
        # switch to tf.io.decode_image(..., expand_animations=False).
        image_features = tf.image.decode_jpeg(image_data, channels=3)
        image_features = tf.image.resize(image_features, (self.img_size, self.img_size))
        return image_features

    def load_image2(self, x):
        """Convert an in-memory image (PIL image or array) to a resized tensor."""
        # The model is built for 3-channel input; RGBA, palette or greyscale
        # PIL images would otherwise produce 4 or 1 channels.
        if isinstance(x, Image.Image) and x.mode != "RGB":
            x = x.convert("RGB")
        image_data = tf.keras.utils.img_to_array(x)
        return tf.image.resize(image_data, (self.img_size, self.img_size))

    def _embed(self, batch):
        """Run *batch* through the model and max-pool the spatial dimensions.

        Index and query embeddings must both come from this method; comparing
        vectors produced with different pooling is meaningless.
        """
        return tf.keras.layers.GlobalMaxPooling2D()(self.model(batch))

    def get_vectors(self, image_data: "tf.data.Dataset") -> "np.ndarray":
        """Embed every batch of *image_data* and return one float16 array."""
        features = [self._embed(batch) for batch in tqdm(image_data)]
        ifeatures = tf.concat(features, axis=0)
        return tf.cast(ifeatures, tf.float16).numpy()

    def similar_image(self, x, k=5):
        """Return the *k* indexed images most similar to *x*.

        Args:
            x: A ``Path`` to an image file, or an in-memory image accepted by
                ``load_image2``.
            k: Number of results to return.

        Returns:
            ``(paths, similarities)``: two lists in descending similarity
            order, where ``similarities[i]`` belongs to ``paths[i]``.
        """
        if self.ifeatures is None or self.image_dict is None:
            self.load_image_dict()

        x = (
            self.load_image(str(x.absolute()))
            if isinstance(x, Path)
            else self.load_image2(x)
        )

        x_logits = (
            self._embed(tf.expand_dims(x, 0))
            .numpy()
            .astype("float16")
            .reshape((1, -1))
            .tolist()
        )

        x_similarity = cosine_similarity(x_logits, self.ifeatures).tolist()[0]

        # Take paths and scores from the same ordering so they stay aligned.
        x_sim_idx = np.argsort(x_similarity)[::-1][: int(k)]
        indexed_paths = list(self.image_dict.keys())
        keys_at_indices = [indexed_paths[index] for index in x_sim_idx]
        x_sim_values = [x_similarity[index] for index in x_sim_idx]
        return keys_at_indices, x_sim_values

    def build_image_features(self):
        """Scan ``img_dir`` and compute the feature index from scratch.

        Raises:
            ValueError: If no image files are found.
        """
        images = self.get_image_paths(self.img_dir, recursive=self.recursive)
        if not images:
            raise ValueError(f"No images found in {self.img_dir}")

        image_data = (
            tf.data.Dataset.from_tensor_slices(images)
            .map(self.load_image, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(self.batch_size)
        )

        self.ifeatures = self.get_vectors(image_data)
        self.image_dict = OrderedDict(zip(images, self.ifeatures))
        self._features_rebuilt = True

    def load_image_dict(self):
        """Load the cached index if it matches the directory, else rebuild it."""
        if os.path.isfile(self.filename):
            # SECURITY: compress_pickle unpickles this file; only load caches
            # that this application wrote itself.
            image_dict = compress_pickle.load(self.filename, compression="lzma")
            images = self.get_image_paths(self.img_dir, recursive=self.recursive)
            if images == list(image_dict.keys()):
                self.image_dict = image_dict
                self.ifeatures = np.array(list(image_dict.values()))
            else:
                self.build_image_features()
        else:
            self.build_image_features()

    def save_image_dict(self):
        """Write the current path -> feature mapping to ``self.filename``."""
        compress_pickle.dump(self.image_dict, self.filename, compression="lzma")

    def is_changed(self):
        """Return True if the directory listing differs from the loaded index.

        An index that has not been loaded yet counts as changed.
        """
        if self.image_dict is None:
            return True
        images = self.get_image_paths(self.img_dir, recursive=self.recursive)
        previous_images = list(self.image_dict.keys())
        return images != previous_images

    def find_similar_images(self, x, k=5):
        """Plot the query image at path *x* and its *k* nearest matches.

        ``k == -1`` means "all indexed images". Returns the similarity values
        of the matches.
        """
        # creating/loading vectors
        self.load_image_dict()
        if k == -1:
            k = self.ifeatures.shape[0]

        sim_img, x_sim = self.similar_image(x, k=k)
        plt.figure(figsize=(5, 5))
        testimg = plt.imread(str(x.absolute()))
        plt.imshow(testimg)
        plt.title(f"{x.name}(main)")
        plt.show()
        self.show_images(sim_img, similar=x_sim)
        return x_sim

    def find_similar_images2(self, x, k=5):
        """Like ``find_similar_images`` but without plotting.

        Returns ``(paths, similarities)`` as produced by ``similar_image``.
        """
        self.load_image_dict()
        if k == -1:
            k = self.ifeatures.shape[0]

        sim_img, x_sim = self.similar_image(x, k=k)
        return sim_img, x_sim

    def show_images(self, x: list, similar: list = None, figsize=None):
        """Plot the images at paths *x* in a grid with five columns.

        Args:
            x: Image paths to display.
            similar: Optional similarity values (0..1 scale assumed by the
                percentage formatting), aligned with *x*.
            figsize: Matplotlib figure size; defaults to ``(20, 5)``.
        """
        n_plots = len(x)
        if figsize is None:
            figsize = (20, 5)

        plt.figure(figsize=figsize)

        # Ceiling division: exactly as many rows as needed (at least one).
        n_rows = max(1, -(-n_plots // 5))

        x = [Path(i) for i in x]
        for num, i in enumerate(x, 1):
            plt.subplot(n_rows, 5, num)
            img = plt.imread(i)
            plt.imshow(img)
            title = (
                f"{i.name}\n({100 * similar[num - 1]:.2f}%)"
                if similar is not None
                else i.name
            )
            plt.title(title)
            plt.axis(False)

        plt.tight_layout()
        plt.show()

    def __call__(self, x: Path, k=5):
        """Show the *k* images most similar to *x* and refresh the cache.

        The cache is written when ``save_model`` is set and the index was
        rebuilt, the directory changed, or no cache file exists yet.
        """
        # The save decision depends on the index loaded by the search, so the
        # two steps run one after the other.
        self.find_similar_images(x, k=k)

        if self.save_model and (
            self._features_rebuilt
            or self.is_changed()
            or (not Path(self.filename).exists())
        ):
            self.save_image_dict()
            self._features_rebuilt = False
203
+
204
+
205
def resize_image(img_path, max_size=800):
    """Open the image at *img_path* and scale it so its longer side is *max_size*.

    The aspect ratio is kept (the shorter side is truncated to an int).
    Images smaller than *max_size* are scaled up as well. Returns the resized
    PIL image.
    """
    with Image.open(img_path) as source:
        src_width, src_height = source.size
        if src_width > src_height:
            # Landscape: pin the width, derive the height.
            target_size = (max_size, int(src_height * (max_size / src_width)))
        else:
            # Portrait or square: pin the height, derive the width.
            target_size = (int(src_width * (max_size / src_height)), max_size)
        return source.resize(target_size)
217
+
218
+
219
+
220
def get_image_paths(directory_path: Path, recursive: bool = False) -> list:
    """Collect absolute paths (as ``str``) of image files in *directory_path*.

    A file counts as an image when its lower-cased suffix is .jpg, .jpeg or
    .png. Sub-directories are descended into only when *recursive* is true.
    Entries appear in the order ``iterdir()`` yields them.
    """
    allowed_suffixes = [".jpg", ".jpeg", ".png"]  # Add more extensions if needed
    found = []

    for entry in directory_path.iterdir():
        if entry.is_file():
            if entry.suffix.lower() in allowed_suffixes:
                found.append(str(entry.absolute()))
            continue
        if recursive and entry.is_dir():
            found += get_image_paths(entry, recursive)

    return found
232
+
233
def find_similar_images(img_dir, img_path, similar_images, save_model, recursive):
    """Gradio handler: search *img_dir* for images similar to the query image.

    Args:
        img_dir: Directory to search, as typed by the user.
        img_path: Query image, either a file path (``str``) or an image array.
        similar_images: Requested number of results; clamped to the range
            ``1 .. number of images found``.
        save_model: Persist the computed feature index after the search.
        recursive: Also search sub-directories of *img_dir*.

    Returns:
        ``(status, main_image, gallery)``. ``status`` is a message string.
        On success ``main_image`` is a PIL image and ``gallery`` is a list of
        ``(PIL image, caption)`` tuples; otherwise both are ``None``.
    """
    # Test for "missing" explicitly: the truth value of an image array is
    # ambiguous and would raise.
    image_missing = img_path is None or (isinstance(img_path, str) and not img_path)
    if not img_dir or image_missing:
        return "Please provide both directory and image path.", None, None

    search_dir = Path(img_dir)
    if not search_dir.is_dir():
        return f"Directory not found: {img_dir}", None, None

    total_images = len(get_image_paths(search_dir, recursive=recursive))
    if total_images == 0:
        return f"No images found in {img_dir}.", None, None

    # The UI may deliver a float or an empty value; slicing needs an int.
    k = max(1, min(int(similar_images or 1), total_images))

    if isinstance(img_path, str):
        # Copy the pixels so the file handle can be closed right away.
        with Image.open(img_path) as opened:
            main_image = opened.copy()
    else:
        main_image = Image.fromarray(img_path)

    image_similarity = ImageSimilarity(
        img_dir=search_dir, recursive=recursive, save_model=save_model
    )

    similar_image_paths, similarity_values = image_similarity.find_similar_images2(
        main_image, k=k
    )

    if save_model:
        image_similarity.save_image_dict()

    # Prepare the output
    status = f"Found {len(similar_image_paths)} similar images."

    # Resize and load similar images
    similar_images_list = [
        (resize_image(path), f"Similarity: {sim:.4f}")
        for path, sim in zip(similar_image_paths, similarity_values)
    ]

    # resize_image() opens a file, so it only applies to path input; an
    # in-memory query image is returned as it is.
    if isinstance(img_path, str):
        resized_main_image = resize_image(img_path)
    else:
        resized_main_image = main_image

    return status, resized_main_image, similar_images_list
265
+
266
+
267
+
268
# Gradio UI: a directory textbox and an image upload on top, search options
# below, then the status line, the query image and a gallery of matches.
with gr.Blocks() as demo:
    gr.Markdown("# Photo2Photo Search Engine")

    with gr.Row():
        with gr.Column(scale=5):
            img_dir = gr.Textbox(label="Directory to search")
        with gr.Column(scale=3):
            # type="filepath" hands the handler a path string.
            img_path = gr.Image(label="Upload an image", type="filepath")

    with gr.Row():
        with gr.Column(scale=1):
            # precision=0 makes the handler receive an int; the value is used
            # as a result count.
            similar_images = gr.Number(
                label="Number of similar images to display:",
                value=7,
                minimum=1,
                maximum=50,
                step=1,
                precision=0,
            )
        with gr.Column(scale=1):
            save_model = gr.Checkbox(
                label="Save Model",
                value=False,
                info="Save the model for faster loads, check if you search in same folder again and again",
            )
            recursive = gr.Checkbox(
                label="Recursive",
                value=False,
                info="Search recursively for images in child folders",
            )
        with gr.Column(scale=1):
            submit_button = gr.Button("Find Similar Images")

    output_text = gr.Textbox(label="Status")
    main_image_output = gr.Image(label="Main Image")
    similar_images_output = gr.Gallery(label="Similar Images", show_label=True)

    # The order of inputs/outputs must match the handler's parameters and
    # its returned tuple.
    submit_button.click(
        find_similar_images,
        inputs=[img_dir, img_path, similar_images, save_model, recursive],
        outputs=[output_text, main_image_output, similar_images_output],
    )


if __name__ == "__main__":
    demo.launch()