Iker committed
Commit 4db19c9 · Parent(s): ed535ee

Fix: Remove duplicates when using more than 1 device

Files changed (1)
  1. translate.py +17 -4
translate.py CHANGED
@@ -134,11 +134,17 @@ def main(
 
     model, data_loader = accelerator.prepare(model, data_loader)
 
+    samples_seen: int = 0
+
     with tqdm(
-        total=total_lines, desc="Dataset translation", leave=True, ascii=True
+        total=total_lines,
+        desc="Dataset translation",
+        leave=True,
+        ascii=True,
+        disable=(not accelerator.is_main_process),
     ) as pbar, open(output_path, "w", encoding="utf-8") as output_file:
         with torch.no_grad():
-            for batch in data_loader:
+            for step, batch in enumerate(data_loader):
                 batch["input_ids"] = batch["input_ids"]
                 batch["attention_mask"] = batch["attention_mask"]
 
@@ -157,8 +163,15 @@ def main(
                 tgt_text = tokenizer.batch_decode(
                     generated_tokens, skip_special_tokens=True
                 )
-
-                print("\n".join(tgt_text), file=output_file)
+                if accelerator.is_main_process:
+                    if step == len(data_loader) - 1:
+                        tgt_text = tgt_text[
+                            : len(data_loader.dataset) - samples_seen
+                        ]
+                    else:
+                        samples_seen += len(tgt_text)
+
+                print("\n".join(tgt_text), file=output_file)
 
                 pbar.update(len(tgt_text))
 
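
Why the fix is needed: when more than one device is used, Accelerate pads the sharded DataLoader so every process sees the same number of batches, so the gathered output of the final batch repeats samples from the start of the dataset, and those duplicates were being written to the output file. The commit counts samples_seen on the main process and slices the last batch down so exactly len(data_loader.dataset) lines are written. Below is a minimal runnable sketch of the same pattern, not the repo's code; the names toy_dataset, items, and collected are illustrative.

# Minimal sketch (assumed names, not part of translate.py) of the
# last-batch truncation this commit adds.
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()

# 10 samples do not divide evenly across processes * batch_size, so with
# multiple processes Accelerate wraps around and the final gathered batch
# contains duplicated samples from the start of the dataset.
toy_dataset = torch.arange(10)
data_loader = accelerator.prepare(DataLoader(toy_dataset, batch_size=4))

samples_seen = 0
collected = []
for step, batch in enumerate(data_loader):
    gathered = accelerator.gather(batch)  # concatenate results from all ranks
    if accelerator.is_main_process:
        items = gathered.tolist()
        if step == len(data_loader) - 1:
            # Final batch: keep only the samples that are still missing.
            items = items[: len(data_loader.dataset) - samples_seen]
        else:
            samples_seen += len(items)
        collected.extend(items)

if accelerator.is_main_process:
    assert collected == toy_dataset.tolist()  # each sample exactly once

Later Accelerate releases also provide accelerator.gather_for_metrics, which performs this end-of-dataloader deduplication internally; the manual samples_seen bookkeeping is the way to get the same result without relying on that helper.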