Spaces:
Running
Running
Fix input type
Browse files- translate_troch2trt.py +2 -2
translate_troch2trt.py
CHANGED
@@ -56,7 +56,7 @@ def main(
|
|
56 |
from torch2trt import torch2trt
|
57 |
|
58 |
model = torch2trt(
|
59 |
-
model, [torch.randn((batch_size, max_length)).to(device, dtype=
|
60 |
)
|
61 |
|
62 |
else:
|
@@ -73,7 +73,7 @@ def main(
|
|
73 |
) as output_file:
|
74 |
with torch.no_grad():
|
75 |
for batch in data_loader:
|
76 |
-
batch["input_ids"] = batch["input_ids"].to(device
|
77 |
generated_tokens = model.generate(
|
78 |
**batch, forced_bos_token_id=lang_code_to_idx
|
79 |
)
|
|
|
56 |
from torch2trt import torch2trt
|
57 |
|
58 |
model = torch2trt(
|
59 |
+
model, [torch.randn((batch_size, max_length)).to(device, dtype=torch.long)]
|
60 |
)
|
61 |
|
62 |
else:
|
|
|
73 |
) as output_file:
|
74 |
with torch.no_grad():
|
75 |
for batch in data_loader:
|
76 |
+
batch["input_ids"] = batch["input_ids"].to(device)
|
77 |
generated_tokens = model.generate(
|
78 |
**batch, forced_bos_token_id=lang_code_to_idx
|
79 |
)
|