import pytorch_lightning as pl
from omegaconf import DictConfig
from torch.utils import data

from dataset_loaders.deam import DEAMDataset
from dataset_loaders.emomusic import EmoMusicDataset
from dataset_loaders.jamendo import JamendoDataset
from dataset_loaders.pmemo import PMEmoDataset

# Maps config dataset names to their Dataset classes.
DATASET_REGISTRY = {
    "jamendo": JamendoDataset,
    "pmemo": PMEmoDataset,
    "deam": DEAMDataset,
    "emomusic": EmoMusicDataset,
}
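
# Expected config layout (illustrative sketch, not the project's exact schema):
#
#   datasets: [jamendo, deam]
#   dataset:
#     jamendo: {batch_size: 32, num_workers: 4, ...}
#     deam:    {batch_size: 16, num_workers: 4, ...}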


class DataModule(pl.LightningDataModule):
    """LightningDataModule wrapping one or more music-emotion datasets.

    ``cfg.datasets`` lists the dataset names to load; ``cfg.dataset[name]``
    holds each dataset's options and is splatted into its constructor.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []

    def setup(self, stage=None):
        # Reset the lists so repeated setup() calls do not duplicate datasets.
        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []

        for dataset_name in self.cfg.datasets:
            if dataset_name not in DATASET_REGISTRY:
                raise ValueError(
                    f"Dataset '{dataset_name}' not found in registry; "
                    f"available: {sorted(DATASET_REGISTRY)}"
                )

            dataset_cls = DATASET_REGISTRY[dataset_name]
            dataset_cfg = self.cfg.dataset[dataset_name]

            # One dataset instance per split.
            self.train_datasets.append(dataset_cls(**dataset_cfg, cfg=self.cfg, tr_val='train'))
            self.val_datasets.append(dataset_cls(**dataset_cfg, cfg=self.cfg, tr_val='validation'))
            self.test_datasets.append(dataset_cls(**dataset_cfg, cfg=self.cfg, tr_val='test'))

    def _make_dataloaders(self, datasets, shuffle):
        # Build one DataLoader per dataset, each with its own batch size and
        # worker count taken from the per-dataset config.
        return [
            data.DataLoader(
                ds,
                batch_size=self.cfg.dataset[ds_name].batch_size,
                shuffle=shuffle,
                num_workers=self.cfg.dataset[ds_name].num_workers,
                # persistent_workers is only valid with num_workers > 0.
                persistent_workers=self.cfg.dataset[ds_name].num_workers > 0,
            )
            for ds, ds_name in zip(datasets, self.cfg.datasets)
        ]

    def train_dataloader(self):
        return self._make_dataloaders(self.train_datasets, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloaders(self.val_datasets, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloaders(self.test_datasets, shuffle=False)
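

# A minimal usage sketch with a hypothetical config; the real per-dataset
# options (annotation paths, sample rates, ...) come from the project's
# OmegaConf config and are omitted here, so this will only run end to end
# once those options and the underlying data are in place.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "datasets": ["deam"],
        "dataset": {"deam": {"batch_size": 16, "num_workers": 2}},  # hypothetical values
    })
    dm = DataModule(cfg)
    dm.setup()  # builds one (train, val, test) dataset triple per entry in cfg.datasets
    print(f"Built {len(dm.train_dataloader())} train dataloader(s)")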