Iker commited on
Commit
dacb273
·
1 Parent(s): a911872
Files changed (1) hide show
  1. translate_troch2trt.py +3 -1
translate_troch2trt.py CHANGED
@@ -55,8 +55,10 @@ def main(
55
  device = "cuda"
56
  from torch2trt import torch2trt
57
 
 
 
58
  model = torch2trt(
59
- model.to(device, dtype=dtype),
60
  [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
61
  )
62
 
 
55
  device = "cuda"
56
  from torch2trt import torch2trt
57
 
58
+ model.to(device, dtype=dtype)
59
+
60
  model = torch2trt(
61
+ model,
62
  [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
63
  )
64