ruanchaves committed on
Commit
086ac2b
·
1 Parent(s): f016c88

feat: Add --keep_special_tokens argument to control special token decoding

Browse files
Files changed (1) hide show
  1. translate.py +9 -1
translate.py CHANGED
@@ -75,6 +75,7 @@ def main(
75
  temperature: float = 1.0,
76
  top_k: int = 50,
77
  top_p: float = 1.0,
 
78
  ):
79
 
80
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
@@ -196,7 +197,7 @@ def main(
196
  )
197
 
198
  tgt_text = tokenizer.batch_decode(
199
- generated_tokens, skip_special_tokens=True
200
  )
201
  if accelerator.is_main_process:
202
  if (
@@ -335,6 +336,12 @@ if __name__ == "__main__":
335
  help="If do_sample is True, will sample from the top k most likely tokens.",
336
  )
337
 
 
 
 
 
 
 
338
  args = parser.parse_args()
339
 
340
  main(
@@ -353,4 +360,5 @@ if __name__ == "__main__":
353
  temperature=args.temperature,
354
  top_k=args.top_k,
355
  top_p=args.top_p,
 
356
  )
 
75
  temperature: float = 1.0,
76
  top_k: int = 50,
77
  top_p: float = 1.0,
78
+ keep_special_tokens: bool = False,
79
  ):
80
 
81
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
 
197
  )
198
 
199
  tgt_text = tokenizer.batch_decode(
200
+ generated_tokens, skip_special_tokens=not keep_special_tokens
201
  )
202
  if accelerator.is_main_process:
203
  if (
 
336
  help="If do_sample is True, will sample from the top k most likely tokens.",
337
  )
338
 
339
+ parser.add_argument(
340
+ "--keep_special_tokens",
341
+ action="store_true",
342
+ help="Keep special tokens in the decoded text.",
343
+ )
344
+
345
  args = parser.parse_args()
346
 
347
  main(
 
360
  temperature=args.temperature,
361
  top_k=args.top_k,
362
  top_p=args.top_p,
363
+ keep_special_tokens=args.keep_special_tokens
364
  )