import typing as tp
from collections import namedtuple
from functools import lru_cache, partial

import torch
from transformers import pipeline


@lru_cache(maxsize=1)
def get_translator():
    """Build the Russian -> English translation pipeline (once).

    The result is cached so that every ``TranslationModel`` shares a single
    copy of the translation weights instead of loading its own.

    Returns:
        A ``transformers`` translation pipeline on CUDA when available,
        otherwise on CPU.
    """
    return pipeline(
        # The model translates ru -> en, so the task name must say so too.
        "translation_ru_to_en",
        model="Helsinki-NLP/opus-mt-ru-en",
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype="auto",
    )


class TranslationModel:
    """Wrap a model so that its input is machine-translated ru -> en first.

    Each input record is flattened to a single string, translated with the
    shared translation pipeline, and the translated batch is passed to the
    wrapped model.
    """

    def __init__(self, get_model):
        """Create the wrapper.

        Args:
            get_model: zero-argument factory returning the wrapped model. The
                model is called with a list of translated strings.
        """
        self.translator = get_translator()
        self.model = get_model()

    @staticmethod
    def _to_text(item):
        """Flatten a record dict into one string; pass other items through."""
        if isinstance(item, dict):
            # Field order is authors, abstract, title (kept from the original).
            return " ".join((item["authors"], item["abstract"], item["title"]))
        return item

    def __call__(self, input, **kwargs):
        """Translate ``input`` to English and run the wrapped model on it.

        Args:
            input: a single record or an iterable of records. A record is
                either a dict with string fields ``"authors"``,
                ``"abstract"`` and ``"title"``, or a plain string.
            **kwargs: accepted for interface compatibility; currently unused.

        Returns:
            The wrapped model's output for the translated batch. When that
            output has exactly one element, the element itself is returned.
        """
        # Strings and dicts are iterable but each represents ONE record.
        if isinstance(input, (str, dict)) or not isinstance(input, tp.Iterable):
            input = [input]
        texts = [self._to_text(item) for item in input]

        # truncation=True keeps long abstracts within the translator's maximum
        # input length instead of failing on them.
        translated_batch = self.translator(texts, truncation=True)
        translated = [item["translation_text"] for item in translated_batch]

        out = self.model(translated)
        # NOTE: a one-element result is unwrapped even when the caller passed a
        # one-element list; kept as-is for backward compatibility.
        if len(out) == 1:
            return out[0]
        return out


def create_translation_models(models):
    """Wrap every model factory so that it runs on translated input.

    Args:
        models: mapping of display name -> zero-argument model factory.

    Returns:
        A dict mapping ``"<name> (С помощью перевода)"`` (Russian for "via
        translation") to a zero-argument factory that builds a
        ``TranslationModel`` around the corresponding original model.
    """
    return {
        f"{name} (С помощью перевода)": partial(TranslationModel, get_model=get_model)
        for name, get_model in models.items()
    }