Spaces:
Build error
Build error
import json | |
import os | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.utils.data | |
import torchaudio | |
import torchvision | |
from tqdm import tqdm | |
from ttts.classifier.infer import read_jsonl | |
class PreprocessedMelDataset(torch.utils.data.Dataset):
    """Dataset of pre-computed mel spectrograms stored as ``<name>.mel.pth`` files.

    The list of files comes from a JSONL manifest read with ``read_jsonl``;
    each record's ``'path'`` value is joined onto a common prefix directory and
    suffixed with ``.mel.pth``.  Every item is cropped (at a random offset) or
    right-padded with zeros so that its last axis has exactly ``pad_to``
    entries.

    Expected keys under ``opt['dataset']``:
        path: manifest location passed to ``read_jsonl``.
        pre: prefix directory for the per-record paths (``~`` is expanded).
        pad_to_samples: target length of the last tensor axis.
        should_squeeze: if truthy, size-1 dimensions are squeezed out of the
            returned tensor.

    NOTE(review): stored tensors are indexed as 3-D (``mel[:, :, a:b]``) and
    the placeholder is ``(1, 100, pad_to)``, so files are presumably shaped
    ``(1, n_mels, frames)`` -- confirm against the preprocessing script.
    """

    # Channel count of the all-zero placeholder returned for unreadable files.
    _FALLBACK_MEL_CHANNELS = 100

    def __init__(self, opt):
        super().__init__()
        records = read_jsonl(opt['dataset']['path'])
        prefix = os.path.expanduser(opt['dataset']['pre'])
        self.paths = [os.path.join(prefix, rec['path']) + '.mel.pth' for rec in records]
        self.pad_to = opt['dataset']['pad_to_samples']
        self.squeeze = opt['dataset']['should_squeeze']

    def __getitem__(self, index):
        """Return the mel at ``index`` with its last axis fixed to ``pad_to``.

        A file that cannot be loaded yields an all-zero placeholder instead of
        raising, so one bad file does not abort iteration.
        """
        try:
            mel = torch.load(self.paths[index])
        except Exception:
            # Deliberate best-effort fallback.  Catching ``Exception`` rather
            # than using a bare ``except`` lets KeyboardInterrupt/SystemExit
            # propagate normally.
            mel = torch.zeros(1, self._FALLBACK_MEL_CHANNELS, self.pad_to)

        length = mel.shape[-1]
        if length >= self.pad_to:
            # Random crop; the upper bound is exclusive, hence the +1 so that
            # the last valid offset can be drawn (and offset 0 when the
            # lengths are equal).
            start = torch.randint(0, length - self.pad_to + 1, (1,)).item()
            mel = mel[:, :, start:start + self.pad_to]
        else:
            # Too short: zero-pad on the right of the last axis.
            mel = F.pad(mel, (0, self.pad_to - length))

        if self.squeeze:
            mel = mel.squeeze()
        return mel

    def __len__(self):
        """Return the number of manifest entries."""
        return len(self.paths)
if __name__ == '__main__':
    # Manual smoke test: build the dataset from the VQ-VAE config and dump the
    # first 21 batches as PNG images (0.png .. 20.png) in the working directory.
    with open('vqvae/config.json') as config_file:
        cfg = json.load(config_file)
    ds = PreprocessedMelDataset(cfg)
    dl = torch.utils.data.DataLoader(ds, **cfg['dataloader'])
    for i, batch in enumerate(tqdm(dl)):
        # The dataset yields plain tensors (not dicts), so the batch is used
        # directly.  When the dataset squeezed the channel axis away the batch
        # is 3-D; add the axis back so each item is saved as a 1-channel image.
        if batch.dim() == 3:
            batch = batch.unsqueeze(1)
        # NOTE(review): the (x + 1) / 2 rescale assumes mels lie in [-1, 1] --
        # confirm against the preprocessing normalisation.
        torchvision.utils.save_image((batch + 1) / 2, f'{i}.png')
        if i >= 20:
            break