File size: 4,656 Bytes
a924f05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae1680f
a924f05
 
144f917
a924f05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr

from tool import VisualRAGTool

# Single module-level tool instance used by the `search` and `index` handlers.
tool = VisualRAGTool()

def search(query, k, api_key):
    """Retrieve pages for a query and return them alongside the answer.

    Delegates to ``tool.search`` and converts every returned page into an
    ``(image, caption)`` pair so it can be shown in a gallery.

    Args:
        query: The question text entered by the user.
        k: Number of pages to retrieve, forwarded to the tool.
        api_key: API key forwarded to the tool.

    Returns:
        A 2-tuple ``(gallery, answer)``: a visible ``gr.Gallery`` holding the
        retrieved pages, and the answer produced by ``tool.search``.
    """
    print("=============== SEARCHING ===============")

    pages, response = tool.search(query, k, api_key)

    # The gallery is fed (image, caption) tuples, one per retrieved page.
    gallery_items = []
    for page in pages:
        gallery_items.append((page.image, page.caption))

    print("========================================")

    gallery = gr.Gallery(
        value=gallery_items,
        label="Retrieved Documents",
        height=400,
        show_label=True,
        visible=True,
    )
    return gallery, response

def index(files, contextualize_embeds, api_key):
    """Index the uploaded files and unlock the query box.

    Calls ``tool.index`` with ``overwrite_db=True`` and reports the count it
    returns in a status message.

    Args:
        files: The uploaded files, passed through to the tool.
        contextualize_embeds: Forwarded to the tool as ``contextualize``.
        api_key: API key forwarded to the tool.

    Returns:
        A 2-tuple of ``gr.Textbox`` components: the status message, and an
        interactive query box.
    """
    print("=============== INDEXING ===============")

    page_count = tool.index(
        files=files,
        contextualize=contextualize_embeds,
        api_key=api_key,
        overwrite_db=True,
    )

    print("========================================")

    status_box = gr.Textbox(f"Uploaded and processed {page_count} pages!")
    # Same query box configuration as the layout, with input enabled.
    query_box = gr.Textbox(
        lines=2,
        label="Query",
        show_label=False,
        placeholder="Enter your prompt here and press Shift+Enter or press the button",
        interactive=True,
    )
    return status_box, query_box

def show_processing_status():
    """Build the component states shown when processing starts.

    Returns:
        A 3-tuple of components: a visible, read-only "Processing Status"
        textbox; a hidden "Contextualize Embeddings" checkbox; and a
        non-interactive query textbox.
    """
    status_box = gr.Textbox(
        label="Processing Status",
        interactive=False,
        visible=True,
    )
    hidden_checkbox = gr.Checkbox(
        label="Contextualize Embeddings",
        visible=False,
    )
    # The query box is returned with input disabled.
    locked_query_box = gr.Textbox(
        lines=2,
        label="Query",
        show_label=False,
        placeholder="Enter your prompt here and press Shift+Enter or press the button",
        interactive=False,
    )
    return status_box, hidden_checkbox, locked_query_box

# Layout and event wiring for the demo UI.
with gr.Blocks(
    theme=gr.themes.Ocean(),
    title="ColPali Tool Demo",
) as demo:
    gr.Markdown("""\
    # ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚   
    Demo to test the ColPali RAG Tool powered by ColQwen2 (ColPali) on PDF documents.
    ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).

    This tool allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents!

    ⚠️ This demo uses a CPU version of the model, so it may be slow! For faster results you may want to fork the space and run it on a GPU. 
    """)

    # Mask the key in the browser; the handlers still receive the plain string.
    api_key = gr.Textbox(
        placeholder="Enter your OpenAI KEY here",
        label="API key",
        type="password",
    )

    # NOTE(review): not referenced by any handler in the visible file —
    # confirm before removing.
    stored_messages = gr.State(value=[])

    gr.Markdown("## 1️⃣ Upload PDFs")
    gr.Markdown("Upload PDF files to index and search.")
    with gr.Group():
        contextualize_embeds = gr.Checkbox(label="Contextualize Embeddings", info="Add images surrounding context as metadata. Generated using gpt-4o-mini. ⚠️ Indexing will be longer!", value=True)
        upload_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload files")
        # Hidden at first; made visible by show_processing_status.
        processing_status = gr.Textbox(label="Processing Status", interactive=False, visible=False)

    gr.Markdown("## 2️⃣ Search")
    gr.Markdown("Ask a question relevant to the documents you uploaded.")
    with gr.Group():
        chatbot = gr.Textbox(label="AI Assistant", placeholder="Generated response based on retrieved documents.", lines=6)
        # Hidden at first; the search handler returns a visible gallery.
        output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True, visible=False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                text_input = gr.Textbox(
                    lines=2,
                    label="Query",
                    show_label=False,
                    placeholder="Enter your prompt here and press Shift+Enter or press the button",
                )
            with gr.Column(scale=1):
                k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Pages to retrieve")
        search_button = gr.Button("🔍 Search", variant="primary")

    # Define the flow of the demo.
    # On a file change: first show the status box, hide the checkbox and lock
    # the query box; then run indexing, which writes the status message and
    # re-enables the query box.
    upload_files.change(
        show_processing_status,
        inputs=[],
        outputs=[processing_status, contextualize_embeds, text_input],
    ).then(
        index,
        inputs=[upload_files, contextualize_embeds, api_key],
        outputs=[processing_status, text_input],
    )
    # Submitting the query box and clicking the button trigger the same search.
    text_input.submit(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
    search_button.click(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])

if __name__ == "__main__":
    # Configure the queue (max_size=5) first, then start the server.
    demo.queue(max_size=5)
    demo.launch(debug=True, server_name="0.0.0.0")