# cvquest-colpali / app.py
# HUANG-Stephanie's picture
# Update app.py
# 3862235 verified
# raw
# history blame
# 3.49 kB
from io import BytesIO
import os
import sys
import temp_file
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse, FileResponse
import gradio as gr
import requests
import uvicorn
from typing import List
import torch
from pdf2image import convert_from_bytes
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor
import base64
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
app = FastAPI()

# Load model: base weights come from "google/paligemma-3b-mix-448" and the
# adapter named by `model_name` is loaded on top of them.
model_name = "vidore/colpali"
# Read from the HF_TOKEN environment variable; None when it is not set.
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

# Prefer the first CUDA device when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Compare torch.device objects: a plain str never compares equal to a
# torch.device, so comparing `device` directly would always report a mismatch.
if torch.device(device) != model.device:
    model.to(device)

# Blank white 448x448 RGB image handed to process_queries alongside the query.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage
ds = []  # one embedding tensor per indexed page
images = []  # page images, in the same order as `ds`
@app.get("/")
def read_root():
    """Redirect requests for the root path to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Replace the in-memory index with the pages of the uploaded files.

    Each upload is read fully and passed to ``convert_from_bytes``; the
    resulting page images are embedded with the model in batches of four and
    the per-page embeddings are stored on the CPU.

    Args:
        files: Uploaded files whose bytes are handed to ``convert_from_bytes``.

    Returns:
        A dict whose ``"message"`` entry reports how many pages were indexed.
    """
    global ds, images

    # Build the new index in local lists and publish it only once it is
    # complete. Mutating the globals in place would leave `images` and `ds`
    # out of sync if conversion or embedding raised part-way through, and
    # would expose a half-built index to requests served while this
    # coroutine is suspended on `await`.
    new_images = []
    for file in files:
        content = await file.read()
        new_images.extend(convert_from_bytes(content))

    dataloader = DataLoader(
        new_images,
        batch_size=4,
        shuffle=False,  # keep embeddings aligned with the order of `new_images`
        collate_fn=lambda batch: process_images(processor, batch),
    )

    new_ds = []
    with torch.no_grad():
        for batch_doc in dataloader:
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # Split the batch along its first dimension: one tensor per page.
            new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    images = new_images
    ds = new_ds
    return {"message": f"Uploaded and converted {len(images)} pages"}
@app.post("/search")
async def search(query: str, k: int):
    """Return a PDF made of the ``k`` indexed pages that best match ``query``.

    The query is embedded with the model, scored against every stored page
    embedding, and the highest-scoring pages are drawn one per PDF page in
    descending score order.

    Args:
        query: Text to search for.
        k: Maximum number of pages to return; must be at least 1.

    Returns:
        An ``application/pdf`` response offered as ``search_results.pdf``.

    Raises:
        HTTPException: With status 400 when ``k`` is smaller than 1 or when
            nothing has been indexed yet.
    """
    # Imported locally because these names are not among the module imports.
    from fastapi import HTTPException
    from fastapi.responses import Response
    from reportlab.lib.utils import ImageReader

    # `[-k:]` with k == 0 is `[-0:]`, which selects every page instead of none.
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be a positive integer")
    if not ds:
        raise HTTPException(
            status_code=400,
            detail="No documents indexed; upload files via /index first",
        )

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Row 0 belongs to the single query. argsort is ascending, so take the
    # last k entries and reverse them to get best-first order.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    # Generate PDF
    pdf_buffer = BytesIO()
    pdf = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for idx in top_k_indices:
        # ImageReader wraps the in-memory PIL image directly, so no temporary
        # PNG files are written to disk (and none are left behind).
        pdf.drawImage(ImageReader(images[idx]), 0, 0, width, height)
        pdf.showPage()
    pdf.save()

    # FileResponse serves a file from a filesystem path; the PDF exists only
    # in memory, so send its bytes with a plain Response instead.
    return Response(
        content=pdf_buffer.getvalue(),
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="search_results.pdf"'},
    )
if __name__ == "__main__":
    # Start the server only when this file is executed as a script:
    # listen on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")