p2p-demo / app.py
Tianyu Ding
u
fd9106d
raw
history blame
10.3 kB
import gradio as gr
from pathlib import Path
from PIL import Image
import os
# from utils import get_image_paths, show_images2
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
import compress_pickle
import concurrent
class ImageSimilarity:
def __init__(
self,
img_dir: Path,
recursive: bool = False,
BATCH_SIZE: int = 64,
IMG_SIZE: int = 224,
save_model: bool = True,
):
self.batch_size = BATCH_SIZE
self.img_size = IMG_SIZE
self.img_dir = img_dir
self.model = tf.keras.applications.MobileNetV2(
input_shape=(IMG_SIZE, IMG_SIZE, 3),
alpha=1.0,
include_top=False,
weights="imagenet",
input_tensor=None,
pooling=None,
classifier_activation="softmax",
)
self.model.trainable = False
self.model.compile()
self.save_model = save_model
self.recursive = recursive
self.ifeatures = None
self.filename = "image_dict.lzma"
self.image_dict = None
self.images_found = None
def get_image_paths(self, directory_path: Path, recursive: bool = False) -> list:
image_extensions = [".jpg", ".jpeg", ".png"] # Add more extensions if needed
image_paths = []
for file_path in directory_path.iterdir():
if file_path.is_file() and (file_path.suffix.lower() in image_extensions):
image_paths.append(str(file_path.absolute()))
elif recursive and file_path.is_dir():
image_paths.extend(self.get_image_paths(file_path, recursive))
return image_paths
def load_image(self, x):
image_data = tf.io.read_file(x)
image_features = tf.image.decode_jpeg(image_data, channels=3)
image_features = tf.image.resize(image_features, (self.img_size, self.img_size))
return image_features
def load_image2(self, x):
image_data = tf.keras.utils.img_to_array(x)
return tf.image.resize(image_data, (self.img_size, self.img_size))
def get_vectors(self, image_data: tf.data.Dataset) -> np.array:
features = []
for i in tqdm(image_data):
y = self.model(i)
pooled_features = tf.keras.layers.GlobalMaxPooling2D()(y)
features.append(pooled_features)
ifeatures = tf.concat(features, axis=0)
ifeatures = tf.cast(ifeatures, tf.float16).numpy()
return ifeatures
def similar_image(self, x, k=5):
x = (
self.load_image(str(x.absolute()))
if isinstance(x, Path)
else self.load_image2(x)
)
x_logits = self.model(tf.expand_dims(x, 0))
x_logits = (
tf.keras.layers.GlobalAveragePooling2D()(x_logits)
.numpy()
.astype("float16")
.reshape((1, -1))
.tolist()
)
x_similarity = cosine_similarity(x_logits, self.ifeatures).tolist()[0]
x_sim_idx = np.argsort(x_similarity)[::-1][:k]
x_sim_values = sorted(x_similarity, reverse=True)[:k]
keys_at_indices = [list(self.image_dict.keys())[index] for index in x_sim_idx]
return keys_at_indices, x_sim_values
def build_image_features(self):
images = self.get_image_paths(self.img_dir, recursive=self.recursive)
image_data = (
tf.data.Dataset.from_tensor_slices(images)
.map(self.load_image, num_parallel_calls=tf.data.AUTOTUNE)
.batch(self.batch_size)
)
self.ifeatures = self.get_vectors(image_data)
self.image_dict = OrderedDict(zip(images, self.ifeatures))
# print('ifeatures.shape:', self.ifeatures.shape)
# print('Features loaded!')
def load_image_dict(self):
if os.path.isfile(self.filename):
image_dict = compress_pickle.load(self.filename, compression="lzma")
images = self.get_image_paths(self.img_dir, recursive=self.recursive)
if images == list(image_dict.keys()):
self.image_dict = image_dict
self.ifeatures = np.array(list(image_dict.values()))
else:
self.build_image_features()
else:
self.build_image_features()
def save_image_dict(self):
compress_pickle.dump(self.image_dict, self.filename, compression="lzma")
def is_changed(self):
images = self.get_image_paths(self.img_dir, recursive=self.recursive)
previous_images = list(self.image_dict.keys())
return images != previous_images
def find_similar_images(self, x, k=5):
# creating/loading vectors
self.load_image_dict()
if k == -1:
k = self.ifeatures.shape[0]
sim_img, x_sim = self.similar_image(x, k=k)
# print('plotting')
plt.figure(figsize=(5, 5))
testimg = plt.imread(str(x.absolute()))
plt.imshow(testimg)
plt.title(f"{x.name}(main)")
plt.show()
self.show_images(sim_img, similar=x_sim)
return x_sim
def find_similar_images2(self, x, k=5):
self.load_image_dict()
if k == -1:
k = self.ifeatures.shape[0]
sim_img, x_sim = self.similar_image(x, k=k)
return sim_img, x_sim
def show_images(self, x: list, similar: list = None, figsize=None):
n_plots = len(x)
# print('n plots: ', n_plots)
if figsize is None:
# figsize = (20, int(n_plots // 5) * 4)
figsize = (20, 5)
# print('figsize: ',figsize)
plt.figure(figsize=figsize)
x = [Path(i) for i in x]
for num, i in enumerate(x, 1):
plt.subplot((n_plots // 5) + 1, 5, num)
img = plt.imread(i)
plt.imshow(img)
title = (
f"{i.name}\n({100 * similar[num - 1]:.2f}%)"
if similar is not None
else i.name
)
plt.title(title)
plt.axis(False)
plt.tight_layout()
plt.show()
def __call__(self, x: Path, k=5):
with concurrent.futures.ThreadPoolExecutor() as executor:
finding = executor.submit(self.find_similar_images(x, k=5))
if self.save_model and (
self.is_changed() or (not Path(self.filename).exists())
):
save_imagedict = executor.submit(self.save_image_dict)
def resize_image(img_path, max_size=800):
with Image.open(img_path) as img:
# change the size of the image to max_size but keep the aspect ratio
width, height = img.size
if width > height:
new_width = max_size
new_height = int(height * (new_width / width))
else:
new_height = max_size
new_width = int(width * (new_height / height))
img = img.resize((new_width, new_height))
return img
def get_image_paths(directory_path: Path, recursive: bool = False) -> list:
image_extensions = [".jpg", ".jpeg", ".png"] # Add more extensions if needed
image_paths = []
for file_path in directory_path.iterdir():
if file_path.is_file() and (file_path.suffix.lower() in image_extensions):
image_paths.append(str(file_path.absolute()))
elif recursive and file_path.is_dir():
image_paths.extend(get_image_paths(file_path, recursive))
return image_paths
def find_similar_images(img_dir, img_path, similar_images, save_model, recursive):
if img_dir and (img_path):
total_images = len(get_image_paths(Path(img_dir), recursive=recursive))
similar_images = min(similar_images, total_images)
main_image = Image.open(img_path) if isinstance(img_path, str) else Image.fromarray(img_path)
image_similarity = ImageSimilarity(
img_dir=Path(img_dir), recursive=recursive, save_model=save_model
)
similar_image_paths, similarity_values = image_similarity.find_similar_images2(main_image, k=similar_images)
# print(similar_image_paths, similarity_values)
if save_model:
image_similarity.save_image_dict()
# Prepare the output
status = f"Found {len(similar_image_paths)} similar images."
# Resize and load similar images
similar_images_list = [
(resize_image(path), f"Similarity: {sim:.4f}")
for path, sim in zip(similar_image_paths, similarity_values)
]
# Resize the main image
resized_main_image = resize_image(img_path)
return status, resized_main_image, similar_images_list
return "Please provide both directory and image path.", None, None
with gr.Blocks() as demo:
gr.Markdown("# Photo2Photo Search Engine")
with gr.Row():
with gr.Column(scale=5):
img_dir = gr.Textbox(label="Directory to search")
with gr.Column(scale=3):
img_path = gr.Image(label="Upload an image", type="filepath")
with gr.Row():
with gr.Column(scale=1):
similar_images = gr.Number(label="Number of similar images to display:", value=7, minimum=1, maximum=50, step=1)
with gr.Column(scale=1):
save_model = gr.Checkbox(label="Save Model", value=False, info="Save the model for faster loads, check if you search in same folder again and again")
recursive = gr.Checkbox(label="Recursive", value=False, info="Search recursively for images in child folders")
with gr.Column(scale=1):
submit_button = gr.Button("Find Similar Images")
output_text = gr.Textbox(label="Status")
main_image_output = gr.Image(label="Main Image")
similar_images_output = gr.Gallery(label="Similar Images", show_label=True)
submit_button.click(
find_similar_images,
inputs=[img_dir, img_path, similar_images, save_model, recursive],
outputs=[output_text, main_image_output, similar_images_output]
)
if __name__ == "__main__":
demo.launch()