Spaces:
Running
Running
| # coding: utf-8 | |
| import os | |
| import os.path as osp | |
| import time | |
| import random | |
| import numpy as np | |
| import random | |
| import soundfile as sf | |
| import librosa | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from torch.utils.data import DataLoader | |
| import logging | |
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import pandas as pd

# Symbol inventory used for text tokenisation. A symbol's token id is its
# position in `symbols`; id 0 (`_pad`) is also the value FilePathDataset
# places before and after every token sequence.
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
# Character -> token id lookup. NOTE: the apostrophe occurs twice in
# `_letters_ipa`; later entries overwrite earlier ones, so it maps to its last
# position and `len(dicts)` is smaller than `len(symbols)`.
dicts = {symbol: index for index, symbol in enumerate(symbols)}
class TextCleaner:
    """Map each character of a text to its integer id from the module's `dicts` table."""

    def __init__(self, dummy=None):
        # `dummy` is not used; it is accepted so existing call sites keep working.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the ids of the characters of ``text`` that are in the table.

        Characters missing from the table are skipped. When any are found, a
        single warning naming them (and the text) is logged per call.

        Args:
            text: The string to tokenise.

        Returns:
            list[int]: ids in the same order as the known characters of ``text``.
        """
        indexes = []
        unknown = []
        for char in text:
            index = self.word_index_dictionary.get(char)
            # Compare against None explicitly: id 0 is a valid id.
            if index is None:
                unknown.append(char)
            else:
                indexes.append(index)
        if unknown:
            # Same logger object as the module-level `logger` (same name).
            logging.getLogger(__name__).warning(
                "Dropping characters not in the symbol table %r from text: %r",
                "".join(dict.fromkeys(unknown)),
                text,
            )
        return indexes
# Seed Python's and NumPy's global RNGs at import time; NumPy's global RNG
# drives the np.random.randint calls in the dataset below.
random.seed(1)
np.random.seed(1)

# STFT settings for the mel front end.
SPECT_PARAMS = dict(n_fft=2048, win_length=1200, hop_length=300)
# Mel filterbank settings.
MEL_PARAMS = dict(n_mels=80)
# Module-wide log-mel front end used by `preprocess`. Built from the parameter
# dicts above (n_mels=80, n_fft=2048, win_length=1200, hop_length=300) so the
# two cannot drift apart.
# NOTE(review): `sample_rate` is not passed, so torchaudio's default is used for
# the mel filterbank even though audio is resampled to 24000 Hz elsewhere in
# this module; changing it would alter the features trained models rely on.
to_mel = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS, **SPECT_PARAMS)
# Normalisation applied by `preprocess`: (log(mel + 1e-5) - mean) / std.
mean, std = -4, 4
def preprocess(wave):
    """Turn a NumPy waveform into a normalised log-mel tensor.

    The wave is cast to float32 and passed through the module-level ``to_mel``
    transform. A leading singleton dimension is added, and the result is mapped
    through ``(log(mel + 1e-5) - mean) / std`` using the module-level constants.
    For a 1-D wave the result is presumably ``(1, n_mels, frames)``; callers
    ``squeeze()`` it.
    """
    spectrogram = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(spectrogram.unsqueeze(0) + 1e-5)
    return (log_mel - mean) / std
class FilePathDataset(torch.utils.data.Dataset):
    """Dataset built from ``wav_path|text|speaker_id`` lines.

    Each item pairs one utterance with a random reference utterance of the
    same speaker and with a randomly drawn text from ``OOD_data``
    (presumably "out-of-distribution" texts).

    Args:
        data_list: Iterable of lines ``wav_path|text|speaker_id``. The speaker
            field is optional and defaults to ``"0"``. Trailing ``\\r``/``\\n``
            are stripped and blank lines are skipped.
        root_path: Directory that ``wav_path`` is joined onto.
        sr: Stored as ``self.sr``. Loading always resamples to 24000 Hz.
        data_augmentation: Stored flag, forced to False when ``validation``.
        validation: See ``data_augmentation``.
        OOD_data: ``|``-separated UTF-8 text file. If the first field of its
            first line contains ``.wav``, texts are read from field 1,
            otherwise from field 0.
        min_length: Minimum number of characters of a sampled OOD text.

    Raises:
        ValueError: if ``OOD_data`` has no non-blank lines, or none of its
            texts is at least ``min_length`` characters long (sampling would
            otherwise never terminate).
    """

    def __init__(
        self,
        data_list,
        root_path,
        sr=24000,
        data_augmentation=False,
        validation=False,
        OOD_data="Data/OOD_texts.txt",
        min_length=50,
    ):
        self.data_list = [self._parse_entry(line) for line in data_list if line.strip()]
        self.text_cleaner = TextCleaner()
        self.sr = sr
        # Column 2 holds the raw speaker-id strings used to pick references.
        self.df = pd.DataFrame(self.data_list)
        # NOTE(review): built from MEL_PARAMS only, so torchaudio's defaults
        # apply to the STFT settings. It is not used by this class (mels come
        # from the module-level `preprocess`); kept so the attribute remains.
        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192
        self.min_length = min_length

        with open(OOD_data, "r", encoding="utf-8") as f:
            lines = [t for t in f if t.strip()]
        if not lines:
            raise ValueError(f"OOD text file {OOD_data!r} has no entries")
        idx = 1 if ".wav" in lines[0].split("|")[0] else 0
        self.ptexts = [t.split("|")[idx] for t in lines]
        if not any(len(t) >= min_length for t in self.ptexts):
            raise ValueError(
                f"No text in {OOD_data!r} has at least {min_length} characters"
            )
        self.root_path = root_path

    @staticmethod
    def _parse_entry(line):
        """Split one ``path|text[|speaker]`` line, defaulting speaker to ``"0"``."""
        fields = line.rstrip("\r\n").split("|")
        # The speaker default is a string so it matches the other rows'
        # speaker column when reference utterances are looked up.
        return fields if len(fields) == 3 else (*fields, "0")

    def __len__(self):
        return len(self.data_list)

    def _sample_ood_text(self):
        """Draw uniformly from ``self.ptexts`` until a text has >= ``min_length`` chars."""
        ps = ""
        while len(ps) < self.min_length:
            # randint's upper bound is exclusive, so this covers every entry.
            ps = self.ptexts[np.random.randint(0, len(self.ptexts))]
        return ps

    def __getitem__(self, idx):
        """Return one training item.

        Returns:
            tuple: ``(speaker_id, mel, text_ids, ood_text_ids, ref_mel,
            ref_speaker_id, path, wave)``. ``mel`` is the 2-D log-mel of the
            utterance trimmed to an even number of frames. ``ref_mel`` is at
            most ``self.max_mel_length`` frames of a random same-speaker
            utterance. Both id tensors are wrapped in 0s.
        """
        data = self.data_list[idx]
        path = data[0]
        wave, text_tensor, speaker_id = self._load_tensor(data)

        acoustic_feature = preprocess(wave).squeeze()
        length_feature = acoustic_feature.size(1)
        # Drop a trailing frame when needed so the frame count is even.
        acoustic_feature = acoustic_feature[:, : length_feature - length_feature % 2]

        # Reference: a random row with the same raw speaker field (always
        # non-empty, since it includes the current row).
        same_speaker = self.df[self.df[2] == data[2]]
        ref_data = same_speaker.sample(n=1).iloc[0].tolist()
        ref_mel_tensor, ref_label = self._load_data(ref_data[:3])

        ood_ids = self.text_cleaner(self._sample_ood_text())
        ref_text = torch.LongTensor([0, *ood_ids, 0])

        return (
            speaker_id,
            acoustic_feature,
            text_tensor,
            ref_text,
            ref_mel_tensor,
            ref_label,
            path,
            wave,
        )

    def _load_tensor(self, data):
        """Load audio and token ids for one ``(wav_path, text, speaker)`` entry.

        Returns:
            tuple: ``(wave, text_ids, speaker_id)``.
            ``wave`` is channel 0 only, resampled to 24000 Hz when needed, with
            5000 zero samples on each side. ``text_ids`` is a ``LongTensor``
            wrapped in 0s. ``speaker_id`` is ``int(speaker)``.
        """
        wave_path, text, speaker_id = data
        speaker_id = int(speaker_id)
        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        if wave.ndim > 1:
            # Multi-channel audio is indexed (frames, channels): keep channel 0.
            wave = wave[:, 0]
        if sr != 24000:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
            logger.debug("Resampled %s from %s Hz to 24000 Hz", wave_path, sr)
        wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
        text = torch.LongTensor([0, *self.text_cleaner(text), 0])
        return wave, text, speaker_id

    def _load_data(self, data):
        """Return ``(mel, speaker_id)`` with ``mel`` randomly cropped to ``max_mel_length`` frames."""
        wave, _, speaker_id = self._load_tensor(data)
        mel_tensor = preprocess(wave).squeeze()
        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            # +1 because randint excludes `high`; lets the final window be chosen.
            random_start = np.random.randint(0, mel_length - self.max_mel_length + 1)
            mel_tensor = mel_tensor[:, random_start : random_start + self.max_mel_length]
        return mel_tensor, speaker_id
class Collater(object):
    """Pad and stack ``FilePathDataset`` items into batch tensors.

    Items are sorted by mel length, longest first, before padding.

    Args:
        return_wave: Stored as ``self.return_wave``. It is not consulted; the
            raw waves are always returned.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.min_mel_length = 192
        # Frame width of the reference-mel buffer; reference mels must fit in it.
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        """Collate a list of dataset items.

        Each item is ``(speaker_id, mel, text_ids, ood_text_ids, ref_mel,
        ref_speaker_id, path, wave)``, with ``mel`` / ``ref_mel`` indexed as
        ``(n_mels, frames)``.

        Returns:
            tuple: ``(waves, texts, input_lengths, ref_texts, ref_lengths,
            mels, output_lengths, ref_mels)``. Id tensors are zero-padded to
            the longest sequence. ``mels`` are zero-padded to the longest mel in
            the batch, and ``ref_mels`` to ``self.max_mel_length`` frames.
            Speaker ids and paths are not returned.
        """
        batch_size = len(batch)
        lengths = [item[1].shape[1] for item in batch]
        order = np.argsort(lengths)[::-1]
        batch = [batch[i] for i in order]

        nmels = batch[0][1].size(0)
        max_mel_length = max(lengths)
        max_text_length = max(item[2].shape[0] for item in batch)
        max_rtext_length = max(item[3].shape[0] for item in batch)

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        texts = torch.zeros((batch_size, max_text_length)).long()
        ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
        input_lengths = torch.zeros(batch_size).long()
        ref_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
        waves = [None] * batch_size

        for bid, (_, mel, text, ref_text, ref_mel, _, _, wave) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            rtext_size = ref_text.size(0)
            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            ref_texts[bid, :rtext_size] = ref_text
            input_lengths[bid] = text_size
            ref_lengths[bid] = rtext_size
            output_lengths[bid] = mel_size
            ref_mels[bid, :, : ref_mel.size(1)] = ref_mel
            waves[bid] = wave

        return (
            waves,
            texts,
            input_lengths,
            ref_texts,
            ref_lengths,
            mels,
            output_lengths,
            ref_mels,
        )
