visual-rag-tool / app.py
paultltc's picture
init commit
a924f05
raw
history blame
4.81 kB
import gradio as gr
from tool import VisualRAGTool
# Single module-level VisualRAGTool instance, created once at import time.
# Both the `search` and `index` handlers below call methods on this same
# object, so every request from every browser session goes through it.
# NOTE(review): `index` calls `tool.index(..., overwrite_db=True)`, so one
# user's upload replaces the indexed documents for all other users too.
# Presumably this constructor also loads the ColQwen2 retrieval model
# (see the demo description) — TODO confirm in tool.py.
tool = VisualRAGTool()
def search(query, k, api_key):
    """Retrieve the most relevant indexed pages for ``query`` and answer it.

    Args:
        query: Free-text question typed into the "Query" textbox.
        k: Number of pages to retrieve, taken from the "Pages to retrieve"
            slider. Gradio documents slider values as floats, so it is
            coerced to ``int`` before being passed on.
        api_key: OpenAI key entered in the UI (may be empty); forwarded
            unchanged to ``tool.search``.

    Returns:
        A visible ``gr.Gallery`` update holding one ``(image, caption)`` pair
        per retrieved page, and the answer text returned by ``tool.search``.

    Raises:
        gr.Error: if ``query`` is empty or whitespace only. This avoids running
            retrieval (and any API call) for a blank prompt. Gradio shows the
            message to the user and leaves the outputs unchanged.
    """
    if not query or not query.strip():
        raise gr.Error("Please enter a query before searching.")

    print("=============== SEARCHING ===============")
    # k is a page count; normalise the slider's (documented float) value.
    context, answer = tool.search(query, int(k), api_key)
    context_gallery = [(page.image, page.caption) for page in context]
    print("========================================")
    return (
        gr.Gallery(
            value=context_gallery,
            label="Retrieved Documents",
            height=400,
            show_label=True,
            visible=True,
        ),
        answer,
    )
def index(files, contextualize_embeds, api_key):
    """Index the uploaded PDF files into the tool's database.

    Wired to ``upload_files.change``. Gradio fires that event when files are
    uploaded and also when the upload widget is cleared. In the cleared case
    ``files`` is None/empty, and nothing is indexed.

    Args:
        files: Value of the ``gr.File`` upload component (None or empty when
            the widget was cleared).
        contextualize_embeds: Value of the "Contextualize Embeddings"
            checkbox; forwarded to ``tool.index`` as ``contextualize``.
        api_key: OpenAI key entered in the UI; forwarded to ``tool.index``.

    Returns:
        A ``gr.Textbox`` update for the processing-status box and a
        ``gr.Textbox`` update that makes the query box interactive again.
    """
    query_box = gr.Textbox(
        lines=2,
        label="Query",
        show_label=False,
        placeholder="Enter your prompt here and press Shift+Enter or press the button",
        interactive=True,
    )

    if not files:
        # Clearing the upload widget triggers `change` with no files. Don't
        # call tool.index(files=None); just re-enable the query box, which
        # the preceding show_processing_status step made non-interactive.
        return gr.Textbox("No files to index."), query_box

    print("=============== INDEXING ===============")
    indexed_files_num = tool.index(
        files=files,
        contextualize=contextualize_embeds,
        api_key=api_key,
        overwrite_db=True,  # replaces any previously indexed documents
    )
    print("========================================")
    return gr.Textbox(f"Uploaded and processed {indexed_files_num} pages!"), query_box
def show_processing_status():
    """Switch the UI into its "indexing in progress" state.

    Runs on ``upload_files.change``, before ``index``. Returns, in order:

    * the processing-status textbox, made visible. It shows an in-progress
      message, which overwrites any result message left over from a
      previous upload.
    * the "Contextualize Embeddings" checkbox, hidden.
    * the query textbox, made non-interactive. NOTE(review): the Search
      button itself is not disabled here.
    """
    # Gradio only applies the props passed explicitly, so `value` must be set
    # here; otherwise the previous "Uploaded and processed N pages!" text
    # stays visible while a new upload is being indexed.
    status_box = gr.Textbox(
        value="Processing uploaded files...",
        label="Processing Status",
        interactive=False,
        visible=True,
    )
    contextualize_box = gr.Checkbox(label="Contextualize Embeddings", visible=False)
    query_box = gr.Textbox(
        lines=2,
        label="Query",
        show_label=False,
        placeholder="Enter your prompt here and press Shift+Enter or press the button",
        interactive=False,
    )
    return status_box, contextualize_box, query_box
with gr.Blocks(
    theme=gr.themes.Ocean(),
    title="ColPali Tool Demo",
) as demo:
    gr.Markdown("""\
# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚
Demo to test the ColPali RAG Tool powered by ColQwen2 (ColPali) on PDF documents.
ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
This tool allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents!
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
Other models will be released with better robustness towards different languages and document formats!
""")

    # Secret: rendered masked so the key is not displayed in clear text.
    # It is passed as an input to both `index` and `search` below.
    api_key = gr.Textbox(
        placeholder="Enter your OpenAI KEY here (optional)",
        label="API key",
        type="password",
    )

    # --- Step 1: upload & index -------------------------------------------
    # NOTE(review): component nesting in this layout was reconstructed;
    # confirm it against the deployed app's layout.
    gr.Markdown("## 1️⃣ Upload PDFs")
    gr.Markdown("Upload PDF files to index and search.")
    with gr.Group():
        contextualize_embeds = gr.Checkbox(
            label="Contextualize Embeddings",
            info="Add each image's surrounding context as metadata. Generated using gpt-4o-mini. ⚠️ Indexing will be longer!",
            value=True,
        )
        upload_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload files")
        # Hidden until the first upload; show_processing_status reveals it.
        processing_status = gr.Textbox(label="Processing Status", interactive=False, visible=False)

    # --- Step 2: search ----------------------------------------------------
    gr.Markdown("## 2️⃣ Search")
    gr.Markdown("Ask a question relevant to the documents you uploaded.")
    with gr.Group():
        chatbot = gr.Textbox(
            label="AI Assistant",
            placeholder="Generated response based on retrieved documents.",
            lines=6,
        )
        # Hidden until the first search returns results.
        output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True, visible=False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                text_input = gr.Textbox(
                    lines=2,
                    label="Query",
                    show_label=False,
                    placeholder="Enter your prompt here and press Shift+Enter or press the button",
                )
            with gr.Column(scale=1):
                k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Pages to retrieve")
                search_button = gr.Button("🔍 Search", variant="primary")

    # --- Event wiring --------------------------------------------------------
    # First lock the UI (show status, hide the checkbox, disable the query box),
    # then index; `index` re-enables the query box when it finishes.
    upload_files.change(
        show_processing_status,
        inputs=[],
        outputs=[processing_status, contextualize_embeds, text_input],
    ).then(
        index,
        inputs=[upload_files, contextualize_embeds, api_key],
        outputs=[processing_status, text_input],
    )
    # Submitting the query box and clicking the button run the same search.
    text_input.submit(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
    search_button.click(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
if __name__ == "__main__":
    # Turn on Gradio's request queue (capped at max_size=5 waiting events),
    # then serve on all network interfaces with debug=True.
    queued_demo = demo.queue(max_size=5)
    queued_demo.launch(server_name="0.0.0.0", debug=True)