# Kaizouku's picture
# Upload 564 files
# 2260825 verified
from typing import TYPE_CHECKING, Optional, Union
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from .base import ArgumentHandler, Pipeline
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
# Can't use @add_end_docstrings(PIPELINE_INIT_ARGS) here because this one does not accept `binary_output`
class FeatureExtractionPipeline(Pipeline):
    """
    Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
    transformer, which can be used as features in downstream tasks.

    This feature extraction pipeline can currently be loaded from :func:`~transformers.pipeline` using the task
    identifier: :obj:`"feature-extraction"`.

    All models may be used for this pipeline. See a list of all models, including community-contributed models on
    `huggingface.co/models <https://huggingface.co/models>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        task (:obj:`str`, defaults to :obj:`""`):
            A task-identifier for the pipeline.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to -1):
            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive value will run the
            model on the associated CUDA device id.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        # Explicit Optional: the default is None, like `modelcard` and `framework` above.
        args_parser: Optional[ArgumentHandler] = None,
        device: int = -1,
        task: str = "",
    ):
        # Every argument is forwarded unchanged to the base pipeline. `binary_output` is
        # deliberately not exposed by this constructor; it is always passed as True.
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            args_parser=args_parser,
            device=device,
            binary_output=True,
            task=task,
        )

    def __call__(self, *args, **kwargs):
        """
        Extract the features of the input(s).

        Args:
            args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) to get the features of.

        Return:
            A nested list of :obj:`float`: The features computed by the model.
        """
        # The base pipeline's result is converted with `.tolist()`, so it must expose that
        # method (presumably an array/tensor-like object -- confirm against Pipeline.__call__).
        return super().__call__(*args, **kwargs).tolist()