# Spaces: Sleeping  (apparently page-status residue from the site this file was captured from; not part of the module)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
import warnings | |
from collections import OrderedDict | |
from ...utils import logging | |
# Add modeling imports here | |
from ..albert.modeling_tf_albert import ( | |
TFAlbertForMaskedLM, | |
TFAlbertForMultipleChoice, | |
TFAlbertForPreTraining, | |
TFAlbertForQuestionAnswering, | |
TFAlbertForSequenceClassification, | |
TFAlbertForTokenClassification, | |
TFAlbertModel, | |
) | |
from ..bart.modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel | |
from ..bert.modeling_tf_bert import ( | |
TFBertForMaskedLM, | |
TFBertForMultipleChoice, | |
TFBertForNextSentencePrediction, | |
TFBertForPreTraining, | |
TFBertForQuestionAnswering, | |
TFBertForSequenceClassification, | |
TFBertForTokenClassification, | |
TFBertLMHeadModel, | |
TFBertModel, | |
) | |
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel | |
from ..blenderbot_small.modeling_tf_blenderbot_small import ( | |
TFBlenderbotSmallForConditionalGeneration, | |
TFBlenderbotSmallModel, | |
) | |
from ..camembert.modeling_tf_camembert import ( | |
TFCamembertForMaskedLM, | |
TFCamembertForMultipleChoice, | |
TFCamembertForQuestionAnswering, | |
TFCamembertForSequenceClassification, | |
TFCamembertForTokenClassification, | |
TFCamembertModel, | |
) | |
from ..convbert.modeling_tf_convbert import ( | |
TFConvBertForMaskedLM, | |
TFConvBertForMultipleChoice, | |
TFConvBertForQuestionAnswering, | |
TFConvBertForSequenceClassification, | |
TFConvBertForTokenClassification, | |
TFConvBertModel, | |
) | |
from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel | |
from ..distilbert.modeling_tf_distilbert import ( | |
TFDistilBertForMaskedLM, | |
TFDistilBertForMultipleChoice, | |
TFDistilBertForQuestionAnswering, | |
TFDistilBertForSequenceClassification, | |
TFDistilBertForTokenClassification, | |
TFDistilBertModel, | |
) | |
from ..dpr.modeling_tf_dpr import TFDPRQuestionEncoder | |
from ..electra.modeling_tf_electra import ( | |
TFElectraForMaskedLM, | |
TFElectraForMultipleChoice, | |
TFElectraForPreTraining, | |
TFElectraForQuestionAnswering, | |
TFElectraForSequenceClassification, | |
TFElectraForTokenClassification, | |
TFElectraModel, | |
) | |
from ..flaubert.modeling_tf_flaubert import ( | |
TFFlaubertForMultipleChoice, | |
TFFlaubertForQuestionAnsweringSimple, | |
TFFlaubertForSequenceClassification, | |
TFFlaubertForTokenClassification, | |
TFFlaubertModel, | |
TFFlaubertWithLMHeadModel, | |
) | |
from ..funnel.modeling_tf_funnel import ( | |
TFFunnelBaseModel, | |
TFFunnelForMaskedLM, | |
TFFunnelForMultipleChoice, | |
TFFunnelForPreTraining, | |
TFFunnelForQuestionAnswering, | |
TFFunnelForSequenceClassification, | |
TFFunnelForTokenClassification, | |
TFFunnelModel, | |
) | |
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model | |
from ..hubert.modeling_tf_hubert import TFHubertModel | |
from ..layoutlm.modeling_tf_layoutlm import ( | |
TFLayoutLMForMaskedLM, | |
TFLayoutLMForSequenceClassification, | |
TFLayoutLMForTokenClassification, | |
TFLayoutLMModel, | |
) | |
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel | |
from ..longformer.modeling_tf_longformer import ( | |
TFLongformerForMaskedLM, | |
TFLongformerForMultipleChoice, | |
TFLongformerForQuestionAnswering, | |
TFLongformerForSequenceClassification, | |
TFLongformerForTokenClassification, | |
TFLongformerModel, | |
) | |
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel | |
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel | |
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel | |
from ..mobilebert.modeling_tf_mobilebert import ( | |
TFMobileBertForMaskedLM, | |
TFMobileBertForMultipleChoice, | |
TFMobileBertForNextSentencePrediction, | |
TFMobileBertForPreTraining, | |
TFMobileBertForQuestionAnswering, | |
TFMobileBertForSequenceClassification, | |
TFMobileBertForTokenClassification, | |
TFMobileBertModel, | |
) | |
from ..mpnet.modeling_tf_mpnet import ( | |
TFMPNetForMaskedLM, | |
TFMPNetForMultipleChoice, | |
TFMPNetForQuestionAnswering, | |
TFMPNetForSequenceClassification, | |
TFMPNetForTokenClassification, | |
TFMPNetModel, | |
) | |
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model | |
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel | |
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel | |
from ..roberta.modeling_tf_roberta import ( | |
TFRobertaForMaskedLM, | |
TFRobertaForMultipleChoice, | |
TFRobertaForQuestionAnswering, | |
TFRobertaForSequenceClassification, | |
TFRobertaForTokenClassification, | |
TFRobertaModel, | |
) | |
from ..roformer.modeling_tf_roformer import ( | |
TFRoFormerForCausalLM, | |
TFRoFormerForMaskedLM, | |
TFRoFormerForMultipleChoice, | |
TFRoFormerForQuestionAnswering, | |
TFRoFormerForSequenceClassification, | |
TFRoFormerForTokenClassification, | |
TFRoFormerModel, | |
) | |
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model | |
from ..transfo_xl.modeling_tf_transfo_xl import ( | |
TFTransfoXLForSequenceClassification, | |
TFTransfoXLLMHeadModel, | |
TFTransfoXLModel, | |
) | |
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model | |
from ..xlm.modeling_tf_xlm import ( | |
TFXLMForMultipleChoice, | |
TFXLMForQuestionAnsweringSimple, | |
TFXLMForSequenceClassification, | |
TFXLMForTokenClassification, | |
TFXLMModel, | |
TFXLMWithLMHeadModel, | |
) | |
from ..xlm_roberta.modeling_tf_xlm_roberta import ( | |
TFXLMRobertaForMaskedLM, | |
TFXLMRobertaForMultipleChoice, | |
TFXLMRobertaForQuestionAnswering, | |
TFXLMRobertaForSequenceClassification, | |
TFXLMRobertaForTokenClassification, | |
TFXLMRobertaModel, | |
) | |
from ..xlnet.modeling_tf_xlnet import ( | |
TFXLNetForMultipleChoice, | |
TFXLNetForQuestionAnsweringSimple, | |
TFXLNetForSequenceClassification, | |
TFXLNetForTokenClassification, | |
TFXLNetLMHeadModel, | |
TFXLNetModel, | |
) | |
from .auto_factory import _BaseAutoModelClass, auto_class_update | |
from .configuration_auto import ( | |
AlbertConfig, | |
BartConfig, | |
BertConfig, | |
BlenderbotConfig, | |
BlenderbotSmallConfig, | |
CamembertConfig, | |
ConvBertConfig, | |
CTRLConfig, | |
DistilBertConfig, | |
DPRConfig, | |
ElectraConfig, | |
FlaubertConfig, | |
FunnelConfig, | |
GPT2Config, | |
HubertConfig, | |
LayoutLMConfig, | |
LEDConfig, | |
LongformerConfig, | |
LxmertConfig, | |
MarianConfig, | |
MBartConfig, | |
MobileBertConfig, | |
MPNetConfig, | |
MT5Config, | |
OpenAIGPTConfig, | |
PegasusConfig, | |
RobertaConfig, | |
RoFormerConfig, | |
T5Config, | |
TransfoXLConfig, | |
Wav2Vec2Config, | |
XLMConfig, | |
XLMRobertaConfig, | |
XLNetConfig, | |
) | |
# Module-level logger, obtained through the package's own `logging` helper
# (imported above from `...utils`) and keyed by this module's name.
logger = logging.get_logger(__name__)
# Configuration class -> bare (headless) TensorFlow model class.
# Used as `_model_mapping` by `TFAutoModel` below.
#
# Notes:
#   * `FunnelConfig` deliberately maps to a tuple of two candidate classes.
#   * `BartConfig` used to be listed twice with the same value; the second,
#     redundant entry was dropped. The resulting dict (keys, values and key
#     order) is unchanged, because a repeated key in an `OrderedDict`
#     constructor keeps its first position and is merely overwritten.
TF_MODEL_MAPPING = OrderedDict(
    [
        # Base model mapping
        (RoFormerConfig, TFRoFormerModel),
        (ConvBertConfig, TFConvBertModel),
        (LEDConfig, TFLEDModel),
        (LxmertConfig, TFLxmertModel),
        (MT5Config, TFMT5Model),
        (T5Config, TFT5Model),
        (DistilBertConfig, TFDistilBertModel),
        (AlbertConfig, TFAlbertModel),
        (BartConfig, TFBartModel),
        (CamembertConfig, TFCamembertModel),
        (XLMRobertaConfig, TFXLMRobertaModel),
        (LongformerConfig, TFLongformerModel),
        (RobertaConfig, TFRobertaModel),
        (LayoutLMConfig, TFLayoutLMModel),
        (BertConfig, TFBertModel),
        (OpenAIGPTConfig, TFOpenAIGPTModel),
        (GPT2Config, TFGPT2Model),
        (MobileBertConfig, TFMobileBertModel),
        (TransfoXLConfig, TFTransfoXLModel),
        (XLNetConfig, TFXLNetModel),
        (FlaubertConfig, TFFlaubertModel),
        (XLMConfig, TFXLMModel),
        (CTRLConfig, TFCTRLModel),
        (ElectraConfig, TFElectraModel),
        (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
        (DPRConfig, TFDPRQuestionEncoder),
        (MPNetConfig, TFMPNetModel),
        (MBartConfig, TFMBartModel),
        (MarianConfig, TFMarianModel),
        (PegasusConfig, TFPegasusModel),
        (BlenderbotConfig, TFBlenderbotModel),
        (BlenderbotSmallConfig, TFBlenderbotSmallModel),
        (Wav2Vec2Config, TFWav2Vec2Model),
        (HubertConfig, TFHubertModel),
    ]
)
# Configuration class -> TensorFlow model class used for pre-training.
# Used as `_model_mapping` by `TFAutoModelForPreTraining` below.
# Entry order is kept exactly as it was.
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        # Model for pre-training mapping
        (LxmertConfig, TFLxmertForPreTraining),
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForPreTraining),
        (BartConfig, TFBartForConditionalGeneration),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (LayoutLMConfig, TFLayoutLMForMaskedLM),
        (BertConfig, TFBertForPreTraining),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (MobileBertConfig, TFMobileBertForPreTraining),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
        (ElectraConfig, TFElectraForPreTraining),
        (FunnelConfig, TFFunnelForPreTraining),
        (MPNetConfig, TFMPNetForMaskedLM),
    ]
)
# Configuration class -> TensorFlow model class carrying a language-modeling
# head (masked, causal and conditional-generation models are mixed here).
# Used as `_model_mapping` by the private `_TFAutoModelWithLMHead` below.
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        # Model with LM heads mapping
        (RoFormerConfig, TFRoFormerForMaskedLM),
        (ConvBertConfig, TFConvBertForMaskedLM),
        (LEDConfig, TFLEDForConditionalGeneration),
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (MarianConfig, TFMarianMTModel),
        (BartConfig, TFBartForConditionalGeneration),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (LongformerConfig, TFLongformerForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (LayoutLMConfig, TFLayoutLMForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (MobileBertConfig, TFMobileBertForMaskedLM),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
        (ElectraConfig, TFElectraForMaskedLM),
        (FunnelConfig, TFFunnelForMaskedLM),
        (MPNetConfig, TFMPNetForMaskedLM),
    ]
)
# Configuration class -> TensorFlow causal language-modeling class.
# Used as `_model_mapping` by `TFAutoModelForCausalLM` below.
TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Causal LM mapping
        (RoFormerConfig, TFRoFormerForCausalLM),
        (BertConfig, TFBertLMHeadModel),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
    ]
)
# Configuration class -> TensorFlow masked language-modeling class.
# Used as `_model_mapping` by `TFAutoModelForMaskedLM` below.
TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    [
        # Model for Masked LM mapping
        (RoFormerConfig, TFRoFormerForMaskedLM),
        (ConvBertConfig, TFConvBertForMaskedLM),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (LongformerConfig, TFLongformerForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (LayoutLMConfig, TFLayoutLMForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (MobileBertConfig, TFMobileBertForMaskedLM),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (ElectraConfig, TFElectraForMaskedLM),
        (FunnelConfig, TFFunnelForMaskedLM),
        (MPNetConfig, TFMPNetForMaskedLM),
    ]
)
# Configuration class -> TensorFlow sequence-to-sequence LM class.
# Used as `_model_mapping` by `TFAutoModelForSeq2SeqLM` below.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (LEDConfig, TFLEDForConditionalGeneration),
        (MT5Config, TFMT5ForConditionalGeneration),
        (T5Config, TFT5ForConditionalGeneration),
        (MarianConfig, TFMarianMTModel),
        (MBartConfig, TFMBartForConditionalGeneration),
        (PegasusConfig, TFPegasusForConditionalGeneration),
        (BlenderbotConfig, TFBlenderbotForConditionalGeneration),
        (BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
        (BartConfig, TFBartForConditionalGeneration),
    ]
)
# Configuration class -> TensorFlow sequence-classification class.
# Used as `_model_mapping` by `TFAutoModelForSequenceClassification` below.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Sequence Classification mapping
        (RoFormerConfig, TFRoFormerForSequenceClassification),
        (ConvBertConfig, TFConvBertForSequenceClassification),
        (DistilBertConfig, TFDistilBertForSequenceClassification),
        (AlbertConfig, TFAlbertForSequenceClassification),
        (CamembertConfig, TFCamembertForSequenceClassification),
        (XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
        (LongformerConfig, TFLongformerForSequenceClassification),
        (RobertaConfig, TFRobertaForSequenceClassification),
        (LayoutLMConfig, TFLayoutLMForSequenceClassification),
        (BertConfig, TFBertForSequenceClassification),
        (XLNetConfig, TFXLNetForSequenceClassification),
        (MobileBertConfig, TFMobileBertForSequenceClassification),
        (FlaubertConfig, TFFlaubertForSequenceClassification),
        (XLMConfig, TFXLMForSequenceClassification),
        (ElectraConfig, TFElectraForSequenceClassification),
        (FunnelConfig, TFFunnelForSequenceClassification),
        (GPT2Config, TFGPT2ForSequenceClassification),
        (MPNetConfig, TFMPNetForSequenceClassification),
        (OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
        (TransfoXLConfig, TFTransfoXLForSequenceClassification),
        (CTRLConfig, TFCTRLForSequenceClassification),
    ]
)
# Configuration class -> TensorFlow question-answering class.
# Used as `_model_mapping` by `TFAutoModelForQuestionAnswering` below.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        # Model for Question Answering mapping
        (RoFormerConfig, TFRoFormerForQuestionAnswering),
        (ConvBertConfig, TFConvBertForQuestionAnswering),
        (DistilBertConfig, TFDistilBertForQuestionAnswering),
        (AlbertConfig, TFAlbertForQuestionAnswering),
        (CamembertConfig, TFCamembertForQuestionAnswering),
        (XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
        (LongformerConfig, TFLongformerForQuestionAnswering),
        (RobertaConfig, TFRobertaForQuestionAnswering),
        (BertConfig, TFBertForQuestionAnswering),
        (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
        (MobileBertConfig, TFMobileBertForQuestionAnswering),
        (FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
        (XLMConfig, TFXLMForQuestionAnsweringSimple),
        (ElectraConfig, TFElectraForQuestionAnswering),
        (FunnelConfig, TFFunnelForQuestionAnswering),
        (MPNetConfig, TFMPNetForQuestionAnswering),
    ]
)
# Configuration class -> TensorFlow token-classification class.
# Used as `_model_mapping` by `TFAutoModelForTokenClassification` below.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Token Classification mapping
        (RoFormerConfig, TFRoFormerForTokenClassification),
        (ConvBertConfig, TFConvBertForTokenClassification),
        (DistilBertConfig, TFDistilBertForTokenClassification),
        (AlbertConfig, TFAlbertForTokenClassification),
        (CamembertConfig, TFCamembertForTokenClassification),
        (FlaubertConfig, TFFlaubertForTokenClassification),
        (XLMConfig, TFXLMForTokenClassification),
        (XLMRobertaConfig, TFXLMRobertaForTokenClassification),
        (LongformerConfig, TFLongformerForTokenClassification),
        (RobertaConfig, TFRobertaForTokenClassification),
        (LayoutLMConfig, TFLayoutLMForTokenClassification),
        (BertConfig, TFBertForTokenClassification),
        (MobileBertConfig, TFMobileBertForTokenClassification),
        (XLNetConfig, TFXLNetForTokenClassification),
        (ElectraConfig, TFElectraForTokenClassification),
        (FunnelConfig, TFFunnelForTokenClassification),
        (MPNetConfig, TFMPNetForTokenClassification),
    ]
)
# Configuration class -> TensorFlow multiple-choice class.
# Used as `_model_mapping` by `TFAutoModelForMultipleChoice` below.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        # Model for Multiple Choice mapping
        (RoFormerConfig, TFRoFormerForMultipleChoice),
        (ConvBertConfig, TFConvBertForMultipleChoice),
        (CamembertConfig, TFCamembertForMultipleChoice),
        (XLMConfig, TFXLMForMultipleChoice),
        (XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
        (LongformerConfig, TFLongformerForMultipleChoice),
        (RobertaConfig, TFRobertaForMultipleChoice),
        (BertConfig, TFBertForMultipleChoice),
        (DistilBertConfig, TFDistilBertForMultipleChoice),
        (MobileBertConfig, TFMobileBertForMultipleChoice),
        (XLNetConfig, TFXLNetForMultipleChoice),
        (FlaubertConfig, TFFlaubertForMultipleChoice),
        (AlbertConfig, TFAlbertForMultipleChoice),
        (ElectraConfig, TFElectraForMultipleChoice),
        (FunnelConfig, TFFunnelForMultipleChoice),
        (MPNetConfig, TFMPNetForMultipleChoice),
    ]
)
# Configuration class -> TensorFlow next-sentence-prediction class.
# Used as `_model_mapping` by `TFAutoModelForNextSentencePrediction` below.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        (BertConfig, TFBertForNextSentencePrediction),
        (MobileBertConfig, TFMobileBertForNextSentencePrediction),
    ]
)
class TFAutoModel(_BaseAutoModelClass):
    # Auto class for bare models: its lookup table is the base-model mapping.
    _model_mapping = TF_MODEL_MAPPING


