File size: 3,490 Bytes
30a1a24
4390904
 
3862235
4390904
33de980
46c7fc6
814b6ba
 
9b4d509
33de980
 
c5d3863
33de980
 
37a9f9c
ae91f97
46c7fc6
 
33de980
2902a60
 
d005da4
 
 
 
 
 
33de980
 
4390904
d005da4
 
 
 
 
4390904
d005da4
 
 
 
 
 
 
33de980
 
 
d005da4
3f023c5
 
 
 
b1bf444
0523b65
b1bf444
 
 
 
c5d3863
74e7ff4
0523b65
b1bf444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0523b65
b1bf444
 
 
 
 
 
 
 
 
285d8ef
 
 
 
 
68d2467
46c7fc6
 
30a1a24
46c7fc6
 
 
68d2467
 
 
 
 
 
46c7fc6
 
 
285d8ef
46c7fc6
b1bf444
5276af2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Standard library
import base64
import os
import sys
import tempfile  # fixed: was `import temp_file`, but the code uses `tempfile.NamedTemporaryFile`
from io import BytesIO
from typing import List

# Third-party
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
import gradio as gr
from pdf2image import convert_from_bytes
from PIL import Image
import requests
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
import torch
from torch.utils.data import DataLoader
from transformers import AutoProcessor
import uvicorn

# Make the vendored colpali package importable before the local imports below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)

app = FastAPI()

# Load model
# Base PaliGemma weights are loaded first in bfloat16 on CPU, then the ColPali
# retrieval adapter is applied on top. HF_TOKEN must grant access to the gated
# google/paligemma-3b-mix-448 checkpoint.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cpu", token = token).eval()

model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token = token)
# Move to GPU after adapter loading when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device != model.device:
    model.to(device)
# Blank 448x448 placeholder image required by ColPali's query-processing path.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage; `ds` holds one CPU multi-vector embedding per PDF page,
# index-aligned with the PIL page images in `images`.
ds = []
images = []

@app.get("/")
def read_root():
    """Redirect the bare root URL to the interactive Swagger docs."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect

@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Convert uploaded PDFs to page images and embed every page with ColPali.

    Resets the in-memory index, then stores one multi-vector embedding per
    page in ``ds`` (kept on CPU) alongside the page image in ``images``.
    """
    global ds, images
    images, ds = [], []

    # Rasterize every page of every uploaded PDF.
    for upload in files:
        pdf_bytes = await upload.read()
        images.extend(convert_from_bytes(pdf_bytes))

    # Embed pages in batches of 4; results are moved back to CPU immediately.
    loader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda pages: process_images(processor, pages),
    )
    for batch in loader:
        with torch.no_grad():
            batch = {key: tensor.to(device) for key, tensor in batch.items()}
            page_embeddings = model(**batch)
        ds.extend(torch.unbind(page_embeddings.to("cpu")))

    return {"message": f"Uploaded and converted {len(images)} pages"}

@app.post("/search")
async def search(query: str, k: int):
    """Return a PDF containing the ``k`` pages most relevant to ``query``.

    Embeds the query with ColPali, scores it against every indexed page
    embedding with the multi-vector evaluator, and renders the top-k page
    images into a single in-memory PDF served as the response.
    """
    # Embed the query; `mock_image` is the blank placeholder the ColPali
    # query-processing path requires.
    qs = []
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: tensor.to(device) for key, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # Best-scoring pages sort last, so take the tail of argsort and reverse.
    # Guard k: the original `[-k:]` slice returned ALL pages when k <= 0
    # (since `[-0:]` is the whole array); also clamp k to the index size.
    if k <= 0 or not ds:
        top_k_indices = []
    else:
        top_k_indices = scores.argsort(axis=1)[0][-min(k, len(ds)):][::-1]

    results = [{"image": images[idx], "page": f"Page {idx}"} for idx in top_k_indices]

    # Generate PDF: one letter-sized page per retrieved image.
    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for result in results:
        img = result["image"]
        # reportlab's drawImage wants a path, so write the PIL image to a
        # temp file; delete=False lets it be reopened by name on Windows,
        # and the finally block removes it (the original leaked these files).
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        try:
            tmp.close()
            img.save(tmp.name, format='PNG')
            c.drawImage(tmp.name, 0, 0, width, height)
            c.showPage()
        finally:
            os.unlink(tmp.name)

    c.save()
    pdf_buffer.seek(0)

    # FileResponse requires a filesystem path and would fail when handed a
    # BytesIO; StreamingResponse serves the in-memory buffer directly.
    return StreamingResponse(
        pdf_buffer,
        media_type='application/pdf',
        headers={"Content-Disposition": 'attachment; filename="search_results.pdf"'},
    )

if __name__ == "__main__":
    # Serve on all interfaces; port 7860 is the Hugging Face Spaces convention.
    uvicorn.run(app, host="0.0.0.0", port=7860)