File size: 6,039 Bytes
7e738ef
78a7c54
f9c5a74
 
 
 
 
 
f42f33d
 
78a7c54
bc890cd
78a7c54
 
 
 
 
 
 
 
 
 
 
f9c5a74
 
 
 
 
 
 
 
 
 
f42f33d
f9c5a74
f21b6d3
2eb1042
f9c5a74
 
f21b6d3
f9c5a74
 
2eb1042
f9c5a74
 
 
 
f21b6d3
f9c5a74
 
 
 
f21b6d3
f9c5a74
 
 
 
 
 
 
 
 
 
 
 
 
 
f21b6d3
f9c5a74
 
 
 
 
 
 
 
 
 
 
f21b6d3
f9c5a74
 
 
 
 
 
 
 
 
 
8313c17
 
 
 
 
2f43f91
8313c17
2f43f91
 
 
 
 
 
 
 
 
2eb1042
 
 
 
8313c17
2eb1042
 
 
 
 
 
 
 
 
2f43f91
 
 
8313c17
bc890cd
 
8313c17
 
2eb1042
 
bc890cd
0a97d2e
8313c17
2eb1042
8313c17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f21b6d3
8313c17
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import spaces
import os
import gradio as gr
from pdf2image import convert_from_path
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import torchvision
import subprocess

# Run the commands from setup.sh to install poppler-utils
def install_poppler():
    """Make sure the Poppler command-line tools are available.

    Poppler is detected by looking for the ``pdfinfo`` executable on ``PATH``.
    The executable is deliberately *not* run as a probe: invoked without
    arguments, ``pdfinfo`` prints its usage and exits with a non-zero status,
    so combining that with ``check=True`` raises ``CalledProcessError`` even
    when Poppler is installed.

    When ``pdfinfo`` is missing, ``poppler-utils`` is installed through
    ``apt-get``. Installation is best-effort: a step that cannot be started or
    that exits non-zero is reported on stdout and the function carries on
    rather than raising.

    Returns:
        None
    """
    # Local import keeps this block independent of the file-level import list.
    import shutil

    if shutil.which("pdfinfo") is not None:
        return

    print("Poppler not found. Installing...")
    install_steps = (
        ["apt-get", "update"],
        ["apt-get", "install", "-y", "poppler-utils"],
    )
    for step in install_steps:
        printable = " ".join(step)
        try:
            # Argument lists (no shell) avoid quoting pitfalls.
            completed = subprocess.run(step, check=False)
        except OSError as exc:
            # e.g. apt-get itself is not available on this system.
            print(f"Could not run '{printable}': {exc}")
            continue
        if completed.returncode != 0:
            print(f"'{printable}' exited with status {completed.returncode}")

# Call the Poppler installation check
install_poppler()

# Install flash-attn if not already installed.
# The child environment is *extended* with the skip-build flag. Passing a bare
# dict as `env` would replace the whole environment and drop PATH, HOME,
# proxy settings, etc. for the pip process.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Load the retriever (ColPali via byaldi) and the Qwen2-VL-2B-Instruct model.
# The generator is moved to the GPU in bfloat16 and put in eval mode once, at
# import time, so every request reuses the same loaded weights.
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda().eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)

@spaces.GPU()
def process_pdf_and_query(pdf_file, user_query):
    """Answer a question about an uploaded PDF with retrieval plus a vision LLM.

    The PDF is converted to a list of page images and indexed with the
    module-level ``RAG`` retriever. The single best match for ``user_query``
    selects one page image, which is sent together with the query to the
    module-level ``model`` / ``processor`` pair to generate the answer.

    Args:
        pdf_file: The uploaded file from the Gradio file input. Either a path
            (``str`` / ``os.PathLike``) or an object whose ``name`` attribute
            holds the path. ``None`` when nothing was uploaded.
        user_query: The question to ask about the document.

    Returns:
        A ``(answer, num_images)`` tuple. ``answer`` is the generated text, or
        a short message when the input is incomplete or nothing usable was
        retrieved. ``num_images`` is the number of page images produced from
        the PDF (``0`` when the PDF was never converted).
    """
    # Guard against submitting the form with missing inputs; without this the
    # attribute access on ``None`` below would raise.
    if pdf_file is None:
        return "Please upload a PDF file.", 0
    if not user_query or not user_query.strip():
        return "Please enter a query.", 0

    # Accept both a plain path and a file-like wrapper exposing ``.name``.
    if isinstance(pdf_file, (str, os.PathLike)):
        pdf_path = os.fspath(pdf_file)
    else:
        pdf_path = pdf_file.name

    # Convert the PDF to images
    images = convert_from_path(pdf_path)
    num_images = len(images)

    # Indexing the PDF in RAG
    RAG.index(
        input_path=pdf_path,
        index_name="image_index",  # index will be saved at index_root/index_name/
        store_collection_with_index=False,
        overwrite=True
    )

    # Search the query in the RAG model
    results = RAG.search(user_query, k=1)
    if not results:
        return "No results found.", num_images

    # The code treats page_num as 1-based; convert to a list index and make
    # sure it really points at a page (a value of 0 would otherwise wrap
    # around to the last page through index -1).
    image_index = results[0]["page_num"] - 1
    if not 0 <= image_index < num_images:
        return "No results found.", num_images

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": images[image_index],
                },
                {"type": "text", "text": user_query},
            ],
        }
    ]

    # Build the model inputs from the chat template and the page image
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Generate the output response
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    # Each output sequence is sliced past the length of its prompt so only
    # newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0], num_images

