Spaces:
Running
Running
import os | |
import random | |
import numpy as np | |
import torch | |
from grad.utils import fix_len_compatibility | |
from grad_extend.utils import parse_filelist | |
class TextMelSpeakerDataset(torch.utils.data.Dataset):
    """Dataset of pre-extracted (mel, vec, pit, spk) feature files.

    Every filelist entry names four files: a mel tensor written with
    ``torch.save`` and three NumPy arrays (content vectors, pitch track and
    speaker embedding). Entries are filtered once at construction time; the
    features are loaded, length-aligned and randomly cropped on each access.
    """

    def __init__(self, filelist_path):
        """Parse the '|'-separated filelist and drop unusable entries.

        Args:
            filelist_path: path handed to ``parse_filelist``; each parsed
                entry is expected to hold four paths (mel, vec, pit, spk).
        """
        super().__init__()
        self.filelist = parse_filelist(filelist_path, split_char='|')
        self._filter()
        print(f'----------{len(self.filelist)}----------')

    def _filter(self):
        """Keep entries whose files all exist and whose pitch track is long enough.

        Rewrites ``self.filelist`` as a list of ``[mel, vec, pit, spk, usel]``
        where ``usel`` is the crop length in frames: the pitch length minus
        one, capped at ``items_max``. Entries with ``usel < items_min`` are
        dropped.
        """
        # Frame counts; the original notes assume 10 ms per frame.
        items_min = 250  # 10ms * 250 = 2.5 s
        items_max = 500  # 10ms * 500 = 5.0 s
        items_new = []
        for mel, vec, pit, spk in self.filelist:
            if not all(os.path.isfile(path) for path in (mel, vec, pit, spk)):
                continue
            usel = int(np.load(pit).shape[0] - 1)  # useful length
            if usel < items_min:
                continue
            items_new.append([mel, vec, pit, spk, min(usel, items_max)])
        self.filelist = items_new

    def get_triplet(self, item):
        """Load one entry and return aligned, cropped ``(mel, vec, pit, spk)``.

        Args:
            item: ``[mel_path, vec_path, pit_path, spk_path, use]`` as
                produced by ``_filter``; ``use`` is the maximum number of
                frames to return.

        Returns:
            Tuple ``(mel, vec, pit, spk)``. ``mel`` is sliced along dim 1 and
            ``vec``/``pit`` along dim 0 so that all three share the same
            number of frames; ``spk`` is returned whole.
        """
        mel_path, vec_path, pit_path, spk_path, use = item[:5]

        # NOTE: torch.load unpickles its input - only list trusted files.
        mel = torch.load(mel_path)
        # Each content vector is duplicated so the vec frame count doubles
        # (original note: "320 VEC -> 160 * 2").
        vec = torch.FloatTensor(np.repeat(np.load(vec_path), 2, 0))
        pit = torch.FloatTensor(np.load(pit_path))
        spk = torch.FloatTensor(np.load(spk_path))
        vec = vec + torch.randn_like(vec)  # Perturbation

        # Align on the shortest stream. The mel length is included so that a
        # short mel can no longer leave vec/pit longer than the spectrogram;
        # two vec frames are held back as a safety margin.
        len_min = min(pit.size(0), vec.size(0) - 2, mel.shape[1])
        mel = mel[:, :len_min]
        vec = vec[:len_min, :]
        pit = pit[:len_min]

        if len_min > use:
            # random.randint is inclusive on both ends, so the window
            # [frame_start, frame_start + use) ends at most at len_min - 1.
            frame_start = random.randint(0, len_min - use - 1)
            frame_end = frame_start + use
            mel = mel[:, frame_start:frame_end]
            vec = vec[frame_start:frame_end, :]
            pit = pit[frame_start:frame_end]
        return (mel, vec, pit, spk)

    def __getitem__(self, index):
        """Return entry ``index`` as a dict with keys mel, vec, pit, spk."""
        mel, vec, pit, spk = self.get_triplet(self.filelist[index])
        return {'mel': mel, 'vec': vec, 'pit': pit, 'spk': spk}

    def __len__(self):
        """Number of entries that survived filtering."""
        return len(self.filelist)

    def sample_test_batch(self, size):
        """Return ``size`` distinct, randomly chosen items as a list.

        Raises:
            ValueError: from ``np.random.choice`` when ``size`` exceeds the
                dataset length (sampling is without replacement).
        """
        idx = np.random.choice(len(self), size=size, replace=False)
        return [self[index] for index in idx]
class TextMelSpeakerBatchCollate(object):
    """Zero-pad a list of dataset items into dense batch tensors.

    Each item is a dict with keys 'mel', 'vec', 'pit' and 'spk'. Shapes per
    the original notes: mel [freq, length], vec [len, 256], pit [len],
    spk [256]; the feature sizes are actually read from the first item.
    """

    def __call__(self, batch):
        """Collate ``batch`` into a dict of padded tensors.

        The padded time length is ``fix_len_compatibility`` applied to the
        longest mel in the batch (helper from ``grad.utils``; presumably it
        rounds the length up to one the model accepts - confirm there).

        Returns:
            dict with 'lengths' (unpadded mel frame count per item), 'vec',
            'pit', 'spk' and 'mel'.
        """
        batch_size = len(batch)
        longest_mel = max(sample['mel'].shape[-1] for sample in batch)
        padded_len = fix_len_compatibility(longest_mel)

        # Feature dimensions are taken from the first item of the batch.
        first = batch[0]
        n_mel = first['mel'].shape[0]
        n_vec = first['vec'].shape[1]
        n_spk = first['spk'].shape[0]

        mel_out = torch.zeros((batch_size, n_mel, padded_len), dtype=torch.float32)
        vec_out = torch.zeros((batch_size, padded_len, n_vec), dtype=torch.float32)
        pit_out = torch.zeros((batch_size, padded_len), dtype=torch.float32)
        spk_out = torch.zeros((batch_size, n_spk), dtype=torch.float32)
        # Every slot is written in the loop below, so no initial fill is needed.
        frame_counts = torch.empty(batch_size, dtype=torch.long)

        for row, sample in enumerate(batch):
            sample_mel = sample['mel']
            sample_vec = sample['vec']
            sample_pit = sample['pit']
            n_frames = sample_mel.shape[1]
            mel_out[row, :, :n_frames] = sample_mel
            vec_out[row, :sample_vec.shape[0], :] = sample_vec
            pit_out[row, :sample_pit.shape[0]] = sample_pit
            spk_out[row] = sample['spk']
            frame_counts[row] = n_frames

        return {
            'lengths': frame_counts,
            'vec': vec_out,
            'pit': pit_out,
            'spk': spk_out,
            'mel': mel_out,
        }