Spaces:
Sleeping
Sleeping
# coding=utf-8 | |
# Copyright 2018 The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" Auto Model class. """ | |
import warnings | |
from collections import OrderedDict | |
from ...utils import logging | |
# Add modeling imports here | |
from ..albert.modeling_albert import ( | |
AlbertForMaskedLM, | |
AlbertForMultipleChoice, | |
AlbertForPreTraining, | |
AlbertForQuestionAnswering, | |
AlbertForSequenceClassification, | |
AlbertForTokenClassification, | |
AlbertModel, | |
) | |
from ..bart.modeling_bart import ( | |
BartForCausalLM, | |
BartForConditionalGeneration, | |
BartForQuestionAnswering, | |
BartForSequenceClassification, | |
BartModel, | |
) | |
from ..bert.modeling_bert import ( | |
BertForMaskedLM, | |
BertForMultipleChoice, | |
BertForNextSentencePrediction, | |
BertForPreTraining, | |
BertForQuestionAnswering, | |
BertForSequenceClassification, | |
BertForTokenClassification, | |
BertLMHeadModel, | |
BertModel, | |
) | |
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder | |
from ..big_bird.modeling_big_bird import ( | |
BigBirdForCausalLM, | |
BigBirdForMaskedLM, | |
BigBirdForMultipleChoice, | |
BigBirdForPreTraining, | |
BigBirdForQuestionAnswering, | |
BigBirdForSequenceClassification, | |
BigBirdForTokenClassification, | |
BigBirdModel, | |
) | |
from ..bigbird_pegasus.modeling_bigbird_pegasus import ( | |
BigBirdPegasusForCausalLM, | |
BigBirdPegasusForConditionalGeneration, | |
BigBirdPegasusForQuestionAnswering, | |
BigBirdPegasusForSequenceClassification, | |
BigBirdPegasusModel, | |
) | |
from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel | |
from ..blenderbot_small.modeling_blenderbot_small import ( | |
BlenderbotSmallForCausalLM, | |
BlenderbotSmallForConditionalGeneration, | |
BlenderbotSmallModel, | |
) | |
from ..camembert.modeling_camembert import ( | |
CamembertForCausalLM, | |
CamembertForMaskedLM, | |
CamembertForMultipleChoice, | |
CamembertForQuestionAnswering, | |
CamembertForSequenceClassification, | |
CamembertForTokenClassification, | |
CamembertModel, | |
) | |
from ..canine.modeling_canine import ( | |
CanineForMultipleChoice, | |
CanineForQuestionAnswering, | |
CanineForSequenceClassification, | |
CanineForTokenClassification, | |
CanineModel, | |
) | |
from ..clip.modeling_clip import CLIPModel | |
from ..convbert.modeling_convbert import ( | |
ConvBertForMaskedLM, | |
ConvBertForMultipleChoice, | |
ConvBertForQuestionAnswering, | |
ConvBertForSequenceClassification, | |
ConvBertForTokenClassification, | |
ConvBertModel, | |
) | |
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel | |
from ..deberta.modeling_deberta import ( | |
DebertaForMaskedLM, | |
DebertaForQuestionAnswering, | |
DebertaForSequenceClassification, | |
DebertaForTokenClassification, | |
DebertaModel, | |
) | |
from ..deberta_v2.modeling_deberta_v2 import ( | |
DebertaV2ForMaskedLM, | |
DebertaV2ForQuestionAnswering, | |
DebertaV2ForSequenceClassification, | |
DebertaV2ForTokenClassification, | |
DebertaV2Model, | |
) | |
from ..deit.modeling_deit import DeiTForImageClassification, DeiTForImageClassificationWithTeacher, DeiTModel | |
from ..detr.modeling_detr import DetrForObjectDetection, DetrModel | |
from ..distilbert.modeling_distilbert import ( | |
DistilBertForMaskedLM, | |
DistilBertForMultipleChoice, | |
DistilBertForQuestionAnswering, | |
DistilBertForSequenceClassification, | |
DistilBertForTokenClassification, | |
DistilBertModel, | |
) | |
from ..dpr.modeling_dpr import DPRQuestionEncoder | |
from ..electra.modeling_electra import ( | |
ElectraForMaskedLM, | |
ElectraForMultipleChoice, | |
ElectraForPreTraining, | |
ElectraForQuestionAnswering, | |
ElectraForSequenceClassification, | |
ElectraForTokenClassification, | |
ElectraModel, | |
) | |
from ..encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel | |
from ..flaubert.modeling_flaubert import ( | |
FlaubertForMultipleChoice, | |
FlaubertForQuestionAnsweringSimple, | |
FlaubertForSequenceClassification, | |
FlaubertForTokenClassification, | |
FlaubertModel, | |
FlaubertWithLMHeadModel, | |
) | |
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel | |
from ..funnel.modeling_funnel import ( | |
FunnelBaseModel, | |
FunnelForMaskedLM, | |
FunnelForMultipleChoice, | |
FunnelForPreTraining, | |
FunnelForQuestionAnswering, | |
FunnelForSequenceClassification, | |
FunnelForTokenClassification, | |
FunnelModel, | |
) | |
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model | |
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel | |
from ..hubert.modeling_hubert import HubertModel | |
from ..ibert.modeling_ibert import ( | |
IBertForMaskedLM, | |
IBertForMultipleChoice, | |
IBertForQuestionAnswering, | |
IBertForSequenceClassification, | |
IBertForTokenClassification, | |
IBertModel, | |
) | |
from ..layoutlm.modeling_layoutlm import ( | |
LayoutLMForMaskedLM, | |
LayoutLMForSequenceClassification, | |
LayoutLMForTokenClassification, | |
LayoutLMModel, | |
) | |
from ..led.modeling_led import ( | |
LEDForConditionalGeneration, | |
LEDForQuestionAnswering, | |
LEDForSequenceClassification, | |
LEDModel, | |
) | |
from ..longformer.modeling_longformer import ( | |
LongformerForMaskedLM, | |
LongformerForMultipleChoice, | |
LongformerForQuestionAnswering, | |
LongformerForSequenceClassification, | |
LongformerForTokenClassification, | |
LongformerModel, | |
) | |
from ..luke.modeling_luke import LukeModel | |
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel | |
from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model | |
from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel | |
from ..mbart.modeling_mbart import ( | |
MBartForCausalLM, | |
MBartForConditionalGeneration, | |
MBartForQuestionAnswering, | |
MBartForSequenceClassification, | |
MBartModel, | |
) | |
from ..megatron_bert.modeling_megatron_bert import ( | |
MegatronBertForCausalLM, | |
MegatronBertForMaskedLM, | |
MegatronBertForMultipleChoice, | |
MegatronBertForNextSentencePrediction, | |
MegatronBertForPreTraining, | |
MegatronBertForQuestionAnswering, | |
MegatronBertForSequenceClassification, | |
MegatronBertForTokenClassification, | |
MegatronBertModel, | |
) | |
from ..mobilebert.modeling_mobilebert import ( | |
MobileBertForMaskedLM, | |
MobileBertForMultipleChoice, | |
MobileBertForNextSentencePrediction, | |
MobileBertForPreTraining, | |
MobileBertForQuestionAnswering, | |
MobileBertForSequenceClassification, | |
MobileBertForTokenClassification, | |
MobileBertModel, | |
) | |
from ..mpnet.modeling_mpnet import ( | |
MPNetForMaskedLM, | |
MPNetForMultipleChoice, | |
MPNetForQuestionAnswering, | |
MPNetForSequenceClassification, | |
MPNetForTokenClassification, | |
MPNetModel, | |
) | |
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model | |
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel | |
from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel | |
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel | |
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function | |
RagModel, | |
RagSequenceForGeneration, | |
RagTokenForGeneration, | |
) | |
from ..reformer.modeling_reformer import ( | |
ReformerForMaskedLM, | |
ReformerForQuestionAnswering, | |
ReformerForSequenceClassification, | |
ReformerModel, | |
ReformerModelWithLMHead, | |
) | |
from ..retribert.modeling_retribert import RetriBertModel | |
from ..roberta.modeling_roberta import ( | |
RobertaForCausalLM, | |
RobertaForMaskedLM, | |
RobertaForMultipleChoice, | |
RobertaForQuestionAnswering, | |
RobertaForSequenceClassification, | |
RobertaForTokenClassification, | |
RobertaModel, | |
) | |
from ..roformer.modeling_roformer import ( | |
RoFormerForCausalLM, | |
RoFormerForMaskedLM, | |
RoFormerForMultipleChoice, | |
RoFormerForQuestionAnswering, | |
RoFormerForSequenceClassification, | |
RoFormerForTokenClassification, | |
RoFormerModel, | |
) | |
from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel | |
from ..squeezebert.modeling_squeezebert import ( | |
SqueezeBertForMaskedLM, | |
SqueezeBertForMultipleChoice, | |
SqueezeBertForQuestionAnswering, | |
SqueezeBertForSequenceClassification, | |
SqueezeBertForTokenClassification, | |
SqueezeBertModel, | |
) | |
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model | |
from ..tapas.modeling_tapas import ( | |
TapasForMaskedLM, | |
TapasForQuestionAnswering, | |
TapasForSequenceClassification, | |
TapasModel, | |
) | |
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel | |
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel | |
from ..vit.modeling_vit import ViTForImageClassification, ViTModel | |
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model | |
from ..xlm.modeling_xlm import ( | |
XLMForMultipleChoice, | |
XLMForQuestionAnsweringSimple, | |
XLMForSequenceClassification, | |
XLMForTokenClassification, | |
XLMModel, | |
XLMWithLMHeadModel, | |
) | |
from ..xlm_prophetnet.modeling_xlm_prophetnet import ( | |
XLMProphetNetForCausalLM, | |
XLMProphetNetForConditionalGeneration, | |
XLMProphetNetModel, | |
) | |
from ..xlm_roberta.modeling_xlm_roberta import ( | |
XLMRobertaForCausalLM, | |
XLMRobertaForMaskedLM, | |
XLMRobertaForMultipleChoice, | |
XLMRobertaForQuestionAnswering, | |
XLMRobertaForSequenceClassification, | |
XLMRobertaForTokenClassification, | |
XLMRobertaModel, | |
) | |
from ..xlnet.modeling_xlnet import ( | |
XLNetForMultipleChoice, | |
XLNetForQuestionAnsweringSimple, | |
XLNetForSequenceClassification, | |
XLNetForTokenClassification, | |
XLNetLMHeadModel, | |
XLNetModel, | |
) | |
from .auto_factory import _BaseAutoModelClass, auto_class_update | |
from .configuration_auto import ( | |
AlbertConfig, | |
BartConfig, | |
BertConfig, | |
BertGenerationConfig, | |
BigBirdConfig, | |
BigBirdPegasusConfig, | |
BlenderbotConfig, | |
BlenderbotSmallConfig, | |
CamembertConfig, | |
CanineConfig, | |
CLIPConfig, | |
ConvBertConfig, | |
CTRLConfig, | |
DebertaConfig, | |
DebertaV2Config, | |
DeiTConfig, | |
DetrConfig, | |
DistilBertConfig, | |
DPRConfig, | |
ElectraConfig, | |
EncoderDecoderConfig, | |
FlaubertConfig, | |
FSMTConfig, | |
FunnelConfig, | |
GPT2Config, | |
GPTNeoConfig, | |
HubertConfig, | |
IBertConfig, | |
LayoutLMConfig, | |
LEDConfig, | |
LongformerConfig, | |
LukeConfig, | |
LxmertConfig, | |
M2M100Config, | |
MarianConfig, | |
MBartConfig, | |
MegatronBertConfig, | |
MobileBertConfig, | |
MPNetConfig, | |
MT5Config, | |
OpenAIGPTConfig, | |
PegasusConfig, | |
ProphetNetConfig, | |
ReformerConfig, | |
RetriBertConfig, | |
RobertaConfig, | |
RoFormerConfig, | |
Speech2TextConfig, | |
SqueezeBertConfig, | |
T5Config, | |
TapasConfig, | |
TransfoXLConfig, | |
VisualBertConfig, | |
ViTConfig, | |
Wav2Vec2Config, | |
XLMConfig, | |
XLMProphetNetConfig, | |
XLMRobertaConfig, | |
XLNetConfig, | |
) | |
logger = logging.get_logger(__name__)


