|
import importlib.util |
|
import os |
|
import sys |
|
import typing as tp |
|
import warnings |
|
from pathlib import Path |
|
|
|
from . import pipeline |
|
from .translation import create_translation_models |
|
|
|
|
|
def import_model_module(file_path: os.PathLike):
    """Import the Python source file at *file_path* and return the module.

    The module is registered in ``sys.modules`` under a dotted name derived
    from the file's location: relative to the current working directory when
    the file lives beneath it, otherwise relative to the filesystem root.
    The ``.py`` suffix is dropped so the name is a proper dotted path and the
    module's ``__package__`` is its containing directory (needed for relative
    imports inside the loaded file).

    Args:
        file_path: Path to the ``.py`` file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no loader can be created for *file_path*.
        Exception: Anything raised while executing the module; the partially
            initialised module is removed from ``sys.modules`` first.
    """
    path = Path(file_path).absolute()
    try:
        # e.g. "<cwd>/pkg/models/foo/model.py" -> "pkg.models.foo.model"
        relative = path.relative_to(Path.cwd())
    except ValueError:
        # File is outside the working directory (installed package, or the
        # process was started elsewhere); use the path below the root instead.
        relative = path.relative_to(path.anchor)
    module_name = ".".join(relative.with_suffix("").parts)

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Like the regular import system, don't leave a broken module behind.
        sys.modules.pop(module_name, None)
        raise
    return module
|
|
|
|
|
# Global registry: model name -> zero-argument factory.
models = dict()
# Per-language registry: language key -> {model name: factory}.
language_to_models = dict()

# Directory holding this file; its sub-directories are scanned for model.py.
file_dir = Path(__file__).parent
|
|
|
def _check_attr_exists(model_file_path, attr_name, attr):
    """Return True if *attr* is set; otherwise warn and return False.

    ``None`` counts as missing because attributes are read with
    ``getattr(module, attr_name, None)``.
    """
    if attr is None:
        warnings.warn(
            f"Module {model_file_path} should define attribute '{attr_name}'"
        )
        return False
    return True


def _check_attr_type(model_file_path, attr_name, attr, expected_type):
    """Return True if *attr* is an instance of *expected_type*; otherwise warn.

    The parameter is deliberately not called ``type``: the builtin ``type``
    is needed below to report the actual type of *attr*.
    """
    if isinstance(attr, expected_type):
        return True
    expected_name = getattr(expected_type, "__name__", expected_type)
    warnings.warn(
        f"Module {model_file_path}: '{attr_name}' should be of type "
        f"{expected_name}, but it is of type {type(attr).__name__}"
    )
    return False


def _check_attr_callable(model_file_path, attr_name, attr):
    """Return True if *attr* is callable; otherwise warn and return False."""
    if callable(attr):
        return True
    warnings.warn(f"Module {model_file_path}: '{attr_name}' should be callable")
    return False


def _load_model_entry(model_file_path):
    """Import a model module and validate its registration attributes.

    The module must define ``name`` (str), ``get_model`` (callable) and
    ``supported_langs`` (iterable of language keys). A single string is
    treated as one language instead of being iterated character by character.

    Returns:
        ``(name, get_model, supported_langs)``, or ``None`` after warning
        about the first failed check.
    """
    module = import_model_module(model_file_path)
    name = getattr(module, "name", None)
    factory = getattr(module, "get_model", None)
    supported_langs = getattr(module, "supported_langs", None)

    # `and` short-circuits, so only the first problem is reported.
    valid = (
        _check_attr_exists(model_file_path, "name", name)
        and _check_attr_exists(model_file_path, "get_model", factory)
        and _check_attr_exists(
            model_file_path, "supported_langs", supported_langs
        )
        and _check_attr_type(model_file_path, "name", name, str)
        and _check_attr_callable(model_file_path, "get_model", factory)
        and _check_attr_type(
            model_file_path, "supported_langs", supported_langs, tp.Iterable
        )
    )
    if not valid:
        return None
    if isinstance(supported_langs, str):
        # `supported_langs = "en"` would otherwise register "e" and "n".
        supported_langs = (supported_langs,)
    return name, factory, supported_langs


def _discover_models(directory):
    """Register each ``<directory>/<subdir>/model.py`` in the registries.

    Sub-directories are visited in sorted order so that registration order
    (and therefore the order of ``models``) does not depend on how the
    filesystem lists entries. A module reusing an already-registered model
    name replaces the earlier entry, with a warning.
    """
    for subdir in sorted(directory.iterdir()):
        model_file_path = subdir / "model.py"
        if not subdir.is_dir() or not model_file_path.is_file():
            continue
        entry = _load_model_entry(model_file_path)
        if entry is None:
            continue
        name, factory, supported_langs = entry
        if name in models:
            warnings.warn(
                f"Module {model_file_path}: model name '{name}' is already "
                "registered and will be replaced"
            )
        models[name] = factory
        for lang in supported_langs:
            language_to_models.setdefault(lang, {})[name] = factory


_discover_models(file_dir)
|
|
|
# Models returned by create_translation_models for the English models are
# registered under "ru" and in the global registry. (Presumably they wrap the
# English models with translation - TODO confirm against .translation.)
if "en" in language_to_models:
    translation_models = create_translation_models(language_to_models["en"])
else:
    # No English model was registered, so there is nothing to build
    # translation models from; indexing language_to_models["en"] here would
    # raise KeyError and abort the whole package import.
    translation_models = {}
language_to_models.setdefault("ru", {}).update(translation_models)
models.update(translation_models)
|
|
|
|
|
def get_model(name: str):
    """Build and return a new instance of the model registered as *name*.

    Raises:
        KeyError: If no model with that name has been registered.
    """
    try:
        factory = models[name]
    except KeyError:
        raise KeyError(f"No model with name {name}") from None
    return factory()
|
|
|
|
|
def get_all_model_names():
    """Return a new list with every registered model name, in insertion order."""
    return [*models]
|
|
|
|
|
def get_model_names_by_lang(lang):
    """Return the names of the models registered for language *lang*.

    Always returns a new list (empty for an unknown language), matching
    ``get_all_model_names``. The list is a copy, so callers cannot mutate
    the internal per-language registry through it. Use ``get_model(name)``
    to instantiate a model.
    """
    return list(language_to_models.get(lang, {}))
|
|