|
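"""Photo2Photo search engine.

Indexes a directory of images with MobileNetV2 feature vectors, ranks them by
cosine similarity against a query image, and exposes the search through a
Gradio interface.
"""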
import concurrent.futures
import os
from collections import OrderedDict
from pathlib import Path

import compress_pickle
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
|
|
|
|
class ImageSimilarity: |
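    """Index a directory of images with frozen MobileNetV2 features and query it by cosine similarity."""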
|
def __init__( |
|
self, |
|
img_dir: Path, |
|
recursive: bool = False, |
|
BATCH_SIZE: int = 64, |
|
IMG_SIZE: int = 224, |
|
save_model: bool = True, |
|
): |
|
self.batch_size = BATCH_SIZE |
|
self.img_size = IMG_SIZE |
|
self.img_dir = img_dir |
|
        # Frozen MobileNetV2 backbone used purely as a feature extractor;
        # include_top=False drops the ImageNet classification head.
        self.model = tf.keras.applications.MobileNetV2(
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            alpha=1.0,
            include_top=False,
            weights="imagenet",
            pooling=None,
        )
|
|
|
self.model.trainable = False |
|
self.model.compile() |
|
|
|
self.save_model = save_model |
|
self.recursive = recursive |
|
self.ifeatures = None |
|
self.filename = "image_dict.lzma" |
|
self.image_dict = None |
|
self.images_found = None |
|
|
|
def get_image_paths(self, directory_path: Path, recursive: bool = False) -> list: |
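        """Collect absolute paths of .jpg/.jpeg/.png files in directory_path, optionally recursing into subfolders."""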
|
image_extensions = [".jpg", ".jpeg", ".png"] |
|
image_paths = [] |
|
|
|
for file_path in directory_path.iterdir(): |
|
if file_path.is_file() and (file_path.suffix.lower() in image_extensions): |
|
image_paths.append(str(file_path.absolute())) |
|
|
|
elif recursive and file_path.is_dir(): |
|
image_paths.extend(self.get_image_paths(file_path, recursive)) |
|
|
|
return image_paths |
|
|
|
    def load_image(self, x):
        # decode_image handles both JPEG and PNG, which matches the extension filter.
        image_data = tf.io.read_file(x)
        image_features = tf.io.decode_image(image_data, channels=3, expand_animations=False)
        image_features = tf.image.resize(image_features, (self.img_size, self.img_size))
        return image_features
|
|
|
def load_image2(self, x): |
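        """Resize an in-memory image (PIL Image or array) to the model input size."""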
|
image_data = tf.keras.utils.img_to_array(x) |
|
return tf.image.resize(image_data, (self.img_size, self.img_size)) |
|
|
|
def get_vectors(self, image_data: tf.data.Dataset) -> np.array: |
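        """Run each batch through the frozen backbone, max-pool to one vector per image, and stack as float16."""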
|
features = [] |
|
for i in tqdm(image_data): |
|
y = self.model(i) |
|
pooled_features = tf.keras.layers.GlobalMaxPooling2D()(y) |
|
features.append(pooled_features) |
|
|
|
ifeatures = tf.concat(features, axis=0) |
|
ifeatures = tf.cast(ifeatures, tf.float16).numpy() |
|
return ifeatures |
|
|
|
def similar_image(self, x, k=5): |
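        """Return the k most similar indexed image paths and their cosine-similarity scores for query x."""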
|
x = ( |
|
self.load_image(str(x.absolute())) |
|
if isinstance(x, Path) |
|
else self.load_image2(x) |
|
) |
|
|
|
        x_features = self.model(tf.expand_dims(x, 0))
        # Pool with the same GlobalMaxPooling2D used in get_vectors so the query
        # vector is directly comparable to the indexed feature vectors.
        x_features = (
            tf.keras.layers.GlobalMaxPooling2D()(x_features)
            .numpy()
            .astype("float16")
            .reshape((1, -1))
            .tolist()
        )

        x_similarity = cosine_similarity(x_features, self.ifeatures).tolist()[0]
|
|
|
x_sim_idx = np.argsort(x_similarity)[::-1][:k] |
|
x_sim_values = sorted(x_similarity, reverse=True)[:k] |
|
keys_at_indices = [list(self.image_dict.keys())[index] for index in x_sim_idx] |
|
return keys_at_indices, x_sim_values |
|
|
|
def build_image_features(self): |
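        """Extract features for every image under img_dir and build the path -> feature-vector map."""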
|
images = self.get_image_paths(self.img_dir, recursive=self.recursive) |
|
|
|
image_data = ( |
|
tf.data.Dataset.from_tensor_slices(images) |
|
.map(self.load_image, num_parallel_calls=tf.data.AUTOTUNE) |
|
.batch(self.batch_size) |
|
) |
|
|
|
self.ifeatures = self.get_vectors(image_data) |
|
self.image_dict = OrderedDict(zip(images, self.ifeatures)) |
|
|
|
|
|
|
|
|
|
def load_image_dict(self): |
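        """Reuse the cached feature dictionary if it still matches the directory contents, otherwise rebuild it."""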
|
if os.path.isfile(self.filename): |
|
image_dict = compress_pickle.load(self.filename, compression="lzma") |
|
images = self.get_image_paths(self.img_dir, recursive=self.recursive) |
|
if images == list(image_dict.keys()): |
|
self.image_dict = image_dict |
|
self.ifeatures = np.array(list(image_dict.values())) |
|
else: |
|
self.build_image_features() |
|
else: |
|
self.build_image_features() |
|
|
|
def save_image_dict(self): |
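        """Persist the path -> feature-vector map to disk as an LZMA-compressed pickle."""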
|
compress_pickle.dump(self.image_dict, self.filename, compression="lzma") |
|
|
|
def is_changed(self): |
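        """Return True if the images on disk no longer match the cached feature keys."""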
|
images = self.get_image_paths(self.img_dir, recursive=self.recursive) |
|
previous_images = list(self.image_dict.keys()) |
|
return images != previous_images |
|
|
|
def find_similar_images(self, x, k=5): |
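        """Plot the query image (expected as a Path) alongside its k most similar matches."""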
|
|
|
self.load_image_dict() |
|
if k == -1: |
|
k = self.ifeatures.shape[0] |
|
|
|
sim_img, x_sim = self.similar_image(x, k=k) |
|
|
|
plt.figure(figsize=(5, 5)) |
|
testimg = plt.imread(str(x.absolute())) |
|
plt.imshow(testimg) |
|
plt.title(f"{x.name}(main)") |
|
plt.show() |
|
self.show_images(sim_img, similar=x_sim) |
|
return x_sim |
|
|
|
def find_similar_images2(self, x, k=5): |
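        """Like find_similar_images, but return (paths, similarity scores) without plotting."""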
|
self.load_image_dict() |
|
if k == -1: |
|
k = self.ifeatures.shape[0] |
|
|
|
sim_img, x_sim = self.similar_image(x, k=k) |
|
return sim_img, x_sim |
|
|
|
def show_images(self, x: list, similar: list = None, figsize=None): |
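        """Display images in a 5-column grid, optionally annotating each with its similarity score."""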
|
n_plots = len(x) |
|
|
|
if figsize is None: |
|
|
|
figsize = (20, 5) |
|
|
|
|
|
plt.figure(figsize=figsize) |
|
|
|
x = [Path(i) for i in x] |
|
for num, i in enumerate(x, 1): |
|
plt.subplot((n_plots // 5) + 1, 5, num) |
|
img = plt.imread(i) |
|
plt.imshow(img) |
|
title = ( |
|
f"{i.name}\n({100 * similar[num - 1]:.2f}%)" |
|
if similar is not None |
|
else i.name |
|
) |
|
plt.title(title) |
|
plt.axis(False) |
|
plt.tight_layout() |
|
|
|
plt.show() |
|
|
|
    def __call__(self, x: Path, k=5):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # submit expects the callable and its arguments separately; calling the
            # method here would run it eagerly and submit its return value instead.
            finding = executor.submit(self.find_similar_images, x, k=k)
            finding.result()  # wait so image_dict is populated before checking for changes

            if self.save_model and (
                self.is_changed() or (not Path(self.filename).exists())
            ):
                executor.submit(self.save_image_dict)
|
|
|
|
|
def resize_image(img_path, max_size=800): |
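    """Resize an image so its longer side equals max_size pixels, preserving the aspect ratio."""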
|
with Image.open(img_path) as img: |
|
|
|
width, height = img.size |
|
if width > height: |
|
new_width = max_size |
|
new_height = int(height * (new_width / width)) |
|
else: |
|
new_height = max_size |
|
new_width = int(width * (new_height / height)) |
|
img = img.resize((new_width, new_height)) |
|
return img |
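
# Example of using ImageSimilarity directly, outside the Gradio UI.
# The directory and query paths below are hypothetical placeholders:
#
#     searcher = ImageSimilarity(img_dir=Path("./photos"), recursive=True, save_model=False)
#     paths, scores = searcher.find_similar_images2(Path("./query.jpg"), k=5)
#     for p, s in zip(paths, scores):
#         print(f"{p}: {s:.3f}")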
|
|
|
|
|
|
|
def get_image_paths(directory_path: Path, recursive: bool = False) -> list: |
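    """Collect absolute paths of .jpg/.jpeg/.png files in directory_path, optionally recursing into subfolders."""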
|
image_extensions = [".jpg", ".jpeg", ".png"] |
|
image_paths = [] |
|
|
|
for file_path in directory_path.iterdir(): |
|
if file_path.is_file() and (file_path.suffix.lower() in image_extensions): |
|
image_paths.append(str(file_path.absolute())) |
|
|
|
elif recursive and file_path.is_dir(): |
|
image_paths.extend(get_image_paths(file_path, recursive)) |
|
|
|
return image_paths |
|
|
|
def find_similar_images(img_dir, img_path, similar_images, save_model, recursive): |
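    """Gradio callback: find images in img_dir similar to img_path and return (status, main image, gallery items)."""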
|
    if img_dir and img_path:
        total_images = len(get_image_paths(Path(img_dir), recursive=recursive))
        # Gradio's Number component may deliver a float; similarity slicing needs an int.
        similar_images = int(min(similar_images, total_images))
|
|
|
main_image = Image.open(img_path) if isinstance(img_path, str) else Image.fromarray(img_path) |
|
image_similarity = ImageSimilarity( |
|
img_dir=Path(img_dir), recursive=recursive, save_model=save_model |
|
) |
|
|
|
similar_image_paths, similarity_values = image_similarity.find_similar_images2(main_image, k=similar_images) |
|
|
|
|
|
|
|
if save_model: |
|
image_similarity.save_image_dict() |
|
|
|
|
|
status = f"Found {len(similar_image_paths)} similar images." |
|
|
|
|
|
similar_images_list = [ |
|
(resize_image(path), f"Similarity: {sim:.4f}") |
|
for path, sim in zip(similar_image_paths, similarity_values) |
|
] |
|
|
|
|
|
resized_main_image = resize_image(img_path) |
|
|
|
return status, resized_main_image, similar_images_list |
|
|
|
return "Please provide both directory and image path.", None, None |
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
gr.Markdown("# Photo2Photo Search Engine") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=5): |
|
img_dir = gr.Textbox(label="Directory to search") |
|
with gr.Column(scale=3): |
|
img_path = gr.Image(label="Upload an image", type="filepath") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
            similar_images = gr.Number(label="Number of similar images to display", value=7, minimum=1, maximum=50, step=1, precision=0)
|
with gr.Column(scale=1): |
|
save_model = gr.Checkbox(label="Save Model", value=False, info="Save the model for faster loads, check if you search in same folder again and again") |
|
recursive = gr.Checkbox(label="Recursive", value=False, info="Search recursively for images in child folders") |
|
with gr.Column(scale=1): |
|
submit_button = gr.Button("Find Similar Images") |
|
|
|
output_text = gr.Textbox(label="Status") |
|
main_image_output = gr.Image(label="Main Image") |
|
similar_images_output = gr.Gallery(label="Similar Images", show_label=True) |
|
|
|
submit_button.click( |
|
find_similar_images, |
|
inputs=[img_dir, img_path, similar_images, save_model, recursive], |
|
outputs=[output_text, main_image_output, similar_images_output] |
|
) |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|