"""Gradio demo that shows the closest matches for a query image.

A pickled LSH object is loaded from ``lsh.pickle`` and attached to a
``BuildLSHTable`` that wraps the ``nateraw/vit-base-beans`` checkpoint. The
shuffled ``train`` split of the ``beans`` dataset supplies the candidate images.

Thanks to Freddy Boulton (https://github.com/freddyaboulton) for helping with this.
"""

import pickle

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

from similarity_utils import BuildLSHTable

seed = 42

# Caption shown under the query image in the gallery.
_QUERY_CAPTION = "Query image"

# Only runs once when the script is first run.
# NOTE: unpickling can execute arbitrary code, so ``lsh.pickle`` must come
# from a trusted source.
with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Load model for computing embeddings.
model_ckpt = "nateraw/vit-base-beans"
model = AutoModel.from_pretrained(model_ckpt)

lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images.
# NOTE(review): the numeric ids stored in the LSH keys index into this shuffled
# split, so presumably the table was built with the same seed -- confirm.
dataset = load_dataset("beans")
candidate_dataset = dataset["train"].shuffle(seed=seed)


def query(image, top_k):
    """Return gallery entries for ``image`` and its best-scoring candidates.

    Args:
        image: Query image. It is forwarded to ``lsh_builder.query`` and is
            also written to disk via its ``save(filename)`` method, so it has
            to be a PIL-style image object.
        top_k: Maximum number of candidates to show. Coerced with ``int``
            because a slider may deliver the value as a float.

    Returns:
        A list of ``(file path, caption)`` tuples: the query image first,
        followed by at most ``top_k`` candidates ordered by descending score.
    """
    # ``results`` maps "<image_id>_<label>" keys to a score.
    results = lsh_builder.query(image)
    best_keys = sorted(results, key=results.get, reverse=True)[: int(top_k)]

    # The query image always leads the gallery. It gets its own caption
    # instead of borrowing the label of whichever candidate was visited last
    # (which was also undefined when there were no results at all).
    candidates = [image]
    labels = [_QUERY_CAPTION]
    for key in best_keys:
        image_id, label = key.split("_")[:2]
        candidates.append(candidate_dataset[int(image_id)]["image"])
        labels.append(label)

    # gr.Gallery is fed file paths here, so every image is written to disk.
    # NOTE: the file names are fixed ("0.png", "1.png", ...) in the working
    # directory, so each call overwrites the files of the previous one.
    images = []
    for i, candidate in enumerate(candidates):
        filename = f"{i}.png"
        candidate.save(filename)
        images.append(filename)

    # The gallery component can be a list of tuples, where the first element
    # is a path to a file and the second element is an optional caption.
    return list(zip(images, labels))


# gr.Image can hand the callback a PIL image, a numpy array or a file path.
# ``query`` calls ``.save()`` on the uploaded image, which only a PIL image
# provides, so the PIL type is requested explicitly (the default is numpy).
gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(value=5, minimum=1, maximum=10, step=1),
    ],
    outputs=gr.Gallery(),
).launch()