File size: 390 Bytes
3e1d9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from typing import Dict, Any, Tuple

from torch import nn

from .build_shikra import load_pretrained_shikra

# Type alias for the preprocessor bundle returned alongside the model: a dict
# keyed by str. NOTE(review): the concrete keys/values are defined by the
# individual loaders (e.g. load_pretrained_shikra) — confirm there.
PREPROCESSOR = Dict[str, Any]


# TODO: Registry
def load_pretrained(model_args, training_args) -> Tuple[nn.Module, PREPROCESSOR]:
    """Load a pretrained model and its preprocessor, dispatching on ``model_args.type``.

    Args:
        model_args: Model configuration object; its ``type`` attribute selects
            which loader is used. Currently only ``'shikra'`` is supported.
        training_args: Training configuration, forwarded unchanged to the
            selected loader.

    Returns:
        The ``(model, preprocessor)`` tuple produced by the selected loader.

    Raises:
        ValueError: If ``model_args.type`` is not a supported model type.
    """
    type_ = model_args.type
    if type_ == 'shikra':
        return load_pretrained_shikra(model_args, training_args)
    # Raise explicitly rather than ``assert False``: assert statements are
    # removed under ``python -O``, which would make this silently return None.
    raise ValueError(f"Unsupported model type: {type_!r}; expected one of ['shikra']")