import os

import gradio as gr
import torch

from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor

# Ensure the temporary directory used by Gradio exists
os.makedirs("/tmp/gradio", exist_ok=True)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

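# Load the ColQwen2 retriever in bfloat16 and keep it in eval mode for inference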
model = ColQwen2.from_pretrained(
    "manu/colqwen2-v1.0-alpha",
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()
processor = ColQwen2Processor.from_pretrained("manu/colqwen2-v1.0-alpha")

def search(query: str, ds, images, k):
    if not ds:
        return None, "No documents have been indexed. Please upload and index documents first."
    
    k = min(k, len(ds))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device, so compare against torch.device(device)
    if model.device != torch.device(device):
        model.to(device)
        
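    # Embed the query; ColQwen2 returns one embedding vector per query token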
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

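    # Late-interaction scoring: for each query token, take the maximum similarity
    # over a page's patch embeddings, then sum over query tokens (MaxSim)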
    scores = processor.score(qs, ds, device=device)

    top_k_indices = scores[0].topk(k).indices.tolist()

    results = []
    for idx in top_k_indices:
        # Display 1-based page numbers
        results.append((images[idx], f"Page {idx + 1}"))

    return results, ""

def index(files, ds):
    if not files:
        return "No files uploaded. Please upload PDF files to index.", ds, []
    
    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")
    status_message, ds, images = index_gpu(images, ds)
    print(f"Indexed {len(ds)} embeddings.")
    return status_message, ds, images

def convert_files(files):
    images = []
    for f in files:
        if isinstance(f, dict):
            file_path = f['filepath']
        elif isinstance(f, str):
            file_path = f
        else:
            raise TypeError(f"Unsupported file type: {type(f)}")
        print(f"Processing file: {file_path}")
        if not os.path.exists(file_path):
            print(f"File does not exist: {file_path}")
            continue
        try:
            images.extend(convert_from_path(file_path, thread_count=4))
        except Exception as e:
            print(f"Error converting {file_path}: {e}")
            # Skip files that fail to convert and continue with the rest
    if len(images) >= 150:
        raise gr.Error("The number of images in the dataset should be less than 150.")
    return images

def index_gpu(images, ds):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if model.device != torch.device(device):
        model.to(device)
        
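    # Embed pages in batches; the collate_fn preprocesses each batch of images
    # and moves it to the model's device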
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )

    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents. 
    ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).

    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents !

    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
    Other models will be released with better robustness towards different languages and document formats !
    """)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")

            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)

    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    error_output = gr.Textbox(label="Error Message")

    # Wire up the actions
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery, error_output])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)
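
# A minimal sketch of how to run this demo locally (an assumption, not part of the
# original app): pdf2image needs the poppler system package installed, and a CUDA
# GPU is strongly recommended for indexing speed.
#   pip install gradio torch pdf2image pillow tqdm colpali-engine
#   python app.py  # substitute this file's actual name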