browser-backend / app.py
atwang's picture
update API variables
73ff1bb
raw
history blame
4.35 kB
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
import click
def getRandID():
    """Pick a random position and return the sample ID stored there.

    Returns:
        A ``(sample_id, position)`` tuple: ``position`` is a random integer
        in ``range(396503)`` and ``sample_id`` is the value that
        ``indx_to_id_dict`` holds for that position.
    """
    # 396503 is a hard-coded exclusive upper bound; presumably the number of
    # entries in indx_to_id_dict -- TODO confirm against the loaded table.
    position = random.randrange(396503)
    sample_id = indx_to_id_dict[position]
    return sample_id, position
def chooseImageIndex(indexType):
    """Return the FAISS image index that corresponds to ``indexType``.

    Only ``"FlatIP(default)"`` resolves to an index (the module-level
    ``image_index_IP``). The other names are recognised but disabled.

    Args:
        indexType: Name of the index type, e.g. ``"FlatIP(default)"``.

    Returns:
        The module-level ``image_index_IP`` object.

    Raises:
        NotImplementedError: If ``indexType`` is ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` or ``"LSH"``.
        ValueError: If ``indexType`` is not a recognised name (previously this
            case silently returned ``None``).
    """
    if indexType == "FlatIP(default)":
        return image_index_IP

    # These index types are known names but have no usable index behind them;
    # the old ``return`` statements after each ``raise`` were unreachable.
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"Image index type {indexType!r} is not available"
        )

    raise ValueError(f"Unknown image index type: {indexType!r}")
def chooseDNAIndex(indexType):
    """Return the FAISS DNA index that corresponds to ``indexType``.

    Only ``"FlatIP(default)"`` resolves to an index (the module-level
    ``dna_index_IP``). The other names are recognised but disabled.

    Args:
        indexType: Name of the index type, e.g. ``"FlatIP(default)"``.

    Returns:
        The module-level ``dna_index_IP`` object.

    Raises:
        NotImplementedError: If ``indexType`` is ``"FlatL2"``, ``"HNSWFlat"``,
            ``"IVFFlat"`` or ``"LSH"``.
        ValueError: If ``indexType`` is not a recognised name (previously this
            case silently returned ``None``).
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP

    # These index types are known names but have no usable index behind them;
    # the old ``return`` statements after each ``raise`` were unreachable.
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"DNA index type {indexType!r} is not available"
        )

    raise ValueError(f"Unknown DNA index type: {indexType!r}")
def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the IDs of the nearest neighbours of one sample's embedding.

    The embedding stored for ``id`` in the ``mod1`` table is used as the
    query, and the ``mod2`` index selected by ``indexType`` is searched for
    its 10 nearest neighbours.

    Args:
        id: Sample ID whose embedding is the query. (The name shadows the
            builtin; it is kept so existing callers keep working.)
        mod1: Modality of the query embedding: ``"Image"`` or ``"DNA"``.
        mod2: Modality of the index to search: ``"Image"`` or ``"DNA"``.
        indexType: Index name forwarded to ``chooseImageIndex`` or
            ``chooseDNAIndex``.

    Returns:
        A list of sample IDs, in the order the index returned them.

    Raises:
        ValueError: If ``mod1`` or ``mod2`` is not ``"Image"``/``"DNA"``, or
            if no index could be resolved for ``indexType``.
        NotImplementedError: If the embedding table for ``mod1`` is ``None``
            (not loaded), or if the index chooser raises it.
        KeyError: If ``id`` is not present in the embedding table.
    """
    # Validate up front: previously a missing/unknown modality surfaced as an
    # UnboundLocalError (mod1) or as a search on an empty placeholder index
    # followed by a lookup failure (mod2).
    modalities = ("Image", "DNA")
    if mod1 not in modalities:
        raise ValueError(f"Unknown query modality: {mod1!r}")
    if mod2 not in modalities:
        raise ValueError(f"Unknown index modality: {mod2!r}")

    num_neighbors = 10

    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    else:
        index = chooseDNAIndex(indexType)
    if index is None:
        raise ValueError(f"Unknown index type: {indexType!r}")

    embeddings = id_to_image_emb_dict if mod1 == "Image" else id_to_dna_emb_dict
    if embeddings is None:
        raise NotImplementedError(
            f"{mod1} query embeddings are not loaded"
        )

    # FAISS searches take a 2-D, C-contiguous float32 array
    # (n_queries, dim); atleast_2d lets a single 1-D vector through as well.
    query = np.atleast_2d(
        np.ascontiguousarray(embeddings[id], dtype=np.float32)
    )
    _, neighbors = index.search(query, num_neighbors)

    # FAISS fills missing results with label -1; skip those instead of
    # looking them up.
    return [indx_to_id_dict[indx] for indx in neighbors[0] if indx != -1]
# Build the Gradio UI. The indexes and lookup tables loaded inside this block
# are module-level names that getRandID / choose*Index / searchEmbeddings read.
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# statement order -- confirm against the deployed layout.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # load indexes
    # Only the FlatIP indexes are loaded; the other index types are disabled
    # (their loads are commented out and the choosers raise NotImplementedError).
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")
    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")
    # with open("dataset_processid_list.pickle", "rb") as f:
    # dataset_processid_list = pickle.load(f)
    # with open("processid_to_index.pickle", "rb") as f:
    # processid_to_index = pickle.load(f)
    # Maps an index position (as returned by a FAISS search) to a sample ID.
    # NOTE: pickle.load can execute arbitrary code; these files must be trusted.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)
    # initialize both possible dicts
    # Maps a sample ID to its image embedding (the query for "Search From: Image").
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    # with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
    # id_to_dna_emb_dict = pickle.load(f)
    # The DNA embedding table is not loaded, so searchEmbeddings cannot look up
    # a query embedding when "Search From" is "DNA".
    id_to_dna_emb_dict = None
    with gr.Column():
        with gr.Row():
            # Left column: random-ID helper (outputs of getRandID).
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right column: search options passed to searchEmbeddings.
            # key_type and query_type have no default value set here.
            with gr.Column():
                key_type = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                query_type = gr.Radio(choices=["DNA", "Image"], label="Search To:")
                index_type = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
                )
        # Query ID input, result box and search trigger.
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")
        # getRandID returns (id, index), matching the two output textboxes.
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
        # Inputs are passed positionally as (id, mod1, mod2, indexType).
        search_btn.click(fn=searchEmbeddings, inputs=[process_id, key_type, query_type, index_type], outputs=[process_id_list])
# Start the app when this file is executed (runs at import time as well).
demo.launch()