Tianyu Ding
committed on
Commit
·
fd9106d
1
Parent(s):
6f6bd67
app.py
CHANGED
@@ -1,12 +1,299 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
demo = gr.Interface(
|
7 |
-
fn=greet,
|
8 |
-
inputs=["text", "slider"],
|
9 |
-
outputs=["text"],
|
10 |
-
)
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from pathlib import Path
|
3 |
+
from PIL import Image
|
4 |
+
import os
|
5 |
+
# from utils import get_image_paths, show_images2
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import os
|
11 |
+
import tensorflow as tf
|
12 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
13 |
+
from pathlib import Path
|
14 |
+
import numpy as np
|
15 |
+
from tqdm import tqdm
|
16 |
+
from collections import OrderedDict
|
17 |
+
import compress_pickle
|
18 |
+
import concurrent
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
class ImageSimilarity:
    """Rank the images in a directory by visual similarity to a query image.

    Every image found under ``img_dir`` is passed through an
    ImageNet-pretrained MobileNetV2 backbone (``include_top=False``), the
    resulting feature maps are max-pooled into one vector per image, and the
    query vector is compared against them with cosine similarity. The
    ``path -> vector`` mapping can be cached as an LZMA-compressed pickle named
    ``image_dict.lzma`` (relative to the current working directory).
    """

    def __init__(
        self,
        img_dir: Path,
        recursive: bool = False,
        BATCH_SIZE: int = 64,
        IMG_SIZE: int = 224,
        save_model: bool = True,
    ):
        """Load the backbone and remember the search settings.

        Args:
            img_dir: Directory whose images form the search index.
            recursive: Also index images in sub-directories.
            BATCH_SIZE: Number of images embedded per forward pass.
            IMG_SIZE: Side length (pixels) images are resized to.
            save_model: Whether ``__call__`` should persist the index.
        """
        self.batch_size = BATCH_SIZE
        self.img_size = IMG_SIZE
        self.img_dir = img_dir
        self.model = tf.keras.applications.MobileNetV2(
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            alpha=1.0,
            include_top=False,
            weights="imagenet",
            input_tensor=None,
            pooling=None,
            classifier_activation="softmax",
        )

        self.model.trainable = False
        self.model.compile()

        self.save_model = save_model
        self.recursive = recursive
        self.ifeatures = None
        self.filename = "image_dict.lzma"
        self.image_dict = None
        self.images_found = None
        # True when the in-memory index is newer than the file on disk.
        self._cache_stale = False

    def get_image_paths(self, directory_path: Path, recursive: bool = False) -> list:
        """Return absolute paths (as strings) of the images in a directory.

        Entries are visited in sorted order: ``Path.iterdir`` yields them in
        arbitrary order, and the cache check in ``load_image_dict`` compares
        this list against the stored keys element by element.
        """
        image_extensions = (".jpg", ".jpeg", ".png")  # Add more extensions if needed
        image_paths = []

        for file_path in sorted(directory_path.iterdir()):
            if file_path.is_file() and file_path.suffix.lower() in image_extensions:
                image_paths.append(str(file_path.absolute()))
            elif recursive and file_path.is_dir():
                image_paths.extend(self.get_image_paths(file_path, recursive))

        return image_paths

    def load_image(self, x):
        """Read the image file at path ``x`` and resize it to the model input size."""
        image_data = tf.io.read_file(x)
        image_features = tf.image.decode_jpeg(image_data, channels=3)
        return tf.image.resize(image_features, (self.img_size, self.img_size))

    def load_image2(self, x):
        """Convert an in-memory image to a tensor resized to the model input size.

        PIL images are converted to RGB first, because the backbone was built
        for exactly three channels and RGBA or grayscale inputs would be
        rejected. Other inputs (e.g. arrays) are passed through unchanged.
        """
        if isinstance(x, Image.Image):
            x = x.convert("RGB")
        image_data = tf.keras.utils.img_to_array(x)
        return tf.image.resize(image_data, (self.img_size, self.img_size))

    def _pool(self, feature_maps):
        """Collapse backbone feature maps into one vector per image.

        Indexed images and queries must share this pooling, otherwise the
        cosine similarity compares vectors from different feature spaces.
        """
        return tf.keras.layers.GlobalMaxPooling2D()(feature_maps)

    def get_vectors(self, image_data: "tf.data.Dataset") -> np.array:
        """Embed every batch of ``image_data`` and return the stacked float16 vectors."""
        features = [self._pool(self.model(batch)) for batch in tqdm(image_data)]
        ifeatures = tf.concat(features, axis=0)
        return tf.cast(ifeatures, tf.float16).numpy()

    def similar_image(self, x, k=5):
        """Return the ``k`` indexed images most similar to ``x``.

        Args:
            x: A ``Path`` to an image file, or an in-memory image.
            k: Number of results to return.

        Returns:
            ``(paths, similarities)``, both ordered from most to least similar.
        """
        if isinstance(x, Path):
            image = self.load_image(str(x.absolute()))
        else:
            image = self.load_image2(x)

        feature_maps = self.model(tf.expand_dims(image, 0))
        x_logits = (
            self._pool(feature_maps)
            .numpy()
            .astype("float16")
            .reshape((1, -1))
            .tolist()
        )

        x_similarity = cosine_similarity(x_logits, self.ifeatures).tolist()[0]

        # Take paths and scores from the same index list so they always line up.
        x_sim_idx = np.argsort(x_similarity)[::-1][:k]
        indexed_paths = list(self.image_dict)
        keys_at_indices = [indexed_paths[index] for index in x_sim_idx]
        x_sim_values = [x_similarity[index] for index in x_sim_idx]
        return keys_at_indices, x_sim_values

    def build_image_features(self):
        """Embed every image under ``img_dir`` and rebuild the in-memory index.

        Raises:
            ValueError: If the directory contains no supported images.
        """
        images = self.get_image_paths(self.img_dir, recursive=self.recursive)
        if not images:
            raise ValueError(f"No images found in {self.img_dir}")

        image_data = (
            tf.data.Dataset.from_tensor_slices(images)
            .map(self.load_image, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(self.batch_size)
        )

        self.ifeatures = self.get_vectors(image_data)
        self.image_dict = OrderedDict(zip(images, self.ifeatures))
        self._cache_stale = True

    def load_image_dict(self):
        """Load the index from the cache file, rebuilding it when missing or outdated."""
        if os.path.isfile(self.filename):
            # NOTE(security): this unpickles a file from the working directory;
            # only use caches written by a trusted run of this program.
            image_dict = compress_pickle.load(self.filename, compression="lzma")
            images = self.get_image_paths(self.img_dir, recursive=self.recursive)
            if images == list(image_dict.keys()):
                self.image_dict = image_dict
                self.ifeatures = np.array(list(image_dict.values()))
                self._cache_stale = False
                return

        self.build_image_features()

    def save_image_dict(self):
        """Write the current index to the cache file."""
        compress_pickle.dump(self.image_dict, self.filename, compression="lzma")
        self._cache_stale = False

    def is_changed(self):
        """Return True when the directory listing differs from the loaded index."""
        if self.image_dict is None:
            return True
        images = self.get_image_paths(self.img_dir, recursive=self.recursive)
        return images != list(self.image_dict.keys())

    def find_similar_images(self, x, k=5):
        """Plot the query image and its ``k`` nearest matches; return the scores.

        ``x`` must be a ``Path`` here because the query is re-read from disk for
        plotting. ``k == -1`` means "every indexed image".
        """
        # creating/loading vectors
        self.load_image_dict()
        if k == -1:
            k = self.ifeatures.shape[0]

        sim_img, x_sim = self.similar_image(x, k=k)

        plt.figure(figsize=(5, 5))
        testimg = plt.imread(str(x.absolute()))
        plt.imshow(testimg)
        plt.title(f"{x.name}(main)")
        plt.show()
        self.show_images(sim_img, similar=x_sim)
        return x_sim

    def find_similar_images2(self, x, k=5):
        """Return ``(paths, similarities)`` for the ``k`` nearest matches, without plotting."""
        self.load_image_dict()
        if k == -1:
            k = self.ifeatures.shape[0]

        return self.similar_image(x, k=k)

    def show_images(self, x: list, similar: list = None, figsize=None):
        """Plot the images at paths ``x`` in rows of five, titled with their scores."""
        n_plots = len(x)
        if figsize is None:
            figsize = (20, 5)

        plt.figure(figsize=figsize)

        for num, image_path in enumerate((Path(i) for i in x), 1):
            plt.subplot((n_plots // 5) + 1, 5, num)
            plt.imshow(plt.imread(image_path))
            if similar is not None:
                title = f"{image_path.name}\n({100 * similar[num - 1]:.2f}%)"
            else:
                title = image_path.name
            plt.title(title)
            plt.axis(False)

        # Lay the grid out once, after every subplot exists.
        plt.tight_layout()
        plt.show()

    def __call__(self, x: Path, k=5):
        """Search for images similar to ``x`` and persist the index when needed.

        The search runs first so the index is loaded before the save decision
        is made. The index is written when saving is enabled and it was
        rebuilt, the directory changed, or no cache file exists yet.

        Returns:
            The similarity scores of the ``k`` best matches.
        """
        x_sim = self.find_similar_images(x, k=k)

        needs_save = (
            getattr(self, "_cache_stale", False)
            or self.is_changed()
            or not Path(self.filename).exists()
        )
        if self.save_model and needs_save:
            self.save_image_dict()

        return x_sim
|
203 |
+
|
204 |
+
|
205 |
+
def resize_image(img_path, max_size=800):
    """Open an image file and scale it so its longer side equals ``max_size``.

    The aspect ratio is kept. Images smaller than ``max_size`` are scaled up.

    Args:
        img_path: Path of the image file to open.
        max_size: Target length, in pixels, of the longer side.

    Returns:
        The resized image as a new PIL image; the source file is closed.
    """
    with Image.open(img_path) as img:
        width, height = img.size
        if width > height:
            new_width = max_size
            new_height = int(height * (new_width / width))
        else:
            new_height = max_size
            new_width = int(width * (new_height / height))

        # Truncation can reach 0 for extreme aspect ratios, which resize()
        # rejects; keep at least one pixel on the short side.
        new_size = (max(1, new_width), max(1, new_height))
        return img.resize(new_size)
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
def get_image_paths(directory_path: Path, recursive: bool = False) -> list:
    """Return absolute paths (as strings) of the images in a directory.

    Files are matched by extension, case-insensitively. Entries are visited in
    sorted order because ``Path.iterdir`` yields them in arbitrary order, so
    the result is deterministic for a given directory tree.

    Args:
        directory_path: Directory to scan.
        recursive: Also descend into sub-directories.

    Returns:
        A list of absolute path strings.
    """
    image_extensions = (".jpg", ".jpeg", ".png")  # Add more extensions if needed
    image_paths = []

    for file_path in sorted(directory_path.iterdir()):
        if file_path.is_file() and file_path.suffix.lower() in image_extensions:
            image_paths.append(str(file_path.absolute()))
        elif recursive and file_path.is_dir():
            image_paths.extend(get_image_paths(file_path, recursive))

    return image_paths
|
232 |
+
|
233 |
+
def find_similar_images(img_dir, img_path, similar_images, save_model, recursive):
    """Gradio callback: find the images in ``img_dir`` most similar to a query.

    Args:
        img_dir: Directory to search, as typed by the user.
        img_path: The query image, either a file path or an image array.
        similar_images: Maximum number of matches to return.
        save_model: Whether to write the feature index to disk afterwards.
        recursive: Whether to include images in sub-directories.

    Returns:
        ``(status, main_image, gallery)`` where ``gallery`` is a list of
        ``(image, caption)`` pairs. On invalid input the last two are ``None``
        and ``status`` explains the problem.
    """
    is_file_path = isinstance(img_path, str)
    # Test arrays against None: their truth value is ambiguous.
    if not img_dir or img_path is None or (is_file_path and not img_path):
        return "Please provide both directory and image path.", None, None

    # NOTE(security): img_dir comes straight from the UI and is read from the
    # server's file system; restrict it if the app is exposed to untrusted users.
    search_dir = Path(img_dir)
    if not search_dir.is_dir():
        return f"Directory not found: {img_dir}", None, None

    total_images = len(get_image_paths(search_dir, recursive=recursive))
    if total_images == 0:
        return f"No images found in {img_dir}.", None, None

    # The number widget may deliver a float, but the value is used as a slice bound.
    similar_images = min(int(similar_images), total_images)

    main_image = Image.open(img_path) if is_file_path else Image.fromarray(img_path)
    try:
        image_similarity = ImageSimilarity(
            img_dir=search_dir, recursive=recursive, save_model=save_model
        )

        similar_image_paths, similarity_values = image_similarity.find_similar_images2(
            main_image, k=similar_images
        )

        if save_model:
            image_similarity.save_image_dict()

        # Prepare the output
        status = f"Found {len(similar_image_paths)} similar images."

        # Resize and load similar images
        similar_images_list = [
            (resize_image(path), f"Similarity: {sim:.4f}")
            for path, sim in zip(similar_image_paths, similarity_values)
        ]

        # resize_image() only accepts a file path; an array query is returned as loaded.
        resized_main_image = resize_image(img_path) if is_file_path else main_image
    finally:
        # Release the file handle opened for the query image.
        if is_file_path:
            main_image.close()

    return status, resized_main_image, similar_images_list
|
265 |
+
|
266 |
+
|
267 |
+
|
268 |
+
# Gradio front end: a directory box and a query-image upload on the first row,
# the search options and the trigger button on the second, results underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Photo2Photo Search Engine")

    with gr.Row():
        with gr.Column(scale=5):
            img_dir = gr.Textbox(label="Directory to search")
        with gr.Column(scale=3):
            img_path = gr.Image(label="Upload an image", type="filepath")

    with gr.Row():
        with gr.Column(scale=1):
            similar_images = gr.Number(
                label="Number of similar images to display:",
                value=7,
                minimum=1,
                maximum=50,
                step=1,
            )
        with gr.Column(scale=1):
            save_model = gr.Checkbox(
                label="Save Model",
                value=False,
                info="Save the model for faster loads, check if you search in same folder again and again",
            )
            recursive = gr.Checkbox(
                label="Recursive",
                value=False,
                info="Search recursively for images in child folders",
            )
        with gr.Column(scale=1):
            submit_button = gr.Button("Find Similar Images")

    output_text = gr.Textbox(label="Status")
    main_image_output = gr.Image(label="Main Image")
    similar_images_output = gr.Gallery(label="Similar Images", show_label=True)

    # The callback receives the widgets' values in this order and fills the
    # three result components in this order.
    search_inputs = [img_dir, img_path, similar_images, save_model, recursive]
    search_outputs = [output_text, main_image_output, similar_images_output]
    submit_button.click(
        find_similar_images,
        inputs=search_inputs,
        outputs=search_outputs,
    )


if __name__ == "__main__":
    demo.launch()
|