Spaces:
Running
Running
Fix bug
Browse files- translate_troch2trt.py +2 -1
translate_troch2trt.py
CHANGED
@@ -56,7 +56,8 @@ def main(
|
|
56 |
from torch2trt import torch2trt
|
57 |
|
58 |
model = torch2trt(
|
59 |
-
model
|
|
|
60 |
)
|
61 |
|
62 |
else:
|
|
|
56 |
from torch2trt import torch2trt
|
57 |
|
58 |
model = torch2trt(
|
59 |
+
model.to(device, dtype=dtype),
|
60 |
+
[torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
|
61 |
)
|
62 |
|
63 |
else:
|