Spaces:
Running
Running
File size: 3,490 Bytes
30a1a24 4390904 3862235 4390904 33de980 46c7fc6 814b6ba 9b4d509 33de980 c5d3863 33de980 37a9f9c ae91f97 46c7fc6 33de980 2902a60 d005da4 33de980 4390904 d005da4 4390904 d005da4 33de980 d005da4 3f023c5 b1bf444 0523b65 b1bf444 c5d3863 74e7ff4 0523b65 b1bf444 0523b65 b1bf444 285d8ef 68d2467 46c7fc6 30a1a24 46c7fc6 68d2467 46c7fc6 285d8ef 46c7fc6 b1bf444 5276af2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from io import BytesIO
import os
import sys
import temp_file
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse, FileResponse
import gradio as gr
import requests
import uvicorn
from typing import List
import torch
from pdf2image import convert_from_bytes
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor
import base64
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
app = FastAPI()

# Load model: the PaliGemma base checkpoint is loaded on the CPU first, then
# the ColPali adapter named by `model_name` is attached to it.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")  # None when the variable is not set
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

# Prefer the first CUDA device when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Compare as torch.device objects: `device` is a plain string, and a string
# does not compare equal to a torch.device (which `model.device` presumably
# is), so comparing the raw string would always trigger the move.
if torch.device(device) != model.device:
    model.to(device)

# Blank white image handed to `process_queries` alongside each text query.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage
ds = []  # one embedding tensor per indexed page, kept on the CPU
images = []  # page images, index-aligned with `ds`
@app.get("/")
def read_root():
    """Redirect requests for the root URL to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Rebuild the in-memory index from the uploaded PDF files.

    Each upload is converted to page images with ``convert_from_bytes``; the
    pages are then run through the model in batches of four and the resulting
    per-page embeddings are moved to the CPU.

    The module-level ``images`` and ``ds`` lists are replaced together, and
    only once every page has been embedded. If conversion or embedding fails
    part-way through, the previously stored index is left untouched rather
    than being replaced by pages that have no matching embeddings.

    Args:
        files: The uploaded PDF documents.

    Returns:
        A dict whose ``message`` reports how many pages were indexed.
    """
    global ds, images

    # Collect the pages of every upload before touching shared state.
    new_images = []
    for file in files:
        content = await file.read()
        new_images.extend(convert_from_bytes(content))

    dataloader = DataLoader(
        new_images,
        batch_size=4,
        shuffle=False,  # keep embeddings in the same order as the pages
        collate_fn=lambda batch: process_images(processor, batch),
    )

    new_ds = []
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # Split the batch along its first dimension: one entry per page.
        new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    # Publish both lists together so they always stay index-aligned.
    images = new_images
    ds = new_ds
    return {"message": f"Uploaded and converted {len(images)} pages"}
@app.post("/search")
async def search(query: str, k: int):
    """Return a PDF made of the ``k`` indexed pages that score highest for ``query``.

    The query is embedded with the model, scored against every stored page
    embedding in ``ds``, and the best-scoring page images are drawn one per
    page (stretched to US-letter size) into a PDF that is built in memory.

    Args:
        query: Free-text search query.
        k: Number of pages to return; must be at least 1. If fewer than ``k``
            pages are indexed, all of them are returned.

    Returns:
        An ``application/pdf`` response offered as ``search_results.pdf``,
        pages ordered from best to worst match.

    Raises:
        HTTPException: With status 400 when ``k`` is smaller than 1 or when
            nothing has been indexed yet.
    """
    # Imported here so the handler is self-contained and does not depend on
    # names that are absent from the module's import block.
    from fastapi import HTTPException
    from fastapi.responses import Response
    from reportlab.lib.utils import ImageReader

    # `[-k:]` with k == 0 would select *every* page, so reject it explicitly.
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be a positive integer")
    if not ds:
        raise HTTPException(status_code=400, detail="No documents have been indexed yet")

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Row 0 belongs to the single query. Take the last k entries of the
    # (presumably ascending) argsort and reverse them to get best-first order.
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    # Generate PDF
    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for idx in top_k_indices:
        # ImageReader wraps the in-memory PIL image, so no temporary file is
        # written to disk (and none is left behind afterwards).
        c.drawImage(ImageReader(images[idx]), 0, 0, width, height)
        c.showPage()
    c.save()

    # FileResponse needs a filesystem path; send the in-memory bytes directly.
    return Response(
        content=pdf_buffer.getvalue(),
        media_type="application/pdf",
        headers={"Content-Disposition": 'attachment; filename="search_results.pdf"'},
    )
if __name__ == "__main__":
    # Start the ASGI server on all interfaces, port 7860, when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=7860)