Spaces:
Running
Running
Move attention mask to device
Browse files- translate_troch2trt.py +1 -0
translate_troch2trt.py
CHANGED
@@ -74,6 +74,7 @@ def main(
|
|
74 |
with torch.no_grad():
|
75 |
for batch in data_loader:
|
76 |
batch["input_ids"] = batch["input_ids"].to(device)
|
|
|
77 |
generated_tokens = model.generate(
|
78 |
**batch, forced_bos_token_id=lang_code_to_idx
|
79 |
)
|
|
|
74 |
with torch.no_grad():
|
75 |
for batch in data_loader:
|
76 |
batch["input_ids"] = batch["input_ids"].to(device)
|
77 |
+
batch["attention_mask"] = batch["attention_mask"].to(device)
|
78 |
generated_tokens = model.generate(
|
79 |
**batch, forced_bos_token_id=lang_code_to_idx
|
80 |
)
|