import typing as tp | |
from collections import namedtuple | |
from functools import partial | |
import torch | |
from transformers import pipeline | |
def get_translator():
    """Build the Hugging Face pipeline used to translate inputs into English.

    The checkpoint name ``Helsinki-NLP/opus-mt-ru-en`` indicates a
    Russian -> English model, so the task string names the same direction
    (``translation_ru_to_en``); a mismatched direction in the task name is
    misleading and, for tokenizers that honour ``src_lang``/``tgt_lang``,
    wrong.

    Returns:
        A ``transformers`` translation pipeline placed on CUDA when torch
        reports a GPU, otherwise on CPU.
    """
    return pipeline(
        "translation_ru_to_en",
        model="Helsinki-NLP/opus-mt-ru-en",
        # Prefer the GPU whenever torch can see one.
        device="cuda" if torch.cuda.is_available() else "cpu",
        # Let transformers choose the dtype automatically for this checkpoint.
        torch_dtype="auto",
    )
class TranslationModel:
    """Wrap a model so that its inputs are machine-translated before inference.

    ``__call__`` turns each input item into a string, translates the batch
    with the pipeline returned by ``get_translator()``, and feeds the
    translated strings to the wrapped model.
    """

    def __init__(self, get_model):
        """Create the translator pipeline and the wrapped model.

        Args:
            get_model: zero-argument factory returning the model to wrap.
                The model is called with a list of translated strings
                (presumably returning one result per string -- TODO confirm
                against the model factories passed in by callers).
        """
        self.translator = get_translator()
        self.model = get_model()

    @staticmethod
    def _to_text(item):
        """Return the text to translate for a single input item.

        Dict items are flattened into ``"<authors> <abstract> <title>"``
        (the fields must be strings); any other item, e.g. a plain string,
        is passed through unchanged.
        """
        if isinstance(item, dict):
            return item["authors"] + " " + item["abstract"] + " " + item["title"]
        return item

    def __call__(self, input, **kwargs):
        """Translate ``input`` and run the wrapped model on the translation.

        Args:
            input: a single item -- a ``str`` or a dict with ``"authors"``,
                ``"abstract"`` and ``"title"`` string fields -- or an iterable
                of such items.
            **kwargs: accepted for interface compatibility; currently ignored.

        Returns:
            The wrapped model's output. When that output has exactly one
            element, the element itself is returned instead of the container.
        """
        # Strings and dicts are iterable but each represents ONE item; wrap
        # them so we never iterate over a string's characters or a dict's keys.
        if isinstance(input, (str, dict)) or not isinstance(input, tp.Iterable):
            input = [input]
        texts = [self._to_text(item) for item in input]
        translated = [
            result["translation_text"] for result in self.translator(texts)
        ]
        out = self.model(translated)
        if len(out) == 1:
            return out[0]
        return out
def create_translation_models(models):
    """Derive a translating variant of every model factory in ``models``.

    Args:
        models: mapping from a display name to a zero-argument factory that
            builds the underlying model.

    Returns:
        A new dict whose keys are the original names suffixed with
        " (С помощью перевода)" ("via translation") and whose values are
        ``TranslationModel`` constructors pre-bound to the matching factory.
    """
    wrapped = {}
    for model_name, model_factory in models.items():
        # partial binds model_factory right now, so there is no late-binding
        # surprise across loop iterations.
        constructor = partial(TranslationModel, get_model=model_factory)
        wrapped[f"{model_name} (С помощью перевода)"] = constructor
    return wrapped