File size: 1,367 Bytes
7dc7c5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
from utils import dataset_rag

# Switch for the environment workaround applied just below.
dirty_hack = True

if dirty_hack:
    import os

    # HACK: presumably silences the "duplicate OpenMP runtime" abort that
    # some numeric libraries trigger -- TODO confirm it is still required.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


# Dataset identifiers offered in the "Dataset" dropdown; the first entry is
# the dropdown's initial value.
datasets = ["not-lain/embedded-pokemon"]

# Probe for the optional ``spaces`` package and record the outcome as a flag.
try:
    import spaces
except ImportError:
    space_installed = False
else:
    space_installed = True

def instance(dataset_name):
    """Return a ``dataset_rag.Instance`` built for *dataset_name*.

    When the optional ``spaces`` package is available, this name is rebound
    right after the definition to its ``spaces.GPU``-wrapped form.
    """
    return dataset_rag.Instance(dataset_name)


if space_installed:
    # Equivalent to decorating the definition above with ``@spaces.GPU``.
    instance = spaces.GPU(instance)

def download(dataset):
    """Load the selected dataset and make it the active search backend.

    Args:
        dataset: Dataset name chosen in the dropdown. When it is empty or
            ``None`` the first entry of ``datasets`` is used instead.

    Returns:
        The object produced by ``instance`` for that dataset. The same
        object is stored in the module-level ``ds`` read by ``search_ds``.
    """
    global ds
    # Honor the caller's selection rather than always loading datasets[0].
    selected = dataset or datasets[0]
    # ``ds`` is only rebound after ``instance`` succeeds, so a failed load
    # leaves any previously loaded dataset in place.
    ds = instance(selected)
    return ds

def search_ds(image):
    """Query the currently loaded dataset with *image*.

    Args:
        image: Query value forwarded unchanged to ``ds.search``.

    Returns:
        A ``(retrieved_examples, scores)`` tuple. ``ds.search`` yields the
        pair as ``(scores, retrieved_examples)``; the order is swapped here
        to match the ``[results, scores]`` outputs of the Search button.

    Raises:
        gr.Error: If no dataset has been loaded yet (``ds`` is ``None``).
    """
    if ds is None:
        # Searching before "Download Dataset" used to fail with an opaque
        # AttributeError on None; surface an actionable message instead.
        raise gr.Error("Please download a dataset before searching.")
    scores, retrieved_examples = ds.search(image)
    return retrieved_examples, scores

# Module-level state: ``ds`` holds the active dataset client and is replaced
# by ``download``; ``interactive_mode`` is a plain flag that is only set here.
ds = None
interactive_mode = False

with gr.Blocks(title="Image RAG") as demo:
    # Dataset selection row.
    dataset_name = gr.Dropdown(
        label="Dataset",
        choices=datasets,
        value=datasets[0],
    )
    download_dataset = gr.Button("Download Dataset")

    # Query image, trigger button and result widgets.
    search = gr.Image(label="Search Image")
    search_button = gr.Button("Search")
    results = gr.Gallery(label="Results")
    scores = gr.Textbox(label="Scores", type="text", value="")

    # Event wiring: search fills the gallery and the scores box; download
    # takes the dropdown value and has no output components.
    search_button.click(
        fn=search_ds,
        inputs=[search],
        outputs=[results, scores],
    )
    download_dataset.click(fn=download, inputs=dataset_name)


demo.launch()