Spaces:
Sleeping
Sleeping
File size: 6,980 Bytes
7d6db8f 8125531 14ae0ea c7866f1 14ae0ea a89496d c7866f1 d9f47ef 7173e65 8125531 c7866f1 14ae0ea e0a5f6f 14ae0ea c7866f1 b175ee9 e0a5f6f d8d3e30 c7866f1 7173e65 6990e4a e0a5f6f 6990e4a d8d3e30 e0a5f6f c7866f1 e0a5f6f c7866f1 57c446b c7866f1 7173e65 8125531 c7866f1 8125531 d8d3e30 7173e65 d8d3e30 4a7a6b8 c7866f1 7173e65 c7866f1 7173e65 e0a5f6f c7866f1 7bb4fe3 7173e65 c7866f1 7173e65 8125531 e0a5f6f c7866f1 e0a5f6f 8125531 e0a5f6f 8125531 e0a5f6f 8125531 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchaudio
from pathlib import Path
import pytorch_lightning as pl
import sys
from typing import Any, Dict
from remfx import effects
from tqdm import tqdm
from remfx.utils import create_sequential_chunks
import shutil
# https://zenodo.org/record/1193957 -> VocalSet
# Ordered collection of effect classes used for labelling. In
# VocalSet.process_effects an applied effect's label id is
# `ALL_EFFECTS.index(type(effect))`, and `len(ALL_EFFECTS)` is the length of the
# multi-hot `effects_present` vector saved alongside each chunk.
ALL_EFFECTS = effects.Pedalboard_Effects
class VocalSet(Dataset):
    """Chunked VocalSet dataset whose effected chunks are rendered to disk.

    On construction (when ``render_files`` is True) every ``*.wav`` under
    ``root / mode`` is split with ``create_sequential_chunks``, resampled to
    ``sample_rate``, passed through ``process_effects`` and written to
    ``render_root / "processed" / <effect names> / mode / <idx>/`` as
    ``input.wav``, ``target.wav`` and ``effect.pt``. ``__getitem__`` reads
    those files back.

    If rendered chunks already exist and ``render_files`` is True, the user is
    asked on stdin whether to re-render; any answer other than ``"y"`` exits
    the process.

    Args:
        root: Dataset root; audio is read recursively from ``root / mode``.
        sample_rate: Rate chunks are resampled to and saved at.
        chunk_size: Passed to ``create_sequential_chunks``; resampled chunks
            with fewer than ``chunk_size`` samples are skipped.
        applied_effects: Pool (name -> effect module) from which extra effects
            are drawn. ``None`` is treated as an empty pool.
        effect_to_remove: Mapping name -> effect module; its first entry is
            always applied to produce the wet signal. Required.
        max_effects_per_file: Upper bound on the total number of effects per
            chunk (drawn extras plus ``effect_to_remove``).
        render_files: Render chunks now (True) or reuse the chunks already in
            the processed directory (False).
        render_root: Parent of the ``processed`` directory; defaults to
            ``root`` when ``None``.
        mode: Split sub-directory name, e.g. ``"train"``.

    Raises:
        ValueError: If ``effect_to_remove`` is ``None`` or empty.
        FileNotFoundError: If ``render_files`` is False and the processed
            directory does not exist.
    """

    def __init__(
        self,
        root: str,
        sample_rate: int,
        chunk_size: int = 3,
        applied_effects: Dict[str, torch.nn.Module] = None,
        effect_to_remove: Dict[str, torch.nn.Module] = None,
        max_effects_per_file: int = 1,
        render_files: bool = True,
        render_root: str = None,
        mode: str = "train",
    ):
        super().__init__()
        if not effect_to_remove:
            raise ValueError(
                "effect_to_remove must be a non-empty dict of name -> effect module"
            )
        self.chunks = []
        self.song_idx = []
        self.root = Path(root)
        # Fall back to the data root; Path(None) would raise a TypeError.
        self.render_root = Path(render_root) if render_root is not None else self.root
        self.chunk_size = chunk_size
        self.sample_rate = sample_rate
        self.mode = mode
        self.max_effects_per_file = max_effects_per_file
        self.effect_to_remove = effect_to_remove
        mode_path = self.root / self.mode
        self.files = sorted(mode_path.glob("./**/*.wav"))
        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
        self.applied_effects = applied_effects if applied_effects is not None else {}
        self.effect_to_remove_name = next(iter(effect_to_remove))
        # The directory name encodes the effect configuration so different
        # configurations do not share rendered chunks.
        effect_str = "_".join(self.applied_effects)
        effect_str += f"_{self.effect_to_remove_name}"
        self.proc_root = self.render_root / "processed" / effect_str / self.mode

        if self.proc_root.exists() and any(self.proc_root.iterdir()):
            print("Found processed files.")
            if render_files:
                re_render = input(
                    "WARNING: By default, will re-render files.\n"
                    "Set render_files=False to skip re-rendering.\n"
                    "Are you sure you want to re-render? (y/n): "
                )
                if re_render != "y":
                    sys.exit()
                shutil.rmtree(self.proc_root)

        self.num_chunks = 0
        print("Total files:", len(self.files))
        print("Processing files...")
        if render_files:
            self._render_chunks()
        else:
            if not self.proc_root.is_dir():
                raise FileNotFoundError(
                    f"No processed files found at {self.proc_root}; "
                    "set render_files=True to render them."
                )
            # One sub-directory per rendered chunk.
            self.num_chunks = len(list(self.proc_root.iterdir()))
        print(
            f"Found {len(self.files)} {self.mode} files .\n"
            f"Total chunks: {self.num_chunks}"
        )

    def _render_chunks(self):
        """Chunk, resample and process every file, writing one numbered
        sub-directory per kept chunk under ``proc_root``."""
        self.proc_root.mkdir(parents=True, exist_ok=True)
        for audio_file in tqdm(self.files, total=len(self.files)):
            chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
            for chunk in chunks:
                resampled_chunk = torchaudio.functional.resample(
                    chunk, orig_sr, self.sample_rate
                )
                if resampled_chunk.shape[-1] < self.chunk_size:
                    # Skip chunks shorter than chunk_size samples
                    continue
                x, y, effect = self.process_effects(resampled_chunk)
                # NOTE(review): x is the signal *without* effect_to_remove and
                # y the one *with* it (see process_effects), so input.wav gets
                # the dry side and target.wav the wet side. Confirm the
                # consumer swaps them if the model is meant to remove the effect.
                output_dir = self.proc_root / str(self.num_chunks)
                output_dir.mkdir(exist_ok=True)
                torchaudio.save(output_dir / "input.wav", x, self.sample_rate)
                torchaudio.save(output_dir / "target.wav", y, self.sample_rate)
                torch.save(effect, output_dir / "effect.pt")
                self.num_chunks += 1

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        """Load rendered chunk ``idx``.

        Returns:
            ``(input, target, effects_present)``: the waveforms stored in
            ``input.wav`` / ``target.wav`` and the tensor stored in
            ``effect.pt``.
        """
        chunk_dir = self.proc_root / str(idx)
        effects_present = torch.load(chunk_dir / "effect.pt")
        x, _ = torchaudio.load(chunk_dir / "input.wav")
        y, _ = torchaudio.load(chunk_dir / "target.wav")
        return (x, y, effects_present)

    def process_effects(self, dry: torch.Tensor):
        """Apply random extra effects, then ``effect_to_remove``, to ``dry``.

        When ``max_effects_per_file > 1``, between 0 and
        ``max_effects_per_file - 1`` (inclusive) distinct effects are drawn
        from ``applied_effects``, excluding the entry named
        ``effect_to_remove_name``. The draw is capped by the size of that
        pool, and the chosen effects are applied in sequence.
        ``effect_to_remove`` is then applied to a copy of the result.

        Returns:
            ``(dry, wet, effects_present)``: the signal before and after
            ``effect_to_remove``, each passed through ``self.normalize``,
            and a float vector of length ``len(ALL_EFFECTS)`` counting how
            many applied effects had each effect type.
        """
        # Defined unconditionally so max_effects_per_file <= 1 still works.
        labels = []
        if self.max_effects_per_file > 1:
            # Filter into a local list rather than popping from
            # self.applied_effects, which may be shared with the caller.
            candidates = [
                name
                for name in self.applied_effects
                if name != self.effect_to_remove_name
            ]
            # randint's upper bound is exclusive: draw 0..max-1 extras so the
            # total including effect_to_remove can reach max_effects_per_file.
            num_effects = torch.randint(self.max_effects_per_file, (1,)).item()
            effect_indices = torch.randperm(len(candidates))[:num_effects]
            for i in effect_indices.tolist():
                effect = self.applied_effects[candidates[i]]
                dry = effect(dry)
                labels.append(ALL_EFFECTS.index(type(effect)))
        # Apply effect_to_remove to a copy so `dry` stays untouched
        effect = self.effect_to_remove[self.effect_to_remove_name]
        wet = effect(torch.clone(dry))
        labels.append(ALL_EFFECTS.index(type(effect)))
        # Sum of one-hot rows -> per-effect-type count vector
        one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
        effects_present = torch.sum(one_hot, dim=0).float()
        normalized_dry = self.normalize(dry)
        normalized_wet = self.normalize(wet)
        return normalized_dry, normalized_wet, effects_present
class VocalSetDatamodule(pl.LightningDataModule):
    """Lightning data module that serves three pre-built dataset splits.

    The datasets are supplied ready-made. This module only wraps each one in a
    ``DataLoader`` sharing ``batch_size``, ``num_workers`` and ``pin_memory``.
    Only the training loader shuffles. Extra keyword arguments are accepted
    and ignored.
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        *,
        batch_size: int,
        num_workers: int,
        pin_memory: bool = False,
        **kwargs: int,
    ) -> None:
        super().__init__()
        self.train_dataset, self.val_dataset, self.test_dataset = (
            train_dataset,
            val_dataset,
            test_dataset,
        )
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Any = None) -> None:
        # Datasets arrive fully constructed; there is nothing to prepare.
        return None

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        # Single place where the shared loader settings are applied.
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.test_dataset, shuffle=False)
|