File size: 4,275 Bytes
db66d62
 
 
 
 
 
b383c02
db66d62
 
 
b383c02
db66d62
 
 
 
b383c02
db66d62
b383c02
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
 
b383c02
db66d62
b383c02
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
b383c02
 
db66d62
 
 
 
 
 
 
 
 
 
 
 
b383c02
db66d62
b383c02
db66d62
 
 
b383c02
db66d62
b383c02
db66d62
 
 
 
 
 
 
 
 
b383c02
db66d62
 
b383c02
db66d62
 
 
 
 
b383c02
 
 
 
 
 
 
 
 
 
 
db66d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b383c02
 
 
db66d62
b383c02
 
db66d62
 
b383c02
 
db66d62
b383c02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random


def getRandID():
    """Pick a random sample and return its ID together with its index.

    The index is drawn uniformly from ``range(0, 396503)`` and mapped to a
    sample ID through the module-level ``indx_to_id_dict``.

    Returns:
        tuple: ``(sample_id, index)``. The Gradio handler displays these in
        the "Random ID:" and "Index:" textboxes respectively.
    """
    # NOTE(review): 396503 is hard-coded; presumably it equals the number of
    # entries in indx_to_id_dict -- confirm, or derive it from the dict.
    chosen = random.randrange(0, 396503)
    sample_id = indx_to_id_dict[chosen]
    return sample_id, chosen


def chooseImageIndex(indexType):
    """Return the FAISS index over image embeddings for the chosen index type.

    Only the inner-product flat index ("FlatIP(default)") is loaded by this
    app. The other index types offered in the UI are not available yet.

    Args:
        indexType: Label selected in the index-type radio button.

    Returns:
        The module-level ``image_index_IP`` FAISS index.

    Raises:
        NotImplementedError: For "FlatL2", "HNSWFlat", "IVFFlat" or "LSH",
            whose index files are not loaded.
        ValueError: For any other (unknown) index type. Previously the
            function silently returned None in this case.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # These index types have no loaded index object, so they cannot be
        # served. Raise with a message instead of keeping unreachable returns.
        raise NotImplementedError(f"Image index type {indexType!r} is not available yet")
    raise ValueError(f"Unknown index type: {indexType!r}")


def chooseDNAIndex(indexType):
    """Return the FAISS index over DNA embeddings for the chosen index type.

    Only the inner-product flat index ("FlatIP(default)") is loaded by this
    app. The other index types offered in the UI are not available yet.

    Args:
        indexType: Label selected in the index-type radio button.

    Returns:
        The module-level ``dna_index_IP`` FAISS index.

    Raises:
        NotImplementedError: For "FlatL2", "HNSWFlat", "IVFFlat" or "LSH",
            whose index files are not loaded.
        ValueError: For any other (unknown) index type. Previously the
            function silently returned None in this case.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        # These index types have no loaded index object, so they cannot be
        # served. Raise with a message instead of keeping unreachable returns.
        raise NotImplementedError(f"DNA index type {indexType!r} is not available yet")
    raise ValueError(f"Unknown index type: {indexType!r}")


def searchEmbeddings(id, mod1, mod2, indexType, num_neighbors=10):
    """Return the IDs of the samples closest to a query sample.

    The query embedding is looked up by sample ID in the modality chosen by
    ``mod1``. It is then searched against the FAISS index of the modality
    chosen by ``mod2``.

    Args:
        id: Sample ID entered by the user. Surrounding whitespace is ignored
            for string input.
        mod1: Query ("Search From") modality, either "Image" or "DNA".
        mod2: Target ("Search To") modality, either "Image" or "DNA".
        indexType: Index-type label forwarded to chooseImageIndex or
            chooseDNAIndex.
        num_neighbors: Number of nearest neighbours to request (default 10).

    Returns:
        list: Sample IDs of the nearest neighbours, in FAISS result order.

    Raises:
        ValueError: If ``mod1`` or ``mod2`` is not "Image"/"DNA" (for example,
            when no radio button is selected), or if ``id`` has no stored
            embedding for ``mod1``.
        NotImplementedError: Propagated from the index selectors for index
            types that are not loaded.
    """
    # Select the index for the target modality. An unselected modality
    # previously fell back to an empty index and failed later with -1 labels.
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    elif mod2 == "DNA":
        index = chooseDNAIndex(indexType)
    else:
        raise ValueError(f"Unknown target modality: {mod2!r}")

    # Select the embedding table for the query modality.
    if mod1 == "Image":
        emb_dict = id_to_image_emb_dict
    elif mod1 == "DNA":
        emb_dict = id_to_dna_emb_dict
    else:
        raise ValueError(f"Unknown query modality: {mod1!r}")

    key = id.strip() if isinstance(id, str) else id
    try:
        query = emb_dict[key]
    except KeyError:
        raise ValueError(f"No {mod1} embedding found for ID {key!r}") from None

    # FAISS expects a C-contiguous float32 matrix with one query per row.
    # atleast_2d also accepts a 1-D (dim,) vector.
    query = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
    _, neighbor_idx = index.search(query, num_neighbors)

    # FAISS pads results with -1 when fewer than num_neighbors vectors exist.
    return [indx_to_id_dict[i] for i in neighbor_idx[0] if i >= 0]


with gr.Blocks() as demo:

    # for hf: change all file paths, indx_to_id_dict as well
    # NOTE: a `with` block does not create a new scope, so every name bound
    # below is a module-level global. getRandID, chooseImageIndex,
    # chooseDNAIndex and searchEmbeddings read those globals. Gradio
    # components must be created inside this Blocks context, and their
    # creation order determines the page layout.

    # load indexes
    # Only the inner-product flat (FlatIP) indexes are loaded. The other
    # index types remain commented out, so the selector functions raise
    # NotImplementedError for them.
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")

    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # SECURITY: pickle.load can execute arbitrary code. These files must come
    # from a trusted source (they are bundled with the app, not user input).
    # NOTE(review): dataset_processid_list and processid_to_index are not
    # referenced by any function visible in this file. Confirm they are needed
    # before paying their load time and memory.
    with open("dataset_processid_list.pickle", "rb") as f:
        dataset_processid_list = pickle.load(f)
    with open("processid_to_index.pickle", "rb") as f:
        processid_to_index = pickle.load(f)
    # Maps a FAISS row index to a sample ID (used by getRandID and to turn
    # search results back into IDs).
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    # initialize both possible dicts
    # Sample ID -> embedding for each query modality. Presumably numpy arrays
    # of width 768, since searchEmbeddings calls .astype on them -- confirm.
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
        id_to_dna_emb_dict = pickle.load(f)

    with gr.Column():
        with gr.Row():
            # Left column: random-sample helper, filled by getRandID.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right column: query modality (mod1) and target modality (mod2).
            # NOTE(review): neither radio sets a default value. Confirm the
            # search handler copes with an unselected modality.
            with gr.Column():
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")

        # FAISS index type. Only "FlatIP(default)" is backed by a loaded index.
        # NOTE(review): the label "Index:" duplicates the rand_id_indx textbox
        # label above, which may confuse users.
        indexType = gr.Radio(
            choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
        )
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")
        # getRandID returns (id, index), which fills the two textboxes in order.
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

    # Inputs are passed positionally as (id, mod1, mod2, indexType). The
    # returned list of IDs is rendered into the results textbox.
    search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2, indexType], outputs=[process_id_list])


# Start the Gradio server. This runs unconditionally, including when the
# module is imported. NOTE(review): consider an `if __name__ == "__main__":`
# guard if the module is ever imported.
demo.launch()