Spaces:
Sleeping
Sleeping
# coding=utf-8 | |
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" Auto Model class. """ | |
from collections import OrderedDict | |
from ...utils import logging | |
from ..bart.modeling_flax_bart import ( | |
FlaxBartForConditionalGeneration, | |
FlaxBartForQuestionAnswering, | |
FlaxBartForSequenceClassification, | |
FlaxBartModel, | |
) | |
from ..bert.modeling_flax_bert import ( | |
FlaxBertForMaskedLM, | |
FlaxBertForMultipleChoice, | |
FlaxBertForNextSentencePrediction, | |
FlaxBertForPreTraining, | |
FlaxBertForQuestionAnswering, | |
FlaxBertForSequenceClassification, | |
FlaxBertForTokenClassification, | |
FlaxBertModel, | |
) | |
from ..big_bird.modeling_flax_big_bird import ( | |
FlaxBigBirdForMaskedLM, | |
FlaxBigBirdForMultipleChoice, | |
FlaxBigBirdForPreTraining, | |
FlaxBigBirdForQuestionAnswering, | |
FlaxBigBirdForSequenceClassification, | |
FlaxBigBirdForTokenClassification, | |
FlaxBigBirdModel, | |
) | |
from ..clip.modeling_flax_clip import FlaxCLIPModel | |
from ..electra.modeling_flax_electra import ( | |
FlaxElectraForMaskedLM, | |
FlaxElectraForMultipleChoice, | |
FlaxElectraForPreTraining, | |
FlaxElectraForQuestionAnswering, | |
FlaxElectraForSequenceClassification, | |
FlaxElectraForTokenClassification, | |
FlaxElectraModel, | |
) | |
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model | |
from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel | |
from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel | |
from ..mbart.modeling_flax_mbart import ( | |
FlaxMBartForConditionalGeneration, | |
FlaxMBartForQuestionAnswering, | |
FlaxMBartForSequenceClassification, | |
FlaxMBartModel, | |
) | |
from ..roberta.modeling_flax_roberta import ( | |
FlaxRobertaForMaskedLM, | |
FlaxRobertaForMultipleChoice, | |
FlaxRobertaForQuestionAnswering, | |
FlaxRobertaForSequenceClassification, | |
FlaxRobertaForTokenClassification, | |
FlaxRobertaModel, | |
) | |
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model | |
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel | |
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model | |
from .auto_factory import _BaseAutoModelClass, auto_class_update | |
from .configuration_auto import ( | |
BartConfig, | |
BertConfig, | |
BigBirdConfig, | |
CLIPConfig, | |
ElectraConfig, | |
GPT2Config, | |
GPTNeoConfig, | |
MarianConfig, | |
MBartConfig, | |
MT5Config, | |
RobertaConfig, | |
T5Config, | |
ViTConfig, | |
Wav2Vec2Config, | |
) | |
logger = logging.get_logger(__name__) | |
FLAX_MODEL_MAPPING = OrderedDict( | |
[ | |
# Base model mapping | |
(RobertaConfig, FlaxRobertaModel), | |
(BertConfig, FlaxBertModel), | |
(BigBirdConfig, FlaxBigBirdModel), | |
(BartConfig, FlaxBartModel), | |
(GPT2Config, FlaxGPT2Model), | |
(GPTNeoConfig, FlaxGPTNeoModel), | |
(ElectraConfig, FlaxElectraModel), | |
(CLIPConfig, FlaxCLIPModel), | |
(ViTConfig, FlaxViTModel), | |
(MBartConfig, FlaxMBartModel), | |
(T5Config, FlaxT5Model), | |
(MT5Config, FlaxT5Model), | |
(Wav2Vec2Config, FlaxWav2Vec2Model), | |
(MarianConfig, FlaxMarianModel), | |
] | |
) | |
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( | |
[ | |
# Model for pre-training mapping | |
(RobertaConfig, FlaxRobertaForMaskedLM), | |
(BertConfig, FlaxBertForPreTraining), | |
(BigBirdConfig, FlaxBigBirdForPreTraining), | |
(BartConfig, FlaxBartForConditionalGeneration), | |
(ElectraConfig, FlaxElectraForPreTraining), | |
(MBartConfig, FlaxMBartForConditionalGeneration), | |
(T5Config, FlaxT5ForConditionalGeneration), | |
(MT5Config, FlaxT5ForConditionalGeneration), | |
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining), | |
] | |
) | |
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( | |
[ | |
# Model for Masked LM mapping | |
(RobertaConfig, FlaxRobertaForMaskedLM), | |
(BertConfig, FlaxBertForMaskedLM), | |
(BigBirdConfig, FlaxBigBirdForMaskedLM), | |
(BartConfig, FlaxBartForConditionalGeneration), | |
(ElectraConfig, FlaxElectraForMaskedLM), | |
(MBartConfig, FlaxMBartForConditionalGeneration), | |
] | |
) | |
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( | |
[ | |
# Model for Seq2Seq Causal LM mapping | |
(BartConfig, FlaxBartForConditionalGeneration), | |
(T5Config, FlaxT5ForConditionalGeneration), | |
(MT5Config, FlaxT5ForConditionalGeneration), | |
(MarianConfig, FlaxMarianMTModel), | |
] | |
) | |
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( | |
[ | |
# Model for Image-classsification | |
(ViTConfig, FlaxViTForImageClassification), | |
] | |
) | |
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( | |
[ | |
# Model for Causal LM mapping | |
(GPT2Config, FlaxGPT2LMHeadModel), | |
(GPTNeoConfig, FlaxGPTNeoForCausalLM), | |
] | |
) | |
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( | |
[ | |
# Model for Sequence Classification mapping | |
(RobertaConfig, FlaxRobertaForSequenceClassification), | |
(BertConfig, FlaxBertForSequenceClassification), | |
(BigBirdConfig, FlaxBigBirdForSequenceClassification), | |
(BartConfig, FlaxBartForSequenceClassification), | |
(ElectraConfig, FlaxElectraForSequenceClassification), | |
(MBartConfig, FlaxMBartForSequenceClassification), | |
] | |
) | |
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( | |
[ | |
# Model for Question Answering mapping | |
(RobertaConfig, FlaxRobertaForQuestionAnswering), | |
(BertConfig, FlaxBertForQuestionAnswering), | |
(BigBirdConfig, FlaxBigBirdForQuestionAnswering), | |
(BartConfig, FlaxBartForQuestionAnswering), | |
(ElectraConfig, FlaxElectraForQuestionAnswering), | |
(MBartConfig, FlaxMBartForQuestionAnswering), | |
] | |
) | |
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( | |
[ | |
# Model for Token Classification mapping | |
(RobertaConfig, FlaxRobertaForTokenClassification), | |
(BertConfig, FlaxBertForTokenClassification), | |
(BigBirdConfig, FlaxBigBirdForTokenClassification), | |
(ElectraConfig, FlaxElectraForTokenClassification), | |
] | |
) | |
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( | |
[ | |
# Model for Multiple Choice mapping | |
(RobertaConfig, FlaxRobertaForMultipleChoice), | |
(BertConfig, FlaxBertForMultipleChoice), | |
(BigBirdConfig, FlaxBigBirdForMultipleChoice), | |
(ElectraConfig, FlaxElectraForMultipleChoice), | |
] | |
) | |
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( | |
[ | |
(BertConfig, FlaxBertForNextSentencePrediction), | |
] | |
) | |
class FlaxAutoModel(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_MAPPING | |
FlaxAutoModel = auto_class_update(FlaxAutoModel) | |
class FlaxAutoModelForPreTraining(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING | |
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining") | |
class FlaxAutoModelForCausalLM(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING | |
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling") | |
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING | |
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling") | |
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING | |
FlaxAutoModelForSeq2SeqLM = auto_class_update( | |
FlaxAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base" | |
) | |
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING | |
FlaxAutoModelForSequenceClassification = auto_class_update( | |
FlaxAutoModelForSequenceClassification, head_doc="sequence classification" | |
) | |
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING | |
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering") | |
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING | |
FlaxAutoModelForTokenClassification = auto_class_update( | |
FlaxAutoModelForTokenClassification, head_doc="token classification" | |
) | |
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING | |
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice") | |
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING | |
FlaxAutoModelForNextSentencePrediction = auto_class_update( | |
FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction" | |
) | |
class FlaxAutoModelForImageClassification(_BaseAutoModelClass): | |
_model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING | |
FlaxAutoModelForImageClassification = auto_class_update( | |
FlaxAutoModelForImageClassification, head_doc="image classification" | |
) | |