# Spaces: Runtime error
| import numpy as np | |
| from PIL import Image | |
| from tensorflow.keras.datasets import cifar10 | |
| from huggingface_hub import from_pretrained_keras | |
| import gradio as gr | |
def prepare_output(neighbours):
    """Build an image grid from the test images at the given indices.

    Tiles are placed five per row, left to right and top to bottom, in the
    reverse of the order given. A caller passing indices sorted by ascending
    similarity therefore gets the closest match in the top-left cell.

    @params neighbours: Sequence of indices into the global ``x_test``
    @returns: RGB ``PIL.Image`` that is ``HEIGHT_WIDTH * 5`` pixels wide and
        at least ``HEIGHT_WIDTH * 2`` pixels tall
    """
    columns = 5
    ordered = list(reversed(neighbours))

    # Keep the usual 5x2 canvas for up to ten neighbours, but add rows when
    # more are supplied so the extra tiles are not pasted off-canvas and lost.
    rows = max(2, -(-len(ordered) // columns))
    img_grid = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))

    for idx, nn_idx in enumerate(ordered):
        # x_test is scaled to [0, 1] when it is loaded. Round instead of
        # truncating so the float round-trip cannot land one below the
        # original 8-bit value, and clip to stay inside the uint8 range.
        pixels = np.clip(np.rint(np.asarray(x_test[nn_idx]) * 255), 0, 255)
        tile = Image.fromarray(pixels.astype(np.uint8), "RGB")
        row, col = divmod(idx, columns)
        img_grid.paste(tile, (col * HEIGHT_WIDTH, row * HEIGHT_WIDTH))
    return img_grid
def get_nearest_neighbours(img):
    """Run inference and return a grid of the test images closest to ``img``.

    The input is embedded together with the global ``x_test`` images, the
    test images are ranked by the dot product of their embedding with the
    input's embedding, and the top ``NEAR_NEIGHBOURS`` are drawn as a grid.

    @params img: Image array with 0-255 pixel values whose shape matches a
        single entry of ``x_test``
    @returns: ``numpy`` array of the image grid built by ``prepare_output``
    @raises ValueError: If ``img`` is ``None``
    """
    if img is None:
        raise ValueError("An input image is required.")

    # Pre-process the same way x_test is prepared: float32 scaled to [0, 1].
    # Staying in float32 also keeps the concatenation below from upcasting
    # the whole test set to float64.
    query = np.expand_dims(np.asarray(img, dtype=np.float32) / 255.0, axis=0)
    img_x_test = np.concatenate([x_test, query], axis=0)

    # NOTE: the test-set embeddings are recomputed on every call.
    embeddings = np.asarray(model.predict(img_x_test))
    test_embeddings, query_embedding = embeddings[:-1], embeddings[-1]

    # Only the query's similarities are needed, so compute that single row
    # rather than the full (N+1) x (N+1) gram matrix and its argsort.
    # The dot product equals cosine similarity if the model emits
    # L2-normalised embeddings (presumed, not verifiable from this file).
    # Leaving the query out of the candidates guarantees its own index, which
    # is out of range for x_test, can never show up among the neighbours.
    similarities = test_embeddings @ query_embedding
    order = np.argsort(similarities)
    near_neighbours = order[max(len(order) - NEAR_NEIGHBOURS, 0):]

    # Make image grid output (indices are in ascending similarity order).
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)
if __name__ == "__main__":
    # Module-level settings read by the functions defined in this file.
    NEAR_NEIGHBOURS = 10
    HEIGHT_WIDTH = 32

    # Text shown on the demo page.
    title = "Metric Learning for Image Similarity Search"
    more_text = """Embeddings for the input image are computed using the model. The nearest neighbours are then calculated
using cosine distance. These are shown here as an image grid."""
    description = (
        "This space uses model trained on CIFAR10 dataset using metric learning technique.\n"
        + more_text
        + "\n\n"
    )
    article = """
<p style='text-align: center'>
<a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example given by Mat Kelcey</a>
<br>
Space by Vrinda Prabhu
</p>
"""
    examples = ["examples/yatch.jpeg", "examples/horse.jpeg", "examples/car.jpeg"]

    # Load the CIFAR-10 splits; the test images are rescaled to [0, 1] floats.
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0

    # Fetch the pretrained embedding model from the Hugging Face Hub.
    model = from_pretrained_keras("keras-io/cifar10_metric_learning")

    # Assemble the Gradio interface and start serving it.
    demo = gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(32, 32)),  # Resize to CIFAR
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    )
    demo.launch(enable_queue=True)