# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import importlib
import json
import os
import pickle
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
from ..utils import logging


if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TFAutoModel

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import AutoModel

if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel


logger = logging.get_logger(__name__)
def infer_framework_load_model(
    model,
    config: AutoConfig,
    model_classes: Optional[Dict[str, Tuple[type]]] = None,
    task: Optional[str] = None,
    framework: Optional[str] = None,
    **model_kwargs
):
    """
    Select the framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework,
    model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        config (:class:`~transformers.AutoConfig`):
            The config associated with the model, to help pick the correct class.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping from framework to class.
        task (:obj:`str`):
            The task defining which pipeline will be returned.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function.

    Returns:
        :obj:`Tuple`: A tuple (framework, model).
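
    Example (illustrative; assumes the ``bert-base-cased`` checkpoint is reachable and that at least one of PyTorch
    or TensorFlow is installed)::

        config = AutoConfig.from_pretrained("bert-base-cased")
        framework, model = infer_framework_load_model("bert-base-cased", config)
        # framework == "pt" when PyTorch is available, "tf" otherwise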
""" | |
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        model_kwargs["_from_pipeline"] = task
        class_tuple = ()
        look_pt = is_torch_available() and framework in {"pt", None}
        look_tf = is_tf_available() and framework in {"tf", None}
        if model_classes:
            if look_pt:
                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
            if look_tf:
                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
        if config.architectures:
            classes = []
            for architecture in config.architectures:
                transformers_module = importlib.import_module("transformers")
                if look_pt:
                    _class = getattr(transformers_module, architecture, None)
                    if _class is not None:
                        classes.append(_class)
                if look_tf:
                    _class = getattr(transformers_module, f"TF{architecture}", None)
                    if _class is not None:
                        classes.append(_class)
            class_tuple = class_tuple + tuple(classes)

        if len(class_tuple) == 0:
            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")

        for model_class in class_tuple:
            kwargs = model_kwargs.copy()
            if framework == "pt" and model.endswith(".h5"):
                kwargs["from_tf"] = True
                logger.warning(
                    "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
                    "Trying to load the model with PyTorch."
                )
            elif framework == "tf" and model.endswith(".bin"):
                kwargs["from_pt"] = True
                logger.warning(
                    "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                    "Trying to load the model with TensorFlow."
                )

            try:
                model = model_class.from_pretrained(model, **kwargs)
                # Stop loading on the first successful load.
                break
            except (OSError, ValueError):
                continue

        if isinstance(model, str):
            raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework, model


def infer_framework_from_model(
    model,
    model_classes: Optional[Dict[str, Tuple[type]]] = None,
    task: Optional[str] = None,
    framework: Optional[str] = None,
    **model_kwargs
):
    """
    Select the framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework,
    model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping from framework to class.
        task (:obj:`str`):
            The task defining which pipeline will be returned.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function.

    Returns:
        :obj:`Tuple`: A tuple (framework, model).
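
    Example (illustrative; assumes the ``gpt2`` checkpoint is reachable)::

        framework, model = infer_framework_from_model("gpt2")
        # Loads the config from the checkpoint name, then defers to `infer_framework_load_model`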
""" | |
    if isinstance(model, str):
        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
    else:
        config = model.config
    return infer_framework_load_model(
        model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
    )


def get_framework(model, revision: Optional[str] = None):
    """
    Select the framework (TensorFlow or PyTorch) to use.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
            the model name). If no specific model is provided, defaults to using PyTorch.
    """
    warnings.warn(
        "`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
        FutureWarning,
    )
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        if is_torch_available() and not is_tf_available():
            model = AutoModel.from_pretrained(model, revision=revision)
        elif is_tf_available() and not is_torch_available():
            model = TFAutoModel.from_pretrained(model, revision=revision)
        else:
            try:
                model = AutoModel.from_pretrained(model, revision=revision)
            except OSError:
                model = TFAutoModel.from_pretrained(model, revision=revision)

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework


def get_default_model(targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]) -> str:
    """
    Select a default model to use for a given task. Defaults to PyTorch if ambiguous.

    Args:
        targeted_task (:obj:`Dict`):
            Dictionary representing the given task, which should contain default models.
        framework (:obj:`str`, `optional`):
            :obj:`"pt"`, :obj:`"tf"` or :obj:`None`, representing a specific framework if it was specified, or
            :obj:`None` if we don't know yet.
        task_options (:obj:`Any`, `optional`):
            Any further value required by the task to get fully specified, for instance (SRC, TGT) languages for a
            translation task.

    Returns:
        :obj:`str`: The model string representing the default model for this pipeline.
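
    Example (a minimal sketch; the task dictionary below is hypothetical and only mirrors the expected
    ``{"default": {"model": {"pt": ..., "tf": ...}}}`` shape)::

        targeted_task = {"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}}}
        get_default_model(targeted_task, framework=None, task_options=None)
        # Returns "distilbert-base-cased", defaulting to the "pt" entry when both frameworks are installed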
""" | |
    if is_torch_available() and not is_tf_available():
        framework = "pt"
    elif is_tf_available() and not is_torch_available():
        framework = "tf"

    defaults = targeted_task["default"]
    if task_options:
        if task_options not in defaults:
            raise ValueError(f"The task does not provide any default models for options {task_options}")
        default_models = defaults[task_options]["model"]
    elif "model" in defaults:
        default_models = targeted_task["default"]["model"]
    else:
        # XXX This error message needs to be updated to be more generic if more tasks are going to become
        # parametrized
        raise ValueError('The task defaults can\'t be correctly selected. You probably meant "translation_XX_to_YY"')

    if framework is None:
        framework = "pt"

    return default_models[framework]


class PipelineException(Exception):
    """
    Raised by a :class:`~transformers.Pipeline` when handling __call__.

    Args:
        task (:obj:`str`): The task of the pipeline.
        model (:obj:`str`): The model used by the pipeline.
        reason (:obj:`str`): The error message to display.
    """

    def __init__(self, task: str, model: str, reason: str):
        super().__init__(reason)

        self.task = task
        self.model = model


class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class PipelineDataFormat:
    """
    Base class for all the pipeline supported data formats, both for reading and writing. Supported data formats
    currently include:

    - JSON
    - CSV
    - stdin/stdout (pipe)

    :obj:`PipelineDataFormat` also includes some utilities to work with multi-column arguments, like mapping from
    dataset columns to pipeline keyword arguments through the :obj:`dataset_kwarg_1=dataset_column_1` format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
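
    Example (a minimal sketch; the file names are hypothetical)::

        # Map the dataset columns "question" and "context" to pipeline keyword arguments of the same names
        data_format = PipelineDataFormat.from_str(
            "csv", output_path="out.csv", input_path="in.csv", column="question=question,context=context"
        )
        for item in data_format:
            ...  # each item is a dict like {"question": ..., "context": ...}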
""" | |

    SUPPORTED_FORMATS = ["json", "csv", "pipe"]

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite: bool = False,
    ):
        self.output_path = output_path
        self.input_path = input_path
        self.column = column.split(",") if column is not None else [""]
        self.is_multi_columns = len(self.column) > 1

        if self.is_multi_columns:
            self.column = [tuple(c.split("=")) if "=" in c else (c, c) for c in self.column]

        if output_path is not None and not overwrite:
            if exists(abspath(self.output_path)):
                raise OSError(f"{self.output_path} already exists on disk")

        if input_path is not None:
            if not exists(abspath(self.input_path)):
                raise OSError(f"{self.input_path} doesn't exist on disk")

    @abstractmethod
    def __iter__(self):
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: Union[dict, List[dict]]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.
        """
        raise NotImplementedError()

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        """
        Save the provided data object as pickle-formatted binary data on disk.

        Args:
            data (:obj:`dict` or list of :obj:`dict`): The data to store.

        Returns:
            :obj:`str`: Path where the data has been saved.
        """
        path, _ = os.path.splitext(self.output_path)
        binary_path = os.path.extsep.join((path, "pickle"))

        with open(binary_path, "wb+") as f_output:
            pickle.dump(data, f_output)

        return binary_path

    @staticmethod
    def from_str(
        format: str,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ) -> "PipelineDataFormat":
        """
        Creates an instance of the right subclass of :class:`~transformers.pipelines.PipelineDataFormat` depending on
        :obj:`format`.

        Args:
            format (:obj:`str`):
                The format of the desired pipeline. Acceptable values are :obj:`"json"`, :obj:`"csv"` or :obj:`"pipe"`.
            output_path (:obj:`str`, `optional`):
                Where to save the outgoing data.
            input_path (:obj:`str`, `optional`):
                Where to look for the input data.
            column (:obj:`str`, `optional`):
                The column to read.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to overwrite the :obj:`output_path`.

        Returns:
            :class:`~transformers.pipelines.PipelineDataFormat`: The proper data format.
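
        Example (illustrative; paths are hypothetical)::

            data_format = PipelineDataFormat.from_str(
                "json", output_path="predictions.json", input_path="inputs.json", column="text"
            )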
""" | |
if format == "json": | |
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite) | |
elif format == "csv": | |
return CsvPipelineDataFormat(output_path, input_path, column, overwrite=overwrite) | |
elif format == "pipe": | |
return PipedPipelineDataFormat(output_path, input_path, column, overwrite=overwrite) | |
else: | |
raise KeyError(f"Unknown reader {format} (Available reader are json/csv/pipe)") | |


class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using CSV data format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

    def __iter__(self):
        with open(self.input_path, "r") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Save the provided data object with the representation for the current
        :class:`~transformers.pipelines.PipelineDataFormat`.

        Args:
            data (:obj:`List[dict]`): The data to store.
        """
        with open(self.output_path, "w") as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)


class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Support for pipelines using JSON file format.

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __init__(
        self,
        output_path: Optional[str],
        input_path: Optional[str],
        column: Optional[str],
        overwrite=False,
    ):
        super().__init__(output_path, input_path, column, overwrite=overwrite)

        with open(input_path, "r") as f:
            self._entries = json.load(f)

    def __iter__(self):
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                yield entry[self.column[0]]

    def save(self, data: dict):
        """
        Save the provided data object in a JSON file.

        Args:
            data (:obj:`dict`): The data to store.
        """
        with open(self.output_path, "w") as f:
            json.dump(data, f)


class PipedPipelineDataFormat(PipelineDataFormat):
    """
    Read data from piped input to the python process. For multi-column data, columns should be separated by \t

    If columns are provided, then the output will be a dictionary with {column_x: value_x}

    Args:
        output_path (:obj:`str`, `optional`): Where to save the outgoing data.
        input_path (:obj:`str`, `optional`): Where to look for the input data.
        column (:obj:`str`, `optional`): The column to read.
        overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to overwrite the :obj:`output_path`.
    """

    def __iter__(self):
        for line in sys.stdin:
            # Split for multi-columns
            if "\t" in line:
                line = line.split("\t")
                if self.column:
                    # Dictionary to map arguments
                    yield {kwargs: l for (kwargs, _), l in zip(self.column, line)}
                else:
                    yield tuple(line)

            # No dictionary to map arguments
            else:
                yield line

    def save(self, data: dict):
        """
        Print the data.

        Args:
            data (:obj:`dict`): The data to print.
        """
        print(data)

    def save_binary(self, data: Union[dict, List[dict]]) -> str:
        if self.output_path is None:
            raise KeyError(
                "Using piped input with a pipeline that outputs large objects requires an output file path. "
                "Please provide such an output path through the --output argument."
            )

        return super().save_binary(data)


class _ScikitCompat(ABC):
    """
    Interface layer for scikit-learn and Keras compatibility.
    """

    @abstractmethod
    def transform(self, X):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, X):
        raise NotImplementedError()
PIPELINE_INIT_ARGS = r""" | |
Arguments: | |
model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`): | |
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from | |
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for | |
TensorFlow. | |
tokenizer (:obj:`~transformers.PreTrainedTokenizer`): | |
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from | |
:class:`~transformers.PreTrainedTokenizer`. | |
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`): | |
Model card attributed to the model for this pipeline. | |
framework (:obj:`str`, `optional`): | |
The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework | |
must be installed. | |
If no framework is specified, will default to the one currently installed. If no framework is specified and | |
both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model | |
is provided. | |
task (:obj:`str`, defaults to :obj:`""`): | |
A task-identifier for the pipeline. | |
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`): | |
Reference to the object in charge of parsing supplied pipeline parameters. | |
device (:obj:`int`, `optional`, defaults to -1): | |
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on | |
the associated CUDA device id. | |
binary_output (:obj:`bool`, `optional`, defaults to :obj:`False`): | |
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text. | |
""" | |


class Pipeline(_ScikitCompat):
    """
    The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
    different pipelines.

    Base class implementing pipelined operations. Pipeline workflow is defined as a sequence of the following
    operations:

        Input -> Tokenization -> Model Inference -> Post-Processing (task dependent) -> Output

    Pipeline supports running on CPU or GPU through the device argument (see below).

    Some pipelines, like for instance :class:`~transformers.FeatureExtractionPipeline` (:obj:`'feature-extraction'`),
    output large tensor objects as nested lists. In order to avoid dumping such large structures as textual data, we
    provide the :obj:`binary_output` constructor argument. If set to :obj:`True`, the output will be stored in the
    pickle format.
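
    Example (illustrative; uses the top-level :func:`~transformers.pipeline` factory, which instantiates the
    :class:`Pipeline` subclass registered for the requested task)::

        from transformers import pipeline

        nlp = pipeline("sentiment-analysis")  # downloads a default model for the task
        nlp("We are very happy to include pipelines in the transformers repository.")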
""" | |
default_input_names = None | |

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: Optional[PreTrainedTokenizer] = None,
        feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        args_parser: ArgumentHandler = None,
        device: int = -1,
        binary_output: bool = False,
    ):
        if framework is None:
            framework, model = infer_framework_load_model(model, config=model.config)

        self.task = task
        self.model = model
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.modelcard = modelcard
        self.framework = framework
        self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
        self.binary_output = binary_output

        # Special handling
        if self.framework == "pt" and self.device.type == "cuda":
            self.model = self.model.to(self.device)

        # Update config with task specific parameters
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None and task in task_specific_params:
            self.model.config.update(task_specific_params.get(task))

    def save_pretrained(self, save_directory: str):
        """
        Save the pipeline's model and tokenizer.

        Args:
            save_directory (:obj:`str`):
                A path to the directory where to save the pipeline. It will be created if it doesn't exist.
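
        Example (illustrative; the directory name is hypothetical)::

            pipe.save_pretrained("./my-pipeline")
            # Writes the model, tokenizer, feature extractor and model card files (when present)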
""" | |
if os.path.isfile(save_directory): | |
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") | |
return | |
os.makedirs(save_directory, exist_ok=True) | |
self.model.save_pretrained(save_directory) | |
if self.tokenizer is not None: | |
self.tokenizer.save_pretrained(save_directory) | |
if self.feature_extractor is not None: | |
self.feature_extractor.save_pretrained(save_directory) | |
if self.modelcard is not None: | |
self.modelcard.save_pretrained(save_directory) | |

    def transform(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
        """
        return self(X=X)

    def predict(self, X):
        """
        Scikit / Keras interface to transformers' pipelines. This method will forward to __call__().
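
        Example (illustrative; ``pipe`` is any instantiated pipeline)::

            pipe.predict(["This is great!"])  # forwards to pipe.__call__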
""" | |
return self(X=X) | |

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in a framework-agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the requested device
                output = pipe(...)
        """
if self.framework == "tf": | |
with tf.device("/CPU:0" if self.device == -1 else f"/device:GPU:{self.device}"): | |
yield | |
else: | |
if self.device.type == "cuda": | |
torch.cuda.set_device(self.device) | |
yield | |

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Returns:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
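
        Example (a minimal sketch; assumes a PyTorch pipeline and ``tokenized``, a dict of tensors returned by the
        tokenizer)::

            batch = pipe.ensure_tensor_on_device(**tokenized)
            # Every torch.Tensor value in `tokenized` is moved to pipe.device; other values pass through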
""" | |
return { | |
name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor | |
for name, tensor in inputs.items() | |
} | |

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
        Check if the model class is supported by the pipeline.

        Args:
            supported_models (:obj:`List[str]` or :obj:`dict`):
                The list of models supported by the pipeline, or a dictionary with model class values.
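
        Example (a minimal sketch; ``MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING`` is the kind of auto-model mapping
        typically passed here)::

            self.check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)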
""" | |
if not isinstance(supported_models, list): # Create from a model mapping | |
supported_models_names = [] | |
for config, model in supported_models.items(): | |
# Mapping can now contain tuples of models for the same configuration. | |
if isinstance(model, tuple): | |
supported_models_names.extend([_model.__name__ for _model in model]) | |
else: | |
supported_models_names.append(model.__name__) | |
supported_models = supported_models_names | |
if self.model.__class__.__name__ not in supported_models: | |
raise PipelineException( | |
self.task, | |
self.model.base_model_prefix, | |
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}", | |
) | |

    def _parse_and_tokenize(
        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
    ):
        """
        Parse arguments and tokenize.
        """
        # Parse arguments
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation=truncation,
        )

        return inputs

    def __call__(self, *args, **kwargs):
        inputs = self._parse_and_tokenize(*args, **kwargs)
        return self._forward(inputs)

    def _forward(self, inputs, return_tensors=False):
        """
        Internal framework-specific forward dispatching.

        Args:
            inputs: dict holding all the keyword arguments required by the model forward method.
            return_tensors: Whether to return native framework (pt/tf) tensors rather than NumPy arrays.

        Returns:
            NumPy array
        """
        # Encode for forward
        with self.device_placement():
            if self.framework == "tf":
                # TODO trace model
                predictions = self.model(inputs.data, training=False)[0]
            else:
                with torch.no_grad():
                    inputs = self.ensure_tensor_on_device(**inputs)
                    predictions = self.model(**inputs)[0].cpu()

        if return_tensors:
            return predictions
        else:
            return predictions.numpy()