Spaces:
Running
Running
File size: 5,801 Bytes
4390904 0af5344 814b6ba 9b4d509 23a02a2 b61b33c 33de980 c5d3863 33de980 37a9f9c 46c7fc6 b61b33c 33de980 2902a60 d005da4 33de980 4390904 d005da4 4390904 d005da4 33de980 d005da4 b909e86 b1bf444 0523b65 b1bf444 c5d3863 74e7ff4 0523b65 b1bf444 4ae29c1 30a1a24 46c7fc6 5306b43 46c7fc6 5306b43 68d2467 5306b43 68d2467 5306b43 46c7fc6 4ae29c1 285d8ef 5306b43 aaf0cd4 4ae29c1 60947a6 4ae29c1 fe7b387 4ae29c1 b1bf444 aaf0cd4 5276af2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import os
import sys
import tempfile
import gradio as gr
import requests
import uvicorn
import torch
import base64
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse, StreamingResponse
from typing import List
from pdf2image import convert_from_bytes
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from io import BytesIO
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './colpali-main')))
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
app = FastAPI()

# --- Model loading -----------------------------------------------------------
# Repo id of the ColPali adapter; also used to load the matching processor.
model_name = "vidore/colpali"
# Optional Hugging Face access token read from HF_TOKEN (None when unset).
token = os.environ.get("HF_TOKEN")
# Load the PaliGemma base weights on CPU in bfloat16, switch to eval mode,
# then attach the ColPali adapter on top.
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

# Prefer the first GPU when available. `device` stays a plain string because
# the request handlers pass it directly to `Tensor.to`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# `model.device` is a torch.device; compare like with like rather than
# relying on torch.device-vs-str equality semantics.
if torch.device(device) != model.device:
    model.to(device)

# Blank white 448x448 image passed to `process_queries` alongside text-only
# queries in the /search endpoint.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory index shared by the endpoints: `ds[i]` is the embedding of the
# page image `images[i]`. Both are replaced by the /index endpoint.
ds = []
images = []
@app.get("/")
def read_root():
    """Send requests for the bare root URL to the interactive API docs."""
    docs_url = "/docs"
    redirect = RedirectResponse(url=docs_url)
    return redirect
