Spaces:
Running
Running
import gradio as gr | |
import torch | |
import numpy as np | |
import h5py | |
import faiss | |
from PIL import Image | |
import io | |
import pickle | |
import random | |
import click | |
def getRandID():
    """Pick a random dataset row and return its sample ID together with the row index.

    The row index is drawn uniformly from ``range(len(indx_to_id_dict))``.
    This assumes the mapping's keys are the contiguous integers 0..len-1,
    i.e. the same row indices that FAISS ``index.search`` returns and that
    ``searchEmbeddings`` looks up (TODO: confirm against the pickle).

    Returns:
        tuple: ``(sample_id, row_index)``.
    """
    # Derive the range from the loaded mapping rather than a hard-coded dataset
    # size, so loading a different pickle cannot produce a KeyError (range too
    # large) or leave rows unreachable (range too small).
    indx = random.randrange(len(indx_to_id_dict))
    return indx_to_id_dict[indx], indx
def chooseImageIndex(indexType):
    """Return the FAISS index over image embeddings for the requested index type.

    Only the "FlatIP(default)" variant is backed by a loaded index
    (``image_index_IP``). The other variants offered by the UI have no index
    object available in this app.

    Args:
        indexType: One of "FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH".

    Returns:
        The FAISS index object to search.

    Raises:
        NotImplementedError: for a recognised index type that is not loaded.
        ValueError: for an unrecognised index type (including ``None``).
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # No image_index_* object exists for these types, so fail loudly rather
        # than return a name that was never defined.
        raise NotImplementedError(f"Image index type {indexType!r} is not available in this demo")
    # Previously this fell through and returned None, which surfaced later as an
    # opaque AttributeError on `.search`.
    raise ValueError(f"Unknown index type: {indexType!r}")
def chooseDNAIndex(indexType):
    """Return the FAISS index over DNA embeddings for the requested index type.

    Only the "FlatIP(default)" variant is backed by a loaded index
    (``dna_index_IP``). The other variants offered by the UI have no index
    object available in this app.

    Args:
        indexType: One of "FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH".

    Returns:
        The FAISS index object to search.

    Raises:
        NotImplementedError: for a recognised index type that is not loaded.
        ValueError: for an unrecognised index type (including ``None``).
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # No dna_index_* object exists for these types, so fail loudly rather
        # than return a name that was never defined.
        raise NotImplementedError(f"DNA index type {indexType!r} is not available in this demo")
    # Previously this fell through and returned None, which surfaced later as an
    # opaque AttributeError on `.search`.
    raise ValueError(f"Unknown index type: {indexType!r}")
def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the sample IDs nearest to one sample's embedding, possibly across modalities.

    Args:
        id: Sample ID whose stored embedding is used as the query
            (key into ``id_to_image_emb_dict`` / ``id_to_dna_emb_dict``).
        mod1: Query modality, "Image" or "DNA" (the UI's "Search From").
        mod2: Modality of the index that is searched, "Image" or "DNA"
            (the UI's "Search To").
        indexType: Index variant, passed to ``chooseImageIndex`` /
            ``chooseDNAIndex``.

    Returns:
        list: Up to 10 sample IDs, in the order FAISS returns the neighbours.

    Raises:
        ValueError: if a modality is missing/unknown or ``id`` has no embedding.
        NotImplementedError: if the query modality's embeddings are not loaded
            (or the chosen index type is unavailable).
    """
    num_neighbors = 10

    # Pick the index to search. Fail fast on a missing/unknown target modality
    # instead of searching an empty default index (which yields only -1 labels).
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    elif mod2 == "DNA":
        index = chooseDNAIndex(indexType)
    else:
        raise ValueError(f"Unknown target modality: {mod2!r}")

    # Pick the table the query embedding comes from.
    if mod1 == "Image":
        emb_dict = id_to_image_emb_dict
    elif mod1 == "DNA":
        emb_dict = id_to_dna_emb_dict
    else:
        raise ValueError(f"Unknown query modality: {mod1!r}")
    if emb_dict is None:
        # The app may leave a modality's embedding table unloaded (set to None).
        raise NotImplementedError(f"{mod1} embeddings are not loaded in this demo")
    try:
        query = emb_dict[id]
    except KeyError:
        raise ValueError(f"No {mod1} embedding found for ID {id!r}") from None

    # FAISS expects a float32, C-contiguous matrix of shape (n_queries, dim).
    # atleast_2d promotes a single 1-D vector to one row and leaves 2-D input as is.
    query = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
    _, neighbor_rows = index.search(query, num_neighbors)

    # FAISS pads the result with -1 when fewer than num_neighbors hits exist;
    # skip those instead of looking them up.
    return [indx_to_id_dict[int(row)] for row in neighbor_rows[0] if row >= 0]
def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    Unpickling can execute arbitrary code, so only use this on data files that
    ship with the app, never on user-supplied paths.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# NOTE: when deploying to the hosted Space, update every file path below,
# including the index-to-ID mapping.

# ---- FAISS indexes: only the FlatIP variants are loaded ---------------------
image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
# Other index variants, not loaded:
# image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
# image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
# image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
# image_index_LSH = faiss.read_index("big_image_index_LSH.index")
# dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
# dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
# dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
# dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

# ---- Lookup tables ----------------------------------------------------------
# dataset_processid_list = _load_pickle("dataset_processid_list.pickle")
# processid_to_index = _load_pickle("processid_to_index.pickle")

# Maps a FAISS result row index to a sample ID.
indx_to_id_dict = _load_pickle("big_indx_to_id_dict.pickle")
# Maps a sample ID to the embedding used as a search query.
id_to_image_emb_dict = _load_pickle("big_id_to_image_emb_dict.pickle")
# DNA query embeddings are not loaded; the table is deliberately left as None.
# id_to_dna_emb_dict = _load_pickle("big_id_to_dna_emb_dict.pickle")
id_to_dna_emb_dict = None

# ---- Gradio UI ----------------------------------------------------------------
_MODALITIES = ["DNA", "Image"]
_INDEX_TYPES = ["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"]

with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            # Left: draw a random sample ID to try.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right: configure and run a nearest-neighbour search.
            with gr.Column():
                key_type = gr.Radio(choices=_MODALITIES, label="Search From:")
                query_type = gr.Radio(choices=_MODALITIES, label="Search To:")
                index_type = gr.Radio(
                    choices=_INDEX_TYPES,
                    label="Index:",
                    value="FlatIP(default)",
                )
                process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
                process_id_list = gr.Textbox(label="Closest 10 matches:")
                search_btn = gr.Button("Search")

    # Event wiring.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type],
        outputs=[process_id_list],
    )

demo.launch()