Iker commited on
Commit
745c4a7
·
1 Parent(s): 236bc54
Files changed (1) hide show
  1. translate_troch2trt.py +1 -1
translate_troch2trt.py CHANGED
@@ -73,7 +73,7 @@ def main(
73
  ) as output_file:
74
  with torch.no_grad():
75
  for batch in data_loader:
76
- batch = batch.to(device, dtype=dtype)
77
  generated_tokens = model.generate(
78
  **batch, forced_bos_token_id=lang_code_to_idx
79
  )
 
73
  ) as output_file:
74
  with torch.no_grad():
75
  for batch in data_loader:
76
+ batch["input_ids"] = batch["input_ids"].to(device, dtype=dtype)
77
  generated_tokens = model.generate(
78
  **batch, forced_bos_token_id=lang_code_to_idx
79
  )