kevinpro committed
Commit 51be568 · verified · Parent: 3716781

Update app.py

Files changed (1): app.py (+20 -28)
app.py CHANGED

@@ -54,10 +54,6 @@ def translate(text: str, src_lang: str, tgt_lang: str):
 # Only assign GPU if cache not used
 @spaces.GPU
 def _translate(text: str, src_lang: str, tgt_lang: str):
-    src_code = code_mapping[src_lang]
-    tgt_code = code_mapping[tgt_lang]
-    tokenizer.src_lang = src_code
-    tokenizer.tgt_lang = tgt_code
 
     # normalizing the punctuation first
     text = punct_normalizer.normalize(text)
@@ -66,31 +62,27 @@ def _translate(text: str, src_lang: str, tgt_lang: str):
     translated_paragraphs = []
 
     for paragraph in paragraphs:
-        splitter = get_language_specific_sentence_splitter(src_code)
-        sentences = list(splitter(paragraph))
         translated_sentences = []
-
-        for sentence in sentences:
-            input_tokens = (
-                tokenizer(sentence, return_tensors="pt")
-                .input_ids[0]
-                .cpu()
-                .numpy()
-                .tolist()
-            )
-            translated_chunk = model.generate(
-                input_ids=torch.tensor([input_tokens]).to(device),
-                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
-                max_length=len(input_tokens) + 50,
-                num_return_sequences=1,
-                num_beams=5,
-                no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
-                renormalize_logits=True,  # recompute token probabilities after banning the repetitions
-            )
-            translated_chunk = tokenizer.decode(
-                translated_chunk[0], skip_special_tokens=True
-            )
-            translated_sentences.append(translated_chunk)
+        input_tokens = (
+            tokenizer("Translate to Chinese:\n\n" + paragraph, return_tensors="pt")
+            .input_ids[0]
+            .cpu()
+            .numpy()
+            .tolist()
+        )
+        translated_chunk = model.generate(
+            input_ids=torch.tensor([input_tokens]).to(device),
+            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
+            max_length=len(input_tokens) + 50,
+            num_return_sequences=1,
+            num_beams=5,
+            no_repeat_ngram_size=4,  # repetition blocking works better if this number is below num_beams
+            renormalize_logits=True,  # recompute token probabilities after banning the repetitions
+        )
+        translated_chunk = tokenizer.decode(
+            translated_chunk[0], skip_special_tokens=True
+        )
+        translated_sentences.append(translated_chunk)
 
         translated_paragraph = " ".join(translated_sentences)
         translated_paragraphs.append(translated_paragraph)
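
After this change, _translate no longer splits paragraphs into sentences: each whole paragraph is tokenized behind a "Translate to Chinese:\n\n" prompt and translated in a single model.generate call. One thing to watch: the second hunk still passes tgt_code to forced_bos_token_id, but the first hunk deletes the tgt_code = code_mapping[tgt_lang] lookup, so the function as committed would raise a NameError at runtime. Below is a minimal sketch of the resulting function with that lookup restored. It assumes, as in the surrounding app.py, that code_mapping, tokenizer, model, device, and punct_normalizer are already defined; the paragraph split and the final return fall outside the hunks shown, so those lines are assumptions.

import torch
import spaces

# Sketch of _translate as of this commit, with the tgt_code lookup restored.
# Assumed to exist elsewhere in app.py (not shown in the diff): code_mapping,
# tokenizer, model, device, punct_normalizer.

@spaces.GPU
def _translate(text: str, src_lang: str, tgt_lang: str):
    # Restored from the removed hunk: generate() below still needs tgt_code.
    tgt_code = code_mapping[tgt_lang]

    # normalizing the punctuation first
    text = punct_normalizer.normalize(text)

    # Assumption: paragraphs are split on newlines between the two hunks.
    paragraphs = text.split("\n")
    translated_paragraphs = []

    for paragraph in paragraphs:
        translated_sentences = []  # now holds a single chunk per paragraph
        input_tokens = (
            tokenizer("Translate to Chinese:\n\n" + paragraph, return_tensors="pt")
            .input_ids[0]
            .cpu()
            .numpy()
            .tolist()
        )
        translated_chunk = model.generate(
            input_ids=torch.tensor([input_tokens]).to(device),
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
            max_length=len(input_tokens) + 50,  # output budget scales with input length
            num_return_sequences=1,
            num_beams=5,
            no_repeat_ngram_size=4,  # repetition blocking works better if below num_beams
            renormalize_logits=True,  # recompute token probabilities after banning repetitions
        )
        translated_chunk = tokenizer.decode(
            translated_chunk[0], skip_special_tokens=True
        )
        translated_sentences.append(translated_chunk)

        translated_paragraph = " ".join(translated_sentences)
        translated_paragraphs.append(translated_paragraph)

    # Assumption: paragraphs are rejoined with newlines for the final output.
    return "\n".join(translated_paragraphs)

Since translated_sentences now only ever holds one chunk, the " ".join is effectively a no-op; keeping it minimizes the diff against the old sentence-level loop. Translating whole paragraphs trades the old per-sentence control for fewer generate calls and more context per call.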