Spaces:
Sleeping
Sleeping
Commit
·
9072475
1
Parent(s):
78820af
Switch to sequential chunks
Browse files- remfx/datasets.py +19 -10
remfx/datasets.py
CHANGED
@@ -20,7 +20,6 @@ class GuitarFXDataset(Dataset):
|
|
20 |
sample_rate: int,
|
21 |
length: int = LENGTH,
|
22 |
chunk_size_in_sec: int = 3,
|
23 |
-
num_chunks: int = 10,
|
24 |
effect_types: List[str] = None,
|
25 |
):
|
26 |
self.length = length
|
@@ -30,7 +29,6 @@ class GuitarFXDataset(Dataset):
|
|
30 |
self.labels = []
|
31 |
self.root = Path(root)
|
32 |
self.chunk_size_in_sec = chunk_size_in_sec
|
33 |
-
self.num_chunks = num_chunks
|
34 |
|
35 |
if effect_types is None:
|
36 |
effect_types = [
|
@@ -46,10 +44,10 @@ class GuitarFXDataset(Dataset):
|
|
46 |
self.dry_files += dry_files
|
47 |
self.labels += [i] * len(wet_files)
|
48 |
for audio_file in wet_files:
|
49 |
-
|
50 |
audio_file, self.chunk_size_in_sec, self.num_chunks
|
51 |
)
|
52 |
-
self.chunks +=
|
53 |
print(
|
54 |
f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
|
55 |
f"Total chunks: {len(self.chunks)}"
|
@@ -67,8 +65,9 @@ class GuitarFXDataset(Dataset):
|
|
67 |
effect_label = self.labels[song_idx] # Effect label
|
68 |
|
69 |
chunk_indices = self.chunks[idx]
|
70 |
-
|
71 |
-
|
|
|
72 |
|
73 |
resampled_x = self.resampler(x)
|
74 |
resampled_y = self.resampler(y)
|
@@ -83,8 +82,9 @@ class GuitarFXDataset(Dataset):
|
|
83 |
def create_random_chunks(
|
84 |
audio_file: str, chunk_size: int, num_chunks: int
|
85 |
) -> List[Tuple[int, int]]:
|
86 |
-
"""Create random chunks of size chunk_size (seconds)
|
87 |
-
|
|
|
88 |
"""
|
89 |
audio, sr = torchaudio.load(audio_file)
|
90 |
chunk_size_in_samples = chunk_size * sr
|
@@ -93,11 +93,20 @@ def create_random_chunks(
|
|
93 |
chunks = []
|
94 |
for i in range(num_chunks):
|
95 |
start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
|
96 |
-
|
97 |
-
chunks.append((start, end))
|
98 |
return chunks
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
class Datamodule(pl.LightningDataModule):
|
102 |
def __init__(
|
103 |
self,
|
|
|
20 |
sample_rate: int,
|
21 |
length: int = LENGTH,
|
22 |
chunk_size_in_sec: int = 3,
|
|
|
23 |
effect_types: List[str] = None,
|
24 |
):
|
25 |
self.length = length
|
|
|
29 |
self.labels = []
|
30 |
self.root = Path(root)
|
31 |
self.chunk_size_in_sec = chunk_size_in_sec
|
|
|
32 |
|
33 |
if effect_types is None:
|
34 |
effect_types = [
|
|
|
44 |
self.dry_files += dry_files
|
45 |
self.labels += [i] * len(wet_files)
|
46 |
for audio_file in wet_files:
|
47 |
+
chunk_starts = create_sequential_chunks(
|
48 |
audio_file, self.chunk_size_in_sec, self.num_chunks
|
49 |
)
|
50 |
+
self.chunks += chunk_starts
|
51 |
print(
|
52 |
f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
|
53 |
f"Total chunks: {len(self.chunks)}"
|
|
|
65 |
effect_label = self.labels[song_idx] # Effect label
|
66 |
|
67 |
chunk_indices = self.chunks[idx]
|
68 |
+
chunk_size_in_samples = self.chunk_size * sr
|
69 |
+
x = x[:, chunk_indices[0] : chunk_indices[0] + chunk_size_in_samples]
|
70 |
+
y = y[:, chunk_indices[0] : chunk_indices[0] + chunk_size_in_samples]
|
71 |
|
72 |
resampled_x = self.resampler(x)
|
73 |
resampled_y = self.resampler(y)
|
|
|
82 |
def create_random_chunks(
|
83 |
audio_file: str, chunk_size: int, num_chunks: int
|
84 |
) -> List[Tuple[int, int]]:
|
85 |
+
"""Create num_chunks random chunks of size chunk_size (seconds)
|
86 |
+
from an audio file.
|
87 |
+
Return sample_index of start of each chunk
|
88 |
"""
|
89 |
audio, sr = torchaudio.load(audio_file)
|
90 |
chunk_size_in_samples = chunk_size * sr
|
|
|
93 |
chunks = []
|
94 |
for i in range(num_chunks):
|
95 |
start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
|
96 |
+
chunks.append(start)
|
|
|
97 |
return chunks
|
98 |
|
99 |
|
100 |
+
def create_sequential_chunks(audio_file: str, chunk_size: int) -> List[int]:
    """Return the start offsets of back-to-back chunks covering an audio file.

    The file is loaded with ``torchaudio.load`` to obtain its number of
    samples and its sample rate. Chunk starts are then placed every
    ``chunk_size * sr`` samples, beginning at sample 0.

    Args:
        audio_file: Path of the audio file to split.
        chunk_size: Chunk duration in seconds.

    Returns:
        Start sample indices, in ascending order, as a plain Python list
        (one entry per chunk). The final chunk may be shorter than
        ``chunk_size`` seconds when the file length is not an exact
        multiple of the chunk length, because a start is emitted for every
        offset that lies inside the file.
    """
    audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
    # 0, step, 2 * step, ... for every offset below the number of samples.
    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
    # Convert to built-in numbers: the previous version returned the tensor
    # itself, which contradicted the declared list return type and made
    # list-extending callers accumulate 0-d tensors instead of ints.
    return chunk_starts.tolist()
|
108 |
+
|
109 |
+
|
110 |
class Datamodule(pl.LightningDataModule):
|
111 |
def __init__(
|
112 |
self,
|