Valeriy Sinyukov
Add translation models
5c5407c
raw
history blame
3.05 kB
import importlib.util
import os
import sys
import typing as tp
import warnings
from pathlib import Path
from . import pipeline
from .translation import create_translation_models
def _module_name_for(file_path: os.PathLike) -> str:
    """Derive a dotted module name for the source file at *file_path*.

    The name is the file's path relative to the current working directory,
    without the ``.py`` suffix, joined with dots (``models/foo/model.py`` ->
    ``models.foo.model``). Symlinks are tried both unresolved and resolved.
    If the file is not under the working directory at all, the absolute path
    (minus its root/drive) is used instead, so the name stays unique and
    deterministic rather than raising ``ValueError``.
    """
    absolute = Path(os.path.abspath(file_path))
    cwd = Path(os.getcwd())
    for candidate, base in ((absolute, cwd), (absolute.resolve(), cwd.resolve())):
        try:
            relative = candidate.relative_to(base)
            break
        except ValueError:
            continue
    else:
        relative = absolute.relative_to(absolute.anchor)
    # Drop the ".py" suffix so the last dotted component is the module itself,
    # which also gives the module a sensible ``__package__``.
    return ".".join(relative.with_suffix("").parts)


def import_model_module(file_path: os.PathLike):
    """Import the Python source file at *file_path* and return the module.

    The module is registered in ``sys.modules`` under the name produced by
    ``_module_name_for`` before it is executed. This follows the standard
    importlib recipe, so the module can find itself during its own import.

    Args:
        file_path: Path to a ``.py`` file.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
        Exception: Anything raised while executing the module. In that case
            the partially-initialised module is removed from ``sys.modules``.
    """
    module_name = _module_name_for(file_path)
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the regular import system: never leave a broken module cached.
        sys.modules.pop(module_name, None)
        raise
    return module
# Registry of discovered models: model name -> zero-argument factory
# (``get_model`` below instantiates a model by calling the factory).
models = {}
# Language code -> {model name -> factory}, for every language a model declares.
language_to_models = {}
# Directory containing this file. Each sub-directory holding a ``model.py`` is
# loaded as a model plugin.
file_dir = Path(__file__).parent


def _check_attr_exists(module_path, attr_name, attr):
    """Warn and return False if *attr* is ``None`` (missing on the module)."""
    if attr is None:
        warnings.warn(f"Module {module_path} should define attribute '{attr_name}'")
        return False
    return True


def _check_attr_type(module_path, attr_name, attr, expected_type):
    """Warn and return False unless *attr* is an instance of *expected_type*."""
    # The parameter is deliberately not called ``type``. Otherwise ``type(attr)``
    # below would call *expected_type* (e.g. ``typing.Iterable(attr)``, which
    # raises TypeError) instead of the builtin.
    if isinstance(attr, expected_type):
        return True
    warnings.warn(
        f"Module {module_path}: '{attr_name}' should be of type {expected_type}, "
        f"but it is of type {type(attr)}"
    )
    return False


def _check_attr_callable(module_path, attr_name, attr):
    """Warn and return False unless *attr* is callable."""
    if callable(attr):
        return True
    warnings.warn(f"Module {module_path}: '{attr_name}' should be callable")
    return False


def _validate_model_module(module_path, module):
    """Check a plugin module's contract and extract its registration data.

    A plugin module must define ``name`` (str), ``get_model`` (callable) and
    ``supported_langs`` (iterable of language codes). Each violation is
    reported with a warning.

    Returns:
        ``(name, get_model, supported_langs_list)`` if the module is valid,
        otherwise ``None``.
    """
    name = getattr(module, "name", None)
    factory = getattr(module, "get_model", None)
    supported_langs = getattr(module, "supported_langs", None)
    # ``and`` short-circuits, so only the first failing check warns.
    is_valid = (
        _check_attr_exists(module_path, "name", name)
        and _check_attr_exists(module_path, "get_model", factory)
        and _check_attr_exists(module_path, "supported_langs", supported_langs)
        and _check_attr_type(module_path, "name", name, str)
        and _check_attr_callable(module_path, "get_model", factory)
        and _check_attr_type(module_path, "supported_langs", supported_langs, tp.Iterable)
    )
    if not is_valid:
        return None
    # A bare string is iterable too. Treat it as a single language code instead
    # of registering each of its characters as a language.
    if isinstance(supported_langs, str):
        supported_langs = [supported_langs]
    return name, factory, list(supported_langs)


def _load_model_modules(directory):
    """Import every ``<sub-dir>/model.py`` under *directory* and register it.

    Sub-directories are visited in sorted order, so that when two plugins share
    a name the winner (the last one) is deterministic.
    """
    for path in sorted(directory.glob("*")):
        model_file_path = path / "model.py"
        if not path.is_dir() or not model_file_path.exists():
            continue
        module = import_model_module(model_file_path)
        validated = _validate_model_module(model_file_path, module)
        if validated is None:
            continue
        name, factory, supported_langs = validated
        if name in models:
            warnings.warn(
                f"Model '{name}' from {model_file_path} overrides a previously "
                f"registered model with the same name"
            )
        models[name] = factory
        for lang in supported_langs:
            language_to_models.setdefault(lang, {})[name] = factory


_load_model_modules(file_dir)
# Build translation-based models from the English models via
# ``create_translation_models``. Register them in the global registry and under
# the "ru" language key. (Presumably they translate Russian input for the
# English models -- TODO confirm against create_translation_models.)
if language_to_models.get("en"):
    translation_models = create_translation_models(language_to_models["en"])
else:
    # Without this guard, a package with no valid English model failed to
    # import at all, with KeyError: 'en'.
    warnings.warn("No English models found; translation models are not created")
    translation_models = {}
if translation_models:
    language_to_models.setdefault("ru", {}).update(translation_models)
    models.update(translation_models)
def get_model(name: str):
    """Instantiate the model registered under *name*.

    The registered entry is called with no arguments and its result is
    returned.

    Raises:
        KeyError: If no model with that name is registered.
    """
    if name in models:
        factory = models[name]
        return factory()
    raise KeyError(f"No model with name {name}")
def get_all_model_names():
    """Return a new list with the name of every registered model.

    Names appear in the registry dict's insertion order.
    """
    return list(models)
def get_model_names_by_lang(lang):
    """Return the names of the models registered for language *lang*.

    Args:
        lang: A language code, as listed in some model module's
            ``supported_langs`` (translation models are registered as "ru").

    Returns:
        A new list of model names, which is empty if no model supports *lang*.
        A fresh list is returned in both cases, so the return type is
        consistent (and matches ``get_all_model_names``) and callers cannot
        mutate the internal registry.
    """
    return list(language_to_models.get(lang, {}))