import re
import sys
import typing as tp
import unicodedata

from sacremoses import MosesPunctNormalizer
from sentence_splitter import SentenceSplitter
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

import torch

MODEL_URL = "slone/nllb-210-v1"
LANGUAGES = {
    "Русский | Russian": "rus_Cyrl",
    "English | Английский": "eng_Latn",
    "Azərbaycan | Azerbaijani  | Азербайджанский": "azj_Latn",
    "Башҡорт | Bashkir | Башкирский": "bak_Cyrl",
    "Буряад | Buryat | Бурятский": "bxr_Cyrl",
    "Чӑваш | Chuvash | Чувашский": "chv_Cyrl",
    "Хакас | Khakas | Хакасский": "kjh_Cyrl",
    "Къарачай-малкъар | Karachay-Balkar | Карачаево-балкарский": "krc_Cyrl",
    "Марий | Meadow Mari | Марийский": "mhr_Cyrl",
    "Эрзянь | Erzya | Эрзянский": "myv_Cyrl",
    "Татар | Tatar | Татарский": "tat_Cyrl",
    "Тыва | Тувинский | Tuvan ": "tyv_Cyrl",
}
L1, L2 = "rus_Cyrl", "eng_Latn"

def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
    # Replace every non-printing code point (Unicode categories C*, i.e. \p{C}).
    non_printable_map = {
        ord(c): replace_by
        for c in map(chr, range(sys.maxunicode + 1))
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }
    return lambda line: line.translate(non_printable_map)

class TextPreprocessor:
    """Moses punctuation normalization, removal of non-printing characters, and Unicode NFKC normalization."""

    def __init__(self, lang="en"):
        self.mpn = MosesPunctNormalizer(lang=lang)
        # Pre-compile the normalizer's regex substitutions for speed.
        self.mpn.substitutions = [(re.compile(r), sub) for r, sub in self.mpn.substitutions]
        self.replace_nonprint = get_non_printing_char_replacer(" ")

    def __call__(self, text: str) -> str:
        return unicodedata.normalize("NFKC", self.replace_nonprint(self.mpn.normalize(text)))

def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split `text` into sentences, also returning the separators between them,
    so the original spacing can be restored after translation."""
    if fix_double_space:
        text = re.sub(" +", " ", text)
    sentences = splitter.split(text)
    fillers = []
    i = 0
    for sentence in sentences:
        start_idx = text.find(sentence, i)
        if ignore_errors and start_idx == -1:
            # Sentence not found verbatim (e.g. the splitter altered it); skip one character.
            start_idx = i + 1
        assert start_idx != -1, f"sent not found after {i}: `{sentence}`"
        fillers.append(text[i:start_idx])
        i = start_idx + len(sentence)
    fillers.append(text[i:])
    return sentences, fillers
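
# Illustrative behaviour of sentenize_with_fillers (a sketch, assuming the splitter
# keeps each sentence verbatim): `fillers` always has one element more than
# `sentences`, so interleaving fillers and sentences reproduces the input text, e.g.
#   sents, fills = sentenize_with_fillers("Привет. Как дела?", SentenceSplitter("ru"))
#   # sents == ["Привет.", "Как дела?"],  fills == ["", " ", ""]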

class Translator:
    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=True)
        if torch.cuda.is_available():
            self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
        self.splitter = SentenceSplitter(language="ru")
        self.preprocessor = TextPreprocessor()
        self.languages = LANGUAGES

    def translate(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, by_sentence=True, preprocess=True, **kwargs):
        """Translate `text`, optionally sentence by sentence, preserving the separators between sentences."""
        if by_sentence:
            sents, fillers = sentenize_with_fillers(text, self.splitter, ignore_errors=True)
        else:
            sents, fillers = [text], ["", ""]
        if preprocess:
            sents = [self.preprocessor(sent) for sent in sents]
        results = []
        for sent, sep in zip(sents, fillers):
            results.append(sep)
            results.append(self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))
        results.append(fillers[-1])
        return "".join(results)

    def translate_single(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, n_out=None, **kwargs):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        if max_length == "auto":
            # Allow roughly twice the input length plus a constant margin.
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return out[0] if isinstance(text, str) and n_out is None else out
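
# A minimal usage sketch (not part of the original module): it assumes the
# `slone/nllb-210-v1` checkpoint can be downloaded from the Hugging Face hub
# and falls back to CPU when no GPU is available.
if __name__ == "__main__":
    translator = Translator()
    print(translator.translate("Привет, мир!", src_lang="rus_Cyrl", tgt_lang="eng_Latn"))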