# coding=utf-8 # Copyright 2018 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Auto Model class. """ import warnings from collections import OrderedDict from ...utils import logging # Add modeling imports here from ..albert.modeling_albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, AlbertForPreTraining, AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertModel, ) from ..bart.modeling_bart import ( BartForCausalLM, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, BartModel, ) from ..bert.modeling_bert import ( BertForMaskedLM, BertForMultipleChoice, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertLMHeadModel, BertModel, ) from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder from ..big_bird.modeling_big_bird import ( BigBirdForCausalLM, BigBirdForMaskedLM, BigBirdForMultipleChoice, BigBirdForPreTraining, BigBirdForQuestionAnswering, BigBirdForSequenceClassification, BigBirdForTokenClassification, BigBirdModel, ) from ..bigbird_pegasus.modeling_bigbird_pegasus import ( BigBirdPegasusForCausalLM, BigBirdPegasusForConditionalGeneration, BigBirdPegasusForQuestionAnswering, BigBirdPegasusForSequenceClassification, BigBirdPegasusModel, ) from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, 
BlenderbotForConditionalGeneration, BlenderbotModel from ..blenderbot_small.modeling_blenderbot_small import ( BlenderbotSmallForCausalLM, BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel, ) from ..camembert.modeling_camembert import ( CamembertForCausalLM, CamembertForMaskedLM, CamembertForMultipleChoice, CamembertForQuestionAnswering, CamembertForSequenceClassification, CamembertForTokenClassification, CamembertModel, ) from ..canine.modeling_canine import ( CanineForMultipleChoice, CanineForQuestionAnswering, CanineForSequenceClassification, CanineForTokenClassification, CanineModel, ) from ..clip.modeling_clip import CLIPModel from ..convbert.modeling_convbert import ( ConvBertForMaskedLM, ConvBertForMultipleChoice, ConvBertForQuestionAnswering, ConvBertForSequenceClassification, ConvBertForTokenClassification, ConvBertModel, ) from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel from ..deberta.modeling_deberta import ( DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification, DebertaModel, ) from ..deberta_v2.modeling_deberta_v2 import ( DebertaV2ForMaskedLM, DebertaV2ForQuestionAnswering, DebertaV2ForSequenceClassification, DebertaV2ForTokenClassification, DebertaV2Model, ) from ..deit.modeling_deit import DeiTForImageClassification, DeiTForImageClassificationWithTeacher, DeiTModel from ..detr.modeling_detr import DetrForObjectDetection, DetrModel from ..distilbert.modeling_distilbert import ( DistilBertForMaskedLM, DistilBertForMultipleChoice, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DistilBertForTokenClassification, DistilBertModel, ) from ..dpr.modeling_dpr import DPRQuestionEncoder from ..electra.modeling_electra import ( ElectraForMaskedLM, ElectraForMultipleChoice, ElectraForPreTraining, ElectraForQuestionAnswering, ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel, ) from 
..encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel from ..flaubert.modeling_flaubert import ( FlaubertForMultipleChoice, FlaubertForQuestionAnsweringSimple, FlaubertForSequenceClassification, FlaubertForTokenClassification, FlaubertModel, FlaubertWithLMHeadModel, ) from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel from ..funnel.modeling_funnel import ( FunnelBaseModel, FunnelForMaskedLM, FunnelForMultipleChoice, FunnelForPreTraining, FunnelForQuestionAnswering, FunnelForSequenceClassification, FunnelForTokenClassification, FunnelModel, ) from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel from ..hubert.modeling_hubert import HubertModel from ..ibert.modeling_ibert import ( IBertForMaskedLM, IBertForMultipleChoice, IBertForQuestionAnswering, IBertForSequenceClassification, IBertForTokenClassification, IBertModel, ) from ..layoutlm.modeling_layoutlm import ( LayoutLMForMaskedLM, LayoutLMForSequenceClassification, LayoutLMForTokenClassification, LayoutLMModel, ) from ..led.modeling_led import ( LEDForConditionalGeneration, LEDForQuestionAnswering, LEDForSequenceClassification, LEDModel, ) from ..longformer.modeling_longformer import ( LongformerForMaskedLM, LongformerForMultipleChoice, LongformerForQuestionAnswering, LongformerForSequenceClassification, LongformerForTokenClassification, LongformerModel, ) from ..luke.modeling_luke import LukeModel from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel from ..mbart.modeling_mbart import ( MBartForCausalLM, MBartForConditionalGeneration, MBartForQuestionAnswering, MBartForSequenceClassification, MBartModel, ) from 
..megatron_bert.modeling_megatron_bert import ( MegatronBertForCausalLM, MegatronBertForMaskedLM, MegatronBertForMultipleChoice, MegatronBertForNextSentencePrediction, MegatronBertForPreTraining, MegatronBertForQuestionAnswering, MegatronBertForSequenceClassification, MegatronBertForTokenClassification, MegatronBertModel, ) from ..mobilebert.modeling_mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, MobileBertForNextSentencePrediction, MobileBertForPreTraining, MobileBertForQuestionAnswering, MobileBertForSequenceClassification, MobileBertForTokenClassification, MobileBertModel, ) from ..mpnet.modeling_mpnet import ( MPNetForMaskedLM, MPNetForMultipleChoice, MPNetForQuestionAnswering, MPNetForSequenceClassification, MPNetForTokenClassification, MPNetModel, ) from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function RagModel, RagSequenceForGeneration, RagTokenForGeneration, ) from ..reformer.modeling_reformer import ( ReformerForMaskedLM, ReformerForQuestionAnswering, ReformerForSequenceClassification, ReformerModel, ReformerModelWithLMHead, ) from ..retribert.modeling_retribert import RetriBertModel from ..roberta.modeling_roberta import ( RobertaForCausalLM, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel, ) from ..roformer.modeling_roformer import ( RoFormerForCausalLM, RoFormerForMaskedLM, RoFormerForMultipleChoice, RoFormerForQuestionAnswering, RoFormerForSequenceClassification, 
RoFormerForTokenClassification, RoFormerModel, ) from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel from ..squeezebert.modeling_squeezebert import ( SqueezeBertForMaskedLM, SqueezeBertForMultipleChoice, SqueezeBertForQuestionAnswering, SqueezeBertForSequenceClassification, SqueezeBertForTokenClassification, SqueezeBertModel, ) from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model from ..tapas.modeling_tapas import ( TapasForMaskedLM, TapasForQuestionAnswering, TapasForSequenceClassification, TapasModel, ) from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel from ..vit.modeling_vit import ViTForImageClassification, ViTModel from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model from ..xlm.modeling_xlm import ( XLMForMultipleChoice, XLMForQuestionAnsweringSimple, XLMForSequenceClassification, XLMForTokenClassification, XLMModel, XLMWithLMHeadModel, ) from ..xlm_prophetnet.modeling_xlm_prophetnet import ( XLMProphetNetForCausalLM, XLMProphetNetForConditionalGeneration, XLMProphetNetModel, ) from ..xlm_roberta.modeling_xlm_roberta import ( XLMRobertaForCausalLM, XLMRobertaForMaskedLM, XLMRobertaForMultipleChoice, XLMRobertaForQuestionAnswering, XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification, XLMRobertaModel, ) from ..xlnet.modeling_xlnet import ( XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple, XLNetForSequenceClassification, XLNetForTokenClassification, XLNetLMHeadModel, XLNetModel, ) from .auto_factory import _BaseAutoModelClass, auto_class_update from .configuration_auto import ( AlbertConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, CamembertConfig, CanineConfig, CLIPConfig, ConvBertConfig, 
CTRLConfig, DebertaConfig, DebertaV2Config, DeiTConfig, DetrConfig, DistilBertConfig, DPRConfig, ElectraConfig, EncoderDecoderConfig, FlaubertConfig, FSMTConfig, FunnelConfig, GPT2Config, GPTNeoConfig, HubertConfig, IBertConfig, LayoutLMConfig, LEDConfig, LongformerConfig, LukeConfig, LxmertConfig, M2M100Config, MarianConfig, MBartConfig, MegatronBertConfig, MobileBertConfig, MPNetConfig, MT5Config, OpenAIGPTConfig, PegasusConfig, ProphetNetConfig, ReformerConfig, RetriBertConfig, RobertaConfig, RoFormerConfig, Speech2TextConfig, SqueezeBertConfig, T5Config, TapasConfig, TransfoXLConfig, VisualBertConfig, ViTConfig, Wav2Vec2Config, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLNetConfig, ) logger = logging.get_logger(__name__) MODEL_MAPPING = OrderedDict( [ # Base model mapping (VisualBertConfig, VisualBertModel), (CanineConfig, CanineModel), (RoFormerConfig, RoFormerModel), (CLIPConfig, CLIPModel), (BigBirdPegasusConfig, BigBirdPegasusModel), (DeiTConfig, DeiTModel), (LukeConfig, LukeModel), (DetrConfig, DetrModel), (GPTNeoConfig, GPTNeoModel), (BigBirdConfig, BigBirdModel), (Speech2TextConfig, Speech2TextModel), (ViTConfig, ViTModel), (Wav2Vec2Config, Wav2Vec2Model), (HubertConfig, HubertModel), (M2M100Config, M2M100Model), (ConvBertConfig, ConvBertModel), (LEDConfig, LEDModel), (BlenderbotSmallConfig, BlenderbotSmallModel), (RetriBertConfig, RetriBertModel), (MT5Config, MT5Model), (T5Config, T5Model), (PegasusConfig, PegasusModel), (MarianConfig, MarianMTModel), (MBartConfig, MBartModel), (BlenderbotConfig, BlenderbotModel), (DistilBertConfig, DistilBertModel), (AlbertConfig, AlbertModel), (CamembertConfig, CamembertModel), (XLMRobertaConfig, XLMRobertaModel), (BartConfig, BartModel), (LongformerConfig, LongformerModel), (RobertaConfig, RobertaModel), (LayoutLMConfig, LayoutLMModel), (SqueezeBertConfig, SqueezeBertModel), (BertConfig, BertModel), (OpenAIGPTConfig, OpenAIGPTModel), (GPT2Config, GPT2Model), (MegatronBertConfig, MegatronBertModel), 
(MobileBertConfig, MobileBertModel), (TransfoXLConfig, TransfoXLModel), (XLNetConfig, XLNetModel), (FlaubertConfig, FlaubertModel), (FSMTConfig, FSMTModel), (XLMConfig, XLMModel), (CTRLConfig, CTRLModel), (ElectraConfig, ElectraModel), (ReformerConfig, ReformerModel), (FunnelConfig, (FunnelModel, FunnelBaseModel)), (LxmertConfig, LxmertModel), (BertGenerationConfig, BertGenerationEncoder), (DebertaConfig, DebertaModel), (DebertaV2Config, DebertaV2Model), (DPRConfig, DPRQuestionEncoder), (XLMProphetNetConfig, XLMProphetNetModel), (ProphetNetConfig, ProphetNetModel), (MPNetConfig, MPNetModel), (TapasConfig, TapasModel), (MarianConfig, MarianModel), (IBertConfig, IBertModel), ] ) MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( [ # Model for pre-training mapping (VisualBertConfig, VisualBertForPreTraining), (LayoutLMConfig, LayoutLMForMaskedLM), (RetriBertConfig, RetriBertModel), (T5Config, T5ForConditionalGeneration), (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForPreTraining), (CamembertConfig, CamembertForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM), (BartConfig, BartForConditionalGeneration), (FSMTConfig, FSMTForConditionalGeneration), (LongformerConfig, LongformerForMaskedLM), (RobertaConfig, RobertaForMaskedLM), (SqueezeBertConfig, SqueezeBertForMaskedLM), (BertConfig, BertForPreTraining), (BigBirdConfig, BigBirdForPreTraining), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), (MegatronBertConfig, MegatronBertForPreTraining), (MobileBertConfig, MobileBertForPreTraining), (TransfoXLConfig, TransfoXLLMHeadModel), (XLNetConfig, XLNetLMHeadModel), (FlaubertConfig, FlaubertWithLMHeadModel), (XLMConfig, XLMWithLMHeadModel), (CTRLConfig, CTRLLMHeadModel), (ElectraConfig, ElectraForPreTraining), (LxmertConfig, LxmertForPreTraining), (FunnelConfig, FunnelForPreTraining), (MPNetConfig, MPNetForMaskedLM), (TapasConfig, TapasForMaskedLM), (IBertConfig, IBertForMaskedLM), (DebertaConfig, DebertaForMaskedLM), (DebertaV2Config, 
DebertaV2ForMaskedLM), (Wav2Vec2Config, Wav2Vec2ForPreTraining), ] ) MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping (RoFormerConfig, RoFormerForMaskedLM), (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration), (GPTNeoConfig, GPTNeoForCausalLM), (BigBirdConfig, BigBirdForMaskedLM), (Speech2TextConfig, Speech2TextForConditionalGeneration), (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (M2M100Config, M2M100ForConditionalGeneration), (ConvBertConfig, ConvBertForMaskedLM), (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), (LayoutLMConfig, LayoutLMForMaskedLM), (T5Config, T5ForConditionalGeneration), (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForMaskedLM), (CamembertConfig, CamembertForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM), (MarianConfig, MarianMTModel), (FSMTConfig, FSMTForConditionalGeneration), (BartConfig, BartForConditionalGeneration), (LongformerConfig, LongformerForMaskedLM), (RobertaConfig, RobertaForMaskedLM), (SqueezeBertConfig, SqueezeBertForMaskedLM), (BertConfig, BertForMaskedLM), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), (MegatronBertConfig, MegatronBertForMaskedLM), (MobileBertConfig, MobileBertForMaskedLM), (TransfoXLConfig, TransfoXLLMHeadModel), (XLNetConfig, XLNetLMHeadModel), (FlaubertConfig, FlaubertWithLMHeadModel), (XLMConfig, XLMWithLMHeadModel), (CTRLConfig, CTRLLMHeadModel), (ElectraConfig, ElectraForMaskedLM), (EncoderDecoderConfig, EncoderDecoderModel), (ReformerConfig, ReformerModelWithLMHead), (FunnelConfig, FunnelForMaskedLM), (MPNetConfig, MPNetForMaskedLM), (TapasConfig, TapasForMaskedLM), (DebertaConfig, DebertaForMaskedLM), (DebertaV2Config, DebertaV2ForMaskedLM), (IBertConfig, IBertForMaskedLM), (MegatronBertConfig, MegatronBertForCausalLM), ] ) MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Causal LM mapping (RoFormerConfig, RoFormerForCausalLM), (BigBirdPegasusConfig, 
BigBirdPegasusForCausalLM), (GPTNeoConfig, GPTNeoForCausalLM), (BigBirdConfig, BigBirdForCausalLM), (CamembertConfig, CamembertForCausalLM), (XLMRobertaConfig, XLMRobertaForCausalLM), (RobertaConfig, RobertaForCausalLM), (BertConfig, BertLMHeadModel), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), (TransfoXLConfig, TransfoXLLMHeadModel), (XLNetConfig, XLNetLMHeadModel), ( XLMConfig, XLMWithLMHeadModel, ), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now (CTRLConfig, CTRLLMHeadModel), (ReformerConfig, ReformerModelWithLMHead), (BertGenerationConfig, BertGenerationDecoder), (XLMProphetNetConfig, XLMProphetNetForCausalLM), (ProphetNetConfig, ProphetNetForCausalLM), (BartConfig, BartForCausalLM), (MBartConfig, MBartForCausalLM), (PegasusConfig, PegasusForCausalLM), (MarianConfig, MarianForCausalLM), (BlenderbotConfig, BlenderbotForCausalLM), (BlenderbotSmallConfig, BlenderbotSmallForCausalLM), (MegatronBertConfig, MegatronBertForCausalLM), ] ) MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Image Classification mapping (ViTConfig, ViTForImageClassification), (DeiTConfig, (DeiTForImageClassification, DeiTForImageClassificationWithTeacher)), ] ) MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( [ # Model for Masked LM mapping (RoFormerConfig, RoFormerForMaskedLM), (BigBirdConfig, BigBirdForMaskedLM), (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (ConvBertConfig, ConvBertForMaskedLM), (LayoutLMConfig, LayoutLMForMaskedLM), (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForMaskedLM), (BartConfig, BartForConditionalGeneration), (MBartConfig, MBartForConditionalGeneration), (CamembertConfig, CamembertForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM), (LongformerConfig, LongformerForMaskedLM), (RobertaConfig, RobertaForMaskedLM), (SqueezeBertConfig, SqueezeBertForMaskedLM), (BertConfig, BertForMaskedLM), (MegatronBertConfig, MegatronBertForMaskedLM), (MobileBertConfig, 
MobileBertForMaskedLM), (FlaubertConfig, FlaubertWithLMHeadModel), (XLMConfig, XLMWithLMHeadModel), (ElectraConfig, ElectraForMaskedLM), (ReformerConfig, ReformerForMaskedLM), (FunnelConfig, FunnelForMaskedLM), (MPNetConfig, MPNetForMaskedLM), (TapasConfig, TapasForMaskedLM), (DebertaConfig, DebertaForMaskedLM), (DebertaV2Config, DebertaV2ForMaskedLM), (IBertConfig, IBertForMaskedLM), ] ) MODEL_FOR_OBJECT_DETECTION_MAPPING = OrderedDict( [ # Model for Object Detection mapping (DetrConfig, DetrForObjectDetection), ] ) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration), (M2M100Config, M2M100ForConditionalGeneration), (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), (MT5Config, MT5ForConditionalGeneration), (T5Config, T5ForConditionalGeneration), (PegasusConfig, PegasusForConditionalGeneration), (MarianConfig, MarianMTModel), (MBartConfig, MBartForConditionalGeneration), (BlenderbotConfig, BlenderbotForConditionalGeneration), (BartConfig, BartForConditionalGeneration), (FSMTConfig, FSMTForConditionalGeneration), (EncoderDecoderConfig, EncoderDecoderModel), (XLMProphetNetConfig, XLMProphetNetForConditionalGeneration), (ProphetNetConfig, ProphetNetForConditionalGeneration), ] ) MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Sequence Classification mapping (CanineConfig, CanineForSequenceClassification), (RoFormerConfig, RoFormerForSequenceClassification), (BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification), (BigBirdConfig, BigBirdForSequenceClassification), (ConvBertConfig, ConvBertForSequenceClassification), (LEDConfig, LEDForSequenceClassification), (DistilBertConfig, DistilBertForSequenceClassification), (AlbertConfig, AlbertForSequenceClassification), (CamembertConfig, CamembertForSequenceClassification), (XLMRobertaConfig, XLMRobertaForSequenceClassification), 
(MBartConfig, MBartForSequenceClassification), (BartConfig, BartForSequenceClassification), (LongformerConfig, LongformerForSequenceClassification), (RobertaConfig, RobertaForSequenceClassification), (SqueezeBertConfig, SqueezeBertForSequenceClassification), (LayoutLMConfig, LayoutLMForSequenceClassification), (BertConfig, BertForSequenceClassification), (XLNetConfig, XLNetForSequenceClassification), (MegatronBertConfig, MegatronBertForSequenceClassification), (MobileBertConfig, MobileBertForSequenceClassification), (FlaubertConfig, FlaubertForSequenceClassification), (XLMConfig, XLMForSequenceClassification), (ElectraConfig, ElectraForSequenceClassification), (FunnelConfig, FunnelForSequenceClassification), (DebertaConfig, DebertaForSequenceClassification), (DebertaV2Config, DebertaV2ForSequenceClassification), (GPT2Config, GPT2ForSequenceClassification), (GPTNeoConfig, GPTNeoForSequenceClassification), (OpenAIGPTConfig, OpenAIGPTForSequenceClassification), (ReformerConfig, ReformerForSequenceClassification), (CTRLConfig, CTRLForSequenceClassification), (TransfoXLConfig, TransfoXLForSequenceClassification), (MPNetConfig, MPNetForSequenceClassification), (TapasConfig, TapasForSequenceClassification), (IBertConfig, IBertForSequenceClassification), ] ) MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( [ # Model for Question Answering mapping (CanineConfig, CanineForQuestionAnswering), (RoFormerConfig, RoFormerForQuestionAnswering), (BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering), (BigBirdConfig, BigBirdForQuestionAnswering), (ConvBertConfig, ConvBertForQuestionAnswering), (LEDConfig, LEDForQuestionAnswering), (DistilBertConfig, DistilBertForQuestionAnswering), (AlbertConfig, AlbertForQuestionAnswering), (CamembertConfig, CamembertForQuestionAnswering), (BartConfig, BartForQuestionAnswering), (MBartConfig, MBartForQuestionAnswering), (LongformerConfig, LongformerForQuestionAnswering), (XLMRobertaConfig, XLMRobertaForQuestionAnswering), (RobertaConfig, 
RobertaForQuestionAnswering), (SqueezeBertConfig, SqueezeBertForQuestionAnswering), (BertConfig, BertForQuestionAnswering), (XLNetConfig, XLNetForQuestionAnsweringSimple), (FlaubertConfig, FlaubertForQuestionAnsweringSimple), (MegatronBertConfig, MegatronBertForQuestionAnswering), (MobileBertConfig, MobileBertForQuestionAnswering), (XLMConfig, XLMForQuestionAnsweringSimple), (ElectraConfig, ElectraForQuestionAnswering), (ReformerConfig, ReformerForQuestionAnswering), (FunnelConfig, FunnelForQuestionAnswering), (LxmertConfig, LxmertForQuestionAnswering), (MPNetConfig, MPNetForQuestionAnswering), (DebertaConfig, DebertaForQuestionAnswering), (DebertaV2Config, DebertaV2ForQuestionAnswering), (IBertConfig, IBertForQuestionAnswering), ] ) MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict( [ # Model for Table Question Answering mapping (TapasConfig, TapasForQuestionAnswering), ] ) MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Token Classification mapping (CanineConfig, CanineForTokenClassification), (RoFormerConfig, RoFormerForTokenClassification), (BigBirdConfig, BigBirdForTokenClassification), (ConvBertConfig, ConvBertForTokenClassification), (LayoutLMConfig, LayoutLMForTokenClassification), (DistilBertConfig, DistilBertForTokenClassification), (CamembertConfig, CamembertForTokenClassification), (FlaubertConfig, FlaubertForTokenClassification), (XLMConfig, XLMForTokenClassification), (XLMRobertaConfig, XLMRobertaForTokenClassification), (LongformerConfig, LongformerForTokenClassification), (RobertaConfig, RobertaForTokenClassification), (SqueezeBertConfig, SqueezeBertForTokenClassification), (BertConfig, BertForTokenClassification), (MegatronBertConfig, MegatronBertForTokenClassification), (MobileBertConfig, MobileBertForTokenClassification), (XLNetConfig, XLNetForTokenClassification), (AlbertConfig, AlbertForTokenClassification), (ElectraConfig, ElectraForTokenClassification), (FlaubertConfig, FlaubertForTokenClassification), 
(FunnelConfig, FunnelForTokenClassification), (MPNetConfig, MPNetForTokenClassification), (DebertaConfig, DebertaForTokenClassification), (DebertaV2Config, DebertaV2ForTokenClassification), (IBertConfig, IBertForTokenClassification), ] ) MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( [ # Model for Multiple Choice mapping (CanineConfig, CanineForMultipleChoice), (RoFormerConfig, RoFormerForMultipleChoice), (BigBirdConfig, BigBirdForMultipleChoice), (ConvBertConfig, ConvBertForMultipleChoice), (CamembertConfig, CamembertForMultipleChoice), (ElectraConfig, ElectraForMultipleChoice), (XLMRobertaConfig, XLMRobertaForMultipleChoice), (LongformerConfig, LongformerForMultipleChoice), (RobertaConfig, RobertaForMultipleChoice), (SqueezeBertConfig, SqueezeBertForMultipleChoice), (BertConfig, BertForMultipleChoice), (DistilBertConfig, DistilBertForMultipleChoice), (MegatronBertConfig, MegatronBertForMultipleChoice), (MobileBertConfig, MobileBertForMultipleChoice), (XLNetConfig, XLNetForMultipleChoice), (AlbertConfig, AlbertForMultipleChoice), (XLMConfig, XLMForMultipleChoice), (FlaubertConfig, FlaubertForMultipleChoice), (FunnelConfig, FunnelForMultipleChoice), (MPNetConfig, MPNetForMultipleChoice), (IBertConfig, IBertForMultipleChoice), ] ) MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( [ (BertConfig, BertForNextSentencePrediction), (MegatronBertConfig, MegatronBertForNextSentencePrediction), (MobileBertConfig, MobileBertForNextSentencePrediction), ] ) class AutoModel(_BaseAutoModelClass): _model_mapping = MODEL_MAPPING AutoModel = auto_class_update(AutoModel) class AutoModelForPreTraining(_BaseAutoModelClass): _model_mapping = MODEL_FOR_PRETRAINING_MAPPING AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") # Private on purpose, the public class will add the deprecation warnings. 
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Backing implementation for the deprecated public ``AutoModelWithLMHead`` defined below.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")


class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")


class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)


class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)


class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")


class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


# Single source for the deprecation text so both entry points below stay in sync.
_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE = (
    "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
    "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
    "`AutoModelForSeq2SeqLM` for encoder-decoder models."
)


# Deprecated public wrapper around ``_AutoModelWithLMHead``: both constructors emit a FutureWarning
# and then delegate unchanged to the base implementation. Comments are used instead of docstrings so
# the documentation inherited from the base class is not shadowed.
class AutoModelWithLMHead(_AutoModelWithLMHead):
    @classmethod
    def from_config(cls, config, **kwargs):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        # Extra keyword arguments are forwarded unchanged; previously they could not be passed at all.
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)