# Configuration class -> base model class. This table is the ``_model_mapping``
# of ``AutoModel``.
#
# Every configuration class must appear exactly once: ``OrderedDict`` keeps the
# position of the first occurrence of a key but the value of the last one, so a
# repeated key silently discards the earlier value.
MODEL_MAPPING = OrderedDict(
    [
        # Base model mapping
        (VisualBertConfig, VisualBertModel),
        (CanineConfig, CanineModel),
        (RoFormerConfig, RoFormerModel),
        (CLIPConfig, CLIPModel),
        (BigBirdPegasusConfig, BigBirdPegasusModel),
        (DeiTConfig, DeiTModel),
        (LukeConfig, LukeModel),
        (DetrConfig, DetrModel),
        (GPTNeoConfig, GPTNeoModel),
        (BigBirdConfig, BigBirdModel),
        (Speech2TextConfig, Speech2TextModel),
        (ViTConfig, ViTModel),
        (Wav2Vec2Config, Wav2Vec2Model),
        (HubertConfig, HubertModel),
        (M2M100Config, M2M100Model),
        (ConvBertConfig, ConvBertModel),
        (LEDConfig, LEDModel),
        (BlenderbotSmallConfig, BlenderbotSmallModel),
        (RetriBertConfig, RetriBertModel),
        (MT5Config, MT5Model),
        (T5Config, T5Model),
        (PegasusConfig, PegasusModel),
        # Single Marian entry: the base model is MarianModel (MarianMTModel is
        # the generation head and is registered in the LM-head tables instead).
        (MarianConfig, MarianModel),
        (MBartConfig, MBartModel),
        (BlenderbotConfig, BlenderbotModel),
        (DistilBertConfig, DistilBertModel),
        (AlbertConfig, AlbertModel),
        (CamembertConfig, CamembertModel),
        (XLMRobertaConfig, XLMRobertaModel),
        (BartConfig, BartModel),
        (LongformerConfig, LongformerModel),
        (RobertaConfig, RobertaModel),
        (LayoutLMConfig, LayoutLMModel),
        (SqueezeBertConfig, SqueezeBertModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
        (MegatronBertConfig, MegatronBertModel),
        (MobileBertConfig, MobileBertModel),
        (TransfoXLConfig, TransfoXLModel),
        (XLNetConfig, XLNetModel),
        (FlaubertConfig, FlaubertModel),
        (FSMTConfig, FSMTModel),
        (XLMConfig, XLMModel),
        (CTRLConfig, CTRLModel),
        (ElectraConfig, ElectraModel),
        (ReformerConfig, ReformerModel),
        # The value is a tuple of candidate classes; how one is chosen is
        # decided by the auto factory (not visible here).
        (FunnelConfig, (FunnelModel, FunnelBaseModel)),
        (LxmertConfig, LxmertModel),
        (BertGenerationConfig, BertGenerationEncoder),
        (DebertaConfig, DebertaModel),
        (DebertaV2Config, DebertaV2Model),
        (DPRConfig, DPRQuestionEncoder),
        (XLMProphetNetConfig, XLMProphetNetModel),
        (ProphetNetConfig, ProphetNetModel),
        (MPNetConfig, MPNetModel),
        (TapasConfig, TapasModel),
        (IBertConfig, IBertModel),
    ]
)
# Configuration class -> model class carrying the pre-training head(s).
# This table is the ``_model_mapping`` of ``AutoModelForPreTraining``.
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    (
        # Model for pre-training mapping
        (VisualBertConfig, VisualBertForPreTraining),
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (RetriBertConfig, RetriBertModel),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForPreTraining),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (FSMTConfig, FSMTForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (SqueezeBertConfig, SqueezeBertForMaskedLM),
        (BertConfig, BertForPreTraining),
        (BigBirdConfig, BigBirdForPreTraining),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (MegatronBertConfig, MegatronBertForPreTraining),
        (MobileBertConfig, MobileBertForPreTraining),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForPreTraining),
        (LxmertConfig, LxmertForPreTraining),
        (FunnelConfig, FunnelForPreTraining),
        (MPNetConfig, MPNetForMaskedLM),
        (TapasConfig, TapasForMaskedLM),
        (IBertConfig, IBertForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
        (DebertaV2Config, DebertaV2ForMaskedLM),
        (Wav2Vec2Config, Wav2Vec2ForPreTraining),
    )
)
# Configuration class -> model class with a language-modeling head.
# This table is the ``_model_mapping`` of the private ``_AutoModelWithLMHead``
# class (and therefore of the deprecated ``AutoModelWithLMHead``).
#
# Every configuration class must appear exactly once: ``OrderedDict`` keeps the
# position of the first occurrence of a key but the value of the last one, so a
# repeated key silently discards the earlier value.
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        # Model with LM heads mapping
        (RoFormerConfig, RoFormerForMaskedLM),
        (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
        (GPTNeoConfig, GPTNeoForCausalLM),
        (BigBirdConfig, BigBirdForMaskedLM),
        (Speech2TextConfig, Speech2TextForConditionalGeneration),
        (Wav2Vec2Config, Wav2Vec2ForMaskedLM),
        (M2M100Config, M2M100ForConditionalGeneration),
        (ConvBertConfig, ConvBertForMaskedLM),
        (LEDConfig, LEDForConditionalGeneration),
        (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (MarianConfig, MarianMTModel),
        (FSMTConfig, FSMTForConditionalGeneration),
        (BartConfig, BartForConditionalGeneration),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (SqueezeBertConfig, SqueezeBertForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        # Single MegatronBert entry. The table previously also listed
        # MegatronBertForMaskedLM under this key, but that value was always
        # overwritten by the causal-LM class, which is the one actually served.
        (MegatronBertConfig, MegatronBertForCausalLM),
        (MobileBertConfig, MobileBertForMaskedLM),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (EncoderDecoderConfig, EncoderDecoderModel),
        (ReformerConfig, ReformerModelWithLMHead),
        (FunnelConfig, FunnelForMaskedLM),
        (MPNetConfig, MPNetForMaskedLM),
        (TapasConfig, TapasForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
        (DebertaV2Config, DebertaV2ForMaskedLM),
        (IBertConfig, IBertForMaskedLM),
    ]
)
# Configuration class -> causal language-modeling class.
# This table is the ``_model_mapping`` of ``AutoModelForCausalLM``.
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
    (
        # Model for Causal LM mapping
        (RoFormerConfig, RoFormerForCausalLM),
        (BigBirdPegasusConfig, BigBirdPegasusForCausalLM),
        (GPTNeoConfig, GPTNeoForCausalLM),
        (BigBirdConfig, BigBirdForCausalLM),
        (CamembertConfig, CamembertForCausalLM),
        (XLMRobertaConfig, XLMRobertaForCausalLM),
        (RobertaConfig, RobertaForCausalLM),
        (BertConfig, BertLMHeadModel),
        (OpenAIGPTConfig, OpenAIGPTLMHeadModel),
        (GPT2Config, GPT2LMHeadModel),
        (TransfoXLConfig, TransfoXLLMHeadModel),
        (XLNetConfig, XLNetLMHeadModel),
        # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
        (XLMConfig, XLMWithLMHeadModel),
        (CTRLConfig, CTRLLMHeadModel),
        (ReformerConfig, ReformerModelWithLMHead),
        (BertGenerationConfig, BertGenerationDecoder),
        (XLMProphetNetConfig, XLMProphetNetForCausalLM),
        (ProphetNetConfig, ProphetNetForCausalLM),
        (BartConfig, BartForCausalLM),
        (MBartConfig, MBartForCausalLM),
        (PegasusConfig, PegasusForCausalLM),
        (MarianConfig, MarianForCausalLM),
        (BlenderbotConfig, BlenderbotForCausalLM),
        (BlenderbotSmallConfig, BlenderbotSmallForCausalLM),
        (MegatronBertConfig, MegatronBertForCausalLM),
    )
)
# Configuration class -> image-classification class.
# This table is the ``_model_mapping`` of ``AutoModelForImageClassification``.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
    (
        # Model for Image Classification mapping
        (ViTConfig, ViTForImageClassification),
        # DeiT registers a tuple of two candidate classes; the choice between
        # them is made by the auto factory (not visible here).
        (
            DeiTConfig,
            (DeiTForImageClassification, DeiTForImageClassificationWithTeacher),
        ),
    )
)
# Configuration class -> masked language-modeling class.
# This table is the ``_model_mapping`` of ``AutoModelForMaskedLM``.
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    (
        # Model for Masked LM mapping
        (RoFormerConfig, RoFormerForMaskedLM),
        (BigBirdConfig, BigBirdForMaskedLM),
        (Wav2Vec2Config, Wav2Vec2ForMaskedLM),
        (ConvBertConfig, ConvBertForMaskedLM),
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (DistilBertConfig, DistilBertForMaskedLM),
        (AlbertConfig, AlbertForMaskedLM),
        (BartConfig, BartForConditionalGeneration),
        (MBartConfig, MBartForConditionalGeneration),
        (CamembertConfig, CamembertForMaskedLM),
        (XLMRobertaConfig, XLMRobertaForMaskedLM),
        (LongformerConfig, LongformerForMaskedLM),
        (RobertaConfig, RobertaForMaskedLM),
        (SqueezeBertConfig, SqueezeBertForMaskedLM),
        (BertConfig, BertForMaskedLM),
        (MegatronBertConfig, MegatronBertForMaskedLM),
        (MobileBertConfig, MobileBertForMaskedLM),
        (FlaubertConfig, FlaubertWithLMHeadModel),
        (XLMConfig, XLMWithLMHeadModel),
        (ElectraConfig, ElectraForMaskedLM),
        (ReformerConfig, ReformerForMaskedLM),
        (FunnelConfig, FunnelForMaskedLM),
        (MPNetConfig, MPNetForMaskedLM),
        (TapasConfig, TapasForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
        (DebertaV2Config, DebertaV2ForMaskedLM),
        (IBertConfig, IBertForMaskedLM),
    )
)
# Configuration class -> object-detection class (a single entry at present).
MODEL_FOR_OBJECT_DETECTION_MAPPING = OrderedDict(
    (
        # Model for Object Detection mapping
        (DetrConfig, DetrForObjectDetection),
    )
)
# Configuration class -> sequence-to-sequence language-modeling class.
# This table is the ``_model_mapping`` of ``AutoModelForSeq2SeqLM``.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    (
        # Model for Seq2Seq Causal LM mapping
        (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
        (M2M100Config, M2M100ForConditionalGeneration),
        (LEDConfig, LEDForConditionalGeneration),
        (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
        (MT5Config, MT5ForConditionalGeneration),
        (T5Config, T5ForConditionalGeneration),
        (PegasusConfig, PegasusForConditionalGeneration),
        (MarianConfig, MarianMTModel),
        (MBartConfig, MBartForConditionalGeneration),
        (BlenderbotConfig, BlenderbotForConditionalGeneration),
        (BartConfig, BartForConditionalGeneration),
        (FSMTConfig, FSMTForConditionalGeneration),
        (EncoderDecoderConfig, EncoderDecoderModel),
        (XLMProphetNetConfig, XLMProphetNetForConditionalGeneration),
        (ProphetNetConfig, ProphetNetForConditionalGeneration),
    )
)
# Configuration class -> sequence-classification class.
# This table is the ``_model_mapping`` of ``AutoModelForSequenceClassification``.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    (
        # Model for Sequence Classification mapping
        (CanineConfig, CanineForSequenceClassification),
        (RoFormerConfig, RoFormerForSequenceClassification),
        (BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification),
        (BigBirdConfig, BigBirdForSequenceClassification),
        (ConvBertConfig, ConvBertForSequenceClassification),
        (LEDConfig, LEDForSequenceClassification),
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
        (XLMRobertaConfig, XLMRobertaForSequenceClassification),
        (MBartConfig, MBartForSequenceClassification),
        (BartConfig, BartForSequenceClassification),
        (LongformerConfig, LongformerForSequenceClassification),
        (RobertaConfig, RobertaForSequenceClassification),
        (SqueezeBertConfig, SqueezeBertForSequenceClassification),
        (LayoutLMConfig, LayoutLMForSequenceClassification),
        (BertConfig, BertForSequenceClassification),
        (XLNetConfig, XLNetForSequenceClassification),
        (MegatronBertConfig, MegatronBertForSequenceClassification),
        (MobileBertConfig, MobileBertForSequenceClassification),
        (FlaubertConfig, FlaubertForSequenceClassification),
        (XLMConfig, XLMForSequenceClassification),
        (ElectraConfig, ElectraForSequenceClassification),
        (FunnelConfig, FunnelForSequenceClassification),
        (DebertaConfig, DebertaForSequenceClassification),
        (DebertaV2Config, DebertaV2ForSequenceClassification),
        (GPT2Config, GPT2ForSequenceClassification),
        (GPTNeoConfig, GPTNeoForSequenceClassification),
        (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
        (ReformerConfig, ReformerForSequenceClassification),
        (CTRLConfig, CTRLForSequenceClassification),
        (TransfoXLConfig, TransfoXLForSequenceClassification),
        (MPNetConfig, MPNetForSequenceClassification),
        (TapasConfig, TapasForSequenceClassification),
        (IBertConfig, IBertForSequenceClassification),
    )
)
# Configuration class -> question-answering class.
# This table is the ``_model_mapping`` of ``AutoModelForQuestionAnswering``.
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    (
        # Model for Question Answering mapping
        (CanineConfig, CanineForQuestionAnswering),
        (RoFormerConfig, RoFormerForQuestionAnswering),
        (BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering),
        (BigBirdConfig, BigBirdForQuestionAnswering),
        (ConvBertConfig, ConvBertForQuestionAnswering),
        (LEDConfig, LEDForQuestionAnswering),
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (CamembertConfig, CamembertForQuestionAnswering),
        (BartConfig, BartForQuestionAnswering),
        (MBartConfig, MBartForQuestionAnswering),
        (LongformerConfig, LongformerForQuestionAnswering),
        (XLMRobertaConfig, XLMRobertaForQuestionAnswering),
        (RobertaConfig, RobertaForQuestionAnswering),
        (SqueezeBertConfig, SqueezeBertForQuestionAnswering),
        (BertConfig, BertForQuestionAnswering),
        (XLNetConfig, XLNetForQuestionAnsweringSimple),
        (FlaubertConfig, FlaubertForQuestionAnsweringSimple),
        (MegatronBertConfig, MegatronBertForQuestionAnswering),
        (MobileBertConfig, MobileBertForQuestionAnswering),
        (XLMConfig, XLMForQuestionAnsweringSimple),
        (ElectraConfig, ElectraForQuestionAnswering),
        (ReformerConfig, ReformerForQuestionAnswering),
        (FunnelConfig, FunnelForQuestionAnswering),
        (LxmertConfig, LxmertForQuestionAnswering),
        (MPNetConfig, MPNetForQuestionAnswering),
        (DebertaConfig, DebertaForQuestionAnswering),
        (DebertaV2Config, DebertaV2ForQuestionAnswering),
        (IBertConfig, IBertForQuestionAnswering),
    )
)
# Configuration class -> table question-answering class (a single entry at present).
# This table is the ``_model_mapping`` of ``AutoModelForTableQuestionAnswering``.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict(
    (
        # Model for Table Question Answering mapping
        (TapasConfig, TapasForQuestionAnswering),
    )
)
# Configuration class -> token-classification class.
# This table is the ``_model_mapping`` of ``AutoModelForTokenClassification``.
#
# Each configuration class is listed exactly once; a repeated key in an
# ``OrderedDict`` only overwrites the earlier value and adds nothing.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Token Classification mapping
        (CanineConfig, CanineForTokenClassification),
        (RoFormerConfig, RoFormerForTokenClassification),
        (BigBirdConfig, BigBirdForTokenClassification),
        (ConvBertConfig, ConvBertForTokenClassification),
        (LayoutLMConfig, LayoutLMForTokenClassification),
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (FlaubertConfig, FlaubertForTokenClassification),
        (XLMConfig, XLMForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (LongformerConfig, LongformerForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (SqueezeBertConfig, SqueezeBertForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (MegatronBertConfig, MegatronBertForTokenClassification),
        (MobileBertConfig, MobileBertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
        (AlbertConfig, AlbertForTokenClassification),
        (ElectraConfig, ElectraForTokenClassification),
        (FunnelConfig, FunnelForTokenClassification),
        (MPNetConfig, MPNetForTokenClassification),
        (DebertaConfig, DebertaForTokenClassification),
        (DebertaV2Config, DebertaV2ForTokenClassification),
        (IBertConfig, IBertForTokenClassification),
    ]
)
# Configuration class -> multiple-choice class.
# This table is the ``_model_mapping`` of ``AutoModelForMultipleChoice``.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    (
        # Model for Multiple Choice mapping
        (CanineConfig, CanineForMultipleChoice),
        (RoFormerConfig, RoFormerForMultipleChoice),
        (BigBirdConfig, BigBirdForMultipleChoice),
        (ConvBertConfig, ConvBertForMultipleChoice),
        (CamembertConfig, CamembertForMultipleChoice),
        (ElectraConfig, ElectraForMultipleChoice),
        (XLMRobertaConfig, XLMRobertaForMultipleChoice),
        (LongformerConfig, LongformerForMultipleChoice),
        (RobertaConfig, RobertaForMultipleChoice),
        (SqueezeBertConfig, SqueezeBertForMultipleChoice),
        (BertConfig, BertForMultipleChoice),
        (DistilBertConfig, DistilBertForMultipleChoice),
        (MegatronBertConfig, MegatronBertForMultipleChoice),
        (MobileBertConfig, MobileBertForMultipleChoice),
        (XLNetConfig, XLNetForMultipleChoice),
        (AlbertConfig, AlbertForMultipleChoice),
        (XLMConfig, XLMForMultipleChoice),
        (FlaubertConfig, FlaubertForMultipleChoice),
        (FunnelConfig, FunnelForMultipleChoice),
        (MPNetConfig, MPNetForMultipleChoice),
        (IBertConfig, IBertForMultipleChoice),
    )
)
# Configuration class -> next-sentence-prediction class.
# This table is the ``_model_mapping`` of ``AutoModelForNextSentencePrediction``.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
    (
        # Model for Next Sentence Prediction mapping
        (BertConfig, BertForNextSentencePrediction),
        (MegatronBertConfig, MegatronBertForNextSentencePrediction),
        (MobileBertConfig, MobileBertForNextSentencePrediction),
    )
)
class AutoModel(_BaseAutoModelClass):
    # Lookup table: configuration class -> base model class.
    _model_mapping = MODEL_MAPPING


