Spaces:
Running
Running
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import tensorflow as tf | |
| from annoy import AnnoyIndex | |
| from tensorflow import keras | |
def load_image(image_path, size=(224, 224)):
    """Read an image file and return it as a float32 RGB array scaled to [0, 1].

    Args:
        image_path: Path to an image file (JPEG, PNG, BMP or GIF; for GIFs
            only the first frame is used).
        size: Target ``(height, width)``; defaults to the 224x224 used
            throughout this app.

    Returns:
        numpy.ndarray of shape ``(height, width, 3)``, dtype float32, with
        values in [0, 1].
    """
    raw = tf.io.read_file(image_path)
    # decode_image dispatches on the file contents; expand_animations=False
    # guarantees a 3-D (H, W, C) tensor even for animated GIFs.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # resize returns float32 still in the 0-255 range (bilinear interpolation
    # never leaves that range), so one division by 255 maps it to [0, 1].
    image = tf.image.resize(image, size)
    return (image / 255.).numpy()
# Root folder of the image database; images are looked up as
# <database_path>/<class_name>/<file_name>.
database_path = './AnimalSubset'


def _parse_class_names(text):
    """Split a comma-separated list of class names.

    Surrounding whitespace is stripped and empty entries (e.g. from a
    trailing comma or newline) are dropped, so the last class is kept whether
    or not the file ends with a comma.
    """
    return [name.strip() for name in text.split(',') if name.strip()]


# Load the class names displayed as "supported" in the UI.
with open('./Animal-ClassNames.txt', mode='r', encoding='utf-8') as names:
    class_names = _parse_class_names(names.read())
# Create Example Images: one sample picture from each of several class
# folders, offered as clickable examples in the UI.
example_image_paths = [
    './AnimalSubset/Beetle/Beetle-Train (101).jpeg',
    './AnimalSubset/Butterfly/Butterfly-train (1042).jpeg',
    './AnimalSubset/Cat/Cat-Train (1004).jpeg',
    './AnimalSubset/Cow/Cow-Train (1022).jpeg',
    './AnimalSubset/Dog/Dog-Train (1144).jpeg',
    './AnimalSubset/Elephant/Elephant-Train (1043).jpeg',
    './AnimalSubset/Gorilla/Gorilla (1045).jpeg',
    './AnimalSubset/Hippo/Hippo - Train (1133).jpeg',
    './AnimalSubset/Lizard/Lizard-Train (161).jpeg',
    './AnimalSubset/Monkey/M (224).jpeg',
    './AnimalSubset/Mouse/Mouse-Train (1225).jpeg',
    './AnimalSubset/Panda/Panda (1992).jpeg',
    './AnimalSubset/Spider/Spider-Train (1191).jpeg',
    './AnimalSubset/Tiger/Tiger (1020).jpeg',
    './AnimalSubset/Zebra/Zebra (975).jpeg',
]

# Decode every example up front with load_image so the UI can show them
# immediately.
example_images = list(map(load_image, example_image_paths))
# Embedding network that turns an image into a search vector. compile=False:
# the model is only used for predict() in this app, so no optimizer/loss is
# needed.
feature_extractor_path = './Animal-FeatureExtractor.keras'
feature_extractor = keras.models.load_model(feature_extractor_path, compile=False)

# Pre-built Annoy index over the database embeddings. Its dimensionality must
# equal the length of the vectors produced by feature_extractor, because
# Annoy rejects query vectors of any other length.
EMBEDDING_DIM = 256
index_path = './AnimalSubset.ann'
annoy_index = AnnoyIndex(EMBEDDING_DIM, 'angular')
annoy_index.load(index_path)
# Metadata tables keyed by CSV path, so each file is parsed only once.
_metadata_cache = {}


def _load_metadata(metadata_path):
    """Return the metadata DataFrame at *metadata_path*, reading it from disk once."""
    if metadata_path not in _metadata_cache:
        _metadata_cache[metadata_path] = pd.read_csv(metadata_path, index_col=0)
    return _metadata_cache[metadata_path]


