import re
import sys
import typing as tp
import unicodedata

from sacremoses import MosesPunctNormalizer
from sentence_splitter import SentenceSplitter
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

import torch

MODEL_URL = "slone/nllb-210-v1"

LANGUAGES = {
    "Русский | Russian": "rus_Cyrl",
    "English | Английский": "eng_Latn",
    "Azərbaycan | Azerbaijani | Азербайджанский": "azj_Latn",
    "Башҡорт | Bashkir | Башкирский": "bak_Cyrl",
    "Буряад | Buryat | Бурятский": "bxr_Cyrl",
    "Чӑваш | Chuvash | Чувашский": "chv_Cyrl",
    "Хакас | Khakas | Хакасский": "kjh_Cyrl",
    "Къарачай-малкъар | Karachay-Balkar | Карачаево-балкарский": "krc_Cyrl",
    "Марий | Meadow Mari | Марийский": "mhr_Cyrl",
    "Эрзянь | Erzya | Эрзянский": "myv_Cyrl",
    "Татар | Tatar | Татарский": "tat_Cyrl",
    "Тыва | Tuvan | Тувинский": "tyv_Cyrl",
}
L1, L2 = "rus_Cyrl", "eng_Latn"  # default source and target language codes


def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
    """Build a function that replaces every non-printing character (Unicode category C*)."""
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # the same set of characters as \p{C} in regular expressions
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }
    return lambda line: line.translate(non_printable_map)

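
# Illustrative check (an assumption, not part of the original script):
# U+200B (zero-width space) has Unicode category Cf, so it is replaced.
# replace_nonprint = get_non_printing_char_replacer(" ")
# assert replace_nonprint("foo\u200bbar") == "foo bar"
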

class TextPreprocessor:
    """Normalize punctuation (Moses-style), strip non-printing characters, and apply NFKC."""

    def __init__(self, lang="en"):
        self.mpn = MosesPunctNormalizer(lang=lang)
        self.mpn.substitutions = [(re.compile(r), sub) for r, sub in self.mpn.substitutions]
        self.replace_nonprint = get_non_printing_char_replacer(" ")

    def __call__(self, text: str) -> str:
        return unicodedata.normalize("NFKC", self.replace_nonprint(self.mpn.normalize(text)))

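
# Illustrative example (an assumption about the normalizers' behavior, not from the
# original script): fancy punctuation is rewritten to its plain form, non-printing
# characters become spaces, and the result is NFKC-normalized, e.g.
#   TextPreprocessor()("«Привет»\u200bмир")  # -> roughly '"Привет" мир'
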

def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split text into sentences, also returning the "fillers" (separators) between them."""
    if fix_double_space:
        text = re.sub(" +", " ", text)
    sentences = splitter.split(text)
    fillers = []
    i = 0
    for sentence in sentences:
        start_idx = text.find(sentence, i)
        if ignore_errors and start_idx == -1:
            # the sentence was altered by the splitter; fall back to the next position
            start_idx = i + 1
        assert start_idx != -1, f"sent not found after {i}: `{sentence}`"
        fillers.append(text[i:start_idx])
        i = start_idx + len(sentence)
    fillers.append(text[i:])
    return sentences, fillers

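
# Illustrative example (assuming the splitter cuts at the "!"; not from the original
# script): the double space is collapsed first, and the separators are preserved:
#   sentenize_with_fillers("Привет!  Как дела?", SentenceSplitter("ru"))
#   -> (["Привет!", "Как дела?"], ["", " ", ""])
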

class Translator:
    """Wraps the NLLB model and tokenizer for sentence-by-sentence translation."""

    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=True)
        if torch.cuda.is_available():
            self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
        self.splitter = SentenceSplitter("ru")
        self.preprocessor = TextPreprocessor()
        self.languages = LANGUAGES

    def translate(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, by_sentence=True, preprocess=True, **kwargs):
        """Translate a text sentence by sentence, preserving the fillers between sentences."""
        if by_sentence:
            sents, fillers = sentenize_with_fillers(text, self.splitter, ignore_errors=True)
        else:
            sents, fillers = [text], ["", ""]
        if preprocess:
            sents = [self.preprocessor(sent) for sent in sents]
        results = []
        for sent, sep in zip(sents, fillers):
            results.append(sep)
            results.append(self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))
        results.append(fillers[-1])
        return "".join(results)

    def translate_single(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, n_out=None, **kwargs):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        if max_length == "auto":
            # scale the output budget with the input length
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return out[0] if isinstance(text, str) and n_out is None else out
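

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original script): translate a Russian
    # sentence into Tuvan, using language codes from the LANGUAGES mapping above.
    translator = Translator()
    print(translator.translate("Привет, мир!", src_lang="rus_Cyrl", tgt_lang="tyv_Cyrl"))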