from typing import Dict, Any, Tuple from torch import nn from .build_shikra import load_pretrained_shikra PREPROCESSOR = Dict[str, Any] # TODO: Registry def load_pretrained(model_args, training_args) -> Tuple[nn.Module, PREPROCESSOR]: type_ = model_args.type if type_ == 'shikra': return load_pretrained_shikra(model_args, training_args) else: assert False