|
# Module-level registry: maps a model name (the last dotted component of the
# registering function's module) to the entrypoint function stored for it.
_model_entrypoints = {}
|
|
|
|
|
def build_model(config, **kwargs):
    """Build the model selected by ``config['MODEL']['NAME']``.

    Args:
        config: Nested mapping; ``config['MODEL']['NAME']`` is read as the
            name of a registered model entrypoint.
        **kwargs: Forwarded unchanged to the registered entrypoint.

    Returns:
        Whatever the registered entrypoint returns when called as
        ``entrypoint(config, **kwargs)``.

    Raises:
        ValueError: If no entrypoint is registered under the requested name.
    """
    model_name = config['MODEL']['NAME']

    # Fail early with a readable message instead of the bare KeyError the
    # registry lookup would otherwise raise.
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)
|
|
|
def register_model(fn):
    """Record ``fn`` as a model entrypoint and return it unchanged.

    The registry key is the final dot-separated component of
    ``fn.__module__``. Registering another function under the same key
    replaces the earlier entry. Because ``fn`` itself is returned, this can
    be applied as a decorator.
    """
    # rpartition leaves the whole string as the tail when there is no dot,
    # so un-dotted module names are used as-is.
    _, _, registry_key = fn.__module__.rpartition('.')
    _model_entrypoints[registry_key] = fn
    return fn
|
|
|
def model_entrypoints(model_name):
    """Return the entrypoint stored in the registry under ``model_name``.

    Raises:
        KeyError: If nothing is registered under ``model_name``.
    """
    entrypoint = _model_entrypoints[model_name]
    return entrypoint
|
|
|
def is_model(model_name):
    """Return True when an entrypoint is registered under ``model_name``."""
    is_registered = model_name in _model_entrypoints
    return is_registered