Updated handler.py to use ColPaliProcessor, ColPaliForRetrieval
handler.py (CHANGED, +23 -35)
@@ -3,40 +3,40 @@ import io
 import os
 from PIL import Image
 import torch
-from transformers import …
+from transformers import ColPaliProcessor, ColPaliForRetrieval
 from typing import Dict, Any, List
 
 class EndpointHandler:
     def __init__(self, model_path: str = None):
         """
-        Initialize the endpoint handler …
-        If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf'
+        Initialize the endpoint handler using the ColPali retrieval model.
+        If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf'.
         """
         if model_path is None:
-            model_path = …
+            model_path = "vidore/colpali-v1.3-hf"
         try:
-            # Select GPU if available, otherwise fall back to CPU.
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            # …
-            self.model = …(
+            # Use the specialized ColPaliForRetrieval class.
+            self.model = ColPaliForRetrieval.from_pretrained(
                 model_path,
                 device_map="cuda" if torch.cuda.is_available() else "cpu",
                 trust_remote_code=True,
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                _attn_implementation="flash_attention_2"
             ).to(self.device)
-            # …
-            self.processor = …
+            # Use the specialized ColPaliProcessor.
+            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
         except Exception as e:
             raise RuntimeError(f"Error loading model or processor: {e}")
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Process the input data …
+        Process the input data, run inference using the ColPali retrieval model,
+        and return the outputs.
+
         Expects a dictionary with an "inputs" key containing a list of dictionaries.
         Each dictionary should have:
           - "image": a base64-encoded image string.
-          - "prompt": (optional) a text prompt (…
+          - "prompt": (optional) a text prompt (default is used if missing).
         """
         try:
             inputs_list = data.get("inputs", [])
@@ -53,42 +53,30 @@
             if not image_b64:
                 return {"error": "One of the input items is missing 'image' data."}
             try:
-                # Decode base64 image and convert to RGB.
+                # Decode the base64-encoded image and convert to RGB.
                 image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                 images.append(image)
             except Exception as e:
                 return {"error": f"Failed to decode one of the images: {e}"}
-            # Use the provided prompt or …
+            # Use the provided prompt or a default prompt.
             prompt = item.get("prompt", "Describe the image content in detail.")
             texts.append(prompt)
 
-            # …
+            # Prepare inputs with the ColPali processor.
             model_inputs = self.processor(
-                text=texts,
                 images=images,
+                text=texts,
                 padding=True,
                 return_tensors="pt",
             ).to(self.device)
 
-            # …
-            …
-            …
-            …
-            …
-            …
-            # Generate outputs using the model.
-            outputs = self.model.generate(
-                **model_inputs,
-                temperature=temperature,
-                max_new_tokens=max_new_tokens,
-                num_return_sequences=num_return_sequences,
-                do_sample=do_sample,
-            )
-
-            # Decode the generated tokens into human-readable text.
-            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            # For retrieval, we call the model directly rather than using generate().
+            outputs = self.model(**model_inputs)
+            # Assuming that the model returns logits or retrieval scores,
+            # we extract and convert them to lists.
+            retrieval_scores = outputs.logits.tolist() if hasattr(outputs, "logits") else outputs
 
-            return {"responses": …}
+            return {"responses": retrieval_scores}
 
         except Exception as e:
             return {"error": f"Unexpected error: {e}"}
@@ -99,7 +87,7 @@ _service = EndpointHandler()
 def handle(data, context):
     """
     Entry point for the Hugging Face dedicated inference endpoint.
-    …
+    Processes the input data and returns the model's outputs.
     """
     try:
         if data is None:
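
For reference, the handler expects an "inputs" list of objects, each carrying a base64-encoded "image" and an optional "prompt". A minimal client-side sketch of such a request (the file name, endpoint URL, and token below are placeholders, not part of the commit):

import base64
import requests

# Encode a local image file as base64 ("page.png" is a placeholder name).
with open("page.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# Matches the schema documented in EndpointHandler.__call__.
payload = {"inputs": [{"image": image_b64, "prompt": "What does this page describe?"}]}

# ENDPOINT_URL and HF_TOKEN are placeholders for your deployment.
response = requests.post(
    "ENDPOINT_URL",
    headers={"Authorization": "Bearer HF_TOKEN"},
    json=payload,
)
print(response.json())  # {"responses": ...} on success, {"error": ...} otherwise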
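One caveat on the new inference path: in the transformers ColPali API (see the vidore/colpali-v1.3-hf model card), ColPaliProcessor is documented to embed images and query texts in separate calls, and ColPaliForRetrieval's output exposes multi-vector embeddings rather than logits, with similarity computed via processor.score_retrieval. If the combined processor(images=..., text=...) call above raises, the documented two-pass flow looks roughly like this sketch (placeholder image and query for illustration):

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor

model_name = "vidore/colpali-v1.3-hf"
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

# Placeholder inputs for illustration.
images = [Image.new("RGB", (448, 448), "white")]
queries = ["What does this page describe?"]

# Images and queries are embedded in separate processor/model passes.
batch_images = processor(images=images).to(model.device)
batch_queries = processor(text=queries).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings
    query_embeddings = model(**batch_queries).embeddings

# Late-interaction (MaxSim) scores, shape (num_queries, num_images).
scores = processor.score_retrieval(query_embeddings, image_embeddings)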