@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Rebuild the in-memory page index from one or more uploaded PDFs.

    Every page of every uploaded PDF is rasterised with ``convert_from_bytes``
    and embedded with the ColPali model in batches of 4. The new page images
    and embeddings are assembled in local lists and only then swapped into the
    module-level ``images`` / ``ds``. As a result, a failure part-way through
    (e.g. an unreadable PDF) leaves the previous index intact rather than
    cleared or half-built. Requests served while uploads are still being read
    also keep seeing the previous, consistent index.

    Args:
        files: The uploaded PDF files.

    Returns:
        A JSON message reporting how many pages were indexed.
    """
    global ds, images

    new_images = []
    for file in files:
        content = await file.read()
        new_images.extend(convert_from_bytes(content))

    dataloader = DataLoader(
        new_images,
        batch_size=4,
        shuffle=False,  # keeps new_ds[i] aligned with new_images[i]
        collate_fn=lambda batch: process_images(processor, batch),
    )
    new_ds = []
    with torch.no_grad():
        for batch_doc in dataloader:
            batch_doc = {key: value.to(device) for key, value in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # One embedding tensor per page, moved back to CPU for storage.
            new_ds.extend(torch.unbind(embeddings_doc.to("cpu")))

    # Swap both structures together so images[i] always pairs with ds[i].
    images = new_images
    ds = new_ds
    return {"message": f"Uploaded and converted {len(images)} pages"}
def generate_pdf(results):
    """Render result images into a multi-page, letter-sized PDF.

    Args:
        results: Iterable of dicts whose ``"image"`` value is a base64-encoded
            image. The search endpoints pass PNG bytes. Other keys are ignored.

    Returns:
        A ``BytesIO`` rewound to offset 0 that contains the PDF. Each result
        becomes one page, with its image stretched to cover the whole page.
    """
    # ImageReader lets reportlab read each image straight from memory. This
    # avoids writing temporary files that could leak if drawing failed, and
    # that cannot be reopened by name while open on some platforms.
    from reportlab.lib.utils import ImageReader

    pdf_buffer = BytesIO()
    c = canvas.Canvas(pdf_buffer, pagesize=letter)
    width, height = letter
    for result in results:
        img_data = base64.b64decode(result["image"])
        c.drawImage(ImageReader(BytesIO(img_data)), 0, 0, width, height)
        c.showPage()
    c.save()
    pdf_buffer.seek(0)
    return pdf_buffer
@app.get("/search")
async def search(query: str, k: int = 1):
    """Return the ``k`` indexed pages that best match a text query, as a PDF.

    Args:
        query: Free-text search query.
        k: Number of pages to return, best match first. Values below 1 yield
            an empty result. Values above the index size return every page.

    Returns:
        A ``StreamingResponse`` carrying ``results.pdf`` as an attachment.

    Raises:
        HTTPException: 400 if no documents have been indexed yet.
    """
    from fastapi import HTTPException

    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed. Call /index first.")

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: value.to(device) for key, value in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Row 0 holds the single query's scores. argsort is ascending, so reverse
    # it to rank best-first. Slicing from the front means k <= 0 selects
    # nothing; a negative-start slice with k == 0 would select everything.
    ranked = scores.argsort(axis=1)[0][::-1]
    top_k_indices = ranked[: max(k, 0)]

    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})

    pdf_buffer = generate_pdf(results)
    # StreamingResponse serves the in-memory PDF without touching disk.
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response
@app.post("/search_by_cv")
async def search_by_cv(file: UploadFile = File(...), k: int = 10):
    """Find indexed pages similar to an uploaded PDF (e.g. a CV), as a PDF.

    Every page of the upload is embedded, but ranking uses only the scores of
    the upload's first page (row 0 of the score matrix). The single best match
    is skipped and the next ``k`` are returned, best first.

    NOTE(review): the skip presumably assumes the uploaded document is already
    indexed and would match itself. If it is not indexed, the genuinely best
    match is dropped. Also, embeddings of pages after the first are computed
    but unused. Confirm both are intended.

    Args:
        file: The uploaded PDF.
        k: Number of pages to return after the skipped top match. Values below
            1 yield an empty result.

    Returns:
        A ``StreamingResponse`` carrying ``results.pdf`` as an attachment.

    Raises:
        HTTPException: 400 if nothing has been indexed yet or the upload has
            no pages.
    """
    from fastapi import HTTPException

    # Read the uploaded PDF and rasterise its pages.
    content = await file.read()
    pdf_image_list = convert_from_bytes(content)

    # Checked after the await so the check sees the index that will be scored.
    if not ds:
        raise HTTPException(status_code=400, detail="No documents indexed. Call /index first.")
    if not pdf_image_list:
        raise HTTPException(status_code=400, detail="Uploaded PDF contains no pages.")

    # Embed the uploaded PDF's pages.
    qs = []
    dataloader = DataLoader(
        pdf_image_list,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )
    with torch.no_grad():
        for batch_query in dataloader:
            batch_query = {key: value.to(device) for key, value in batch_query.items()}
            embeddings_query = model(**batch_query)
            qs.extend(torch.unbind(embeddings_query.to("cpu")))

    # Compare the upload's embeddings with the already-indexed pages.
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Rank the first page's scores best-first, skip rank 0 and keep the next k.
    # For k >= 0 this equals the ascending-order slice [-k-1:-1] reversed, but
    # a negative k can no longer select almost the whole index.
    ranked = scores.argsort(axis=1)[0][::-1]
    top_k_indices = ranked[1 : max(k, 0) + 1]

    # Encode the selected page images for generate_pdf.
    results = []
    for idx in top_k_indices:
        img_byte_arr = BytesIO()
        images[idx].save(img_byte_arr, format='PNG')
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        results.append({"image": img_base64, "page": f"Page {idx}"})

    # Build the results PDF and stream it back from memory.
    pdf_buffer = generate_pdf(results)
    response = StreamingResponse(pdf_buffer, media_type='application/pdf')
    response.headers['Content-Disposition'] = 'attachment; filename="results.pdf"'
    return response
if __name__ == "__main__":
    # Serve the API on all network interfaces, port 7860, when run directly.
    uvicorn.run(app, host="0.0.0.0", port=7860)