# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Config class. """
import re
from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from ..bigbird_pegasus.configuration_bigbird_pegasus import (
    BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BigBirdPegasusConfig,
)
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from ..blenderbot_small.configuration_blenderbot_small import (
    BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BlenderbotSmallConfig,
)
from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from ..canine.configuration_canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig
from ..clip.configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig
from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from ..luke.configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..megatron_bert.configuration_megatron_bert import (
    MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    MegatronBertConfig,
)
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
from ..mt5.configuration_mt5 import MT5Config
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from ..rag.configuration_rag import RagConfig
from ..reformer.configuration_reformer import ReformerConfig
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from ..roformer.configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
from ..speech_to_text.configuration_speech_to_text import (
    SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    Speech2TextConfig,
)
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from ..visual_bert.configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
    XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLMProphetNetConfig,
)
from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig


ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        # Add archive maps here
        VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DETR_PRETRAINED_CONFIG_ARCHIVE_MAP,
        GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
        IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value in pretrained_map.items()
)
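
# Illustrative sketch, not executed at import time: each per-model archive map pairs a
# checkpoint identifier with the URL of its ``config.json``, so the merged dict above can
# be queried directly. Assuming the BERT archive map lists "bert-base-uncased", something
# like the following would hold (the exact URL depends on that archive map):
#
#     url = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP["bert-base-uncased"]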


CONFIG_MAPPING = OrderedDict(
    [
        # Add configs here
        ("visual_bert", VisualBertConfig),
        ("canine", CanineConfig),
        ("roformer", RoFormerConfig),
        ("clip", CLIPConfig),
        ("bigbird_pegasus", BigBirdPegasusConfig),
        ("deit", DeiTConfig),
        ("luke", LukeConfig),
        ("detr", DetrConfig),
        ("gpt_neo", GPTNeoConfig),
        ("big_bird", BigBirdConfig),
        ("speech_to_text", Speech2TextConfig),
        ("vit", ViTConfig),
        ("wav2vec2", Wav2Vec2Config),
        ("m2m_100", M2M100Config),
        ("convbert", ConvBertConfig),
        ("led", LEDConfig),
        ("blenderbot-small", BlenderbotSmallConfig),
        ("retribert", RetriBertConfig),
        ("ibert", IBertConfig),
        ("mt5", MT5Config),
        ("t5", T5Config),
        ("mobilebert", MobileBertConfig),
        ("distilbert", DistilBertConfig),
        ("albert", AlbertConfig),
        ("bert-generation", BertGenerationConfig),
        ("camembert", CamembertConfig),
        ("xlm-roberta", XLMRobertaConfig),
        ("pegasus", PegasusConfig),
        ("marian", MarianConfig),
        ("mbart", MBartConfig),
        ("megatron-bert", MegatronBertConfig),
        ("mpnet", MPNetConfig),
        ("bart", BartConfig),
        ("blenderbot", BlenderbotConfig),
        ("reformer", ReformerConfig),
        ("longformer", LongformerConfig),
        ("roberta", RobertaConfig),
        ("deberta-v2", DebertaV2Config),
        ("deberta", DebertaConfig),
        ("flaubert", FlaubertConfig),
        ("fsmt", FSMTConfig),
        ("squeezebert", SqueezeBertConfig),
        ("hubert", HubertConfig),
        ("bert", BertConfig),
        ("openai-gpt", OpenAIGPTConfig),
        ("gpt2", GPT2Config),
        ("transfo-xl", TransfoXLConfig),
        ("xlnet", XLNetConfig),
        ("xlm-prophetnet", XLMProphetNetConfig),
        ("prophetnet", ProphetNetConfig),
        ("xlm", XLMConfig),
        ("ctrl", CTRLConfig),
        ("electra", ElectraConfig),
        ("encoder-decoder", EncoderDecoderConfig),
        ("funnel", FunnelConfig),
        ("lxmert", LxmertConfig),
        ("dpr", DPRConfig),
        ("layoutlm", LayoutLMConfig),
        ("rag", RagConfig),
        ("tapas", TapasConfig),
    ]
)
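
# Minimal usage sketch: CONFIG_MAPPING resolves a ``model_type`` string to its configuration
# class, which is how ``AutoConfig.for_model`` and ``AutoConfig.from_pretrained`` dispatch:
#
#     config_class = CONFIG_MAPPING["bert"]   # -> BertConfig
#     config = config_class()                 # a default BertConfig instance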


MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),
        ("clip", "CLIP"),
        ("bigbird_pegasus", "BigBirdPegasus"),
        ("deit", "DeiT"),
        ("luke", "LUKE"),
        ("detr", "DETR"),
        ("gpt_neo", "GPT Neo"),
        ("big_bird", "BigBird"),
        ("speech_to_text", "Speech2Text"),
        ("vit", "ViT"),
        ("wav2vec2", "Wav2Vec2"),
        ("m2m_100", "M2M100"),
        ("convbert", "ConvBERT"),
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
        ("albert", "ALBERT"),
        ("bert-generation", "Bert Generation"),
        ("camembert", "CamemBERT"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("pegasus", "Pegasus"),
        ("blenderbot", "Blenderbot"),
        ("marian", "Marian"),
        ("mbart", "mBART"),
        ("megatron-bert", "MegatronBert"),
        ("bart", "BART"),
        ("reformer", "Reformer"),
        ("longformer", "Longformer"),
        ("roberta", "RoBERTa"),
        ("flaubert", "FlauBERT"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("squeezebert", "SqueezeBERT"),
        ("bert", "BERT"),
        ("openai-gpt", "OpenAI GPT"),
        ("gpt2", "OpenAI GPT-2"),
        ("transfo-xl", "Transformer-XL"),
        ("xlnet", "XLNet"),
        ("xlm", "XLM"),
        ("ctrl", "CTRL"),
        ("electra", "ELECTRA"),
        ("encoder-decoder", "Encoder decoder"),
        ("funnel", "Funnel Transformer"),
        ("lxmert", "LXMERT"),
        ("deberta-v2", "DeBERTa-v2"),
        ("deberta", "DeBERTa"),
        ("layoutlm", "LayoutLM"),
        ("dpr", "DPR"),
        ("rag", "RAG"),
        ("xlm-prophetnet", "XLMProphetNet"),
        ("prophetnet", "ProphetNet"),
        ("mt5", "mT5"),
        ("mpnet", "MPNet"),
        ("tapas", "TAPAS"),
        ("hubert", "Hubert"),
    ]
)
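
# MODEL_NAMES_MAPPING only supplies the human-readable model names used when the docstring
# decorator below renders the "List options" placeholder, e.g.:
#
#     MODEL_NAMES_MAPPING["gpt2"]   # -> "OpenAI GPT-2"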


def _get_class_name(model_class):
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
    return f":class:`~transformers.{model_class.__name__}`"


def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
            model_type_to_name = {
                model_type: f":class:`~transformers.{config.__name__}`"
                for model_type, config in CONFIG_MAPPING.items()
            }
        else:
            model_type_to_name = {
                model_type: _get_class_name(config_to_class[config])
                for model_type, config in CONFIG_MAPPING.items()
                if config in config_to_class
            }
        lines = [
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
            for model_type in sorted(model_type_to_name.keys())
        ]
    else:
        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
        config_to_model_name = {
            config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
        }
        lines = [
            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
            for config_name in sorted(config_to_name.keys())
        ]
    return "\n".join(lines)


def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    def docstring_decorator(fn):
        docstrings = fn.__doc__
        lines = docstrings.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
            if use_model_types:
                indent = f"{indent}    "
            lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
            docstrings = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, "
                f"current docstring is:\n{docstrings}"
            )
        fn.__doc__ = docstrings
        return fn

    return docstring_decorator
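

# Minimal usage sketch with a hypothetical function (not part of this module): the decorator
# replaces a bare "List options" line in a docstring with the generated bullet list.
#
#     @replace_list_option_in_docstrings()
#     def describe_models():
#         """Supported models:
#
#         List options
#         """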


class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the
    library when created with the :meth:`~transformers.AutoConfig.from_pretrained` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )
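
    # Usage sketch: build a default configuration from a bare model type, optionally
    # overriding attributes via kwargs, e.g.
    #
    #     config = AutoConfig.for_model("bert", hidden_size=512)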

    @classmethod
    @replace_list_option_in_docstrings()
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the :obj:`model_type` property of the config
        object that is loaded, or when it's missing, by falling back to using pattern matching on
        :obj:`pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``,
                      or namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing a configuration file saved using the
                      :meth:`~transformers.PretrainedConfig.save_pretrained` method, or the
                      :meth:`~transformers.PreTrainedModel.save_pretrained` method, e.g., ``./my_model_directory/``.
                    - A path or url to a saved configuration JSON `file`, e.g.,
                      ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files and override
                the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
                a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be
                any identifier allowed by git.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where
                `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration
                attributes: i.e., the part of ``kwargs`` which has not been used to update ``config`` and is
                otherwise ignored.
            kwargs (additional keyword arguments, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
                controlled by the ``return_unused_kwargs`` keyword parameter.

        Examples::

            >>> from transformers import AutoConfig

            >>> # Download configuration from huggingface.co and cache.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased')

            >>> # Download configuration from huggingface.co (user-uploaded) and cache.
            >>> config = AutoConfig.from_pretrained('dbmdz/bert-base-german-cased')

            >>> # If configuration file is in a directory (e.g., was saved using `save_pretrained('./test/saved_model/')`).
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/')

            >>> # Load a specific configuration file.
            >>> config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')

            >>> # Change some config attributes when loading a pretrained config.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            >>> config.output_attentions
            True

            >>> config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
            >>> config.output_attentions
            True
            >>> unused_kwargs
            {'foo': False}
        """
kwargs["_from_auto"] = True | |
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) | |
if "model_type" in config_dict: | |
config_class = CONFIG_MAPPING[config_dict["model_type"]] | |
return config_class.from_dict(config_dict, **kwargs) | |
else: | |
# Fallback: use pattern matching on the string. | |
for pattern, config_class in CONFIG_MAPPING.items(): | |
if pattern in str(pretrained_model_name_or_path): | |
return config_class.from_dict(config_dict, **kwargs) | |
raise ValueError( | |
f"Unrecognized model in {pretrained_model_name_or_path}. " | |
"Should have a `model_type` key in its config.json, or contain one of the following strings " | |
f"in its name: {', '.join(CONFIG_MAPPING.keys())}" | |
) | |
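

# Note on the fallback above (illustrative path): when the loaded config has no ``model_type``
# key, the first CONFIG_MAPPING key found as a substring of the name or path wins, so more
# specific keys appear before their prefixes in the mapping (e.g. "deberta-v2" before
# "deberta", "xlm-roberta" before "roberta"):
#
#     AutoConfig.from_pretrained("./checkpoints/deberta-v2-finetuned/")  # matches "deberta-v2"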