Valeriy Sinyukov
Remove model wrappers, use dict and model input
82ec9f7
raw
history blame contribute delete
1.36 kB
import typing as tp
from collections import namedtuple
from functools import partial
import torch
from transformers import pipeline
def get_translator():
    """Create the Russian -> English translation pipeline.

    Loads the ``Helsinki-NLP/opus-mt-ru-en`` checkpoint and places it on CUDA
    when available, otherwise on CPU.

    Returns:
        A ``transformers`` translation pipeline. Called with a list of strings,
        it returns a list of dicts carrying a ``"translation_text"`` key.
    """
    return pipeline(
        # The task direction must agree with the ru -> en checkpoint below.
        "translation_ru_to_en",
        model="Helsinki-NLP/opus-mt-ru-en",
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype="auto",
    )
class TranslationModel:
    """Wrap a model so that its inputs are machine-translated first.

    The translator is built by ``get_translator()``; the wrapped model is
    built by the ``get_model`` factory. Calling an instance converts each
    input item to text, translates the whole batch, and feeds the translated
    strings to the wrapped model.
    """

    def __init__(self, get_model):
        """Build the translator and the wrapped model.

        Args:
            get_model: zero-argument callable returning the model to wrap.
                The returned object is called with a list of strings.
        """
        self.translator = get_translator()
        self.model = get_model()

    @staticmethod
    def _to_text(item):
        """Flatten a record dict into one string; pass other items through."""
        if isinstance(item, dict):
            return item["authors"] + " " + item["abstract"] + " " + item["title"]
        return item

    def __call__(self, input, **kwargs):
        """Translate ``input`` and run the wrapped model on the translation.

        Args:
            input: a single string, a single record dict (with ``"authors"``,
                ``"abstract"`` and ``"title"`` keys), or an iterable of such
                items.
            **kwargs: accepted for interface compatibility; currently ignored.

        Returns:
            The wrapped model's output list, or its only element when the
            model returned exactly one result.

        Raises:
            KeyError: if a record dict lacks one of the required keys.
        """
        # A str is iterable but is one input, not a batch of characters; a
        # dict is one record, not a batch of its keys.
        if isinstance(input, (str, dict)) or not isinstance(input, tp.Iterable):
            items = [input]
        else:
            items = input
        texts = [self._to_text(item) for item in items]
        translated = [result["translation_text"] for result in self.translator(texts)]
        out = self.model(translated)
        # Unwrap a single result so single-item calls get a single value back.
        if len(out) == 1:
            return out[0]
        return out
def create_translation_models(models):
    """Wrap every model factory in ``models`` with a translating factory.

    Args:
        models: mapping from display name to a zero-argument model factory.

    Returns:
        A dict whose keys are the original names followed by the suffix
        " (С помощью перевода)" ("with translation"). Each value is a
        ``functools.partial`` that builds a ``TranslationModel`` around the
        matching factory when called.
    """
    wrapped = {}
    for model_name, factory in models.items():
        key = f"{model_name} (С помощью перевода)"
        wrapped[key] = partial(TranslationModel, get_model=factory)
    return wrapped