Update app.py
app.py CHANGED
@@ -54,10 +54,6 @@ def translate(text: str, src_lang: str, tgt_lang: str):
 # Only assign GPU if cache not used
 @spaces.GPU
 def _translate(text: str, src_lang: str, tgt_lang: str):
-    src_code = code_mapping[src_lang]
-    tgt_code = code_mapping[tgt_lang]
-    tokenizer.src_lang = src_code
-    tokenizer.tgt_lang = tgt_code
 
     # normalizing the punctuation first
     text = punct_normalizer.normalize(text)
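One note on this first hunk: the comment "# Only assign GPU if cache not used" together with the @spaces.GPU decorator implies that the public translate() wrapper answers repeated requests from a cache and only falls through to _translate(), and therefore to a ZeroGPU allocation, on a cache miss. The cache itself sits outside these hunks, so the sketch below is a hypothetical reconstruction of that pattern, with functools.lru_cache standing in for whatever app.py actually uses:

# Hypothetical sketch, not the Space's actual code: lru_cache stands in
# for app.py's real cache, which is not shown in this diff.
import functools

import spaces  # Hugging Face Spaces ZeroGPU helper; only available on Spaces


@functools.lru_cache(maxsize=2048)
def translate(text: str, src_lang: str, tgt_lang: str) -> str:
    # Repeated inputs return here without ever triggering a GPU allocation.
    return _translate(text, src_lang, tgt_lang)


# Only assign GPU if cache not used
@spaces.GPU
def _translate(text: str, src_lang: str, tgt_lang: str) -> str:
    ...  # tokenization and model.generate(...) run while the GPU is attached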
@@ -66,31 +62,27 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
     translated_paragraphs = []
 
     for paragraph in paragraphs:
-        splitter = get_language_specific_sentence_splitter(src_code)
-        sentences = list(splitter(paragraph))
         translated_sentences = []
-        ...  # old per-sentence translation calls (old lines 72-92; their bodies did not survive extraction)
-        translated_sentences.append(translated_chunk)
+        input_tokens = (
+            tokenizer("Translate to Chinese:\n\n" + paragraph, return_tensors="pt")
+            .input_ids[0]
+            .cpu()
+            .numpy()
+            .tolist()
+        )
+        translated_chunk = model.generate(
+            input_ids=torch.tensor([input_tokens]).to(device),
+            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
+            max_length=len(input_tokens) + 50,
+            num_return_sequences=1,
+            num_beams=5,
+            no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
+            renormalize_logits=True,  # recompute token probabilities after banning the repetitions
+        )
+        translated_chunk = tokenizer.decode(
+            translated_chunk[0], skip_special_tokens=True
+        )
+        translated_sentences.append(translated_chunk)
 
         translated_paragraph = " ".join(translated_sentences)
         translated_paragraphs.append(translated_paragraph)