import base64
import io
import os
from PIL import Image
import torch
from transformers import ColPaliProcessor, ColPaliForRetrieval
from typing import Dict, Any, List
class EndpointHandler:
    """Hugging Face inference-endpoint handler for the ColPali retrieval model.

    Loads the model and processor once at construction time; each call scores
    a batch of base64-encoded images paired with optional text prompts.
    """

    def __init__(self, model_path: str = None):
        """
        Initialize the endpoint handler using the ColPali retrieval model.

        Args:
            model_path: Checkpoint path or Hub repo id. Defaults to
                'vidore/colpali-v1.3-hf' when None.

        Raises:
            RuntimeError: If the model or processor fails to load.
        """
        if model_path is None:
            model_path = "vidore/colpali-v1.3-hf"
        try:
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
            # Use the specialized ColPaliForRetrieval class.
            # NOTE: the original passed both device_map= and .to(device), which
            # double-places the weights (and device_map requires accelerate);
            # a single explicit .to() is sufficient and unambiguous.
            # fp16 only on GPU — CPU half-precision inference is slow/fragile.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            ).to(self.device)
            self.model.eval()  # inference-only: disable dropout and friends
            # Use the specialized ColPaliProcessor.
            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data, run inference using the ColPali retrieval model,
        and return the outputs.

        Expects a dictionary with an "inputs" key containing a list of
        dictionaries, each with:
            - "image": a base64-encoded image string (required).
            - "prompt": an optional text prompt (a default is used if missing).

        Returns:
            {"responses": <nested list of floats>} on success, or
            {"error": <message>} on any failure (this handler never raises).
        """
        try:
            inputs_list = data.get("inputs", [])
            if not inputs_list or not isinstance(inputs_list, list):
                return {"error": "Inputs should be a list of dictionaries with 'image' and optionally 'prompt' keys."}
            images: List[Image.Image] = []
            texts: List[str] = []
            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    # Decode the base64-encoded image and convert to RGB.
                    images.append(Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB"))
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}
                # Use the provided prompt or a default prompt.
                texts.append(item.get("prompt", "Describe the image content in detail."))
            # Prepare inputs with the ColPali processor.
            model_inputs = self.processor(
                images=images,
                text=texts,
                padding=True,
                return_tensors="pt",
            ).to(self.device)
            # For retrieval, we call the model directly rather than using
            # generate(); no_grad avoids needless autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model(**model_inputs)
            # BUG FIX: ColPaliForRetrieval returns multi-vector 'embeddings',
            # not 'logits' — the old hasattr(outputs, "logits") check fell
            # through and returned the raw (non-JSON-serializable) output
            # object. Prefer embeddings; keep logits as a fallback.
            if getattr(outputs, "embeddings", None) is not None:
                scores = outputs.embeddings
            elif getattr(outputs, "logits", None) is not None:
                scores = outputs.logits
            else:
                return {"error": "Model output has neither 'embeddings' nor 'logits'."}
            # Cast to float32 on CPU so fp16 GPU tensors serialize cleanly.
            return {"responses": scores.float().cpu().tolist()}
        except Exception as e:
            return {"error": f"Unexpected error: {e}"}
# Instantiate the endpoint handler once at import time, so the model is
# loaded a single time per worker process rather than per request.
_service = EndpointHandler()
def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.

    Args:
        data: Request payload dict (expects an "inputs" key), or None.
        context: Runtime context supplied by the serving framework (unused).

    Returns:
        The handler's response dict, or a dict with an "error" key.
    """
    if data is None:
        return {"error": "No input data received"}
    try:
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}