Spaces:
Runtime error
Runtime error
import os | |
import gin | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
class GeneralDataset(torch.utils.data.Dataset):
    """Dataset of paired audio/control ``.npy`` arrays for one data split.

    Files are read from the following layout::

        <path>/data_mean.npy
        <path>/data_std.npy
        <path>/<split>/audio/audio_<name>.npy
        <path>/<split>/control/control_<name>.npy

    Items are discovered by listing the ``audio`` directory; each audio file
    must have a matching control file.

    Args:
        path: Root directory holding the split folders and the
            ``data_mean.npy`` / ``data_std.npy`` statistics.
        split: Name of the split sub-directory under ``path``.
        load_to_memory: If true, every audio and control array is loaded in
            the constructor; otherwise arrays are loaded from disk on each
            ``__getitem__`` call.
    """

    _AUDIO_PREFIX = "audio_"
    _CONTROL_PREFIX = "control_"
    _EXTENSION = ".npy"

    def __init__(self, path: str, split: str = "train", load_to_memory: bool = True):
        super().__init__()
        self.load_to_memory = load_to_memory
        self.split_path = os.path.join(path, split)

        prefix = self._AUDIO_PREFIX
        # Strip only the *leading* prefix: str.replace would also remove any
        # later "audio_" inside the name and break the file lookup.
        # Sorting makes the index -> item mapping reproducible, since
        # os.listdir returns entries in arbitrary order.
        self.data_list = sorted(
            entry[len(prefix):]
            for entry in os.listdir(os.path.join(self.split_path, "audio"))
            if entry.startswith(prefix) and entry.endswith(self._EXTENSION)
        )

        if load_to_memory:
            self.audio = [np.load(self._audio_path(name)) for name in self.data_list]
            self.control = [
                np.load(self._control_path(name)) for name in self.data_list
            ]

        # Statistics used to undo the normalisation of the control signals.
        self.data_mean = np.load(os.path.join(path, "data_mean.npy"))
        self.data_std = np.load(os.path.join(path, "data_std.npy"))

    def _audio_path(self, name: str) -> str:
        """Return the path of the audio file belonging to item ``name``."""
        return os.path.join(self.split_path, "audio", self._AUDIO_PREFIX + name)

    def _control_path(self, name: str) -> str:
        """Return the path of the control file belonging to item ``name``."""
        return os.path.join(self.split_path, "control", self._CONTROL_PREFIX + name)

    def __len__(self):
        """Return the number of items found in the split."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the item at position ``idx`` as a dict.

        Keys:
            audio: The audio array as stored on disk.
            f0: Row 0 of the de-normalised control array (leading axis kept).
            amp: Row 1 of the de-normalised control array (leading axis kept).
            control: The control array as stored on disk (still normalised).
            name: The item's file name without prefix and extension.
        """
        name = self.data_list[idx]
        if self.load_to_memory:
            audio = self.audio[idx]
            control = self.control[idx]
        else:
            audio = np.load(self._audio_path(name))
            control = np.load(self._control_path(name))

        # NOTE(review): assumes data_std / data_mean broadcast against the
        # 2-D control array - confirm against the script that writes them.
        denormalised_control = (control * self.data_std) + self.data_mean

        return {
            "audio": audio,
            "f0": denormalised_control[0:1, :],
            "amp": denormalised_control[1:2, :],
            "control": control,
            "name": os.path.splitext(os.path.basename(name))[0],
        }
class GeneralDataModule(pl.LightningDataModule):
    """Lightning data module serving ``GeneralDataset`` train/val/test splits.

    Args:
        data_root: Root directory handed to ``GeneralDataset`` as ``path``.
        batch_size: Batch size used for every dataloader.
        load_to_memory: Forwarded to ``GeneralDataset``; see its documentation.
        **dataloader_args: Extra keyword arguments forwarded unchanged to
            every ``torch.utils.data.DataLoader`` (train, val and test alike).
    """

    def __init__(
        self,
        data_root: str,
        batch_size: int = 16,
        load_to_memory: bool = True,
        **dataloader_args
    ):
        super().__init__()
        self.data_dir = data_root
        self.batch_size = batch_size
        self.dataloader_args = dataloader_args
        self.load_to_memory = load_to_memory

    def prepare_data(self):
        """No-op: this module performs no download or preprocessing step."""

    def setup(self, stage: str = None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` builds the train and val datasets,
                ``"validate"`` builds only the val dataset, ``"test"`` builds
                only the test dataset, and ``None`` builds all three. Any
                other value builds nothing.
        """
        # Independent `if`s (rather than if/elif) so that `None` can set up
        # every split and "validate" gets the val dataset it needs.
        if stage in ("fit", None):
            self.urmp_train = GeneralDataset(
                self.data_dir, "train", self.load_to_memory
            )
        if stage in ("fit", "validate", None):
            self.urmp_val = GeneralDataset(self.data_dir, "val", self.load_to_memory)
        if stage in ("test", None):
            self.urmp_test = GeneralDataset(
                self.data_dir, "test", self.load_to_memory
            )

    def _make_dataloader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the configured options."""
        return torch.utils.data.DataLoader(
            dataset, self.batch_size, **self.dataloader_args
        )

    def train_dataloader(self):
        """Return the dataloader over the train split (requires ``setup``)."""
        return self._make_dataloader(self.urmp_train)

    def val_dataloader(self):
        """Return the dataloader over the val split (requires ``setup``)."""
        return self._make_dataloader(self.urmp_val)

    def test_dataloader(self):
        """Return the dataloader over the test split (requires ``setup``)."""
        return self._make_dataloader(self.urmp_test)