# Gradio components: two inputs (PDF upload, free-text query) and two outputs
# (the model's answer and the count of images produced from the PDF).
pdf_input = gr.File(
    label="Upload PDF",
)
query_input = gr.Textbox(
    label="Enter your query",
    placeholder="Ask a question about the PDF",
)
output_text = gr.Textbox(
    label="Model Answer",
)
output_images = gr.Textbox(
    label="Number of Images in PDF",
)

# Footer HTML: author/project links shown at the bottom of the page.
# The heart emoji had been corrupted into mojibake by a wrong text encoding
# and is restored here.
footer = """
<div style="text-align: center; margin-top: 20px;">
    <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
    <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
    <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a> |
    <a href="https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct" target="_blank">Qwen/Qwen2-VL-2B-Instruct</a> |
    <a href="https://github.com/AnswerDotAI/byaldi" target="_blank">Byaldi</a> |
    <a href="https://github.com/illuin-tech/colpali" target="_blank">ColPali</a>
    <br>
    Made with 💖 by Pejman Ebrahimi
</div>
"""

# Introductory HTML describing the Multimodal RAG pipeline.
# Written as adjacent literals so that the trailing spaces on the paragraph
# lines are explicit rather than invisible at the end of a source line.
explanation = (
    '\n'
    '<div style="text-align: center; margin-bottom: 20px;">\n'
    '    <h2 style="font-weight: bold; font-size: 24px;">'
    'Multimodal RAG (Retrieval-Augmented Generation)</h2>\n'
    '    <p>\n'
    '        This application utilizes the ColPali model as a multimodal retriever, \n'
    '        which retrieves relevant information from documents and generates answers \n'
    '        using the Qwen/Qwen2-VL-2B-Instruct LLM (Large Language Model) \n'
    '        via the Byaldi library, developed by Answer.ai.\n'
    '    </p>\n'
    '</div>\n'
)

# Custom CSS for the two extra buttons (matched through their elem_classes).
# This is raw CSS with no <style> wrapper because it is handed to Gradio via
# the `css` argument. A gr.HTML component created outside the app's context
# is never rendered, so the stylesheet was previously not applied at all.
css = """
.submit-button {
    background-color: green;
    color: white;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    font-size: 16px;
    cursor: pointer;
}
.duplicate-button {
    background-color: lightgreen;
    color: white;
    border: none;
    border-radius: 5px;
    padding: 10px 20px;
    font-size: 16px;
    cursor: pointer;
}
"""

# Build the Gradio app around the query function
demo = gr.Interface(
    fn=process_pdf_and_query,
    inputs=[pdf_input, query_input],  # List of inputs
    outputs=[output_text, output_images],  # List of outputs
    title="Multimodal RAG with Image Query - By <a href='https://github.com/arad1367'>Pejman Ebrahimi</a>",
    theme='freddyaboulton/dracula_revamped',
    css=css,
)

# Add additional elements to the interface
with demo:
    gr.HTML(explanation)  # Explanation section
    gr.HTML(footer)  # Footer section
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")  # Duplicate button
    submit_button = gr.Button("Submit", elem_classes="submit-button")  # Custom Submit Button
    # Without a listener the custom button does nothing when clicked; run the
    # same function on the same components as the interface itself.
    submit_button.click(
        fn=process_pdf_and_query,
        inputs=[pdf_input, query_input],
        outputs=[output_text, output_images],
    )

# Launch the Gradio app
demo.launch(debug=True)  # Start the interface