import base64
import io
import os
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor


class EndpointHandler:
    """Hugging Face inference-endpoint handler wrapping the ColPali retrieval model."""

    def __init__(self, model_path: str = None):
        """
        Initialize the endpoint handler using the ColPali retrieval model.

        If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf'.

        Args:
            model_path: Hub id or local path of the ColPali checkpoint.

        Raises:
            RuntimeError: if the model or processor cannot be loaded.
        """
        if model_path is None:
            model_path = "vidore/colpali-v1.3-hf"
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # fp16 only on GPU; half precision is slow or unsupported on many CPUs.
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            # FIX: the original passed device_map= AND then called .to(self.device).
            # Mixing accelerate's device_map with an explicit .to() is redundant and
            # can raise when weights are offloaded; a single .to() is sufficient here.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=dtype,
            ).to(self.device)
            self.model.eval()  # inference only — disable dropout etc.
            # Use the specialized ColPaliProcessor.
            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data, run inference with the ColPali retrieval model,
        and return the outputs.

        Expects a dictionary with an "inputs" key containing a list of dicts:
            - "image": a base64-encoded image string (required).
            - "prompt": an optional text prompt (a default is used if missing).

        Returns:
            {"responses": <nested list of retrieval embeddings/scores>} on
            success, or {"error": <message>} on any failure (this handler never
            raises — errors are reported in-band to the endpoint caller).
        """
        try:
            inputs_list = data.get("inputs", [])
            config = data.get("config", {})  # currently unused; reserved for future options
            if not inputs_list or not isinstance(inputs_list, list):
                return {"error": "Inputs should be a list of dictionaries with 'image' and optionally 'prompt' keys."}

            images: List[Image.Image] = []
            texts: List[str] = []
            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    # Decode the base64-encoded image and convert to RGB.
                    image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}
                # Use the provided prompt or a default prompt.
                prompt = item.get("prompt", "Describe the image content in detail.")
                texts.append(prompt)

            # Prepare inputs with the ColPali processor.
            model_inputs = self.processor(
                images=images,
                text=texts,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            # Retrieval model: a plain forward pass, not generate(). No gradients
            # are needed at inference time, so skip building the autograd graph.
            with torch.no_grad():
                outputs = self.model(**model_inputs)

            # FIX: ColPaliForRetrieval's output exposes `embeddings`, not `logits`.
            # The original fell through to returning the raw output object, which
            # is not JSON-serializable. Prefer embeddings, then logits, else fail
            # explicitly instead of silently returning an unserializable object.
            if hasattr(outputs, "embeddings") and outputs.embeddings is not None:
                retrieval_scores = outputs.embeddings.cpu().tolist()
            elif hasattr(outputs, "logits") and outputs.logits is not None:
                retrieval_scores = outputs.logits.cpu().tolist()
            else:
                return {"error": "Model output has neither 'embeddings' nor 'logits'."}
            return {"responses": retrieval_scores}
        except Exception as e:
            return {"error": f"Unexpected error: {e}"}


# Instantiate once at import time so the model is loaded before the first request.
_service = EndpointHandler()


def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.

    Processes the input data and returns the model's outputs; any exception is
    converted to an in-band {"error": ...} payload rather than propagated.
    """
    try:
        if data is None:
            return {"error": "No input data received"}
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}