import functools
import random
import unittest

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler

# Fix the random state to avoid flaky tests
torch.manual_seed(0)

dataset_config_en = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    path="tests/data/ljspeech",
    language="pt-br",
)

# Add the EN samples twice to create a language-unbalanced dataset
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)

# Generate a speaker-unbalanced dataset: the first 5 samples get one speaker
# label, the rest get another
for i, sample in enumerate(train_samples):
    if i < 5:
        sample["speaker_name"] = "ljspeech-0"
    else:
        sample["speaker_name"] = "ljspeech-1"
def is_balanced(lang_1, lang_2):
    return 0.85 < lang_1 / lang_2 < 1.2


class TestSamplers(unittest.TestCase):
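    # A plain RandomSampler draws uniformly over samples, so the 2:1 EN/PT
    # imbalance built into train_samples should show up in the draws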
    def test_language_random_sampler(self):  # pylint: disable=no-self-use
        random_sampler = torch.utils.data.RandomSampler(train_samples)
        ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for _ in range(100)])
        en, pt = 0, 0
        for index in ids:
            if train_samples[index]["language"] == "en":
                en += 1
            else:
                pt += 1
        assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
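
    # Weighting each sample with get_language_balancer_weights should even
    # out the EN/PT counts despite the duplicated EN dataset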
    def test_language_weighted_random_sampler(self):  # pylint: disable=no-self-use
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_language_balancer_weights(train_samples), len(train_samples)
        )
        ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for _ in range(100)])
        en, pt = 0, 0
        for index in ids:
            if train_samples[index]["language"] == "en":
                en += 1
            else:
                pt += 1
        assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
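
    # Same idea for speakers: the 5-vs-rest speaker split above is unbalanced,
    # and get_speaker_balancer_weights should correct for it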
    def test_speaker_weighted_random_sampler(self):  # pylint: disable=no-self-use
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_speaker_balancer_weights(train_samples), len(train_samples)
        )
        ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for _ in range(100)])
        spk1, spk2 = 0, 0
        for index in ids:
            if train_samples[index]["speaker_name"] == "ljspeech-0":
                spk1 += 1
            else:
                spk2 += 1
        assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
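
    # PerfectBatchSampler should yield batches holding exactly
    # batch_size / num_classes_in_batch samples per class (here 3 + 3)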
    def test_perfect_sampler(self):  # pylint: disable=no-self-use
        classes = set()
        for item in train_samples:
            classes.add(item["speaker_name"])
        sampler = PerfectBatchSampler(
            train_samples,
            classes,
            batch_size=2 * 3,  # total batch size
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=False,
            drop_last=True,
        )
        batches = functools.reduce(lambda a, b: a + b, [list(sampler) for _ in range(100)])
        for batch in batches:
            spk1, spk2 = 0, 0
            # count the samples of each speaker in the batch
            for index in batch:
                if train_samples[index]["speaker_name"] == "ljspeech-0":
                    spk1 += 1
                else:
                    spk2 += 1
            assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
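
    # Same check with shuffle=True and drop_last=False; per-batch class
    # balance must still hold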
    def test_perfect_sampler_shuffle(self):  # pylint: disable=no-self-use
        classes = set()
        for item in train_samples:
            classes.add(item["speaker_name"])
        sampler = PerfectBatchSampler(
            train_samples,
            classes,
            batch_size=2 * 3,  # total batch size
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=True,
            drop_last=False,
        )
        batches = functools.reduce(lambda a, b: a + b, [list(sampler) for _ in range(100)])
        for batch in batches:
            spk1, spk2 = 0, 0
            # count the samples of each speaker in the batch
            for index in batch:
                if train_samples[index]["speaker_name"] == "ljspeech-0":
                    spk1 += 1
                else:
                    spk2 += 1
            assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
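
    # Assign random short lengths to the first 5 samples and long lengths to
    # the rest, then check that bucketing the lengths into two groups and
    # weighting by get_length_balancer_weights evens out the draws; repeated
    # for many random min/max length configurations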
    def test_length_weighted_random_sampler(self):  # pylint: disable=no-self-use
        for _ in range(1000):
            # generate a length-unbalanced dataset with random min/max audio lengths
            min_audio = random.randrange(1, 22050)
            max_audio = random.randrange(44100, 220500)
            for idx, item in enumerate(train_samples):
                # increase the diversity of durations
                random_increase = random.randrange(100, 1000)
                if idx < 5:
                    item["audio_length"] = min_audio + random_increase
                else:
                    item["audio_length"] = max_audio + random_increase
            weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples)
            )
            ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for _ in range(100)])
            len1, len2 = 0, 0
            for index in ids:
                if train_samples[index]["audio_length"] < max_audio:
                    len1 += 1
                else:
                    len2 += 1
            assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
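
    # BucketBatchSampler collects batch_size * bucket_size_multiplier samples
    # into a bucket and sorts the bucket by sort_key (text length) before
    # cutting batches, so items accumulated within one bucket should appear
    # in non-decreasing text-length order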
    def test_bucket_batch_sampler(self):
        bucket_size_multiplier = 2
        sampler = range(len(train_samples))
        sampler = BucketBatchSampler(
            sampler,
            data=train_samples,
            batch_size=7,
            drop_last=True,
            sort_key=lambda x: len(x["text"]),
            bucket_size_multiplier=bucket_size_multiplier,
        )
        # check that the samples are sorted by text length while bucketing
        min_text_len_in_bucket = 0
        bucket_items = []
        for batch_idx, batch in enumerate(list(sampler)):
            if (batch_idx + 1) % bucket_size_multiplier == 0:
                for bucket_item in bucket_items:
                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
                min_text_len_in_bucket = 0
                bucket_items = []
            else:
                bucket_items += batch
        # check the sampler length
        self.assertEqual(len(sampler), len(train_samples) // 7)