Spaces:
Running
Running
File size: 2,722 Bytes
3dd84f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import os
import random
import json
import torch
from torch.utils.data import Dataset
from text import cleaned_text_to_sequence
def intersperse(lst: list, item: int):
"""
putting a blank token between any two input tokens to improve pronunciation
see https://github.com/jaywalnut310/glow-tts/issues/43 for more details
"""
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
class StableDataset(Dataset):
def __init__(self, filelist_path, hop_length):
self.filelist_path = filelist_path
self.hop_length = hop_length
self._load_filelist(filelist_path)
def _load_filelist(self, filelist_path):
filelist, lengths = [], []
with open(filelist_path, 'r', encoding='utf-8') as f:
for line in f:
line = json.loads(line.strip())
filelist.append((line['mel_path'], line['phone']))
lengths.append(line['mel_length'])
self.filelist = filelist
self.lengths = lengths # length is used for DistributedBucketSampler
def __len__(self):
return len(self.filelist)
def __getitem__(self, idx):
mel_path, phone = self.filelist[idx]
mel = torch.load(mel_path, map_location='cpu', weights_only=True)
phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
return mel, phone
def collate_fn(batch):
texts = [item[1] for item in batch]
mels = [item[0] for item in batch]
mels_sliced = [random_slice_tensor(mel) for mel in mels]
text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)
mels_sliced_lengths = torch.tensor([mel_sliced.size(-1) for mel_sliced in mels_sliced], dtype=torch.long)
# pad to the same length
texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)
mels_sliced_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels_sliced), padding=0)
return texts_padded, text_lengths, mels_padded, mel_lengths, mels_sliced_padded, mels_sliced_lengths
# random slice mel for reference encoder to prevent overfitting
def random_slice_tensor(x: torch.Tensor):
length = x.size(-1)
if length < 8:
return x
segmnt_size = random.randint(length // 12, length // 3)
start = random.randint(0, length - segmnt_size)
return x[..., start : start + segmnt_size] |