|
from collections import namedtuple |
|
from functools import partial |
|
|
|
import torch |
|
from transformers import pipeline |
|
|
|
|
|
def get_translator():
    """Build the Russian-to-English translation pipeline.

    The pipeline wraps the ``Helsinki-NLP/opus-mt-ru-en`` checkpoint and is
    placed on CUDA when ``torch.cuda.is_available()`` reports a device,
    otherwise on the CPU.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return pipeline(
        # The checkpoint below translates ru -> en, so the task name has to
        # declare the same direction (it previously said en -> ru).
        "translation_ru_to_en",
        model="Helsinki-NLP/opus-mt-ru-en",
        device=device,
        torch_dtype="auto",
    )
|
|
|
class Input:
    """Mutable record holding the three text fields handed to a model."""

    def __init__(self, title, abstract, authors):
        """Store ``title``, ``abstract`` and ``authors`` as attributes."""
        self.title, self.abstract, self.authors = title, abstract, authors
|
|
|
class TranslationModel:
    """Wrap a model so that its input fields are machine-translated first.

    Every field of the incoming object is passed through the translation
    pipeline, and the wrapped model is then called on an ``Input`` built
    from the translated fields.
    """

    def __init__(self, get_model):
        """Create the translator and the wrapped model.

        Args:
            get_model: Zero-argument callable returning the model to wrap.
        """
        self.translator = get_translator()
        self.model = get_model()

    def _translate(self, text):
        """Return the translation of ``text``, or ``""`` when it is blank.

        Args:
            text: Value to translate. ``None`` is treated as empty; any
                other value is converted with ``str`` before translation.

        Returns:
            The ``translation_text`` of the first pipeline result, or an
            empty string for ``None`` / whitespace-only input.
        """
        if text is None:
            return ""
        # Coerce before stripping: calling ``.strip()`` on the raw value
        # raised AttributeError for non-string fields.
        cleaned = str(text).strip()
        if not cleaned:
            return ""
        return self.translator(cleaned)[0]["translation_text"]

    # NOTE: the parameter name ``input`` shadows the builtin, but it is kept
    # so that existing keyword callers continue to work.
    def __call__(self, input):
        """Translate the fields of ``input`` and run the wrapped model.

        Args:
            input: Object exposing ``title``, ``abstract`` and ``authors``.

        Returns:
            Whatever the wrapped model returns for the translated ``Input``.
        """
        translated = Input(
            self._translate(input.title),
            self._translate(input.abstract),
            self._translate(input.authors),
        )
        return self.model(translated)
|
|
|
|
|
def create_translation_models(models):
    """Map each model factory to a translating counterpart.

    Args:
        models: Mapping of display name to a model factory.

    Returns:
        A dict keyed by ``"<name> (С помощью перевода)"`` whose values are
        ``functools.partial`` objects that construct a ``TranslationModel``
        around the corresponding factory when called.
    """
    wrapped = {}
    for label, factory in models.items():
        wrapped[f"{label} (С помощью перевода)"] = partial(
            TranslationModel, get_model=factory
        )
    return wrapped
|
|
|
|