# File size: 4,019 Bytes
# db66d62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io 
import pickle
import random

def getRandID():
    """Pick a random entry from ``indx_to_id_dict``.

    Returns:
        A ``(sample_id, indx)`` tuple: the ID stored under ``indx`` in
        ``indx_to_id_dict`` and the randomly drawn ``indx`` itself.

    Raises:
        ValueError: if ``indx_to_id_dict`` is empty.
    """
    # The bound is taken from the table itself rather than the previously
    # hard-coded 396503, so swapping in a different dataset cannot produce
    # out-of-range positions.
    # NOTE(review): assumes the keys are the consecutive integers
    # 0..len-1, as the original fixed range implied -- confirm.
    indx = random.randrange(len(indx_to_id_dict))
    return indx_to_id_dict[indx], indx

def chooseImageIndex(indexType):
    """Return the image index registered under the label ``indexType``.

    Args:
        indexType: One of "FlatIP(default)", "FlatL2", "HNSWFlat",
            "IVFFlat" or "LSH".

    Returns:
        The matching module-level image index, or None when the label is
        not one of the five listed above.
    """
    labelled_indexes = (
        ("FlatIP(default)", image_index_IP),
        ("FlatL2", image_index_L2),
        ("HNSWFlat", image_index_HNSW),
        ("IVFFlat", image_index_IVF),
        ("LSH", image_index_LSH),
    )
    for label, candidate in labelled_indexes:
        if indexType == label:
            return candidate
    return None

def chooseDNAIndex(indexType):
    """Return the DNA index registered under the label ``indexType``.

    Args:
        indexType: One of "FlatIP(default)", "FlatL2", "HNSWFlat",
            "IVFFlat" or "LSH".

    Returns:
        The matching module-level DNA index, or None when the label is
        not one of the five listed above.
    """
    labelled_indexes = (
        ("FlatIP(default)", dna_index_IP),
        ("FlatL2", dna_index_L2),
        ("HNSWFlat", dna_index_HNSW),
        ("IVFFlat", dna_index_IVF),
        ("LSH", dna_index_LSH),
    )
    for label, candidate in labelled_indexes:
        if indexType == label:
            return candidate
    return None



def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the IDs of the nearest neighbours of a sample's embedding.

    The embedding stored for ``id`` in the ``mod1`` modality is used as the
    query against the ``mod2`` index selected by ``indexType``.

    Args:
        id: Sample ID whose embedding is the query. (The name shadows the
            builtin; it is kept because it is part of the signature.)
        mod1: Modality to search from, "Image" or "DNA".
        mod2: Modality to search in, "Image" or "DNA".
        indexType: Index label understood by ``chooseImageIndex`` /
            ``chooseDNAIndex``.

    Returns:
        A list of at most 10 sample IDs, in the order the index returned
        them.

    Raises:
        gr.Error: if a modality or index type is not recognised, or if no
            embedding is stored for ``id``. Gradio shows the message to the
            user instead of a bare traceback.
    """
    num_neighbors = 10

    # Resolve the source-modality table up front; previously an unselected
    # radio button left ``query`` unbound and crashed with UnboundLocalError.
    if mod1 == "Image":
        embeddings = id_to_image_emb_dict
    elif mod1 == "DNA":
        embeddings = id_to_dna_emb_dict
    else:
        raise gr.Error('Select "DNA" or "Image" under "Search From:".')

    # Resolve the target index; previously an unselected target silently
    # searched a freshly created, empty IndexFlatIP.
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    elif mod2 == "DNA":
        index = chooseDNAIndex(indexType)
    else:
        raise gr.Error('Select "DNA" or "Image" under "Search To:".')
    if index is None:
        raise gr.Error(f"Unknown index type: {indexType!r}")

    try:
        query = embeddings[id]
    except KeyError:
        raise gr.Error(f"No {mod1} embedding found for ID {id!r}.") from None

    # FAISS expects a 2-D float32 batch; reshape is a no-op when the stored
    # embedding already has shape (1, dim).
    query = np.asarray(query, dtype=np.float32).reshape(1, -1)
    _, neighbors = index.search(query, num_neighbors)

    # FAISS pads the result with -1 when fewer than ``num_neighbors`` hits
    # are found; those are not valid positions and are skipped.
    return [indx_to_id_dict[indx] for indx in neighbors[0] if indx >= 0]

# Load all search data, build the Gradio UI and start the app.
# Every name assigned inside this ``with`` block is a module-level global;
# getRandID, chooseImageIndex, chooseDNAIndex and searchEmbeddings read
# them when their buttons are clicked.
with gr.Blocks() as demo:

    # for hf: change all file paths, indx_to_id_dict as well
    # NOTE(review): "hf" presumably means the Hugging Face deployment -- confirm.

    # load indexes: one prebuilt FAISS index per (modality, index type) pair,
    # read from the current working directory
    image_index_IP = faiss.read_index("big_image_index_FlatIP.index")
    image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    image_index_LSH = faiss.read_index("big_image_index_LSH.index")

    dna_index_IP = faiss.read_index("big_dna_index_FlatIP.index")
    dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only load pickles that come from a trusted source.
    # NOTE(review): dataset_processid_list and processid_to_index are not
    # referenced anywhere else in the visible code -- confirm they are needed.
    with open("dataset_processid_list.pickle", "rb") as f:
        dataset_processid_list = pickle.load(f)
    with open("processid_to_index.pickle", "rb") as f:
        processid_to_index = pickle.load(f)
    # Lookup from an index position (as returned by index.search) to a sample ID.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    # initialize both possible dicts (sample ID -> embedding, one per modality)
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
        id_to_dna_emb_dict = pickle.load(f)

    # UI layout: the top row holds the random-ID helper and the modality
    # pickers; the index picker, ID input, result box and search button follow.
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                # No value= is passed to these two radios.
                # NOTE(review): presumably the handler then receives None
                # until the user picks an option -- confirm.
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")

        indexType = gr.Radio(choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)")
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:" )
        search_btn = gr.Button("Search") 
        # getRandID returns (id, index), filling the two read-out boxes in order.
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

    # The four inputs are passed positionally as (id, mod1, mod2, indexType).
    search_btn.click(fn=searchEmbeddings, inputs=[process_id, mod1, mod2, indexType], 
                     outputs=[process_id_list])
    

# Runs at import time: there is no ``if __name__ == "__main__"`` guard.
demo.launch()