ColPali-demo / app.py
hantech's picture
Update app.py
13ebcf2 verified
import os
import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
import tempfile
from colpali_engine.models import ColQwen2, ColQwen2Processor
# Ensure the temporary directory exists. `exist_ok=True` makes this one
# race-free call instead of a check-then-create pair.
# NOTE(review): the path is not referenced elsewhere in this file; presumably
# it is the directory Gradio uses for uploads — TODO confirm.
os.makedirs('/tmp/gradio', exist_ok=True)
# Load the ColQwen2 retriever and its processor once, at import time; the
# request handlers (`search`, `index_gpu`) use these module-level objects.
_MODEL_ID = "manu/colqwen2-v1.0-alpha"

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

model = ColQwen2.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()  # inference only
processor = ColQwen2Processor.from_pretrained(_MODEL_ID)
def search(query: str, ds, images, k):
    """Return the `k` indexed pages that best match `query`.

    Args:
        query: Free-text search query.
        ds: Page embeddings as produced by `index_gpu` (one CPU tensor per
            page).
        images: Page images; `images[i]` must be the page embedded in
            `ds[i]`.
        k: Number of results requested. It comes from a Gradio slider and may
            arrive as a float.

    Returns:
        `(results, error)`. `results` is a list of `(image, caption)` pairs
        for `gr.Gallery`, or `None` on error. `error` is a message for the
        error textbox, or `""` on success.
    """
    if not ds:
        return None, "No documents have been indexed. Please upload and index documents first."
    if len(images) != len(ds):
        # Embeddings without a matching image would make `images[idx]`
        # raise (or show the wrong page); report it instead of crashing.
        return None, "The indexed embeddings and page images are out of sync. Please re-index your documents."
    # `topk` requires an int; slider values may be floats.
    k = min(int(k), len(ds))

    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare device objects: a plain str never compares equal to a
    # torch.device, so a string comparison would move the model on every call.
    if model.device != torch.device(device_name):
        model.to(device_name)

    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    query_embeddings = list(torch.unbind(embeddings_query.to("cpu")))

    scores = processor.score(query_embeddings, ds, device=device_name)
    top_k_indices = scores[0].topk(k).indices.tolist()
    results = [(images[idx], f"Page {idx}") for idx in top_k_indices]
    return results, ""
def index(files, ds):
    """Convert uploaded PDFs to page images and embed every page.

    Args:
        files: Uploads from the `gr.File` component (paths or dict payloads;
            see `convert_files`).
        ds: Previous embedding state. Kept for interface compatibility; the
            returned index is always built from `files` alone, so that it
            stays aligned with the returned images.

    Returns:
        `(status_message, embeddings, images)`, where `embeddings[i]` is the
        embedding of `images[i]`.
    """
    if not files:
        # Clear both states together. Returning the old embeddings alongside
        # an empty image list would let `search` index past `images`.
        return "No files uploaded. Please upload PDF files to index.", [], []
    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")
    # Embed into a fresh list rather than extending `ds`. Only the newly
    # converted images are returned, so appending to the previous
    # embeddings would misalign embedding indices with image indices.
    status_message, ds, images = index_gpu(images, [])
    print(f"Indexed {len(ds)} embeddings.")
    return status_message, ds, images
def convert_files(files, max_pages=150):
    """Render every page of the given PDF files to PIL images.

    Args:
        files: Iterable of uploads. Each item may be a path (`str` or
            `os.PathLike`), a dict holding the path under `'filepath'` or
            `'path'`, or a file wrapper exposing the path as a `.name` str.
        max_pages: Abort once this many pages have been rendered. The
            default of 150 is the original limit.

    Returns:
        List of page images, in file order and then page order. Missing files
        and files that fail to convert are reported with `print` and skipped.

    Raises:
        TypeError: If an item's path cannot be determined.
        KeyError: If a dict item has neither `'filepath'` nor `'path'`.
        gr.Error: If the total page count reaches `max_pages`.
    """
    images = []
    for f in files:
        if isinstance(f, dict):
            # The original code read 'filepath'. Also accept 'path', the field
            # name of Gradio's FileData payload (TODO confirm which key the
            # deployed Gradio version sends).
            file_path = f['filepath'] if 'filepath' in f else f['path']
        elif isinstance(f, (str, os.PathLike)):
            file_path = os.fspath(f)
        elif isinstance(getattr(f, 'name', None), str):
            # Tempfile-style wrappers expose their on-disk path as `.name`.
            file_path = f.name
        else:
            raise TypeError(f"Unsupported file type: {type(f)}")
        print(f"Processing file: {file_path}")
        if not os.path.exists(file_path):
            print(f"File does not exist: {file_path}")
            continue
        try:
            pages = convert_from_path(file_path, thread_count=4)
        except Exception as e:
            # Best-effort: a single unreadable PDF should not abort the batch.
            print(f"Error converting {file_path}: {e}")
            continue
        images.extend(pages)
        # Check after each file, so that an oversized upload fails fast
        # instead of first rendering every remaining PDF.
        if len(images) >= max_pages:
            raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return images
def index_gpu(images, ds, batch_size=4):
    """Embed page images with the ColQwen2 model and append them to `ds`.

    Args:
        images: Page images to embed (e.g. the output of `convert_files`).
        ds: List that receives one CPU embedding tensor per image, in input
            order. It is extended in place and also returned.
        batch_size: Number of images per forward pass. The default of 4 is
            the original value.

    Returns:
        `(status_message, ds, images)`.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare device objects: a plain str never compares equal to a
    # torch.device, so a string comparison would move the model on every call.
    if model.device != torch.device(device_name):
        model.to(device_name)
    dataloader = DataLoader(
        images,
        batch_size=batch_size,
        shuffle=False,
        # Preprocess each batch and move it to the model's device here, so the
        # loop below needs no second transfer.
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            embeddings_doc = model(**batch_doc)
            # Split the batch into per-page tensors, stored on the CPU.
            ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", ds, images
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Layout: upload/index controls on the left and query controls on the
    # right; the search button and results sit below both columns.
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    # The description lines are kept flush-left on purpose. Inside the
    # triple-quoted string, leading spaces would become part of the Markdown
    # (4+ spaces render as a code block).
    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents !
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
Other models will be released with better robustness towards different languages and document formats !
""")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio's `file_types` expects extensions with a leading dot.
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            # State holding the page embeddings and the page images that
            # `index` returns and `search` consumes.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
    # Define the actions
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    error_output = gr.Textbox(label="Error Message")
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery, error_output])
if __name__ == "__main__":
    # Bounded request queue (at most 10 pending events), then serve the UI.
    queued_app = demo.queue(max_size=10)
    queued_app.launch(debug=True)