Spaces:
Running
Running
Fix bug
Browse files- translate_troch2trt.py +2 -1
translate_troch2trt.py
CHANGED
|
@@ -56,7 +56,8 @@ def main(
|
|
| 56 |
from torch2trt import torch2trt
|
| 57 |
|
| 58 |
model = torch2trt(
|
| 59 |
-
model
|
|
|
|
| 60 |
)
|
| 61 |
|
| 62 |
else:
|
|
|
|
| 56 |
from torch2trt import torch2trt
|
| 57 |
|
| 58 |
model = torch2trt(
|
| 59 |
+
model.to(device, dtype=dtype),
|
| 60 |
+
[torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
|
| 61 |
)
|
| 62 |
|
| 63 |
else:
|