Spaces:
Sleeping
Sleeping
File size: 24,467 Bytes
2260825 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 |
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import is_tf_available, is_torch_available
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import logging
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .base import (
ArgumentHandler,
CsvPipelineDataFormat,
JsonPipelineDataFormat,
PipedPipelineDataFormat,
Pipeline,
PipelineDataFormat,
PipelineException,
get_default_model,
infer_framework_load_model,
)
from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
from .text_classification import TextClassificationPipeline
from .text_generation import TextGenerationPipeline
from .token_classification import (
AggregationStrategy,
NerPipeline,
TokenClassificationArgumentHandler,
TokenClassificationPipeline,
)
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
if is_tf_available():
import tensorflow as tf
from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
)
if is_torch_available():
import torch
from ..models.auto.modeling_auto import (
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
)
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Register all the supported tasks here

# Alternative task names mapped to their canonical task name. `check_task` resolves an alias
# through this table before looking the task up in `SUPPORTED_TASKS`.
TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
}
# Registry of every task known to `pipeline`. Each entry is a dict with the keys:
#   - "impl": the Pipeline subclass instantiated for the task,
#   - "tf" / "pt": tuple of auto-model classes handed to `infer_framework_load_model` for the
#     TensorFlow / PyTorch framework (an empty tuple when the framework is not installed or when
#     no suitable auto class is listed),
#   - "default": default checkpoint identifiers; the whole entry is passed to `get_default_model`
#     when the caller supplies no model (presumably indexed by framework key -- confirm in `.base`).
SUPPORTED_TASKS = {
    "automatic-speech-recognition": {
        "impl": AutomaticSpeechRecognitionPipeline,
        "tf": (),
        # Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
        # do not exist yet.
        "pt": () if is_torch_available() else (),
        "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
    },
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
    },
    "text-classification": {
        "impl": TextClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": "distilbert-base-uncased-finetuned-sst-2-english",
                "tf": "distilbert-base-uncased-finetuned-sst-2-english",
            },
        },
    },
    "token-classification": {
        "impl": TokenClassificationPipeline,
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
                "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
            },
        },
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        "default": {
            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
        },
    },
    # NOTE(review): the "model" defaults below carry a "tokenizer" key and a "tf" checkpoint even
    # though the "tf" class tuple is empty -- confirm whether `get_default_model` uses either.
    "table-question-answering": {
        "impl": TableQuestionAnsweringPipeline,
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {
                "pt": "google/tapas-base-finetuned-wtq",
                "tokenizer": "google/tapas-base-finetuned-wtq",
                "tf": "google/tapas-base-finetuned-wtq",
            },
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
        "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    # Its defaults are keyed by (source, target) language tuples, the same shape as the task
    # options `check_task` returns for "translation_XX_to_YY".
    "translation": {
        "impl": TranslationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
            ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
            ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
        },
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
        "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
            "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
        },
    },
    # Two candidate model classes per framework are listed for this task.
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
        "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
    },
    "image-classification": {
        "impl": ImageClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": "google/vit-base-patch16-224"}},
    },
}
def check_task(task: str) -> Tuple[Dict, Any]:
    """
    Validate a task name and return the matching entry of :obj:`SUPPORTED_TASKS`.

    Args:
        task (:obj:`str`):
            Name of the requested task. Accepted values are:

            - any key of :obj:`SUPPORTED_TASKS` (e.g. :obj:`"text-classification"`, :obj:`"fill-mask"`,
              :obj:`"translation"`),
            - any key of :obj:`TASK_ALIASES` (:obj:`"sentiment-analysis"` is an alias of
              :obj:`"text-classification"`, :obj:`"ner"` is an alias of :obj:`"token-classification"`),
            - a parametrized translation task of the form :obj:`"translation_XX_to_YY"`.

    Returns:
        :obj:`tuple`: ``(task_defaults, task_options)`` where ``task_defaults`` is the registry entry for the task
        and ``task_options`` is :obj:`None`, except for ``"translation_XX_to_YY"`` where it is the tuple
        ``("XX", "YY")``.

    Raises:
        :obj:`KeyError`: If the task is unknown, or starts with ``"translation"`` without following the
        ``"translation_XX_to_YY"`` format.
    """
    # Map an alias onto its canonical name; any other value is left as given.
    task = TASK_ALIASES.get(task, task)

    # Tasks present in the registry carry no extra options.
    if task in SUPPORTED_TASKS:
        return SUPPORTED_TASKS[task], None

    if not task.startswith("translation"):
        raise KeyError(
            f"Unknown task {task}, available tasks are {list(SUPPORTED_TASKS.keys()) + ['translation_XX_to_YY']}"
        )

    # Expected layout after splitting: ["translation", XX, "to", YY]
    pieces = task.split("_")
    if len(pieces) != 4 or pieces[0] != "translation" or pieces[2] != "to":
        raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")

    _, source_lang, _, target_lang = pieces
    return SUPPORTED_TASKS["translation"], (source_lang, target_lang)
