Spaces:
Running
Running
File size: 4,275 Bytes
db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 db66d62 b383c02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
def getRandID():
    """Draw a random row position and return it with the ID stored there.

    Returns:
        A ``(sample_id, position)`` pair, where ``position`` is drawn
        uniformly from ``[0, 396503)`` and ``sample_id`` is
        ``indx_to_id_dict[position]``.
    """
    # NOTE(review): 396503 is presumably the number of entries in
    # indx_to_id_dict -- confirm whenever the lookup table is regenerated.
    position = random.randrange(396503)
    sample_id = indx_to_id_dict[position]
    return sample_id, position
def chooseImageIndex(indexType):
    """Return the image-embedding FAISS index selected by ``indexType``.

    Only ``"FlatIP(default)"`` is backed by a loaded index
    (``image_index_IP``). The remaining names are recognised but have no
    index loaded, so they are rejected explicitly.

    Args:
        indexType: Name of the requested index type.

    Returns:
        The module-level ``image_index_IP`` index.

    Raises:
        NotImplementedError: If ``indexType`` is "FlatL2", "HNSWFlat",
            "IVFFlat" or "LSH".
        ValueError: If ``indexType`` is any other value. Previously this
            case fell through and returned ``None``.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"Image index type {indexType!r} is not available yet"
        )
    raise ValueError(f"Unknown image index type: {indexType!r}")
def chooseDNAIndex(indexType):
    """Return the DNA-embedding FAISS index selected by ``indexType``.

    Only ``"FlatIP(default)"`` is backed by a loaded index
    (``dna_index_IP``). The remaining names are recognised but have no
    index loaded, so they are rejected explicitly.

    Args:
        indexType: Name of the requested index type.

    Returns:
        The module-level ``dna_index_IP`` index.

    Raises:
        NotImplementedError: If ``indexType`` is "FlatL2", "HNSWFlat",
            "IVFFlat" or "LSH".
        ValueError: If ``indexType`` is any other value. Previously this
            case fell through and returned ``None``.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"DNA index type {indexType!r} is not available yet"
        )
    raise ValueError(f"Unknown DNA index type: {indexType!r}")
def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the IDs of the nearest neighbours of sample ``id``.

    The embedding of ``id`` in the ``mod1`` modality is used as the query
    against the ``mod2`` index, so ``mod1 != mod2`` is a cross-modal search.

    Args:
        id: Sample ID to look up. The name shadows the builtin but is kept
            so that existing callers are unaffected.
        mod1: Modality the query embedding is taken from: "Image" or "DNA".
        mod2: Modality of the index that is searched: "Image" or "DNA".
        indexType: Index name, forwarded to ``chooseImageIndex`` or
            ``chooseDNAIndex``.

    Returns:
        A list of up to 10 sample IDs in the order FAISS returned them.

    Raises:
        ValueError: If ``mod1`` or ``mod2`` is not "Image" or "DNA", e.g.
            when no option was selected.
        KeyError: If ``id`` has no embedding in the ``mod1`` table.
        NotImplementedError: Propagated from the index chooser for index
            types that are not available.
    """
    # Validate both modalities before touching any index or table. The
    # original fell back to an empty index for a bad mod2 and left `query`
    # unbound for a bad mod1.
    if mod1 not in ("Image", "DNA"):
        raise ValueError(f'Search From must be "Image" or "DNA", got {mod1!r}')
    if mod2 not in ("Image", "DNA"):
        raise ValueError(f'Search To must be "Image" or "DNA", got {mod2!r}')

    num_neighbors = 10

    # The target modality decides which index is searched.
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    else:
        index = chooseDNAIndex(indexType)

    # The source modality decides where the query vector comes from.
    if mod1 == "Image":
        query = id_to_image_emb_dict[id]
    else:
        query = id_to_dna_emb_dict[id]

    # FAISS expects a float32 matrix of shape (n_queries, dim); promote a
    # bare 1-D vector to a single-row matrix.
    query = np.atleast_2d(query).astype(np.float32)
    _, neighbor_rows = index.search(query, num_neighbors)

    # FAISS pads the label row with -1 when fewer than k neighbours exist;
    # those are not real positions and must not be looked up.
    return [indx_to_id_dict[row] for row in neighbor_rows[0] if row != -1]
# Build the Gradio UI and load the search data. Names assigned inside the
# `with` block are bound at module level (a `with` statement opens no new
# scope), which is how the callback functions above reach them as globals.
# NOTE(review): the original indentation was lost; confirm the nesting of the
# layout containers below against the deployed app.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well

    # load indexes
    # Only the FlatIP indexes are loaded. The commented-out variants match the
    # index types for which chooseImageIndex / chooseDNAIndex raise
    # NotImplementedError.
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")
    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # NOTE: pickle.load can execute arbitrary code while deserialising; these
    # files must come from a trusted source.
    # dataset_processid_list and processid_to_index are loaded here but are not
    # referenced anywhere else in the visible part of this file.
    with open("dataset_processid_list.pickle", "rb") as f:
        dataset_processid_list = pickle.load(f)
    with open("processid_to_index.pickle", "rb") as f:
        processid_to_index = pickle.load(f)
    # Position -> sample ID table, used by getRandID and searchEmbeddings.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    # initialize both possible dicts
    # Sample ID -> embedding tables, one per modality; searchEmbeddings takes
    # its query vector from one of them.
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
        id_to_dna_emb_dict = pickle.load(f)

    with gr.Column():
        with gr.Row():
            # Random-ID helper widgets.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Search options: query modality, target modality, index type.
            with gr.Column():
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
                indexType = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
                )
        # Query input and result output.
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")

    # getRandID takes no inputs; its (id, index) result fills the two textboxes.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    # searchEmbeddings receives (process_id, mod1, mod2, indexType); the
    # returned list is shown in the matches textbox.
    search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2, indexType], outputs=[process_id_list])

# Start the web server for the app.
demo.launch()
|