import torch
import librosa
import numpy as np
import random
import os
from torch.utils.data import DataLoader
from modules.audio import mel_spectrogram
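
# Fine-tuning dataset utilities: recursively index audio files under a root
# directory, load them as mono waveforms at a fixed sampling rate, convert
# them to mel spectrograms, and batch variable-length clips with padding.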
duration_setting = {
    "min": 1.0,
    "max": 30.0,
}
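# Durations are in seconds; __getitem__ skips clips outside [min, max] and
# retries with a randomly chosen other clip.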

# assume single speaker
class FT_Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        spect_params,
        sr=22050,
        batch_size=1,
    ):
        self.data_path = data_path
        # recursively find all audio files in data_path
        self.data = []
        audio_exts = (".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus")
        for root, _, files in os.walk(data_path):
            for file in files:
                if file.endswith(audio_exts):
                    self.data.append(os.path.join(root, file))
        mel_fn_args = {
            "n_fft": spect_params["n_fft"],
            "win_size": spect_params["win_length"],
            "hop_size": spect_params["hop_length"],
            "num_mels": spect_params["n_mels"],
            "sampling_rate": sr,
            "fmin": spect_params["fmin"],
            # configs may spell a missing fmax as the string "None"
            "fmax": None if spect_params["fmax"] == "None" else spect_params["fmax"],
            "center": False,
        }
        self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
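        # shape contract assumed from usage below: to_mel maps a
        # (batch, samples) float tensor to (batch, n_mels, frames)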
        self.sr = sr

        assert len(self.data) != 0, f"no audio files found under {data_path}"
        # if the dataset is smaller than the batch size, repeat it
        while len(self.data) < batch_size:
            self.data += self.data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        idx = idx % len(self.data)
        wav_path = self.data[idx]
        try:
            # librosa.load already resamples to self.sr, so no separate
            # resampling step is needed afterwards
            speech, _ = librosa.load(wav_path, sr=self.sr)
        except Exception as e:
            print(f"Failed to load {wav_path} with error {e}")
            return self.__getitem__(random.randint(0, len(self) - 1))
        if (len(speech) < self.sr * duration_setting["min"]
                or len(speech) > self.sr * duration_setting["max"]):
            print(f"Audio {wav_path} is too short or too long, skipping")
            return self.__getitem__(random.randint(0, len(self) - 1))
        wave = torch.from_numpy(speech).float()
        mel = self.to_mel(wave.unsqueeze(0)).squeeze(0)
        return wave, mel
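
# Each FT_Dataset item is a (wave, mel) pair: wave is a 1-D (samples,) tensor
# and mel is (n_mels, frames). Lengths vary per clip, so batching relies on
# the padding `collate` function below.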
def build_ft_dataloader(data_path, spect_params, sr, batch_size=1, num_workers=0):
    dataset = FT_Dataset(data_path, spect_params, sr, batch_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate,
    )
    return dataloader
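
# `collate` pads a list of (wave, mel) pairs to common lengths so they can be
# stacked into dense tensors, and returns the true lengths so downstream code
# can mask out the padding.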
def collate(batch):
    batch_size = len(batch)

    # sort by mel length, longest first
    lengths = [b[1].shape[1] for b in batch]
    batch_indexes = np.argsort(lengths)[::-1]
    batch = [batch[bid] for bid in batch_indexes]

    nmels = batch[0][1].size(0)
    max_mel_length = max([b[1].shape[1] for b in batch])
    max_wave_length = max([b[0].size(0) for b in batch])

    # pad mels with -10 (a log-scale value close to silence) and waves with zeros
    mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
    waves = torch.zeros((batch_size, max_wave_length)).float()

    mel_lengths = torch.zeros(batch_size).long()
    wave_lengths = torch.zeros(batch_size).long()

    for bid, (wave, mel) in enumerate(batch):
        mel_size = mel.size(1)
        mels[bid, :, :mel_size] = mel
        waves[bid, : wave.size(0)] = wave
        mel_lengths[bid] = mel_size
        wave_lengths[bid] = wave.size(0)

    return waves, mels, wave_lengths, mel_lengths
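
# Downstream code can turn the returned lengths into boolean padding masks,
# e.g. (a sketch, not part of this module):
#   mel_mask = torch.arange(mels.size(2))[None, :] < mel_lengths[:, None]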

if __name__ == "__main__":
    data_path = "./example/reference"
    sr = 22050
    spect_params = {
        "n_fft": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "n_mels": 80,
        "fmin": 0,
        "fmax": 8000,
    }
    dataloader = build_ft_dataloader(data_path, spect_params, sr, batch_size=2, num_workers=0)
    for idx, batch in enumerate(dataloader):
        wave, mel, wave_lengths, mel_lengths = batch
        print(wave.shape, mel.shape)
        if idx == 10:
            break
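    # Each printed pair should look like:
    #   torch.Size([2, max_samples]) torch.Size([2, 80, max_frames])
    # where the maxima vary per batch with the sampled clip lengths.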