# Hugging Face file-viewer metadata (not code), preserved as comments:
#   author: ArneBinder
#   commit: 14a87b0 (verified) — "fix load_model_with_adapter location"
#   file size: 1.1 kB
from typing import Any, Dict
from pie_modules.models import * # noqa: F403
from pie_modules.taskmodules import * # noqa: F403
from pytorch_ie import AutoModel, AutoTaskModule, PyTorchIEModel, TaskModule
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.taskmodules import * # noqa: F403
from transformers import PreTrainedModel, PreTrainedTokenizer
def load_model_from_pie_model(model_kwargs: Dict[str, Any]) -> PreTrainedModel:
    """Load a PyTorch-IE model checkpoint and return the underlying HF transformer.

    Args:
        model_kwargs: Keyword arguments forwarded to ``AutoModel.from_pretrained``
            (e.g. the model name or path).

    Returns:
        The inner ``transformers.PreTrainedModel`` — presumably the PIE wrapper
        nests the transformer two attribute levels deep (TODO confirm for all
        PIE model classes).
    """
    loaded: PyTorchIEModel = AutoModel.from_pretrained(**model_kwargs)
    inner = loaded.model
    return inner.model
def load_tokenizer_from_pie_taskmodule(taskmodule_kwargs: Dict[str, Any]) -> PreTrainedTokenizer:
    """Load a PyTorch-IE taskmodule and return its HF tokenizer.

    Args:
        taskmodule_kwargs: Keyword arguments forwarded to
            ``AutoTaskModule.from_pretrained`` (e.g. the taskmodule name or path).

    Returns:
        The ``transformers.PreTrainedTokenizer`` held by the taskmodule
        (assumes the taskmodule exposes a ``tokenizer`` attribute — TODO
        confirm for taskmodules that do not tokenize).
    """
    taskmodule: TaskModule = AutoTaskModule.from_pretrained(**taskmodule_kwargs)
    return taskmodule.tokenizer
def load_model_with_adapter(
    model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any]
) -> "ModelAdaptersMixin":
    """Load a base transformer via the ``adapters`` library and attach an adapter.

    Args:
        model_kwargs: Keyword arguments forwarded to
            ``AutoAdapterModel.from_pretrained``.
        adapter_kwargs: Keyword arguments forwarded to ``load_adapter``
            (e.g. the adapter name or path).

    Returns:
        The model with the adapter loaded and set active.
    """
    # Imported lazily so that ``adapters`` remains an optional dependency
    # of this module; ``ModelAdaptersMixin`` is only needed for the annotation.
    from adapters import AutoAdapterModel, ModelAdaptersMixin

    base_model = AutoAdapterModel.from_pretrained(**model_kwargs)
    base_model.load_adapter(set_active=True, **adapter_kwargs)
    return base_model