# Copyright (c) OpenMMLab. All rights reserved. | |
import contextlib | |
from typing import Optional | |
import transformers | |
from mmengine.registry import Registry | |
from transformers import AutoConfig, PreTrainedModel | |
from transformers.models.auto.auto_factory import _BaseAutoModelClass | |
from mmpretrain.registry import MODELS, TOKENIZER | |
def register_hf_tokenizer(
    cls: Optional[type] = None,
    registry: Registry = TOKENIZER,
):
    """Register a HuggingFace-style ``PreTrainedTokenizerBase`` class.

    Usable both as a plain call ``register_hf_tokenizer(MyTokenizer)`` and
    as a decorator ``@register_hf_tokenizer()``.

    Args:
        cls (type, optional): The tokenizer class to register. If ``None``,
            return a decorator that performs the registration.
        registry (Registry): The registry to register into.
            Defaults to ``TOKENIZER``.

    Returns:
        type: The registered class itself (or a decorator when ``cls`` is
        ``None``), so the decorated class is left usable.
    """
    if cls is None:

        # use it as a decorator: @register_hf_tokenizer()
        def _register(cls):
            register_hf_tokenizer(cls=cls)
            return cls

        return _register

    def from_pretrained(**kwargs):
        """Registry entry: build the tokenizer via ``cls.from_pretrained``."""
        if ('pretrained_model_name_or_path' not in kwargs
                and 'name_or_path' not in kwargs):
            raise TypeError(
                f'{cls.__name__}.from_pretrained() missing required '
                "argument 'pretrained_model_name_or_path' or 'name_or_path'.")
        # `pretrained_model_name_or_path` is too long for config,
        # add an alias name `name_or_path` here.
        # NOTE: pop lazily. The previous one-liner
        # ``kwargs.pop('pretrained_model_name_or_path', kwargs.pop('name_or_path'))``
        # evaluated the fallback eagerly and raised KeyError whenever only
        # 'pretrained_model_name_or_path' was provided.
        name_or_path = kwargs.pop('name_or_path', None)
        if 'pretrained_model_name_or_path' in kwargs:
            name_or_path = kwargs.pop('pretrained_model_name_or_path')
        return cls.from_pretrained(name_or_path, **kwargs)

    registry._register_module(module=from_pretrained, module_name=cls.__name__)
    return cls
# Module-level flag read by ``build`` below: when False, models are created
# from config only instead of loading pretrained weights.
_load_hf_pretrained_model = True


@contextlib.contextmanager
def no_load_hf_pretrained_model():
    """Context manager that temporarily disables loading HF pretrained weights.

    The original definition yielded but lacked ``@contextlib.contextmanager``,
    so it could not be used in a ``with`` statement. The flag is now also
    restored in a ``finally`` block, so an exception inside the ``with`` body
    cannot leave it stuck at ``False``.
    """
    global _load_hf_pretrained_model
    _load_hf_pretrained_model = False
    try:
        yield
    finally:
        _load_hf_pretrained_model = True
def register_hf_model(
    cls: Optional[type] = None,
    registry: Registry = MODELS,
):
    """Register a HuggingFace-style ``PreTrainedModel`` class.

    Usable both as a plain call ``register_hf_model(MyModel)`` and as a
    decorator ``@register_hf_model()``.

    Args:
        cls (type, optional): The model class (either an ``_BaseAutoModelClass``
            subclass or a ``PreTrainedModel`` subclass) to register. If
            ``None``, return a decorator that performs the registration.
        registry (Registry): The registry to register into.
            Defaults to ``MODELS``.

    Returns:
        type: The registered class itself (or a decorator when ``cls`` is
        ``None``).

    Raises:
        TypeError: If ``cls`` is neither an auto model class nor a
            pretrained model class of HuggingFace transformers.
    """
    if cls is None:

        # use it as a decorator: @register_hf_model()
        def _register(cls):
            register_hf_model(cls=cls)
            return cls

        return _register

    # Auto classes resolve their config via AutoConfig; concrete model
    # classes carry their own config class.
    if issubclass(cls, _BaseAutoModelClass):
        get_config = AutoConfig.from_pretrained
        from_config = cls.from_config
    elif issubclass(cls, PreTrainedModel):
        get_config = cls.config_class.from_pretrained
        from_config = cls
    else:
        raise TypeError('Not auto model nor pretrained model of huggingface.')

    def build(**kwargs):
        """Registry entry: build the model, optionally loading weights."""
        if ('pretrained_model_name_or_path' not in kwargs
                and 'name_or_path' not in kwargs):
            raise TypeError(
                f'{cls.__name__} missing required argument '
                '`pretrained_model_name_or_path` or `name_or_path`.')
        # `pretrained_model_name_or_path` is too long for config,
        # add an alias name `name_or_path` here.
        # NOTE: pop lazily. The previous one-liner
        # ``kwargs.pop('pretrained_model_name_or_path', kwargs.pop('name_or_path'))``
        # evaluated the fallback eagerly and raised KeyError whenever only
        # 'pretrained_model_name_or_path' was provided.
        name_or_path = kwargs.pop('name_or_path', None)
        if 'pretrained_model_name_or_path' in kwargs:
            name_or_path = kwargs.pop('pretrained_model_name_or_path')

        if kwargs.pop('load_pretrained', True) and _load_hf_pretrained_model:
            return cls.from_pretrained(name_or_path, **kwargs)
        else:
            # Build architecture only; no pretrained weights.
            cfg = get_config(name_or_path, **kwargs)
            return from_config(cfg)

    registry._register_module(module=build, module_name=cls.__name__)
    return cls
register_hf_model(transformers.AutoModelForCausalLM) | |