# Rebind the public name to whatever ``auto_class_update`` returns for the class
# (default arguments, no head description).
AutoModel = auto_class_update(AutoModel)
class AutoModelForPreTraining(_BaseAutoModelClass):
    # Lookup table: configuration class -> pre-training model class.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
# Private on purpose, the public class will add the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Lookup table: configuration class -> model class with an LM head.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


# Rebind the private name to the result of ``auto_class_update``.
_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Lookup table: configuration class -> causal LM class.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Lookup table: configuration class -> masked LM class.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Lookup table: configuration class -> sequence-to-sequence LM class.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the public name to the result of ``auto_class_update``; this class
# passes its own example checkpoint ("t5-base") instead of the helper's default.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Lookup table: configuration class -> sequence-classification class.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Lookup table: configuration class -> question-answering class.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Lookup table: configuration class -> table question-answering class.
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Rebind the public name to the result of ``auto_class_update``; this class
# passes a TAPAS checkpoint as its example instead of the helper's default.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
    head_doc="table question answering",
)
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Lookup table: configuration class -> token-classification class.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Lookup table: configuration class -> multiple-choice class.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Lookup table: configuration class -> next-sentence-prediction class.
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
class AutoModelForImageClassification(_BaseAutoModelClass):
    # Lookup table: configuration class -> image-classification class.
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the public name to the result of ``auto_class_update``.
AutoModelForImageClassification = auto_class_update(
    AutoModelForImageClassification,
    head_doc="image classification",
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
    """Deprecated public wrapper around ``_AutoModelWithLMHead``.

    Both constructors emit a ``FutureWarning`` pointing users to
    ``AutoModelForCausalLM``, ``AutoModelForMaskedLM`` and
    ``AutoModelForSeq2SeqLM``, then delegate unchanged to the parent class.
    """

    # Both methods take ``cls`` and are meant to be called on the class itself
    # (``AutoModelWithLMHead.from_pretrained(...)``), so they must be declared as
    # classmethods; without the decorator the first user argument would be bound
    # to ``cls`` and the call would fail.

    @classmethod
    def from_config(cls, config):
        # Warn first, then forward the configuration to the parent unchanged.
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Warn first, then forward every positional and keyword argument as-is.
        warnings.warn(
            "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
            "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)