Spaces:
Running
Running
Move batch to device
Browse files- translate_troch2trt.py +7 -3
translate_troch2trt.py
CHANGED
@@ -52,24 +52,28 @@ def main(
|
|
52 |
raise ValueError("Precision must be 16, 32 or 64.")
|
53 |
|
54 |
if tensorrt:
|
|
|
55 |
from torch2trt import torch2trt
|
56 |
|
57 |
model = torch2trt(
|
58 |
-
model, [torch.randn((batch_size, max_length)).to(
|
59 |
)
|
60 |
|
61 |
else:
|
62 |
if torch.cuda.is_available():
|
63 |
-
|
|
|
64 |
else:
|
65 |
-
|
66 |
print("CUDA not available. Using CPU. This will be slow.")
|
|
|
67 |
|
68 |
with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
|
69 |
output_path, "w+", encoding="utf-8"
|
70 |
) as output_file:
|
71 |
with torch.no_grad():
|
72 |
for batch in data_loader:
|
|
|
73 |
generated_tokens = model.generate(
|
74 |
**batch, forced_bos_token_id=lang_code_to_idx
|
75 |
)
|
|
|
52 |
raise ValueError("Precision must be 16, 32 or 64.")
|
53 |
|
54 |
if tensorrt:
|
55 |
+
device = "cuda"
|
56 |
from torch2trt import torch2trt
|
57 |
|
58 |
model = torch2trt(
|
59 |
+
model, [torch.randn((batch_size, max_length)).to(device, dtype=dtype)]
|
60 |
)
|
61 |
|
62 |
else:
|
63 |
if torch.cuda.is_available():
|
64 |
+
device = "cuda"
|
65 |
+
|
66 |
else:
|
67 |
+
device = "cpu"
|
68 |
print("CUDA not available. Using CPU. This will be slow.")
|
69 |
+
model.to(device, dtype=dtype)
|
70 |
|
71 |
with tqdm(total=total_lines, desc="Dataset translation") as pbar, open(
|
72 |
output_path, "w+", encoding="utf-8"
|
73 |
) as output_file:
|
74 |
with torch.no_grad():
|
75 |
for batch in data_loader:
|
76 |
+
batch = batch.to(device, dtype=dtype)
|
77 |
generated_tokens = model.generate(
|
78 |
**batch, forced_bos_token_id=lang_code_to_idx
|
79 |
)
|