Spaces:
Runtime error
Runtime error
import numpy as np | |
from PIL import Image | |
from tensorflow.keras.datasets import cifar10 | |
from huggingface_hub import from_pretrained_keras | |
import gradio as gr | |
import os | |
def prepare_output(neighbours):
    """Paste the neighbour images into a single 5-column, 2-row RGB grid.

    The indices are consumed in reverse order, so the last entry of
    `neighbours` lands in the top-left cell. Cells are filled left to right,
    then top to bottom.

    @params neighbours: List of indices (into the global `x_test`) of the nearest neighbours
    @returns: PIL image holding the tiled neighbours
    """
    columns, rows = 5, 2
    canvas = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))

    # Image Grid of top-10 neighbours
    for position, neighbour_idx in enumerate(reversed(neighbours)):
        row, col = divmod(position, columns)
        # x_test is stored scaled by 1/255 (see the script section), so scale
        # back up before converting to 8-bit pixels.
        pixels = (np.array(x_test[neighbour_idx]) * 255).astype(np.uint8)
        tile = Image.fromarray(pixels, "RGB")
        canvas.paste(tile, (col * HEIGHT_WIDTH, row * HEIGHT_WIDTH))

    return canvas
def get_nearest_neighbours(img):
    """Has the inference code to get the nearest neighbours from the model.

    The query image is scaled by 1/255, appended to the global `x_test`,
    and the whole batch is embedded with the global `model`. Test images are
    then ranked by the dot product between their embedding and the query's
    embedding (this equals cosine similarity only if the model L2-normalises
    its output - TODO confirm against the model definition).

    @params img: Image to be fed to the model (array whose values are divided by 255)
    @returns: NumPy array of the image grid built by `prepare_output`
    """
    # Pre-process image
    query = np.expand_dims(img / 255, axis=0)
    batch = np.append(x_test, query, axis=0)

    # Get the embeddings; the query is the last row of the batch.
    embeddings = np.asarray(model.predict(batch))
    test_embeddings, query_embedding = embeddings[:-1], embeddings[-1]

    # Only the query's similarities are ever used, so score that single row
    # instead of building and sorting the full pairwise Gram matrix.
    # Excluding the query's own embedding here also guarantees that every
    # returned index is a valid index into x_test, even when a test image
    # ties with (or beats) the query's self-similarity.
    similarities = test_embeddings @ query_embedding

    # argsort is ascending, so the most similar images sit at the tail;
    # prepare_output reverses the order when drawing.
    ranking = np.argsort(similarities)
    top_k = min(NEAR_NEIGHBOURS, ranking.size)
    near_neighbours = ranking[ranking.size - top_k:]

    # Make image grid output
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)
if __name__ == "__main__":
    # Script entry point: load the data and model that the functions above
    # read as globals, then build and launch the Gradio interface.

    # Constants
    HEIGHT_WIDTH = 32
    NEAR_NEIGHBOURS = 10

    # Only the test images are used; the training split and labels are ignored.
    _, (x_test, _) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0

    model = from_pretrained_keras("keras-io/cifar10_metric_learning")

    # os.listdir returns bare file names, so prefix the directory to get
    # paths that resolve from the working directory. Sorted for a stable
    # example order.
    examples_dir = "examples"
    examples = [
        os.path.join(examples_dir, file_name)
        for file_name in sorted(os.listdir(examples_dir))
    ]

    title = "Metric Learning for Image Similarity Search"
    more_text = """Embeddings for the input image are computed using the model. The nearest neighbours are then calculated
using cosine distance. These are shown here as an image grid."""
    description = f"This space uses model trained on CIFAR10 dataset using metric learning technique.\n{more_text}\n\n"
    article = """
<p style='text-align: center'>
<a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example given by Mat Kelcey</a>
<br>
Space by Vrinda Prabhu
</p>
"""

    # NOTE(review): `gr.Image(shape=...)` and `launch(enable_queue=...)` may
    # not be accepted by every Gradio release - confirm against the version
    # pinned for this Space.
    gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(32, 32)),  # Resize to CIFAR
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)