import os
import json
import warnings
from pathlib import Path

import torch
import torch.nn as nn

from transformers import (
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    BatchFeature,
)
from transformers.utils import (
    logging,
    direct_transformers_import,
    PROCESSOR_NAME,
    CHAT_TEMPLATE_NAME,
)
from transformers.image_utils import ImageInput
from transformers.dynamic_module_utils import custom_object_save

logger = logging.get_logger(__name__)

# Dynamically import the transformers module from this file's directory so the
# attribute classes named below (e.g. "AutoTokenizer") can be looked up by name
# in `_get_arguments_from_pretrained`.
transformers_module = direct_transformers_import(Path(__file__).parent)


class MultiProcessorKwargs:
    _defaults = {
        "tokenizer_1_kwargs": {
            "padding": False,
        },
        "tokenizer_2_kwargs": {
            "padding": False,
        },
    }
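    # The defaults above are merged with the caller's kwargs in
    # `MultiProcessor.__call__`; later entries win, so caller kwargs override
    # the defaults. A minimal sketch of the merge (the caller kwargs shown are
    # hypothetical):
    #
    #     merged = {
    #         **MultiProcessorKwargs._defaults["tokenizer_1_kwargs"],
    #         "return_tensors": "pt",
    #         **{"padding": True, "truncation": True},  # caller kwargs
    #     }
    #     # -> {"padding": True, "return_tensors": "pt", "truncation": True}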


class MultiProcessor(ProcessorMixin):
    """Processor that wraps two tokenizers and encodes two text streams in a single call."""

    attributes = ["tokenizer_1", "tokenizer_2"]
    valid_kwargs = ["chat_template"]
    tokenizer_1_class = "AutoTokenizer"
    tokenizer_2_class = "AutoTokenizer"

    tokenizer_1: PreTrainedTokenizer
    tokenizer_2: PreTrainedTokenizer

    def __init__(
        self,
        tokenizer_1=None,
        tokenizer_2=None,
        chat_template=None,
        **kwargs,
    ):
        super().__init__(
            tokenizer_1,
            tokenizer_2,
            chat_template=chat_template,
            **kwargs,
        )

    def __call__(
        self,
        text_1: str | list[str] | None = None,
        text_2: str | list[str] | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize `text_1` with `tokenizer_1` and `text_2` with `tokenizer_2`, returning a
        single `BatchFeature` with `input_ids`/`attention_mask` from the first tokenizer and
        `input_ids_2`/`attention_mask_2` from the second."""

        def _validate_text_input(text) -> str | list[str]:
            if isinstance(text, list):
                assert all(
                    isinstance(t, str) for t in text
                ), "Expected a list of str but got a list containing non-str elements"
                assert all(len(t) > 0 for t in text), "Expected non-empty strings"
            else:
                assert isinstance(text, str), f"Expected str but got {type(text)}"
            return text

        def _normalize_text_input(text: str | list[str]) -> list[str]:
            if isinstance(text, str):
                return [text]
            return text

        _text_1: str | list[str] = _validate_text_input(text_1)
        text_1_list: list[str] = _normalize_text_input(_text_1)
        _text_2: str | list[str] = _validate_text_input(text_2)
        text_2_list: list[str] = _normalize_text_input(_text_2)

        # Per-tokenizer defaults, overridden by any kwargs passed by the caller.
        tokenizer_1_output_kwargs = {
            **MultiProcessorKwargs._defaults["tokenizer_1_kwargs"],
            "return_tensors": "pt",
            **kwargs,
        }
        tokenizer_2_output_kwargs = {
            **MultiProcessorKwargs._defaults["tokenizer_2_kwargs"],
            "return_tensors": "pt",
            **kwargs,
        }

        text_1_inputs = self.tokenizer_1(
            text_1_list,
            **tokenizer_1_output_kwargs,
        )
        text_2_inputs = self.tokenizer_2(
            text_2_list,
            **tokenizer_2_output_kwargs,
        )

        return BatchFeature(
            data={
                "input_ids": text_1_inputs.get("input_ids"),
                "attention_mask": text_1_inputs.get("attention_mask"),
                "input_ids_2": text_2_inputs.get("input_ids"),
                "attention_mask_2": text_2_inputs.get("attention_mask"),
            }
        )
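
    # Illustrative usage of `__call__` (the checkpoint path is a placeholder,
    # not a real artifact):
    #
    #     processor = MultiProcessor.from_pretrained("path/to/saved/processor")
    #     batch = processor(text_1="a photo of a cat", text_2="a photo of a cat")
    #     batch["input_ids"].shape        # from tokenizer_1
    #     batch["input_ids_2"].shape      # from tokenizer_2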

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to `tokenizer_2`'s
        [`~PreTrainedTokenizerBase.batch_decode`]. Please refer to the docstring of that method
        for more information.
        """
        return self.tokenizer_2.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to `tokenizer_2`'s
        [`~PreTrainedTokenizerBase.decode`]. Please refer to the docstring of that method for
        more information.
        """
        return self.tokenizer_2.decode(*args, **kwargs)
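
    # Illustrative round trip through the second tokenizer (output is a sketch;
    # the exact strings depend on which tokenizer was loaded):
    #
    #     batch = processor(text_1="hello", text_2="hello")
    #     processor.batch_decode(batch["input_ids_2"], skip_special_tokens=True)
    #     # -> ["hello"]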

    @property
    def model_input_names(self):
        return ["text_1", "text_2"]

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Load each attribute (tokenizer_1, tokenizer_2) from its own subfolder
        # of the pretrained directory or repo.
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            subfolder = attribute_name
            if isinstance(class_name, tuple):
                classes = tuple(
                    getattr(transformers_module, n) if n is not None else None
                    for n in class_name
                )
                # Prefer the fast tokenizer class when available, unless
                # `use_fast=False` was passed.
                use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = getattr(transformers_module, class_name)

            assert attribute_class is not None, f"Missing attribute class: {class_name}"
            args.append(
                attribute_class.from_pretrained(
                    pretrained_model_name_or_path,
                    subfolder=subfolder,
                    **kwargs,
                )
            )
        return args
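
    # Expected on-disk layout (a sketch; written by `save_pretrained` below and
    # consumed by `_get_arguments_from_pretrained` above, with file names
    # assuming the usual transformers constants):
    #
    #     saved_processor/
    #         processor_config.json      # PROCESSOR_NAME
    #         chat_template.json         # CHAT_TEMPLATE_NAME, only if a chat template is set
    #         tokenizer_1/               # one subfolder per entry in `attributes`
    #         tokenizer_2/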

    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
        """
        Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
        can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.

        <Tip>

        This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
        [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
        methods above for more information.

        </Tip>

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        use_auth_token = kwargs.pop("use_auth_token", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token", None) is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)

        # If this is a custom processor class, copy the file defining it into the
        # save directory so it can be reloaded from the Hub.
        if self._auto_class is not None:
            attrs = [
                getattr(self, attribute_name) for attribute_name in self.attributes
            ]
            configs = [
                (a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a)
                for a in attrs
            ]
            configs.append(self)
            custom_object_save(self, save_directory, config=configs)

        # Save each attribute (tokenizer_1, tokenizer_2) in its own subfolder.
        for attribute_name in self.attributes:
            attribute = getattr(self, attribute_name)

            # Record the processor class on each attribute so this processor can
            # be recovered at load time.
            if hasattr(attribute, "_set_processor_class"):
                attribute._set_processor_class(self.__class__.__name__)
            attribute.save_pretrained(
                os.path.join(
                    save_directory,
                    attribute_name,
                ),
            )

        if self._auto_class is not None:
            # `custom_object_save` added an `auto_map` entry to the tokenizers'
            # `init_kwargs`; clean it up again after saving.
            for attribute_name in self.attributes:
                attribute = getattr(self, attribute_name)
                if isinstance(attribute, PreTrainedTokenizerBase):
                    del attribute.init_kwargs["auto_map"]

        output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
        output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)

        processor_dict = self.to_dict()

        # Save the chat template in its own dedicated JSON file, if one is set.
        if self.chat_template is not None:
            chat_template_json_string = (
                json.dumps(
                    {"chat_template": self.chat_template}, indent=2, sort_keys=True
                )
                + "\n"
            )
            with open(output_chat_template_file, "w", encoding="utf-8") as writer:
                writer.write(chat_template_json_string)
            logger.info(f"chat template saved in {output_chat_template_file}")

        # Only write the processor config file if it contains more than the
        # processor class name.
        if set(processor_dict.keys()) != {"processor_class"}:
            self.to_json_file(output_processor_file)
            logger.info(f"processor saved in {output_processor_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

        if set(processor_dict.keys()) == {"processor_class"}:
            return []
        return [output_processor_file]
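

# Illustrative save/load round trip (a sketch; the tokenizer checkpoints and the
# output directory are placeholders, not real artifacts):
#
#     from transformers import AutoTokenizer
#
#     processor = MultiProcessor(
#         tokenizer_1=AutoTokenizer.from_pretrained("tokenizer-checkpoint-1"),
#         tokenizer_2=AutoTokenizer.from_pretrained("tokenizer-checkpoint-2"),
#     )
#     processor.save_pretrained("saved_processor")  # writes tokenizer_1/, tokenizer_2/, ...
#     reloaded = MultiProcessor.from_pretrained("saved_processor")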
|