image-retriever / app.py
npbm's picture
I can't use git for the life of me; this might need more testing.
7dc7c5c verified
raw
history blame
1.37 kB
import gradio as gr
from utils import dataset_rag
# HACK: allow a duplicate OpenMP runtime to be loaded in this process.
# NOTE(review): presumably works around the Intel OpenMP "libiomp5 already
# initialized" abort seen when two native deps each bundle their own runtime
# (e.g. torch + faiss) — confirm. Also note this runs *after*
# `from utils import dataset_rag`; if that import already initialises OpenMP,
# the flag may come too late and should move above it.
dirty_hack = True
if dirty_hack:
    import os
    os.environ.update(KMP_DUPLICATE_LIB_OK="True")
# Dataset ids offered in the "Dataset" dropdown (presumably Hugging Face Hub
# ids). The first entry is the dropdown's default selection.
datasets = [
    "not-lain/embedded-pokemon",
]
# Record whether the optional `spaces` package is importable; `instance`
# is only wrapped with `spaces.GPU` when it is.
try:
    import spaces
except ImportError:
    space_installed = False
else:
    space_installed = True
def instance(dataset_name):
    """Create a ``dataset_rag.Instance`` for ``dataset_name``.

    When the optional ``spaces`` package was importable, this name is rebound
    just below to the ``spaces.GPU``-wrapped version of the function.
    Otherwise the plain function is used.

    Args:
        dataset_name: Dataset identifier, forwarded unchanged to
            ``dataset_rag.Instance``.

    Returns:
        The newly constructed ``dataset_rag.Instance``.
    """
    return dataset_rag.Instance(dataset_name)


if space_installed:
    # Same as decorating with `@spaces.GPU`. It is applied conditionally so the
    # app still runs where `spaces` is unavailable.
    # NOTE(review): presumably requests ZeroGPU hardware on HF Spaces — confirm.
    instance = spaces.GPU(instance)
def download(dataset):
    """Load the selected dataset and make it the active search index.

    Wired to the "Download Dataset" button. ``dataset`` receives the current
    value of the "Dataset" dropdown.

    Args:
        dataset: Name of the dataset to load (one of ``datasets``). When it is
            empty or ``None``, the default ``datasets[0]`` is used.

    Returns:
        The ``dataset_rag.Instance`` that was created. It is also stored in the
        module-level ``ds``, which ``search_ds`` queries.
    """
    global ds
    # Honor the dropdown selection. Only fall back to the default entry when
    # nothing was selected.
    name = dataset or datasets[0]
    client = instance(name)
    ds = client
    return client
def search_ds(image):
    """Retrieve the dataset entries closest to ``image``.

    Wired to the "Search" button with outputs ``[results, scores]``.

    Args:
        image: Query image from the "Search Image" component. It is ``None``
            when the user has not provided an image.

    Returns:
        ``(retrieved_examples, scores)``. This is the reverse of the order
        returned by ``ds.search``, so it matches the ``[results, scores]``
        outputs.

    Raises:
        gr.Error: If no dataset has been downloaded yet, or no image was given.
            Raising ``gr.Error`` shows a readable message in the UI instead of an
            ``AttributeError`` on ``None.search``.
    """
    if ds is None:
        raise gr.Error("No dataset loaded yet - click 'Download Dataset' first.")
    if image is None:
        raise gr.Error("Please provide an image to search with.")
    scores, retrieved_examples = ds.search(image)
    return retrieved_examples, scores
with gr.Blocks(title="Image RAG") as demo:
    # Active search index. `with` does not open a new scope, so this is the
    # module-level `ds` that `download` writes (via `global`) and `search_ds`
    # reads. It stays None until "Download Dataset" is clicked.
    ds = None
    interactive_mode = False  # NOTE(review): not read anywhere in this file.

    # Components are created in the original order, since Blocks renders its
    # children in creation order.

    # Dataset selection.
    dataset_name = gr.Dropdown(label="Dataset", choices=datasets, value=datasets[0])
    download_dataset = gr.Button("Download Dataset")

    # Query input and results.
    search = gr.Image(label="Search Image")
    search_button = gr.Button("Search")
    results = gr.Gallery(label="Results")
    scores = gr.Textbox(label="Scores", type="text", value="")

    # Event wiring. The registration order is unchanged.
    search_button.click(fn=search_ds, inputs=[search], outputs=[results, scores])
    download_dataset.click(fn=download, inputs=dataset_name)

demo.launch()