Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,21 +7,20 @@ from pdf2image import convert_from_path
|
|
7 |
from PIL import Image
|
8 |
from torch.utils.data import DataLoader
|
9 |
from tqdm import tqdm
|
10 |
-
import os
|
11 |
|
|
|
|
|
|
|
12 |
if not os.path.exists('/tmp/gradio'):
|
13 |
os.makedirs('/tmp/gradio')
|
14 |
-
|
15 |
-
from colpali_engine.models import ColQwen2, ColQwen2Processor
|
16 |
|
17 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
18 |
|
19 |
model = ColQwen2.from_pretrained(
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
).eval()
|
25 |
processor = ColQwen2Processor.from_pretrained("manu/colqwen2-v1.0-alpha")
|
26 |
|
27 |
def search(query: str, ds, images, k):
|
@@ -49,14 +48,16 @@ def search(query: str, ds, images, k):
|
|
49 |
|
50 |
return results
|
51 |
|
52 |
-
|
53 |
def index(files, ds):
|
|
|
|
|
|
|
54 |
print("Converting files")
|
55 |
images = convert_files(files)
|
56 |
print(f"Files converted with {len(images)} images.")
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
|
61 |
def convert_files(files):
|
62 |
images = []
|
@@ -67,15 +68,11 @@ def convert_files(files):
|
|
67 |
raise gr.Error("The number of images in the dataset should be less than 150.")
|
68 |
return images
|
69 |
|
70 |
-
|
71 |
def index_gpu(images, ds):
|
72 |
-
"""Example script to run inference with ColPali (ColQwen2)"""
|
73 |
-
|
74 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
75 |
if device != model.device:
|
76 |
model.to(device)
|
77 |
|
78 |
-
# run inference - docs
|
79 |
dataloader = DataLoader(
|
80 |
images,
|
81 |
batch_size=4,
|
@@ -90,17 +87,15 @@ def index_gpu(images, ds):
|
|
90 |
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
91 |
return f"Uploaded and converted {len(images)} pages", ds, images
|
92 |
|
93 |
-
|
94 |
-
|
95 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
96 |
gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
|
97 |
gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
|
98 |
-
ColPali is model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
|
99 |
|
100 |
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
|
101 |
Refresh the page if you change documents !
|
102 |
|
103 |
-
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing
|
104 |
Other models will be released with better robustness towards different languages and document formats !
|
105 |
""")
|
106 |
with gr.Row():
|
@@ -118,12 +113,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
118 |
query = gr.Textbox(placeholder="Enter your query here", label="Query")
|
119 |
k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
|
120 |
|
121 |
-
|
122 |
# Define the actions
|
|
|
123 |
search_button = gr.Button("🔍 Search", variant="primary")
|
124 |
output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
|
125 |
|
126 |
-
convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
|
127 |
search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
|
128 |
|
129 |
if __name__ == "__main__":
|
|
|
7 |
from PIL import Image
|
8 |
from torch.utils.data import DataLoader
|
9 |
from tqdm import tqdm
|
|
|
10 |
|
11 |
+
from colpali_engine.models import ColQwen2, ColQwen2Processor
|
12 |
+
|
13 |
+
# Ensure the temporary directory exists
|
14 |
if not os.path.exists('/tmp/gradio'):
|
15 |
os.makedirs('/tmp/gradio')
|
|
|
|
|
16 |
|
17 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
18 |
|
19 |
model = ColQwen2.from_pretrained(
|
20 |
+
"manu/colqwen2-v1.0-alpha",
|
21 |
+
torch_dtype=torch.bfloat16,
|
22 |
+
device_map=device,
|
23 |
+
).eval()
|
|
|
24 |
processor = ColQwen2Processor.from_pretrained("manu/colqwen2-v1.0-alpha")
|
25 |
|
26 |
def search(query: str, ds, images, k):
|
|
|
48 |
|
49 |
return results
|
50 |
|
|
|
51 |
def index(files, ds):
|
52 |
+
if not files:
|
53 |
+
return gr.Error("No files uploaded. Please upload PDF files to index."), ds, []
|
54 |
+
|
55 |
print("Converting files")
|
56 |
images = convert_files(files)
|
57 |
print(f"Files converted with {len(images)} images.")
|
58 |
+
status_message, ds, images = index_gpu(images, ds)
|
59 |
+
print(f"Indexed {len(ds)} embeddings.")
|
60 |
+
return status_message, ds, images
|
61 |
|
62 |
def convert_files(files):
|
63 |
images = []
|
|
|
68 |
raise gr.Error("The number of images in the dataset should be less than 150.")
|
69 |
return images
|
70 |
|
|
|
71 |
def index_gpu(images, ds):
|
|
|
|
|
72 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
73 |
if device != model.device:
|
74 |
model.to(device)
|
75 |
|
|
|
76 |
dataloader = DataLoader(
|
77 |
images,
|
78 |
batch_size=4,
|
|
|
87 |
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
88 |
return f"Uploaded and converted {len(images)} pages", ds, images
|
89 |
|
|
|
|
|
90 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
91 |
gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
|
92 |
gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
|
93 |
+
ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
|
94 |
|
95 |
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
|
96 |
Refresh the page if you change documents !
|
97 |
|
98 |
+
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
|
99 |
Other models will be released with better robustness towards different languages and document formats !
|
100 |
""")
|
101 |
with gr.Row():
|
|
|
113 |
query = gr.Textbox(placeholder="Enter your query here", label="Query")
|
114 |
k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
|
115 |
|
|
|
116 |
# Define the actions
|
117 |
+
convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
|
118 |
search_button = gr.Button("🔍 Search", variant="primary")
|
119 |
output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
|
120 |
|
|
|
121 |
search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
|
122 |
|
123 |
if __name__ == "__main__":
|