# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
import transformers
from packaging import version
from transformers import AutoImageProcessor
from transformers.image_utils import ImageInput, is_valid_image, load_image

from .ar_tokenizer_text_tokenizer import TextTokenizer
from .log import log

# Configuration for different vision-language models
IMAGE_CONFIGS = {
    "pixtral": {
        "patch_size": 16,
        "image_token": "[IMG]",
        "image_break_token": "[IMG_BREAK]",
        "image_end_token": "[IMG_END]",
    }
}
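
# Illustrative token accounting (a sketch, not executed): with patch_size=16,
# encode() below expands a 512x768 image into 32 rows of 48 "[IMG]" tokens,
# each row terminated by "[IMG_BREAK]" and the final row by "[IMG_END]".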

# Chat template for Pixtral-12B-Instruct
PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n    {%- set system_message = messages[0]["content"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n    {%- endif %}\n    {%- if message["role"] == "user" %}\n        {%- if loop.last and system_message is defined %}\n            {{- "[INST]" + system_message + "\n\n" }}\n        {%- else %}\n            {{- "[INST]" }}\n        {%- endif %}\n        {%- if message["content"] is not string %}\n            {%- for chunk in message["content"] %}\n                {%- if chunk["type"] == "text" %}\n                    {{- chunk["content"] }}\n                {%- elif chunk["type"] == "image" %}\n                    {{- "[IMG]" }}\n                {%- else %}\n                    {{- raise_exception("Unrecognized content type!") }}\n                {%- endif %}\n            {%- endfor %}\n        {%- else %}\n            {{- message["content"] }}\n        {%- endif %}\n        {{- "[/INST]" }}\n    {%- elif message["role"] == "assistant" %}\n        {{- message["content"] + eos_token}}\n    {%- else %}\n        {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n    {%- endif %}\n{%- endfor %}'
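
# The template above expects messages shaped like:
#   {"role": "user", "content": [{"type": "text", "content": "..."}, {"type": "image"}]}
# with an optional leading system message; user/assistant roles must alternate.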


# Copied from transformers.models.pixtral.processing_pixtral.is_url
def is_url(val) -> bool:
    """Check if the given value is a URL."""
    return isinstance(val, str) and val.startswith("http")


# Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url
def is_image_or_image_url(elem):
    """Check if the given element is an image or an image URL."""
    return is_url(elem) or is_valid_image(elem)


def load_image_list(
    image_list: List[Union[str, PIL.Image.Image]], timeout: Optional[float] = None
) -> List[PIL.Image.Image]:
    """
    Load a list of images.

    Args:
        image_list (List[Union[str, PIL.Image.Image]]): The list of images to load.
        timeout (Optional[float]): The timeout for loading the image.

    Returns:
        List[PIL.Image.Image]: The list of loaded images.
    """
    return [load_image(image, timeout=timeout) for image in image_list]
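
# Usage sketch (the URL is a hypothetical placeholder; per transformers'
# load_image, entries may also be local file paths or PIL.Image.Image objects):
#
#   images = load_image_list(["https://example.com/cat.png"], timeout=10.0)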


class ImageTextTokenizer(TextTokenizer):
    """
    Image-text tokenizer that extends the text tokenizer with support for vision tokens.
    """

    def __init__(
        self,
        model_family: str,
        is_instruct_model: bool,
        tokenizer_path: str,
        image_processor_path: str,
    ):
        """
        Initialize the ImageTextTokenizer.

        Args:
            model_family (str): The model family.
            is_instruct_model (bool): Whether the model is an instruct model.
            tokenizer_path (str): Path to the text tokenizer.
            image_processor_path (str): Path to (or Hugging Face ID of) the image processor.

        Raises:
            AssertionError: If the model family is not supported or if the transformers version is incompatible.
        """
        super().__init__(
            model_family=model_family,
            is_instruct_model=is_instruct_model,
            local_path=tokenizer_path,
        )
        assert model_family in ["pixtral"], f"Unsupported model family: {model_family}"
        if model_family == "pixtral":
            # Need transformers>=4.45.0
            assert version.parse(transformers.__version__) >= version.parse("4.45.0"), (
                "Pixtral requires transformers>=4.45.0"
            )
            assert is_instruct_model, "Pixtral requires is_instruct_model=True"
            if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
                self.tokenizer.chat_template = PIXTRAL_CHAT_TEMPLATE
                log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}")

        # Set up image-specific configurations
        image_config = IMAGE_CONFIGS[model_family]
        self.patch_size = image_config["patch_size"]
        self.image_token = image_config["image_token"]
        self.image_break_token = image_config["image_break_token"]
        self.image_end_token = image_config["image_end_token"]

        # Initialize the image processor
        self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path)

    def encode(
        self,
        text: Union[str, List[str], List[int]],
        *,  # Enforce keyword-only arguments
        images: Optional[ImageInput] = None,
        image_kwargs: Optional[Dict[str, Any]] = None,
        **text_kwargs,
    ) -> Dict[str, Any]:
        """
        Encode the text, and optionally the images, into model inputs.

        Args:
            text (`str`, `List[str]`, `List[int]`):
                The sequence to be encoded. When `images` are given, `text` must be a single
                string containing one image token (e.g. `[IMG]`) per input image.
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared.
            image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
            **text_kwargs: Additional keyword arguments for text processing.

        Returns:
            A dictionary with the following fields:
            - **input_ids** -- List of token ids to be fed to a model.
            - **pixel_values** -- Pixel values to be fed to a model (present only when images are given).
            - **image_sizes** -- The (height, width) of each processed image (present only when images are given).

        Raises:
            ValueError: If the input images are in an invalid format.
        """

        output_dict, image_inputs = {}, {}
        if images is not None:
            # Preprocess images
            if is_image_or_image_url(images):
                images = [[images]]
            elif isinstance(images, list) and is_image_or_image_url(images[0]):
                images = [images]
            elif not (
                isinstance(images, list)
                and isinstance(images[0], list)
                and is_image_or_image_url(images[0][0])
            ):
                raise ValueError(
                    "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
                )

            # Load and process images
            images = [load_image_list(sample) for sample in images]
            image_kwargs = image_kwargs or {}
            image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs)

            # Validate image inputs
            assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs"
            assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs"
            assert (
                len(image_inputs.keys()) == 2
            ), "Expected only pixel_values and image_sizes in image_inputs, got {}".format(image_inputs.keys())

            # Extract pixel values and image sizes
            pixel_values = image_inputs["pixel_values"][0]
            image_sizes = image_inputs["image_sizes"][0]
            unique_sizes = np.unique(image_sizes, axis=0)

            assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes)

            # Convert pixel values to PyTorch tensor
            pixel_values = np.asarray(pixel_values)
            pixel_values = torch.from_numpy(pixel_values)
            output_dict["pixel_values"] = pixel_values
            output_dict["image_sizes"] = image_sizes

        # Expand image tokens in text
        if image_inputs.get("pixel_values") is not None:
            replace_strings = []
            # Calculate the number of tokens needed for each image and create a placeholder
            for image_size in image_sizes:
                height, width = image_size
                num_height_tokens = height // self.patch_size
                num_width_tokens = width // self.patch_size
                replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
                # Flatten list
                replace_tokens = [item for sublist in replace_tokens for item in sublist]
                replace_tokens[-1] = self.image_end_token
                replace_str = "".join(replace_tokens)
                replace_strings.append(replace_str)
                text = text.replace(self.image_token, "<placeholder>", 1)

            # Replace placeholders with actual image token sequences
            while "<placeholder>" in text:
                replace_str = replace_strings.pop(0)
                text = text.replace("<placeholder>", replace_str, 1)

        # Encode the text
        text_inputs = super().encode(text, **text_kwargs)

        output_dict["input_ids"] = text_inputs
        return output_dict
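
    # Usage sketch for encode() (hedged; assumes `tok` is a constructed
    # ImageTextTokenizer and `img` is a PIL image or image URL):
    #
    #   out = tok.encode("[INST]Describe this image: [IMG][/INST]", images=img)
    #   sorted(out.keys())  # ["image_sizes", "input_ids", "pixel_values"]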

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
        *,
        images: Optional[ImageInput] = None,
        image_kwargs: Optional[Dict[str, Any]] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_dict: bool = True,
        return_assistant_tokens_mask: bool = False,
        generation_prefix: str = "",
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Apply the chat template to the conversation.

        Args:
            conversation (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): The conversation to process.
            images (Optional[ImageInput]): Images to include in the conversation.
            image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
            add_generation_prompt (bool): Whether to add a generation prompt.
            tokenize (bool): Whether to tokenize the output.
            padding (bool): Whether to pad the output.
            truncation (bool): Whether to truncate the output.
            max_length (Optional[int]): Maximum length of the output.
            return_tensors (Optional[str]): The type of tensors to return.
            return_dict (bool): Whether to return a dictionary.
            return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask.
            generation_prefix (str): Prefix to add before asking the model to generate; helpful to guide the generation. Defaults to "".
            tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
            **kwargs: Additional keyword arguments.

        Returns:
            The processed conversation with applied chat template.

        Raises:
            AssertionError: If return_dict is False or if the conversation format is invalid.
        """
        assert return_dict, "return_dict must be True for ImageTextTokenizer"
        assert isinstance(conversation, list), "conversation must be a list"
        if isinstance(conversation[0], list):
            assert len(conversation) == 1, "Only single-conversation input is supported, got {}".format(conversation)
            conversation = conversation[0]

        # Extract images from the conversation if not provided
        if images is None:
            images = []
            for msg in conversation:
                if msg.get("images", None) is not None:
                    images.extend(msg["images"])
            images = load_image_list(images)
        # If no images were provided or found in the conversation, fall back to
        # text-only encoding. This allows feeding a VLM inputs with and without images.
        if isinstance(images, list) and len(images) == 0:
            images = None

        # Apply the chat template to the text
        text = super().apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_dict=False,
            return_assistant_tokens_mask=return_assistant_tokens_mask,
            generation_prefix=generation_prefix,
            tokenizer_kwargs=tokenizer_kwargs,
            **kwargs,
        )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        # Encode the text and images
        output = self.encode(
            text,
            images=images,
            image_kwargs=image_kwargs,
            tokenize=tokenize,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=False,
            return_tensors=return_tensors,
            **tokenizer_kwargs,
        )
        return output

    @property
    def model_input_names(self):
        """
        Get the combined model input names from both the text tokenizer and image processor.

        Returns:
            List[str]: A list of unique input names.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
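

if __name__ == "__main__":
    # Minimal end-to-end sketch. The paths and URL below are hypothetical
    # placeholders; point them at a real Pixtral tokenizer/image-processor
    # download before running.
    tokenizer = ImageTextTokenizer(
        model_family="pixtral",
        is_instruct_model=True,
        tokenizer_path="/path/to/pixtral/tokenizer",
        image_processor_path="/path/to/pixtral/image_processor",
    )
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "content": "Describe this image."},
            ],
            "images": ["https://example.com/cat.png"],
        }
    ]
    output = tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
    print(sorted(output.keys()))  # ["image_sizes", "input_ids", "pixel_values"]
    print(output["pixel_values"].shape)  # expected (num_images, C, H, W) after stacking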