English
music
emotion
music2emo / data_loader.py
kjysmu's picture
Upload 5 files
2dfd92b verified
import os
import numpy as np
import pickle
from torch.utils import data
import torchaudio.transforms as T
import torchaudio
import torch
import csv
import pytorch_lightning as pl
from music2latent import EncoderDecoder
import json
import math
from sklearn.preprocessing import StandardScaler
from dataset_loaders.jamendo import JamendoDataset
from dataset_loaders.pmemo import PMEmoDataset
from dataset_loaders.deam import DEAMDataset
from dataset_loaders.emomusic import EmoMusicDataset
from omegaconf import DictConfig
DATASET_REGISTRY = {
"jamendo": JamendoDataset,
"pmemo": PMEmoDataset,
"deam": DEAMDataset,
"emomusic": EmoMusicDataset
}
class DataModule(pl.LightningDataModule):
def __init__(self, cfg: DictConfig):
super().__init__()
self.cfg = cfg
self.train_datasets = []
self.val_datasets = []
self.test_datasets = []
def setup(self, stage=None):
# Clear previous dataset lists
self.train_datasets = []
self.val_datasets = []
self.test_datasets = []
# Register the datasets and load them
for dataset_name in self.cfg.datasets:
dataset_cfg = self.cfg.dataset[dataset_name]
if dataset_name in DATASET_REGISTRY:
train_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='train')
val_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='validation')
test_dataset = DATASET_REGISTRY[dataset_name](**dataset_cfg, cfg=self.cfg, tr_val='test')
self.train_datasets.append(train_dataset)
self.val_datasets.append(val_dataset)
self.test_datasets.append(test_dataset)
else:
raise ValueError(f"Dataset {dataset_name} not found in registry")
def train_dataloader(self):
return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
shuffle=True, num_workers=self.cfg.dataset[ds_name].num_workers,
persistent_workers=True)
for ds, ds_name in zip(self.train_datasets, self.cfg.datasets)]
def val_dataloader(self):
return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
shuffle=False, num_workers=self.cfg.dataset[ds_name].num_workers,
persistent_workers=True)
for ds, ds_name in zip(self.val_datasets, self.cfg.datasets)]
def test_dataloader(self):
return [data.DataLoader(ds, batch_size=self.cfg.dataset[ds_name].batch_size,
shuffle=False, num_workers=self.cfg.dataset[ds_name].num_workers,
persistent_workers=True)
for ds, ds_name in zip(self.test_datasets, self.cfg.datasets)]