File size: 390 Bytes
3e1d9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from typing import Dict, Any, Tuple
from torch import nn
from .build_shikra import load_pretrained_shikra
# Type alias for the second element returned by `load_pretrained`: a dict keyed
# by str. NOTE(review): the concrete keys/values are produced by the per-model
# loaders (e.g. `load_pretrained_shikra`) — confirm the schema there.
PREPROCESSOR = Dict[str, Any]


# TODO: Registry
def load_pretrained(model_args, training_args) -> Tuple[nn.Module, PREPROCESSOR]:
    """Load a pretrained model and its preprocessor, dispatching on ``model_args.type``.

    Args:
        model_args: Config object whose ``type`` attribute selects the loader.
            It is also forwarded unchanged to the selected loader.
        training_args: Forwarded unchanged to the selected loader.

    Returns:
        Whatever the selected loader returns, annotated as a
        ``(model, preprocessor)`` pair.

    Raises:
        ValueError: If ``model_args.type`` is not a supported model type.
    """
    type_ = model_args.type
    if type_ == 'shikra':
        return load_pretrained_shikra(model_args, training_args)
    # An explicit raise (instead of `assert False`) still fires under
    # `python -O`, and it reports which value was rejected.
    raise ValueError(f"unsupported model type {type_!r}; expected one of: 'shikra'")
|