# ttts/classifier/dataset.py
# NOTE: non-code page chrome from the original file viewer removed
# (repo metadata: mrfakename, "Add source code", commit 4ee33aa, 2.38 kB).
import json
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
import torchvision
from tqdm import tqdm
class PreprocessedMelDataset(torch.utils.data.Dataset):
    """Binary-classification dataset over precomputed mel spectrograms.

    Built from two list files given in ``opt['dataset']``:
      - ``clean``: lines labeled 0
      - ``noise``: lines labeled 1
    Each line is either a ``.wav`` path (mapped to its precomputed
    ``.mel.pth`` file) or a directory to recursively scan for
    ``*.mel.pth`` files.

    ``__getitem__`` returns ``(mel, label)`` where ``mel`` is randomly
    cropped or zero-padded to ``pad_to_samples`` frames along the last axis.
    """

    def __init__(self, opt):
        super().__init__()
        cfg = opt['dataset']
        self.paths = []
        self.labels = []
        # Clean (label 0): map 'x.wav' -> 'x.mel.pth'.
        # BUGFIX: the original replaced '.wav' with '.mel.path' (typo),
        # producing paths that never exist on disk.
        self._ingest(cfg['clean'], 0,
                     lambda wav: wav.replace('.wav', '.mel.pth'))
        # Noise (label 1): original appended the suffix ('x.wav' ->
        # 'x.wav.mel.pth'); preserved as-is.
        # NOTE(review): clean and noise use different naming schemes —
        # confirm both match the preprocessor's actual output layout.
        self._ingest(cfg['noise'], 1,
                     lambda wav: wav + '.mel.pth')
        self.pad_to = cfg['pad_to_samples']
        self.squeeze = cfg['should_squeeze']

    def _ingest(self, list_file, label, wav_to_mel):
        """Append mel paths listed in ``list_file`` and tag them with ``label``.

        ``.wav`` lines are converted via ``wav_to_mel``; any other
        non-empty line is treated as a directory and searched recursively
        for ``*.mel.pth``.
        """
        with open(list_file) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blanks: the original fell through to
                    # Path('').rglob(...), which scans the CWD.
                    continue
                if line.endswith('.wav'):
                    self.paths.append(wav_to_mel(line))
                else:
                    self.paths.extend(str(p) for p in Path(line).rglob("*.mel.pth"))
        # Pad labels so every path gathered above carries this label.
        self.labels.extend([label] * (len(self.paths) - len(self.labels)))

    def __getitem__(self, index):
        """Return ``(mel, label)`` with mel cropped/padded to ``pad_to`` frames."""
        # Assumes the saved mel is rank-3 (e.g. (1, n_mels, frames)) — the
        # crop below indexes three axes. TODO(review): confirm with the
        # preprocessor that saves these tensors.
        mel = torch.load(self.paths[index])
        if mel.shape[-1] >= self.pad_to:
            # Random temporal crop. randint's upper bound is exclusive;
            # the +1 makes start=0 reachable when length == pad_to exactly.
            start = torch.randint(0, mel.shape[-1] - self.pad_to + 1, (1,))
            mel = mel[:, :, start:start + self.pad_to]
        else:
            # Zero-pad on the right of the last axis up to pad_to frames.
            mel = F.pad(mel, (0, self.pad_to - mel.shape[-1]))
        assert mel.shape[-1] == self.pad_to
        if self.squeeze:
            mel = mel.squeeze()
        return mel, self.labels[index]

    def __len__(self):
        return len(self.paths)
if __name__ == '__main__':
    # Smoke test: dump the mels of the first 21 batches as PNG images
    # (0.png .. 20.png) for visual inspection.
    with open('ttts/classifier/config.json') as f:  # close the handle (original leaked it)
        cfg = json.load(f)
    ds = PreprocessedMelDataset(cfg)
    dl = torch.utils.data.DataLoader(ds, **cfg['dataloader'])
    for i, (mel, _label) in enumerate(tqdm(dl)):
        # BUGFIX: the original indexed b['mel'], but __getitem__ returns a
        # (mel, label) tuple, so the default collate yields a sequence
        # (mel_batch, label_batch), not a dict — b['mel'] raised TypeError.
        # Mels are assumed to lie in [-1, 1]; map to [0, 1] for save_image.
        torchvision.utils.save_image((mel + 1) / 2, f'{i}.png')
        if i >= 20:
            break