mattricesound committed on
Commit
9072475
·
1 Parent(s): 78820af

Switch to sequential chunks

Browse files
Files changed (1) hide show
  1. remfx/datasets.py +19 -10
remfx/datasets.py CHANGED
@@ -20,7 +20,6 @@ class GuitarFXDataset(Dataset):
20
  sample_rate: int,
21
  length: int = LENGTH,
22
  chunk_size_in_sec: int = 3,
23
- num_chunks: int = 10,
24
  effect_types: List[str] = None,
25
  ):
26
  self.length = length
@@ -30,7 +29,6 @@ class GuitarFXDataset(Dataset):
30
  self.labels = []
31
  self.root = Path(root)
32
  self.chunk_size_in_sec = chunk_size_in_sec
33
- self.num_chunks = num_chunks
34
 
35
  if effect_types is None:
36
  effect_types = [
@@ -46,10 +44,10 @@ class GuitarFXDataset(Dataset):
46
  self.dry_files += dry_files
47
  self.labels += [i] * len(wet_files)
48
  for audio_file in wet_files:
49
- chunks = create_random_chunks(
50
  audio_file, self.chunk_size_in_sec, self.num_chunks
51
  )
52
- self.chunks += chunks
53
  print(
54
  f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
55
  f"Total chunks: {len(self.chunks)}"
@@ -67,8 +65,9 @@ class GuitarFXDataset(Dataset):
67
  effect_label = self.labels[song_idx] # Effect label
68
 
69
  chunk_indices = self.chunks[idx]
70
- x = x[:, chunk_indices[0] : chunk_indices[1]]
71
- y = y[:, chunk_indices[0] : chunk_indices[1]]
 
72
 
73
  resampled_x = self.resampler(x)
74
  resampled_y = self.resampler(y)
@@ -83,8 +82,9 @@ class GuitarFXDataset(Dataset):
83
  def create_random_chunks(
84
  audio_file: str, chunk_size: int, num_chunks: int
85
  ) -> List[Tuple[int, int]]:
86
- """Create random chunks of size chunk_size (seconds) from an audio file.
87
- Return sample_indices
 
88
  """
89
  audio, sr = torchaudio.load(audio_file)
90
  chunk_size_in_samples = chunk_size * sr
@@ -93,11 +93,20 @@ def create_random_chunks(
93
  chunks = []
94
  for i in range(num_chunks):
95
  start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
96
- end = start + chunk_size_in_samples
97
- chunks.append((start, end))
98
  return chunks
99
 
100
 
 
 
 
 
 
 
 
 
 
 
101
  class Datamodule(pl.LightningDataModule):
102
  def __init__(
103
  self,
 
20
  sample_rate: int,
21
  length: int = LENGTH,
22
  chunk_size_in_sec: int = 3,
 
23
  effect_types: List[str] = None,
24
  ):
25
  self.length = length
 
29
  self.labels = []
30
  self.root = Path(root)
31
  self.chunk_size_in_sec = chunk_size_in_sec
 
32
 
33
  if effect_types is None:
34
  effect_types = [
 
44
  self.dry_files += dry_files
45
  self.labels += [i] * len(wet_files)
46
  for audio_file in wet_files:
47
+ chunk_starts = create_sequential_chunks(
48
  audio_file, self.chunk_size_in_sec, self.num_chunks
49
  )
50
+ self.chunks += chunk_starts
51
  print(
52
  f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
53
  f"Total chunks: {len(self.chunks)}"
 
65
  effect_label = self.labels[song_idx] # Effect label
66
 
67
  chunk_indices = self.chunks[idx]
68
+ chunk_size_in_samples = self.chunk_size * sr
69
+ x = x[:, chunk_indices[0] : chunk_indices[0] + chunk_size_in_samples]
70
+ y = y[:, chunk_indices[0] : chunk_indices[0] + chunk_size_in_samples]
71
 
72
  resampled_x = self.resampler(x)
73
  resampled_y = self.resampler(y)
 
82
  def create_random_chunks(
83
  audio_file: str, chunk_size: int, num_chunks: int
84
  ) -> List[Tuple[int, int]]:
85
+ """Create num_chunks random chunks of size chunk_size (seconds)
86
+ from an audio file.
87
+ Return sample_index of start of each chunk
88
  """
89
  audio, sr = torchaudio.load(audio_file)
90
  chunk_size_in_samples = chunk_size * sr
 
93
  chunks = []
94
  for i in range(num_chunks):
95
  start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
96
+ chunks.append(start)
 
97
  return chunks
98
 
99
 
100
def create_sequential_chunks(audio_file: str, chunk_size: int) -> List[int]:
    """Create sequential chunks of size chunk_size (seconds) from an audio file.

    Args:
        audio_file: Path of the audio file; it is passed to ``torchaudio.load``.
        chunk_size: Length of one chunk, in seconds.

    Returns:
        The start sample index of each chunk: 0, chunk_size * sr,
        2 * chunk_size * sr, ... for every start below the number of samples
        in the file (``sr`` is the sample rate reported by ``torchaudio.load``).
        When the file length is not a multiple of the chunk length, the last
        start leaves fewer than chunk_size * sr samples after it.
    """
    audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
    # Hand back plain Python numbers instead of a tensor: callers extend a
    # list with the result, and iterating a tensor would fill that list with
    # 0-dim tensors, unlike create_random_chunks which yields ints.
    return chunk_starts.tolist()
108
+
109
+
110
  class Datamodule(pl.LightningDataModule):
111
  def __init__(
112
  self,