def pipeline(
    task: str,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
    feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    use_fast: bool = True,
    use_auth_token: Optional[Union[str, bool]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs
) -> Pipeline:
    """
    Utility factory method to build a :class:`~transformers.Pipeline`.

    Pipelines are made of:

        - A :doc:`tokenizer <tokenizer>` in charge of mapping raw textual input to token.
        - A :doc:`model <model>` to make predictions from the inputs.
        - Some (optional) post processing for enhancing model's output.

    Args:
        task (:obj:`str`):
            The task defining which pipeline will be returned. Currently accepted tasks are:

            - :obj:`"automatic-speech-recognition"`: will return a
              :class:`~transformers.AutomaticSpeechRecognitionPipeline`.
            - :obj:`"feature-extraction"`: will return a :class:`~transformers.FeatureExtractionPipeline`.
            - :obj:`"text-classification"`: will return a :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"sentiment-analysis"`: (alias of :obj:`"text-classification"`) will return a
              :class:`~transformers.TextClassificationPipeline`.
            - :obj:`"token-classification"`: will return a :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"ner"` (alias of :obj:`"token-classification"`): will return a
              :class:`~transformers.TokenClassificationPipeline`.
            - :obj:`"question-answering"`: will return a :class:`~transformers.QuestionAnsweringPipeline`.
            - :obj:`"table-question-answering"`: will return a
              :class:`~transformers.TableQuestionAnsweringPipeline`.
            - :obj:`"fill-mask"`: will return a :class:`~transformers.FillMaskPipeline`.
            - :obj:`"summarization"`: will return a :class:`~transformers.SummarizationPipeline`.
            - :obj:`"translation_xx_to_yy"`: will return a :class:`~transformers.TranslationPipeline`.
            - :obj:`"text2text-generation"`: will return a :class:`~transformers.Text2TextGenerationPipeline`.
            - :obj:`"text-generation"`: will return a :class:`~transformers.TextGenerationPipeline`.
            - :obj:`"zero-shot-classification"`: will return a :class:`~transformers.ZeroShotClassificationPipeline`.
            - :obj:`"conversational"`: will return a :class:`~transformers.ConversationalPipeline`.
            - :obj:`"image-classification"`: will return a :class:`~transformers.ImageClassificationPipeline`.
        model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`):
            The model that will be used by the pipeline to make predictions. This can be a model identifier or an
            actual instance of a pretrained model inheriting from :class:`~transformers.PreTrainedModel` (for PyTorch)
            or :class:`~transformers.TFPreTrainedModel` (for TensorFlow).

            If not provided, the default for the :obj:`task` will be loaded.
        config (:obj:`str` or :obj:`~transformers.PretrainedConfig`, `optional`):
            The configuration that will be used by the pipeline to instantiate the model. This can be a model
            identifier or an actual pretrained model configuration inheriting from
            :class:`~transformers.PretrainedConfig`.

            If not provided, the default configuration file for the requested model will be used. That means that if
            :obj:`model` is given, its default configuration will be used. However, if :obj:`model` is not supplied,
            this :obj:`task`'s default model's config is used instead.
        tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`):
            The tokenizer that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained tokenizer inheriting from :class:`~transformers.PreTrainedTokenizer`.

            If not provided, the default tokenizer for the given :obj:`model` will be loaded (if it is a string). If
            :obj:`model` is not specified or not a string, then the default tokenizer for :obj:`config` is loaded (if
            it is a string). However, if :obj:`config` is also not given or not a string, then the default tokenizer
            for the given :obj:`task` will be loaded.
        feature_extractor (:obj:`str` or :obj:`~transformers.PreTrainedFeatureExtractor`, `optional`):
            The feature extractor that will be used by the pipeline to encode data for the model. This can be a model
            identifier or an actual pretrained feature extractor inheriting from
            :class:`~transformers.PreTrainedFeatureExtractor`.

            Feature extractors are used for non-NLP models, such as Speech or Vision models as well as multi-modal
            models. Multi-modal models will also require a tokenizer to be passed.

            If not provided, the default feature extractor for the given :obj:`model` will be loaded (if it is a
            string). If :obj:`model` is not specified or not a string, then the default feature extractor for
            :obj:`config` is loaded (if it is a string). However, if :obj:`config` is also not given or not a string,
            then the default feature extractor for the given :obj:`task` will be loaded.
        framework (:obj:`str`, `optional`):
            The framework to use, either :obj:`"pt"` for PyTorch or :obj:`"tf"` for TensorFlow. The specified framework
            must be installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified and
            both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
            is provided.
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            When passing a task name or a string model identifier: The specific model version to use. It can be a
            branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
            artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
        use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
        use_auth_token (:obj:`str` or `bool`, `optional`):
            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). A
            ``"use_auth_token"`` entry already present in :obj:`model_kwargs` takes precedence over this argument.
        model_kwargs (:obj:`Dict[str, Any]`, `optional`):
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function. The dictionary is copied, the caller's object is never modified.
        kwargs:
            Additional keyword arguments passed along to the specific pipeline init (see the documentation for the
            corresponding pipeline class for possible values).

    Returns:
        :class:`~transformers.Pipeline`: A suitable pipeline for the task.

    Raises:
        :obj:`RuntimeError`: If a tokenizer or a feature extractor is given without a model.
        :obj:`KeyError`: If :obj:`task` is not a known task (raised by :obj:`check_task`).
        :obj:`Exception`: If a tokenizer or feature extractor is required but cannot be inferred from a string
        :obj:`model` or :obj:`config`.

    Examples::

        >>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

        >>> # Sentiment analysis pipeline
        >>> pipeline('sentiment-analysis')

        >>> # Question answering pipeline, specifying the checkpoint identifier
        >>> pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')

        >>> # Named entity recognition pipeline, passing in a specific model and tokenizer
        >>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        >>> pipeline('ner', model=model, tokenizer=tokenizer)
    """
    if model is None and tokenizer is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with tokenizer specified but not the model "
            "as the provided tokenizer may not be compatible with the default model. "
            "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing tokenizer."
        )
    if model is None and feature_extractor is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with feature_extractor specified but not the model "
            "as the provided feature_extractor may not be compatible with the default model. "
            "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
        )

    # Work on a private copy: the dict is modified below, and mutating either a shared default or the
    # caller's own dict would leak `use_auth_token` from one call into the next.
    model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}

    # Add use_auth_token to model_kwargs so that every `.from_pretrained` call below (config included)
    # receives it. An explicit entry in model_kwargs wins over the `use_auth_token` argument.
    model_kwargs.setdefault("use_auth_token", use_auth_token)

    # Retrieve the task
    targeted_task, task_options = check_task(task)
    task_class = targeted_task["impl"]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
        # At that point framework might still be undetermined
        model = get_default_model(targeted_task, framework, task_options)

    # Remember string identifiers before `config` is replaced by a config object and `model` by a model
    # instance; they are the fallbacks used to locate the tokenizer / feature extractor.
    model_name = model if isinstance(model, str) else None
    config_name = config if isinstance(config, str) else None

    # Config is the primordial information item.
    # Instantiate config if needed
    if config_name is not None:
        config = AutoConfig.from_pretrained(config_name, revision=revision, _from_pipeline=task, **model_kwargs)
    elif config is None and model_name is not None:
        config = AutoConfig.from_pretrained(model_name, revision=revision, _from_pipeline=task, **model_kwargs)

    # Infer the framework from the model
    # Forced if framework already defined, inferred if it's None
    # Will load the correct model if possible
    model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
    framework, model = infer_framework_load_model(
        model,
        model_classes=model_classes,
        config=config,
        framework=framework,
        revision=revision,
        task=task,
        **model_kwargs,
    )

    # Whether a tokenizer / feature extractor is needed depends on the loaded model's config class.
    model_config = model.config
    load_tokenizer = type(model_config) in TOKENIZER_MAPPING
    load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING

    if load_tokenizer:
        # Try to infer tokenizer from model or config name (if provided as str)
        if tokenizer is None:
            if model_name is not None:
                tokenizer = model_name
            elif config_name is not None:
                tokenizer = config_name
            else:
                # Impossible to guess what is the right tokenizer here
                raise Exception(
                    "Impossible to guess which tokenizer to use. "
                    "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
                )

        # Instantiate tokenizer if needed
        if isinstance(tokenizer, (str, tuple)):
            if isinstance(tokenizer, tuple):
                # For tuple we have (tokenizer name, {kwargs}); copy the kwargs so that popping
                # "use_fast" does not alter the caller's dict.
                tokenizer_identifier = tokenizer[0]
                tokenizer_kwargs = dict(tokenizer[1])
                use_fast = tokenizer_kwargs.pop("use_fast", use_fast)
            else:
                tokenizer_identifier = tokenizer
                tokenizer_kwargs = model_kwargs

            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_identifier, revision=revision, use_fast=use_fast, _from_pipeline=task, **tokenizer_kwargs
            )

    if load_feature_extractor:
        # Try to infer feature extractor from model or config name (if provided as str)
        if feature_extractor is None:
            if model_name is not None:
                feature_extractor = model_name
            elif config_name is not None:
                feature_extractor = config_name
            else:
                # Impossible to guess what is the right feature_extractor here
                raise Exception(
                    "Impossible to guess which feature extractor to use. "
                    "Please provide a PreTrainedFeatureExtractor class or a path/identifier "
                    "to a pretrained feature extractor."
                )

        # Instantiate feature_extractor if needed
        if isinstance(feature_extractor, (str, tuple)):
            feature_extractor = AutoFeatureExtractor.from_pretrained(
                feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
            )

    # A bare "translation" task is specialized to the first "translation*" entry found in the model's
    # task-specific parameters, and the user is warned about the implicit choice.
    if task == "translation" and model.config.task_specific_params:
        for key in model.config.task_specific_params:
            if key.startswith("translation"):
                task = key
                warnings.warn(
                    f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
                    UserWarning,
                )
                break

    # Only forward the pre-processing objects that were actually resolved.
    if tokenizer is not None:
        kwargs["tokenizer"] = tokenizer

    if feature_extractor is not None:
        kwargs["feature_extractor"] = feature_extractor

    return task_class(model=model, framework=framework, task=task, **kwargs)
|