Iker commited on
Commit
a911872
·
1 Parent(s): ab7a062
Files changed (1) hide show
  1. translate_troch2trt.py +2 -1
translate_troch2trt.py CHANGED
@@ -56,7 +56,8 @@ def main(
56
  from torch2trt import torch2trt
57
 
58
  model = torch2trt(
59
- model, [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)]
 
60
  )
61
 
62
  else:
 
56
  from torch2trt import torch2trt
57
 
58
  model = torch2trt(
59
+ model.to(device, dtype=dtype),
60
+ [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
61
  )
62
 
63
  else: