HUANG-Stephanie commited on
Commit
df88ad3
·
verified ·
1 Parent(s): 33de980

Create gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +47 -0
gradio_app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+
4
+ def index_gradio(file, ds):
5
+ """Upload PDFs and get embeddings."""
6
+ url = "http://localhost:8082/index"
7
+ files = [("files", (f.name, f.file)) for f in file]
8
+ response = requests.post(url, files=files)
9
+ result = response.json()
10
+ return result['message'], ds, []
11
+
12
+ def search_gradio(query: str, ds, images, k):
13
+ """Send a search query and get results."""
14
+ url = "http://localhost:8082/search"
15
+ payload = {'query': query, 'k': k}
16
+ response = requests.post(url, json=payload)
17
+ result = response.json()
18
+ results = result['results']
19
+ return results
20
+
21
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
22
+ gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚")
23
+
24
+ with gr.Row():
25
+ with gr.Column(scale=2):
26
+ gr.Markdown("## 1️⃣ Upload PDFs")
27
+ file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")
28
+
29
+ convert_button = gr.Button("🔄 Index documents")
30
+ message = gr.Textbox("Files not yet uploaded", label="Status")
31
+ embeds = gr.State(value=[])
32
+ imgs = gr.State(value=[])
33
+
34
+ with gr.Column(scale=3):
35
+ gr.Markdown("## 2️⃣ Search")
36
+ query = gr.Textbox(placeholder="Enter your query here", label="Query")
37
+ k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=1)
38
+
39
+ # Define the actions
40
+ search_button = gr.Button("🔍 Search", variant="primary")
41
+ output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
42
+
43
+ convert_button.click(index_gradio, inputs=[file, embeds], outputs=[message, embeds, imgs])
44
+ search_button.click(search_gradio, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
45
+
46
+ if __name__ == "__main__":
47
+ demo.queue(max_size=10).launch(debug=True)