Spaces:
Paused
Paused
| import os | |
| import shutil | |
| import numpy as np | |
| from torch.utils.data import DataLoader | |
| from tests import get_tests_output_path, get_tests_path | |
| from TTS.utils.audio import AudioProcessor | |
| from TTS.vocoder.configs import WavernnConfig | |
| from TTS.vocoder.datasets.preprocess import load_wav_feat_data, preprocess_wav_files | |
| from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset | |
# Directory that contains this test module.
file_path = os.path.split(os.path.realpath(__file__))[0]

# Output directory for the loader tests; created at import time if missing.
OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# Default WaveRNN config shared by the test cases in this module.
C = WavernnConfig()

# Fixture data location and the feature sub-directories derived from it.
test_data_path = os.path.join(get_tests_path(), "data/ljspeech/")
test_mel_feat_path, test_quant_feat_path = (
    os.path.join(test_data_path, sub_dir) for sub_dir in ("mel", "quant")
)

# True when the fixture directory exists on disk.
# NOTE(review): not referenced in the visible part of this file - confirm whether it is still needed.
ok_ljspeech = os.path.exists(test_data_path)
def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers):
    """Run the WaveRNN dataloader with the given parameters and check batch shapes.

    Args:
        batch_size: Batch size given to the ``DataLoader``; also stored on ``C.batch_size``.
        seq_len: Sequence length given to ``WaveRNNDataset``; also stored on ``C.seq_len``.
        hop_len: Hop length given to ``WaveRNNDataset`` and used in the shape checks.
        pad: Padding given to ``WaveRNNDataset``; the checks expect ``2 * pad`` extra
            feature frames per item.
        mode: Mode given to ``WaveRNNDataset``; also stored on ``C.mode``.
        mulaw: Forwarded unchanged to ``WaveRNNDataset``.
        num_workers: Number of ``DataLoader`` worker processes.

    Raises:
        AssertionError: If a batch's feature shape does not match the shape implied
            by its input length, ``hop_len`` and ``pad``.
    """
    ap = AudioProcessor(**C.audio)

    # The module-level config is updated in place before preprocessing reads it.
    C.batch_size = batch_size
    C.mode = mode
    C.seq_len = seq_len
    C.data_path = test_data_path

    max_iter = 10  # check at most this many batches per case
    try:
        # Everything from preprocessing onwards sits inside the ``try`` so the
        # feature directories are removed even when setup itself fails.
        # NOTE(review): presumably preprocess_wav_files is what creates the
        # mel/quant directories cleaned up below - confirm.
        preprocess_wav_files(test_data_path, C, ap)
        _, train_items = load_wav_feat_data(test_data_path, test_mel_feat_path, 5)
        dataset = WaveRNNDataset(
            ap=ap, items=train_items, seq_len=seq_len, hop_len=hop_len, pad=pad, mode=mode, mulaw=mulaw
        )
        loader = DataLoader(
            dataset,
            shuffle=True,
            collate_fn=dataset.collate,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
        )

        for count_iter, (x_input, mels, _) in enumerate(loader, start=1):
            expected_feat_shape = (ap.num_mels, (x_input.shape[-1] // hop_len) + (pad * 2))
            assert tuple(mels.shape[1:]) == expected_feat_shape, f" [!] {mels.shape} vs {expected_feat_shape}"
            assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1]
            if count_iter == max_iter:
                break
    finally:
        # Best-effort cleanup: a missing directory must neither hide the real
        # test failure nor prevent the second directory from being removed.
        shutil.rmtree(test_mel_feat_path, ignore_errors=True)
        shutil.rmtree(test_quant_feat_path, ignore_errors=True)
def test_parametrized_wavernn_dataset():
    """Exercise the WaveRNN dataloader check over a fixed set of parameter combinations."""
    hop = C.audio["hop_length"]

    # Each entry: [batch_size, seq_len, hop_len, pad, mode, mulaw, num_workers]
    cases = [
        [16, hop * 10, hop, 2, 10, True, 0],
        [16, hop * 10, hop, 2, "mold", False, 4],
        [1, hop * 10, hop, 2, 9, False, 0],
        [1, hop, hop, 2, 10, True, 0],
        [1, hop, hop, 2, "mold", False, 0],
        [1, hop * 5, hop, 4, 10, False, 2],
        [1, hop * 5, hop, 2, "mold", False, 0],
    ]

    for case in cases:
        print(case)
        wavernn_dataset_case(*case)