def build_dataloader(
    path_list,
    root_path,
    validation=False,
    OOD_data="Data/OOD_texts.txt",
    min_length=50,
    batch_size=4,
    num_workers=1,
    device="cpu",
    collate_config=None,
    dataset_config=None,
):
    """Build a ``DataLoader`` over a ``FilePathDataset`` with the ``Collater``.

    Args:
        path_list: Lines passed to ``FilePathDataset`` as ``data_list``.
        root_path: Audio root directory passed to ``FilePathDataset``.
        validation: When True, disables shuffling and ``drop_last``, and is
            forwarded to the dataset.
        OOD_data: OOD text file path forwarded to the dataset.
        min_length: Minimum OOD text length forwarded to the dataset.
        batch_size: Items per batch.
        num_workers: ``DataLoader`` worker processes.
        device: Memory is pinned unless this is ``"cpu"``.
        collate_config: Extra keyword arguments for ``Collater`` (None = none).
        dataset_config: Extra keyword arguments for ``FilePathDataset``
            (None = none).

    Returns:
        torch.utils.data.DataLoader
    """
    # None defaults instead of `{}` avoid a shared mutable default argument.
    dataset = FilePathDataset(
        path_list,
        root_path,
        OOD_data=OOD_data,
        min_length=min_length,
        validation=validation,
        **(dataset_config or {}),
    )
    collate_fn = Collater(**(collate_config or {}))
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(not validation),
        num_workers=num_workers,
        drop_last=(not validation),
        collate_fn=collate_fn,
        pin_memory=(device != "cpu"),
    )
    return data_loader