from typing import TYPE_CHECKING, Optional, Union

from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from .base import ArgumentHandler, Pipeline


if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
    from ..modeling_utils import PreTrainedModel


# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    This feature extraction pipeline can currently be loaded from :func:`~transformers.pipeline` using the task
    identifier: :obj:`"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models, on
    `huggingface.co/models <https://huggingface.co/models>`__.
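
    Example (a minimal usage sketch; the checkpoint name ``distilbert-base-cased`` is only an illustrative choice, any
    model on the Hub can be used)::

        from transformers import pipeline

        # Load the pipeline through the "feature-extraction" task identifier.
        feature_extractor = pipeline("feature-extraction", model="distilbert-base-cased")

        # Returns a nested list of floats: one hidden-state vector per token of the input.
        features = feature_extractor("Transformers is great!")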

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU support. Setting this to -1 will leverage the CPU; a positive integer will run
            the model on the associated CUDA device id.
""" | |
def __init__( | |
self, | |
model: Union["PreTrainedModel", "TFPreTrainedModel"], | |
tokenizer: PreTrainedTokenizer, | |
modelcard: Optional[ModelCard] = None, | |
framework: Optional[str] = None, | |
args_parser: ArgumentHandler = None, | |
device: int = -1, | |
task: str = "", | |
): | |
super().__init__( | |
model=model, | |
tokenizer=tokenizer, | |
modelcard=modelcard, | |
framework=framework, | |
args_parser=args_parser, | |
device=device, | |
binary_output=True, | |
task=task, | |
) | |

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of :obj:`float`: The features computed by the model.
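
        Example (the pipeline instance, input text, and unpacked names below are illustrative; actual sizes depend on
        the tokenizer and model)::

            features = feature_extractor("Hello world")  # feature_extractor: a FeatureExtractionPipeline instance
            batch_size = len(features)         # one entry per input text
            num_tokens = len(features[0])      # number of tokens after tokenization (including special tokens)
            hidden_size = len(features[0][0])  # size of the base transformer's hidden states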
""" | |
return super().__call__(*args, **kwargs).tolist() | |