Update app.py
app.py CHANGED
@@ -54,10 +54,6 @@ def translate(text: str, src_lang: str, tgt_lang: str):
 # Only assign GPU if cache not used
 @spaces.GPU
 def _translate(text: str, src_lang: str, tgt_lang: str):
-    src_code = code_mapping[src_lang]
-    tgt_code = code_mapping[tgt_lang]
-    tokenizer.src_lang = src_code
-    tokenizer.tgt_lang = tgt_code

     # normalizing the punctuation first
     text = punct_normalizer.normalize(text)
@@ -66,31 +62,27 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
     translated_paragraphs = []

     for paragraph in paragraphs:
-        splitter = get_language_specific_sentence_splitter(src_code)
-        sentences = list(splitter(paragraph))
         translated_sentences = []
-        …
-        )
-        …
-        )
-        translated_sentences.append(translated_chunk)
+        input_tokens = (
+            tokenizer("Translate to Chinese:\n\n" + paragraph, return_tensors="pt")
+            .input_ids[0]
+            .cpu()
+            .numpy()
+            .tolist()
+        )
+        translated_chunk = model.generate(
+            input_ids=torch.tensor([input_tokens]).to(device),
+            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
+            max_length=len(input_tokens) + 50,
+            num_return_sequences=1,
+            num_beams=5,
+            no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
+            renormalize_logits=True,  # recompute token probabilities after banning the repetitions
+        )
+        translated_chunk = tokenizer.decode(
+            translated_chunk[0], skip_special_tokens=True
+        )
+        translated_sentences.append(translated_chunk)

         translated_paragraph = " ".join(translated_sentences)
         translated_paragraphs.append(translated_paragraph)
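
For reference, below is a minimal, self-contained sketch of the generation path the added lines follow. The checkpoint name, device selection, code_mapping entry, and sample paragraph are illustrative assumptions rather than part of app.py; in particular, the added lines still reference tgt_code even though its assignment from code_mapping was removed in the first hunk, so the sketch assumes an equivalent mapping is available.

# Minimal sketch of the new per-paragraph generation path; names marked
# "assumed" are illustrative stand-ins, not taken from app.py.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "facebook/nllb-200-distilled-600M"  # assumed stand-in checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)

code_mapping = {"Chinese": "zho_Hans"}  # assumed mapping; app.py defines its own
tgt_code = code_mapping["Chinese"]

paragraph = "Machine translation quality has improved rapidly in recent years."

# Tokenize the prompt-wrapped paragraph and keep the ids as a plain list,
# so max_length can be tied to the input length as the diff does.
input_tokens = (
    tokenizer("Translate to Chinese:\n\n" + paragraph, return_tensors="pt")
    .input_ids[0]
    .cpu()
    .numpy()
    .tolist()
)

output = model.generate(
    input_ids=torch.tensor([input_tokens]).to(device),
    forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
    max_length=len(input_tokens) + 50,
    num_return_sequences=1,
    num_beams=5,
    no_repeat_ngram_size=4,   # per the inline comment: keep this below num_beams
    renormalize_logits=True,  # recompute token probabilities after banning repetitions
)

print(tokenizer.decode(output[0], skip_special_tokens=True))

The decoding settings mirror the commit: beam search with num_beams=5, an n-gram repetition block smaller than the beam count, and renormalized logits so beams are rescored after banned repetitions are masked out.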