Spaces:
Running
Running
Fix bug
Browse files- translate_troch2trt.py +3 -1
translate_troch2trt.py
CHANGED
@@ -55,8 +55,10 @@ def main(
|
|
55 |
device = "cuda"
|
56 |
from torch2trt import torch2trt
|
57 |
|
|
|
|
|
58 |
model = torch2trt(
|
59 |
-
model
|
60 |
[torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
|
61 |
)
|
62 |
|
|
|
55 |
device = "cuda"
|
56 |
from torch2trt import torch2trt
|
57 |
|
58 |
+
model.to(device, dtype=dtype)
|
59 |
+
|
60 |
model = torch2trt(
|
61 |
+
model,
|
62 |
[torch.randn((batch_size, max_length)).to(device, dtype=torch.long)],
|
63 |
)
|
64 |
|