import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor


class EndpointHandler:
    def __init__(self, model_path: str = None):
        """
        Initialize the endpoint handler using the ColPali retrieval model.
        If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf'.
        """
        if model_path is None:
            model_path = "vidore/colpali-v1.3-hf"
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # device_map already places the weights on the target device, so no
            # extra .to(self.device) call is needed (mixing the two can conflict
            # with accelerate's dispatch hooks). .eval() disables dropout for
            # inference.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                device_map="cuda" if torch.cuda.is_available() else "cpu",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ).eval()

            self.processor = ColPaliProcessor.from_pretrained(model_path)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data, run inference with the ColPali retrieval model,
        and return the retrieval scores.

        Expects a dictionary with an "inputs" key containing a list of dictionaries.
        Each dictionary should have:
          - "image": a base64-encoded image string.
          - "prompt": (optional) a text query; a default query is used if missing.
        """
        try:
            inputs_list = data.get("inputs", [])

            if not inputs_list or not isinstance(inputs_list, list):
                return {"error": "Inputs should be a list of dictionaries with 'image' and optionally 'prompt' keys."}

            images: List[Image.Image] = []
            texts: List[str] = []

            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}

                texts.append(item.get("prompt", "Describe the image content in detail."))

            # ColPaliProcessor handles either images or text per call, not both
            # at once, so the two batches are prepared separately.
            batch_images = self.processor(images=images, return_tensors="pt").to(self.device)
            batch_queries = self.processor(text=texts, return_tensors="pt").to(self.device)

            # ColPaliForRetrieval returns multi-vector embeddings
            # (outputs.embeddings); it does not expose a logits attribute.
            with torch.no_grad():
                image_embeddings = self.model(**batch_images).embeddings
                query_embeddings = self.model(**batch_queries).embeddings

            # Late-interaction (MaxSim) scoring: one row per query, one column per image.
            scores = self.processor.score_retrieval(query_embeddings, image_embeddings)

            return {"responses": scores.tolist()}

        except Exception as e:
            return {"error": f"Unexpected error: {e}"}


_service = EndpointHandler()


def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.
    Processes the input data and returns the model's outputs.
    """
    try:
        if data is None:
            return {"error": "No input data received"}
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}