def _prepare_query(query_image, size=(224, 224)):
    """Convert an input image array to float32 in [0, 1], resized to *size*.

    Integer arrays (e.g. uint8 uploads) and float arrays with values above
    1.0 are treated as 0-255 pixels and divided by 255. Resizing matches the
    224x224 size load_image uses for database images.
    """
    image = np.asarray(query_image)
    # Decide on scaling from dtype/range instead of "max == 255": an upload
    # whose brightest pixel is e.g. 254 is still a 0-255 image.
    needs_scaling = np.issubdtype(image.dtype, np.integer) or image.max() > 1.0
    image = image.astype(np.float32)
    if needs_scaling:
        image = image / 255.
    return tf.image.resize(image, size).numpy()


def similarity_search(
    query_image, num_images=5, *_,
    feature_extractor=feature_extractor,
    annoy_index=annoy_index,
    database_path=database_path,
    metadata_path='./Animals.csv'
):
    """Find the database images most similar to *query_image*.

    Args:
        query_image: Image array from the ``gr.Image`` input, or ``None`` when
            nothing was provided.
        num_images: Number of neighbours to retrieve; coerced to an int >= 1.
        *_: Extra positional arguments are accepted and ignored.
        feature_extractor, annoy_index, database_path, metadata_path:
            Keyword-only overrides; default to the module-level model, Annoy
            index, image root folder and metadata CSV.

    Returns:
        ``(closest_class, gallery)``: the ``class_name`` of the nearest
        neighbour and a visible ``gr.Gallery`` showing the retrieved images.

    Raises:
        gr.Error: if no query image was given or no neighbours were found.
    """
    if query_image is None:
        raise gr.Error('Please provide a query image before searching.')
    # Sliders may deliver floats; Annoy requires an int count.
    num_images = max(1, int(num_images))

    query = _prepare_query(query_image)
    query_vector = feature_extractor.predict(query[np.newaxis, ...], verbose=0)[0]

    # Compute nearest neighbors
    nearest_neighbors = annoy_index.get_nns_by_vector(query_vector, num_images)

    # Annoy item ids are used as *row positions* in the metadata CSV.
    # NOTE(review): presumably rows are in the order items were added to the
    # index — confirm against the index-building script.
    metadata = _load_metadata(metadata_path).iloc[nearest_neighbors]
    if metadata.empty:
        raise gr.Error('No similar images were found.')
    closest_class = metadata.class_name.values[0]

    # Similar Images
    similar_images = [
        load_image(os.path.join(database_path, class_name, file_name))
        for class_name, file_name in zip(metadata.class_name.values, metadata.file_name.values)
    ]
    image_gallery = gr.Gallery(
        value=similar_images,
        label='Similar Images',
        object_fit='fill',
        preview=True,
        visible=True,
    )
    return closest_class, image_gallery
def _format_class_list(names):
    """Join *names* for display as 'A, B and C'.

    A single name is returned unchanged and an empty list gives ''. Formatting
    the last element separately would otherwise raise IndexError on an empty
    list and render ' and A' for a single name.
    """
    if len(names) < 2:
        return ''.join(names)
    return f"{', '.join(names[:-1])} and {names[-1]}"


# Gradio Application
with gr.Blocks(theme='soft') as app:
    gr.Markdown("# Animal - Content Based Image Retrieval (CBIR)")
    gr.Markdown(f"Model only supports: {_format_class_list(class_names)}")
    gr.Markdown("Disclaimer:- Model might suggest incorrect images, try using a different image.")

    with gr.Row(equal_height=True):
        # Image Input
        query_image = gr.Image(
            label='Query Image',
            sources=['upload', 'clipboard'],
            height='50vh'
        )
        # Output Gallery Display: starts hidden; similarity_search returns a
        # visible gr.Gallery that replaces it.
        output_gallery = gr.Gallery(visible=False)

    with gr.Row(equal_height=True):
        # Predicted Class (class of the nearest neighbour)
        pred_class = gr.Textbox(
            label='Predicted Class', placeholder='Let the model think!!...')
        # Number of images to search
        n_images = gr.Slider(
            value=10,
            label='Number of images to search',
            minimum=1,
            maximum=99,
            step=1
        )

    # Search Button
    search_btn = gr.Button('Search')

    # Example Images
    examples = gr.Examples(
        examples=example_images,
        inputs=query_image,
        label='Something similar to me??',
    )

    # Search - On Click
    search_btn.click(
        fn=similarity_search,
        inputs=[query_image, n_images],
        outputs=[pred_class, output_gallery]
    )
# Launch the web server only when executed as a script, not when imported.
if __name__ == '__main__':
    app.launch()