Spaces:
Running
Running
import gradio as gr | |
from tool import VisualRAGTool | |
# Single module-level tool instance; the search() and index() callbacks below
# both call into it.
tool = VisualRAGTool()
def search(query, k, api_key):
    """Run a retrieval query and package the result for the UI.

    Forwards ``query``, ``k`` and ``api_key`` to ``tool.search`` and unpacks
    its result into the retrieved pages and the generated answer.

    Returns:
        A 2-tuple ``(gallery, answer)``: a visible ``gr.Gallery`` holding one
        ``(image, caption)`` pair per retrieved page, and the answer exactly
        as returned by ``tool.search``.
    """
    print("=============== SEARCHING ===============")
    pages, response = tool.search(query, k, api_key)

    # The gallery takes (image, caption) pairs, one per retrieved page.
    gallery_items = []
    for page in pages:
        gallery_items.append((page.image, page.caption))
    print("========================================")

    gallery = gr.Gallery(
        value=gallery_items,
        label="Retrieved Documents",
        height=400,
        show_label=True,
        visible=True,
    )
    return gallery, response
def index(files, contextualize_embeds, api_key):
    """Index the uploaded files and report the outcome to the UI.

    Passes the files to ``tool.index`` together with the contextualization
    flag and the API key, always with ``overwrite_db=True``.

    Returns:
        A 2-tuple ``(status, query_box)``: a ``gr.Textbox`` whose value reports
        the count returned by ``tool.index``, and an interactive "Query"
        ``gr.Textbox``.
    """
    print("=============== INDEXING ===============")
    page_count = tool.index(
        files=files,
        contextualize=contextualize_embeds,
        api_key=api_key,
        overwrite_db=True,
    )
    print("========================================")

    status_box = gr.Textbox(f"Uploaded and processed {page_count} pages!")
    # Returned with interactive=True so the query box accepts input again.
    query_box = gr.Textbox(
        lines=2,
        label="Query",
        show_label=False,
        placeholder="Enter your prompt here and press Shift+Enter or press the button",
        interactive=True,
    )
    return status_box, query_box
def show_processing_status():
    """Build the component updates applied when an upload starts processing.

    Returns:
        A 3-tuple ``(status, checkbox, query_box)``: a visible, read-only
        "Processing Status" textbox, a hidden "Contextualize Embeddings"
        checkbox, and a non-interactive "Query" textbox.
    """
    status_box = gr.Textbox(label="Processing Status", interactive=False, visible=True)
    hidden_checkbox = gr.Checkbox(label="Contextualize Embeddings", visible=False)
    # Same query box configuration as elsewhere in the file, but locked.
    locked_query_box = gr.Textbox(
        lines=2,
        label="Query",
        show_label=False,
        placeholder="Enter your prompt here and press Shift+Enter or press the button",
        interactive=False,
    )
    return status_box, hidden_checkbox, locked_query_box
# Layout and event wiring of the demo UI.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a listing that had lost its indentation — confirm against the deployed UI.
with gr.Blocks(
    theme=gr.themes.Ocean(),
    title="ColPali Tool Demo",
) as demo:
    gr.Markdown("""\
# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚
Demo to test the ColPali RAG Tool powered by ColQwen2 (ColPali) on PDF documents.
ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
This tool allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents!
⚠️ This demo uses a CPU version of the model, so it may be slow! For faster results you may want to fork the space and run it on a GPU.
""")

    # The key is a secret: mask it so it is not displayed in clear text.
    api_key = gr.Textbox(
        placeholder="Enter your OpenAI KEY here",
        label="API key",
        type="password",
    )
    # Not passed to any event listener in this file.
    stored_messages = gr.State(value=[])

    gr.Markdown("## 1️⃣ Upload PDFs")
    gr.Markdown("Upload PDF files to index and search.")
    with gr.Group():
        contextualize_embeds = gr.Checkbox(
            label="Contextualize Embeddings",
            info="Add images surrounding context as metadata. Generated using gpt-4o-mini. ⚠️ Indexing will be longer!",
            value=True,
        )
        upload_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload files")
        # Hidden until an upload starts; show_processing_status() makes it visible.
        processing_status = gr.Textbox(label="Processing Status", interactive=False, visible=False)

    gr.Markdown("## 2️⃣ Search")
    gr.Markdown("Ask a question relevant to the documents you uploaded.")
    with gr.Group():
        chatbot = gr.Textbox(
            label="AI Assistant",
            placeholder="Generated response based on retrieved documents.",
            lines=6,
        )
        # Hidden until search() returns a visible gallery.
        output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True, visible=False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                text_input = gr.Textbox(
                    lines=2,
                    label="Query",
                    show_label=False,
                    placeholder="Enter your prompt here and press Shift+Enter or press the button",
                )
            with gr.Column(scale=1):
                k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Pages to retrieve")
                search_button = gr.Button("🔍 Search", variant="primary")

    # Define the flow of the demo.
    # On upload change: first show the status box, hide the checkbox and lock
    # the query box; then run indexing, which fills the status box and unlocks
    # the query box.
    upload_files.change(
        show_processing_status,
        inputs=[],
        outputs=[processing_status, contextualize_embeds, text_input],
    ).then(
        index,
        inputs=[upload_files, contextualize_embeds, api_key],
        outputs=[processing_status, text_input],
    )

    # Submitting the query box and clicking the button run the same search.
    text_input.submit(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
    search_button.click(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
if __name__ == "__main__":
    # Enable the request queue with max_size=5, then start the server in debug
    # mode listening on 0.0.0.0.
    queued_demo = demo.queue(max_size=5)
    queued_demo.launch(debug=True, server_name="0.0.0.0")