Visual Document Retrieval
Transformers
Safetensors
ColPali
English
pretraining
adrish committed on
Commit 16843db · 1 Parent(s): a5e6882

Updated handler.py to use ColPaliProcessor, ColPaliForRetrieval
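For context, the handler's request schema (described in the `__call__` docstring in the diff below) is a JSON body with an "inputs" list of base64-encoded images, each with an optional prompt. A hypothetical client payload, with an illustrative file name and the handler's default prompt:

import base64

# Read an illustrative document image and base64-encode it for the endpoint.
with open("page.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": [
        {"image": encoded, "prompt": "Describe the image content in detail."}
    ]
}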

Files changed (1)
  1. handler.py +23 -35
handler.py CHANGED
@@ -3,40 +3,40 @@ import io
 import os
 from PIL import Image
 import torch
-from transformers import AutoProcessor, AutoModelForImageTextToText
+from transformers import ColPaliProcessor, ColPaliForRetrieval
 from typing import Dict, Any, List
 
 class EndpointHandler:
     def __init__(self, model_path: str = None):
         """
-        Initialize the endpoint handler by loading the ColPali model for image-to-text generation.
-        If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf' on Hugging Face.
+        Initialize the endpoint handler using the ColPali retrieval model.
+        If no model path is provided, it defaults to 'vidore/colpali-v1.3-hf'.
         """
         if model_path is None:
-            model_path = os.path.dirname(os.path.realpath(__file__))
+            model_path = "vidore/colpali-v1.3-hf"
         try:
-            # Select GPU if available, otherwise fall back to CPU.
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            # Load the model with the generic ImageTextToText interface.
-            self.model = AutoModelForImageTextToText.from_pretrained(
+            # Use the specialized ColPaliForRetrieval class.
+            self.model = ColPaliForRetrieval.from_pretrained(
                 model_path,
                 device_map="cuda" if torch.cuda.is_available() else "cpu",
                 trust_remote_code=True,
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                _attn_implementation="flash_attention_2"
             ).to(self.device)
-            # Load the processor which handles both image preprocessing and text tokenization.
-            self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+            # Use the specialized ColPaliProcessor.
+            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
         except Exception as e:
             raise RuntimeError(f"Error loading model or processor: {e}")
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Process the input data for image-to-text generation.
+        Process the input data, run inference using the ColPali retrieval model,
+        and return the outputs.
+
         Expects a dictionary with an "inputs" key containing a list of dictionaries.
         Each dictionary should have:
           - "image": a base64-encoded image string.
-          - "prompt": (optional) a text prompt (a default prompt is used if missing).
+          - "prompt": (optional) a text prompt (default is used if missing).
         """
         try:
             inputs_list = data.get("inputs", [])
@@ -53,42 +53,30 @@ class EndpointHandler:
                 if not image_b64:
                     return {"error": "One of the input items is missing 'image' data."}
                 try:
-                    # Decode base64 image and convert to RGB.
+                    # Decode the base64-encoded image and convert to RGB.
                     image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                     images.append(image)
                 except Exception as e:
                     return {"error": f"Failed to decode one of the images: {e}"}
-                # Use the provided prompt or fall back to a default prompt.
+                # Use the provided prompt or a default prompt.
                 prompt = item.get("prompt", "Describe the image content in detail.")
                 texts.append(prompt)
 
-            # Process both text and image inputs via the processor.
+            # Prepare inputs with the ColPali processor.
             model_inputs = self.processor(
-                text=texts,
                 images=images,
+                text=texts,
                 padding=True,
                 return_tensors="pt",
             ).to(self.device)
 
-            # Generation configuration (can be overridden by the request).
-            max_new_tokens = config.get("max_new_tokens", 1000)
-            temperature = config.get("temperature", 0.8)
-            num_return_sequences = config.get("num_return_sequences", 1)
-            do_sample = bool(config.get("do_sample", True))
-
-            # Generate outputs using the model.
-            outputs = self.model.generate(
-                **model_inputs,
-                temperature=temperature,
-                max_new_tokens=max_new_tokens,
-                num_return_sequences=num_return_sequences,
-                do_sample=do_sample,
-            )
-
-            # Decode the generated tokens into human-readable text.
-            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            # For retrieval, we call the model directly rather than using generate().
+            outputs = self.model(**model_inputs)
+            # Assuming that the model returns logits or retrieval scores,
+            # we extract and convert them to lists.
+            retrieval_scores = outputs.logits.tolist() if hasattr(outputs, "logits") else outputs
 
-            return {"responses": text_output}
+            return {"responses": retrieval_scores}
 
         except Exception as e:
             return {"error": f"Unexpected error: {e}"}
@@ -99,7 +87,7 @@ _service = EndpointHandler()
 def handle(data, context):
     """
     Entry point for the Hugging Face dedicated inference endpoint.
-    It processes the input data and returns the model's generated responses.
+    Processes the input data and returns the model's outputs.
     """
     try:
         if data is None:
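A note on the new retrieval path: in the documented transformers ColPali API, images and text queries are encoded in separate processor calls, and ColPaliForRetrieval returns a multi-vector embeddings field that is scored with processor.score_retrieval(), rather than a logits field. A minimal sketch of that documented flow, assuming the vidore/colpali-v1.3-hf checkpoint (the placeholder image and query are illustrative only):

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor

model_name = "vidore/colpali-v1.3-hf"
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

images = [Image.new("RGB", (448, 448), "white")]  # placeholder page image
queries = ["Describe the image content in detail."]

# Images and queries go through separate processor calls.
batch_images = processor(images=images, return_tensors="pt").to(model.device)
batch_queries = processor(text=queries, return_tensors="pt").to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings   # multi-vector page embeddings
    query_embeddings = model(**batch_queries).embeddings  # multi-vector query embeddings

# Late-interaction (MaxSim) scoring: one row per query, one column per image.
scores = processor.score_retrieval(query_embeddings, image_embeddings)

Under that assumption, the handler's single processor(images=..., text=...) call would be split in two, and the response would carry the score matrix (or the raw embeddings) instead of outputs.logits.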