|
import pandas as pd
|
|
import os
|
|
import random
|
|
import ast
|
|
import numpy as np
|
|
import torch
|
|
from einops import repeat, rearrange
|
|
import librosa
|
|
|
|
from torch.utils.data import Dataset
|
|
import torchaudio
|
|
|
|
|
|
class DreamData(Dataset):
    """Dataset that pairs a stored speaker tensor with a random text prompt.

    Each item is built from one row of the metadata CSV (restricted to the
    requested subset): the tensor file referenced by the row's ``spk_path``
    column is loaded, and one prompt is drawn at random from the prompt CSV
    rows whose ``speaker_id`` matches the row's ``speaker`` column.
    """

    def __init__(self, data_dir: str, meta_dir: str, subset: str, prompt_dir: str,):
        """Read the metadata and prompt tables.

        Args:
            data_dir: Prefix that is string-concatenated (not path-joined)
                with each row's ``spk_path``; include a trailing separator
                if the stored paths are relative file names.
            meta_dir: Path to the metadata CSV. It must provide the
                ``subset``, ``spk_path`` and ``speaker`` columns.
            subset: Value of the ``subset`` column to keep.
            prompt_dir: Path to the prompt CSV. It must provide the
                ``speaker_id`` and ``prompt`` columns.
        """
        self.datadir = data_dir
        meta = pd.read_csv(meta_dir)
        self.meta = meta[meta['subset'] == subset]
        self.subset = subset

        prompts = pd.read_csv(prompt_dir)
        # read_csv infers an integer dtype for purely numeric ids, while
        # __getitem__ looks speakers up by their string form.  Normalising
        # the column to str once here keeps that lookup from silently
        # matching nothing.
        prompts['speaker_id'] = prompts['speaker_id'].astype(str)
        self.prompts = prompts

    def __getitem__(self, index: int) -> tuple:
        """Return ``(spk, prompt)`` for the ``index``-th row of the subset.

        ``spk`` is the object loaded from the row's tensor file with a
        leading size-1 dimension removed (if present); ``prompt`` is one
        randomly sampled prompt string for the row's speaker.

        Raises:
            ValueError: If the prompt table has no entry for the speaker.
        """
        row = self.meta.iloc[index]

        # NOTE(review): torch.load unpickles the file; only point data_dir
        # at trusted data.
        spk_path = self.datadir + row['spk_path']
        spk = torch.load(spk_path, map_location='cpu').squeeze(0)

        speaker = str(row['speaker'])
        candidates = self.prompts[self.prompts['speaker_id'] == speaker]
        if candidates.empty:
            # Sampling from an empty frame would raise an opaque ValueError;
            # raise the same exception type with an actionable message.
            raise ValueError(f"no prompt found for speaker {speaker!r}")
        prompt = candidates.sample(1)['prompt'].iloc[0]

        return spk, prompt

    def __len__(self) -> int:
        """Return the number of metadata rows in the selected subset."""
        return len(self.meta)