# The class is passed through `auto_class_update` (from `.auto_factory`) and the
# returned object is rebound under the same name; no optional arguments are given.
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    # Auto class for pre-training models.
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForPreTraining = auto_class_update(
    TFAutoModelForPreTraining,
    head_doc="pretraining",
)
# Kept private deliberately: the public `TFAutoModelWithLMHead` subclass defined
# further down is the one that adds the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    # Auto class for models with any kind of language-modeling head.
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
_TFAutoModelWithLMHead = auto_class_update(
    _TFAutoModelWithLMHead,
    head_doc="language modeling",
)
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    # Auto class for causal language models.
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForCausalLM = auto_class_update(
    TFAutoModelForCausalLM,
    head_doc="causal language modeling",
)
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    # Auto class for masked language models.
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForMaskedLM = auto_class_update(
    TFAutoModelForMaskedLM,
    head_doc="masked language modeling",
)
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Auto class for sequence-to-sequence (encoder-decoder) language models.
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
# This is the only auto class here that overrides `checkpoint_for_example`.
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    # Auto class for sequence-classification models.
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Auto class for question-answering models.
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForQuestionAnswering = auto_class_update(
    TFAutoModelForQuestionAnswering,
    head_doc="question answering",
)
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    # Auto class for token-classification models.
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification,
    head_doc="token classification",
)
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    # Auto class for multiple-choice models.
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForMultipleChoice = auto_class_update(
    TFAutoModelForMultipleChoice,
    head_doc="multiple choice",
)
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Auto class for next-sentence-prediction models.
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to whatever `auto_class_update` returns for this class.
TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    """Deprecated public wrapper around `_TFAutoModelWithLMHead`.

    Both constructors emit a `FutureWarning` pointing to the task-specific
    replacements and then delegate to the parent implementation.
    """

    # Both methods take `cls` and are meant to be called on the class itself
    # (e.g. `TFAutoModelWithLMHead.from_pretrained(name)`), so they must be
    # class methods. Without the decorator the first user argument would be
    # bound to `cls` and the call would fail.
    @classmethod
    def from_config(cls, config):
        """Warn about the deprecation, then defer to the parent's `from_config`.

        Args:
            config: Configuration object forwarded unchanged to the parent.

        Returns:
            Whatever the parent's `from_config` returns.
        """
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
            "`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn about the deprecation, then defer to the parent's `from_pretrained`.

        Args:
            pretrained_model_name_or_path: Forwarded unchanged to the parent.
            *model_args: Extra positional arguments, forwarded unchanged.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent's `from_pretrained` returns.
        """
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
            "`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)