File size: 11,211 Bytes
3b5d4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import os
import json
import warnings
from pathlib import Path


import torch
import torch.nn as nn

from transformers import (
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    BatchFeature,
)
from transformers.utils import (
    logging,
    direct_transformers_import,
    PROCESSOR_NAME,
    CHAT_TEMPLATE_NAME,
)
from transformers.image_utils import ImageInput
from transformers.dynamic_module_utils import custom_object_save

logger = logging.get_logger(__name__)

# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
transformers_module = direct_transformers_import(Path(__file__).parent)


class MultiProcessorKwargs:
    """Holder for the default keyword arguments handed to each tokenizer.

    `_defaults` maps a per-tokenizer key (`"tokenizer_1_kwargs"`,
    `"tokenizer_2_kwargs"`) to the keyword arguments used as a baseline when
    that tokenizer is invoked. Both tokenizers default to `padding=False`.
    """

    # Each tokenizer gets its own dict object so the two entries never alias.
    _defaults = dict(
        tokenizer_1_kwargs={"padding": False},
        tokenizer_2_kwargs={"padding": False},
    )


class MultiProcessor(ProcessorMixin):
    """Processor that wraps two tokenizers and tokenizes one text input with each.

    The two tokenizers are exposed as `tokenizer_1` and `tokenizer_2`. When the
    processor is saved, each tokenizer is written to a subfolder named after its
    attribute, and `_get_arguments_from_pretrained` loads each one back from
    that same subfolder.
    """

    attributes = ["tokenizer_1", "tokenizer_2"]
    valid_kwargs = ["chat_template"]
    tokenizer_1_class = "AutoTokenizer"
    tokenizer_2_class = "AutoTokenizer"

    tokenizer_1: PreTrainedTokenizer
    tokenizer_2: PreTrainedTokenizer

    def __init__(
        self,
        tokenizer_1=None,
        tokenizer_2=None,
        chat_template=None,
        **kwargs,
    ):
        """Forward both tokenizers and the chat template to `ProcessorMixin.__init__`."""
        super().__init__(
            tokenizer_1,
            tokenizer_2,
            chat_template=chat_template,
            **kwargs,
        )

    def __call__(
        self,
        text_1: str | list[str] | None = None,
        text_2: str | list[str] | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize `text_1` with `tokenizer_1` and `text_2` with `tokenizer_2`.

        Args:
            text_1: A string or a list of non-empty strings for `tokenizer_1`.
            text_2: A string or a list of non-empty strings for `tokenizer_2`.
            **kwargs: Extra keyword arguments passed to BOTH tokenizers. They
                take precedence over the defaults (`padding=False`,
                `return_tensors="pt"`).

        Returns:
            A `BatchFeature` with keys `input_ids` / `attention_mask` (from
            `tokenizer_1`) and `input_ids_2` / `attention_mask_2` (from
            `tokenizer_2`).

        Raises:
            AssertionError: If either input is not a `str` or a list of
                non-empty `str` (this includes the default `None`).
        """

        def _validate_text_input(text) -> str | list[str]:
            # NOTE: validation uses `assert`, so it is skipped under `python -O`.
            # It is kept as-is so callers keep seeing `AssertionError`.
            if isinstance(text, list):
                assert all(
                    isinstance(t, str) for t in text
                ), f"Expected list of str but got {type(text)}"
                assert all(len(t) > 0 for t in text), "Expected non-empty strings"
            else:
                assert isinstance(text, str), f"Expected str but got {type(text)}"
            return text

        def _normalize_text_input(text: str | list[str]) -> list[str]:
            # Always hand the tokenizers a batch (list), even for a single string.
            if isinstance(text, str):
                return [text]
            return text

        _text_1: str | list[str] = _validate_text_input(text_1)
        text_1_list: list[str] = _normalize_text_input(_text_1)
        _text_2: str | list[str] = _validate_text_input(text_2)
        text_2_list: list[str] = _normalize_text_input(_text_2)

        # Later entries win: class defaults < return_tensors="pt" < caller kwargs.
        tokenizer_1_output_kwargs = {
            **MultiProcessorKwargs._defaults["tokenizer_1_kwargs"],
            "return_tensors": "pt",
            **kwargs,
        }
        tokenizer_2_output_kwargs = {
            **MultiProcessorKwargs._defaults["tokenizer_2_kwargs"],
            "return_tensors": "pt",
            **kwargs,
        }

        # tokenize
        text_1_inputs = self.tokenizer_1(
            text_1_list,
            **tokenizer_1_output_kwargs,
        )
        text_2_inputs = self.tokenizer_2(
            text_2_list,
            **tokenizer_2_output_kwargs,
        )

        return BatchFeature(
            data={
                "input_ids": text_1_inputs.get("input_ids"),
                "attention_mask": text_1_inputs.get("attention_mask"),
                "input_ids_2": text_2_inputs.get("input_ids"),
                "attention_mask_2": text_2_inputs.get("attention_mask"),
            }
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to `tokenizer_2`'s [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        # The processor only defines `tokenizer_1` and `tokenizer_2`; the previous
        # `self.tokenizer_2_tokenizer` did not exist and raised AttributeError.
        return self.tokenizer_2.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to `tokenizer_2`'s [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        # See `batch_decode`: decoding is delegated to `tokenizer_2`.
        return self.tokenizer_2.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Return the input names advertised by this processor."""
        # NOTE(review): these are the `__call__` argument names, while the returned
        # BatchFeature uses `input_ids` / `attention_mask` / `input_ids_2` /
        # `attention_mask_2` -- confirm which one downstream consumers expect.
        return ["text_1", "text_2"]

    # edit from: https://github.com/huggingface/transformers/blob/1d063793318b20654ebb850f48f43e0a247ab7bb/src/transformers/processing_utils.py#L980-L995
    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load every attribute in `cls.attributes` from its own subfolder.

        For each attribute name, the class named by `<attribute>_class` is looked
        up on the dynamically imported Transformers module and its
        `from_pretrained` is called with `subfolder=<attribute name>`.

        Args:
            pretrained_model_name_or_path: Forwarded to each `from_pretrained`.
            **kwargs: Forwarded to each `from_pretrained`; `use_fast` (default
                `True`) selects between slow/fast classes when the class name is
                a `(slow, fast)` tuple.

        Returns:
            A list of loaded attribute instances, in `cls.attributes` order.
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            subfolder = attribute_name  # subfolder is the same as attribute_name
            if isinstance(class_name, tuple):
                classes = tuple(
                    getattr(transformers_module, n) if n is not None else None
                    for n in class_name
                )
                use_fast = kwargs.get("use_fast", True)
                # Prefer the fast class (index 1) when requested and available.
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = getattr(transformers_module, class_name)

            assert attribute_class is not None, f"Missing attribute class: {class_name}"
            args.append(
                attribute_class.from_pretrained(
                    pretrained_model_name_or_path,
                    subfolder=subfolder,
                    **kwargs,
                )
            )
        return args

    # edit from: https://github.com/huggingface/transformers/blob/1d063793318b20654ebb850f48f43e0a247ab7bb/src/transformers/processing_utils.py#L460-L560
    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
        """
        Saves the attributes of this processor (here: the two tokenizers) in the specified directory so that it
        can be reloaded using the [`~ProcessorMixin.from_pretrained`] method. Each attribute is saved into a
        subfolder of `save_directory` named after the attribute (`tokenizer_1`, `tokenizer_2`).

        <Tip>

        This method is simply calling `save_pretrained` on each attribute, e.g.
        [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
        methods above for more information.

        </Tip>

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor files and the tokenizer subfolders will be saved (directory will
                be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Returns:
            `list[str]`: A one-element list with the processor config path, or an empty list when the processor
            dict only contains `processor_class` (in which case no processor config file is written).

        Raises:
            ValueError: If both `token` and the deprecated `use_auth_token` are specified.
        """
        use_auth_token = kwargs.pop("use_auth_token", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token", None) is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # `os.fspath` makes the documented `os.PathLike` input work; calling
            # `.split` directly on a `Path` raises AttributeError.
            default_repo_id = os.fspath(save_directory).split(os.path.sep)[-1]
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)
        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            attrs = [
                getattr(self, attribute_name) for attribute_name in self.attributes
            ]
            configs = [
                (a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a)
                for a in attrs
            ]
            configs.append(self)
            custom_object_save(self, save_directory, config=configs)

        for attribute_name in self.attributes:
            attribute = getattr(self, attribute_name)
            # Include the processor class in the attribute config so this processor can then be reloaded with the
            # `AutoProcessor` API.
            if hasattr(attribute, "_set_processor_class"):
                attribute._set_processor_class(self.__class__.__name__)
            attribute.save_pretrained(
                os.path.join(
                    save_directory,
                    attribute_name,  # CHANGED: save to subfolder
                ),
            )

        if self._auto_class is not None:
            # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
            for attribute_name in self.attributes:
                attribute = getattr(self, attribute_name)
                if isinstance(attribute, PreTrainedTokenizerBase):
                    # `pop` with a default tolerates a tokenizer whose `auto_map` was never set.
                    attribute.init_kwargs.pop("auto_map", None)

        # If we save using the predefined names, we can load using `from_pretrained`
        # plus we save chat_template in its own file
        output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
        output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)

        processor_dict = self.to_dict()
        # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
        # to avoid serializing chat template in json config file. So let's get it from `self` directly
        if self.chat_template is not None:
            chat_template_json_string = (
                json.dumps(
                    {"chat_template": self.chat_template}, indent=2, sort_keys=True
                )
                + "\n"
            )
            with open(output_chat_template_file, "w", encoding="utf-8") as writer:
                writer.write(chat_template_json_string)
            logger.info(f"chat template saved in {output_chat_template_file}")

        # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
        # `auto_map` is not specified.
        if set(processor_dict.keys()) != {"processor_class"}:
            self.to_json_file(output_processor_file)
            logger.info(f"processor saved in {output_processor_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

        if set(processor_dict.keys()) == {"processor_class"}:
            return []
        return [output_processor_file]