mattricesound commited on
Commit
d60a776
·
1 Parent(s): 9325b1e

Add random chunk choice

Browse files
Files changed (1) hide show
  1. remfx/datasets.py +20 -20
remfx/datasets.py CHANGED
@@ -211,26 +211,26 @@ class EffectDataset(Dataset):
211
  chunks, orig_sr = create_sequential_chunks(
212
  random_file_choice, self.chunk_size
213
  )
214
- for chunk in chunks:
215
- resampled_chunk = torchaudio.functional.resample(
216
- chunk, orig_sr, sample_rate
217
- )
218
- if resampled_chunk.shape[-1] < chunk_size:
219
- # Skip if chunk is too small
220
- continue
221
- # Sum to mono
222
- if resampled_chunk.shape[0] > 1:
223
- resampled_chunk = resampled_chunk.sum(0, keepdim=True)
224
-
225
- dry, wet, dry_effects, wet_effects = self.process_effects(
226
- resampled_chunk
227
- )
228
- output_dir = self.proc_root / str(num_chunk)
229
- output_dir.mkdir(exist_ok=True)
230
- torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
231
- torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
232
- torch.save(dry_effects, output_dir / "dry_effects.pt")
233
- torch.save(wet_effects, output_dir / "wet_effects.pt")
234
 
235
  print("Finished rendering")
236
  else:
 
211
  chunks, orig_sr = create_sequential_chunks(
212
  random_file_choice, self.chunk_size
213
  )
214
+ random_chunk = random.choice(chunks)
215
+ resampled_chunk = torchaudio.functional.resample(
216
+ random_chunk, orig_sr, sample_rate
217
+ )
218
+ if resampled_chunk.shape[-1] < chunk_size:
219
+ # Skip if chunk is too small
220
+ continue
221
+ # Sum to mono
222
+ if resampled_chunk.shape[0] > 1:
223
+ resampled_chunk = resampled_chunk.sum(0, keepdim=True)
224
+
225
+ dry, wet, dry_effects, wet_effects = self.process_effects(
226
+ resampled_chunk
227
+ )
228
+ output_dir = self.proc_root / str(num_chunk)
229
+ output_dir.mkdir(exist_ok=True)
230
+ torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
231
+ torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
232
+ torch.save(dry_effects, output_dir / "dry_effects.pt")
233
+ torch.save(wet_effects, output_dir / "wet_effects.pt")
234
 
235
  print("Finished rendering")
236
  else: