Fix: Remove duplicates when using more than 1 device

translate.py (+17 -4)

```diff
@@ -134,11 +134,17 @@ def main(
 
     model, data_loader = accelerator.prepare(model, data_loader)
 
+    samples_seen: int = 0
+
     with tqdm(
-        total=total_lines,
+        total=total_lines,
+        desc="Dataset translation",
+        leave=True,
+        ascii=True,
+        disable=(not accelerator.is_main_process),
     ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
         with torch.no_grad():
-            for batch in data_loader:
+            for step, batch in enumerate(data_loader):
                 batch["input_ids"] = batch["input_ids"]
                 batch["attention_mask"] = batch["attention_mask"]
 
@@ -157,8 +163,15 @@ def main(
                 tgt_text = tokenizer.batch_decode(
                     generated_tokens, skip_special_tokens=True
                 )
-
-
+                if accelerator.is_main_process:
+                    if step == len(data_loader) - 1:
+                        tgt_text = tgt_text[
+                            : len(data_loader.dataset) - samples_seen
+                        ]
+                    else:
+                        samples_seen += len(tgt_text)
+
+                    print("\n".join(tgt_text), file=output_file)
 
                 pbar.update(len(tgt_text))
 
```
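
Why duplicates appear with more than one device: when `accelerator.prepare` shards the `DataLoader` across several processes, Accelerate by default completes the last batches with samples taken from the beginning of the dataset so that every process receives batches of the same size. `accelerator.gather` returns those padding samples as well, so without this fix the final write could repeat lines that were already translated. The change therefore tracks how many samples have been written (`samples_seen`), truncates the final gathered batch to `len(data_loader.dataset) - samples_seen`, and writes (and shows the progress bar) only on the main process. The snippet below is a minimal, hypothetical sketch of that pattern, not part of the commit; the dataset, batch size, and script name are invented for illustration.

```python
# Hypothetical, self-contained sketch (not from translate.py) of the
# de-duplication pattern used by the fix. Run with, for example:
#   accelerate launch --num_processes 2 dedup_sketch.py
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

dataset = list(range(10))  # 10 samples, not a multiple of 2 processes * batch_size 4
data_loader = accelerator.prepare(DataLoader(dataset, batch_size=4))

samples_seen = 0
collected = []
for step, batch in enumerate(data_loader):
    # gather() concatenates every process's batch; on the last step the batches
    # were padded with samples from the start of the dataset, so duplicates
    # would show up here without the truncation below.
    gathered = accelerator.gather(batch).tolist()
    if accelerator.is_main_process:
        if step == len(data_loader) - 1:
            # Keep only the samples genuinely left in the dataset.
            gathered = gathered[: len(data_loader.dataset) - samples_seen]
        else:
            samples_seen += len(gathered)
        collected.extend(gathered)

if accelerator.is_main_process:
    print(collected)  # each of the 10 samples exactly once, in order
```

With Accelerate's default even-batches behavior, the padded entries come after the genuine remainder of the dataset in the gathered result, so the first `len(dataset) - samples_seen` entries of the last batch are exactly the samples not yet written; that is the invariant both this sketch and the fix rely on.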