Iker commited on
Commit
ab7a062
·
1 Parent(s): 24a9f47

Move attention mask to device

Browse files
Files changed (1) hide show
  1. translate_troch2trt.py +1 -0
translate_troch2trt.py CHANGED
@@ -74,6 +74,7 @@ def main(
74
  with torch.no_grad():
75
  for batch in data_loader:
76
  batch["input_ids"] = batch["input_ids"].to(device)
 
77
  generated_tokens = model.generate(
78
  **batch, forced_bos_token_id=lang_code_to_idx
79
  )
 
74
  with torch.no_grad():
75
  for batch in data_loader:
76
  batch["input_ids"] = batch["input_ids"].to(device)
77
+ batch["attention_mask"] = batch["attention_mask"].to(device)
78
  generated_tokens = model.generate(
79
  **batch, forced_bos_token_id=lang_code_to_idx
80
  )