Iker commited on
Commit
24a9f47
·
1 Parent(s): 745c4a7

Fix input type

Browse files
Files changed (1) hide show
  1. translate_troch2trt.py +2 -2
translate_troch2trt.py CHANGED
@@ -56,7 +56,7 @@ def main(
56
  from torch2trt import torch2trt
57
 
58
  model = torch2trt(
59
- model, [torch.randn((batch_size, max_length)).to(device, dtype=dtype)]
60
  )
61
 
62
  else:
@@ -73,7 +73,7 @@ def main(
73
  ) as output_file:
74
  with torch.no_grad():
75
  for batch in data_loader:
76
- batch["input_ids"] = batch["input_ids"].to(device, dtype=dtype)
77
  generated_tokens = model.generate(
78
  **batch, forced_bos_token_id=lang_code_to_idx
79
  )
 
56
  from torch2trt import torch2trt
57
 
58
  model = torch2trt(
59
+ model, [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)]
60
  )
61
 
62
  else:
 
73
  ) as output_file:
74
  with torch.no_grad():
75
  for batch in data_loader:
76
+ batch["input_ids"] = batch["input_ids"].to(device)
77
  generated_tokens = model.generate(
78
  **batch, forced_bos_token_id=lang_code_to_idx
79
  )