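"""
Custom handler for a Hugging Face dedicated inference endpoint that repurposes
the ColPali retrieval model (vidore/colpali-v1.3-hf) to extract OCR-style text
from base64-encoded images.
"""
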
import base64
import io
import os
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor


class EndpointHandler:
    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the endpoint handler using the ColPali model for OCR extraction.
        If no model path is provided, it defaults to the directory containing this
        file (i.e., the deployed copy of 'vidore/colpali-v1.3-hf').
        """
        if model_path is None:
            model_path = os.path.dirname(os.path.realpath(__file__))
        try:
            # Use GPU if available, otherwise CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Load the specialized ColPali model (designed for retrieval but
            # repurposed here for OCR generation). Half precision keeps GPU memory
            # usage down; full precision is safer on CPU. A single .to(self.device)
            # places the weights, so no separate device_map argument is needed.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ).to(self.device)
            # Load the processor that handles image preprocessing.
            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}") from e
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data for OCR extraction.
        Expects a dictionary with an "inputs" key containing a list of dictionaries.
        Each dictionary must have an "image" key with a base64-encoded image string.
        For OCR extraction, no text prompt is provided.
        """
        try:
            inputs_list = data.get("inputs", [])
            if not inputs_list or not isinstance(inputs_list, list):
                return {"error": "Inputs should be a list of dictionaries with an 'image' key."}
            images: List[Image.Image] = []
            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    # Decode the base64 string and convert to an RGB PIL image.
                    image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}
            # Process only images with the processor (to avoid the text+image conflict).
            model_inputs = self.processor(
                images=images,
                return_tensors="pt",
                padding=True,
            ).to(self.device)
            # Manually create a dummy text prompt by inserting a beginning-of-sequence
            # token. This is necessary to trigger text generation even though no
            # prompt is provided. Fall back through CLS and PAD with explicit None
            # checks, since chaining `or` would wrongly skip a valid token id of 0.
            tokenizer = self.processor.tokenizer
            bos_token_id = tokenizer.bos_token_id
            if bos_token_id is None:
                bos_token_id = tokenizer.cls_token_id
            if bos_token_id is None:
                bos_token_id = tokenizer.pad_token_id
            if bos_token_id is None:
                raise RuntimeError("No BOS token found in the tokenizer.")
            batch_size = model_inputs["pixel_values"].shape[0]
            dummy_input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=self.device
            )
            model_inputs["input_ids"] = dummy_input_ids
            # Generation parameters (can be overridden via the "config" field).
            config = data.get("config", {})
            max_new_tokens = config.get("max_new_tokens", 256)
            temperature = config.get("temperature", 0.8)
            num_return_sequences = config.get("num_return_sequences", 1)
            do_sample = bool(config.get("do_sample", True))
            # Call generate on the model using the image-only inputs augmented
            # with the dummy text.
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
            )
            # Decode generated tokens into text using the processor's tokenizer.
            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            return {"responses": text_output}
        except Exception as e:
            return {"error": f"Unexpected error: {e}"}


# Instantiate the endpoint handler once at import time so the model loads
# before the first request arrives.
_service = EndpointHandler()


def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.
    Processes input data and returns the extracted OCR text.
    """
    try:
        if data is None:
            return {"error": "No input data received"}
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}