AriNubar committed on
Commit 5ba9dc3 · 1 Parent(s): 67e0b37

Create translation.py

Files changed (1)
  1. translation.py +187 -0
translation.py ADDED
@@ -0,0 +1,187 @@
import re
import sys
import typing as tp
import unicodedata

import torch
import pysbd
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM

# Module-level Armenian segmenter (unused here; Translator builds its own splitters).
hy_segmenter = pysbd.Segmenter(language="hy", clean=False)

MODEL_NAME = "AriNubar/nllb-200-distilled-600m-en-hyw"

LANGUAGES = {
    "Արեւմտահայերէն | Western Armenian": "hyw_Armn",
    "Անգլերէն | English": "eng_Latn",
}

def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
    """Build a function that maps every non-printing character to `replace_by`."""
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in perl
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line: str) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char

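# Illustrative check (U+200B ZERO WIDTH SPACE has Unicode category Cf, so it
# falls in the \p{C} set above):
#
#     replace_np = get_non_printing_char_replacer()
#     assert replace_np("a\u200bb") == "a b"
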
def clean_text(text: str, lang: str) -> str:
    """Normalize whitespace, quotes, and Armenian punctuation before translation."""
    HYW_CHARS_TO_NORMALIZE = {
        "«": '"',
        "»": '"',
        "“": '"',
        "”": '"',
        "’": "'",
        "‘": "'",
        "–": "-",
        "—": "-",
        "ՙ": "'",
        "՚": "'",
    }

    # Rewrite the Armenian emphasis mark (՛) after the preverbal particles
    # կ/չ/մ as a plain ASCII apostrophe.
    DOUBLE_CHARS_TO_NORMALIZE = {
        "Կ՛": "Կ'",
        "կ՛": "կ'",
        "Չ՛": "Չ'",
        "չ՛": "չ'",
        "Մ՛": "Մ'",
        "մ՛": "մ'",
    }
    replace_nonprint = get_non_printing_char_replacer()

    text = replace_nonprint(text)
    text = text.replace("\t", " ").replace("\n", " ").replace("\r", " ")
    # The original chained str.replace(r"[^\x00-\x7F]+", ...) and
    # str.replace(r"\s+", ...) calls matched those patterns literally (a no-op),
    # since str.replace does not interpret regular expressions. Whitespace is
    # collapsed with re.sub instead; the non-ASCII pattern is dropped because
    # applying it would strip the Armenian script entirely.
    text = re.sub(r"\s+", " ", text)
    text = text.strip()

    if lang == "hyw_Armn":
        text = text.translate(str.maketrans(HYW_CHARS_TO_NORMALIZE))
        for k, v in DOUBLE_CHARS_TO_NORMALIZE.items():
            text = text.replace(k, v)

    return text

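# Example of the Western Armenian normalization path (guillemets become ASCII
# quotes, and the emphasis mark after the particle կ becomes an apostrophe):
#
#     clean_text('«Բարեւ» կ՛ըսեմ', "hyw_Armn")  # -> '"Բարեւ" կ\'ըսեմ'
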
def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
    """Split text into sentences, also returning the "filler" strings between
    them so the original spacing can be restored around the translations."""
    if fix_double_space:
        text = re.sub(r"\s+", " ", text)
        text = text.strip()

    sentences = splitter.segment(text)

    fillers = []
    i = 0

    for sent in sentences:
        start_idx = text.find(sent, i)
        if ignore_errors and start_idx == -1:
            # Tolerate segmenter output that does not occur verbatim in the text.
            start_idx = i + 1
        assert start_idx != -1, f"Sent not found after index {i} in text: {text}"

        fillers.append(text[i:start_idx])
        i = start_idx + len(sent)

    fillers.append(text[i:])
    return sentences, fillers

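# Rough sketch of the contract (exact whitespace inside segments depends on
# pysbd's clean=False behavior; "splitter" is any pysbd Segmenter):
#
#     sents, fillers = sentenize_with_fillers("Hi. Bye.", splitter)
#     # sents   ~ ["Hi.", "Bye."]
#     # fillers ~ ["", " ", ""]   (always len(sents) + 1 entries)
#
# Interleaving fillers and sentences reconstructs the input, which is what
# Translator.translate relies on when reassembling translated sentences.
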
def init_tokenizer(tokenizer, new_lang="hyw_Armn"):
    """Add a new language token to the tokenizer vocabulary
    (this should be done each time after its initialization)."""
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always move "mask" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
    )

    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}

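# init_tokenizer follows the common recipe for grafting a new language code
# onto an NLLB tokenizer. It relies on attributes (lang_code_to_id,
# fairseq_tokens_to_ids, sp_model) that exist in the transformers releases
# this was written against; newer releases reworked NllbTokenizer and may not
# expose them (an assumption worth checking when upgrading). A quick sanity
# check under that assumption:
#
#     tok = NllbTokenizer.from_pretrained(MODEL_NAME)
#     init_tokenizer(tok)
#     assert tok.convert_tokens_to_ids("hyw_Armn") == tok.lang_code_to_id["hyw_Armn"]
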
class Translator:
    def __init__(self) -> None:
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

        if torch.cuda.is_available():
            self.model = self.model.cuda()

        self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME)
        init_tokenizer(self.tokenizer)

        # pysbd has no dedicated Western Armenian rules, so the generic
        # Armenian ("hy") segmenter is used for hyw as well.
        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=False)
        self.eng_splitter = pysbd.Segmenter(language="en", clean=False)
        self.languages = LANGUAGES

    def translate_single(
        self,
        text,
        src_lang,
        tgt_lang,
        max_length="auto",
        num_beams=4,
        n_out=None,
        **kwargs,
    ):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=256
        )
        if max_length == "auto":
            # Budget roughly twice the input length, plus 32 tokens of slack.
            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=n_out or 1,
            **kwargs,
        )
        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        if isinstance(text, str) and n_out is None:
            return out[0]
        return out

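    # Example of the "auto" length budget: a 10-token encoded input caps
    # generation at int(32 + 2.0 * 10) = 52 tokens. forced_bos_token_id pins
    # the first decoder token to the target language code, which is how
    # NLLB-style models select the output language.
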
    def translate(
        self,
        text: str,
        src_lang: str,
        tgt_lang: str,
        max_length="auto",
        num_beams=4,
        by_sentence=True,
        clean=True,
        **kwargs,
    ):
        if by_sentence:
            if src_lang == "eng_Latn":
                sents, fillers = sentenize_with_fillers(text, self.eng_splitter, ignore_errors=True)
            elif src_lang == "hyw_Armn":
                sents, fillers = sentenize_with_fillers(text, self.hyw_splitter, ignore_errors=True)
            else:
                # Previously, an unexpected source language left sents/fillers
                # unbound; fall back to whole-text translation instead.
                sents, fillers = [text], ["", ""]
        else:
            sents = [text]
            fillers = ["", ""]

        if clean:
            sents = [clean_text(sent, src_lang) for sent in sents]

        results = []
        for sent, sep in zip(sents, fillers):
            results.append(sep)
            results.append(self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))

        results.append(fillers[-1])

        return " ".join(results)

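# Usage sketch (parameter values are illustrative; extra **kwargs are passed
# through to model.generate, so options such as no_repeat_ngram_size work):
#
#     t = Translator()
#     t.translate("Hello there. How are you?", "eng_Latn", "hyw_Armn",
#                 num_beams=5, no_repeat_ngram_size=3)
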
if __name__ == "__main__":
    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")
    print(translator.translate("Hello, world!", "eng_Latn", "hyw_Armn"))
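    # Additional smoke tests (illustrative; outputs depend on the model):
    print(translator.translate("Բարեւ, աշխարհ։", "hyw_Armn", "eng_Latn"))
    print(translator.translate("This is a test. It has two sentences.", "eng_Latn", "hyw_Armn"))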