# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import re

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence

torchaudio.set_audio_backend('soundfile')

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}

# Section labels for the "chorus" field; "verse" shares the same id as
# "verse1" because it is mapped to "verse1" downstream.
CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2,
          "outro": 4}

# LRC-style lyric lines: metadata tags ([ti:], [ar:], [al:], [by:], [offset:])
# and "[mm:ss.xx]lyric" timestamp lines.
metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$')
timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$')
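# Illustrative matches (example strings, not from the original file):
#   "[ti:My Song]"          -> metadata_pattern matches; clean_lyrics drops the line
#   "[00:12.34]hello world" -> timestamp_pattern matches; clean_lyrics keeps "hello world"
#   "hello world"           -> neither matches; the stripped line is kept as-is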
def parquet_opener(data, mode='train', audio_data={}):
    """ Given a url or local parquet file, yield one sample per row.
        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, ...row columns...}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            df = pq.read_table(url).to_pandas()
            for i in df.index:
                sample.update(dict(df.loc[i]))
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
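# The processors below expect each parquet row to provide columns such as
# 'utt', 'text', 'time_start', 'time_end', 'chorus', plus token/audio columns
# like 'semantic_token' and 'acoustic_token' (or 'speech_token').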
def clean_lyrics(data, mode="train"):
    """ Strip LRC metadata tags and timestamps from sample['text'],
        keeping only the lyric text, one line per lyric.
    """
    for sample in data:
        lyrics = sample["text"]
        cleaned = []
        for line in lyrics.splitlines():
            if metadata_pattern.match(line):
                continue
            timestamp_match = timestamp_pattern.match(line)
            if timestamp_match:
                lyric = timestamp_match.group(1).strip()
                if lyric:
                    cleaned.append(lyric)
            else:
                if line.strip():
                    cleaned.append(line.strip())
        sample["text"] = '\n'.join(cleaned)
        yield sample
def cut_by_length(data, max_length=8000, num_times=4, mode="train"):
    """ Truncate semantic tokens to `max_length` items and acoustic tokens
        to `max_length * num_times` items.
    """
    for sample in data:
        if "semantic_token" in sample:
            sample["semantic_token"] = [
                sample["semantic_token"][0][:max_length]]
        if "acoustic_token" not in sample:
            sample["acoustic_token"] = sample["speech_token"]
        sample["acoustic_token"] = sample["acoustic_token"][
            :max_length * num_times]
        yield sample
def filter(data,
           max_length=22500,  # 22500 ≈ 5 min
           max_acoustic_length=45000,
           min_length=10,
           min_acoustic_length=150,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when use char unit for
                english modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    if mode == "train":
        for sample in data:
            if "semantic_token" in sample:
                new_sample_frames = sample['semantic_token'][0].shape[0]
            else:
                # Fall back to the speech token count when no semantic tokens
                # are present.
                new_sample_frames = len(sample['speech_token'])
            if "text_token" in sample:
                new_sample_frames += len(sample['text_token'])
            if new_sample_frames > max_length or new_sample_frames < min_length:
                print(f"skipped 1 item length={new_sample_frames}")
                continue
            sample["chorus"] = sample["chorus"].split(",")
            if not isinstance(sample["time_start"], np.ndarray):
                sample["time_start"] = [sample["time_start"]]
                sample["time_end"] = [sample["time_end"]]
            for i, t in enumerate(sample["chorus"]):
                if sample["chorus"][i] == "verse":
                    sample["chorus"][i] = "verse1"
            yield sample
    if mode == "train_flow":
        for sample in data:
            if "semantic_token" in sample:
                new_sample_frames = sample['semantic_token'][0].shape[0]
            if "acoustic_token" in sample:
                target_sample_frames = sample['acoustic_token'][0].shape[0]
            if (new_sample_frames > max_length
                    or new_sample_frames < min_acoustic_length
                    or new_sample_frames < min_length
                    or target_sample_frames > max_acoustic_length):
                print(f"skipped 1 item length={new_sample_frames}, "
                      f"target_length={target_sample_frames}")
                continue
            yield sample
    elif mode == "inference":
        for sample in data:
            yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample
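# Worked example (illustrative numbers): with the defaults above, an 8 kHz
# clip is dropped (below min_sample_rate=16000), a 44.1 kHz clip is resampled
# to 22.05 kHz, and any waveform whose peak amplitude exceeds 1.0 is rescaled
# so its peak becomes exactly 1.0.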
def truncate(data, truncate_length=24576, mode='train'):
    """ Truncate data.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            truncate_length: truncate length

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        waveform = sample['audio']
        if waveform.shape[1] > truncate_length:
            # Random crop of `truncate_length` samples.
            start = random.randint(0, waveform.shape[1] - truncate_length)
            waveform = waveform[:, start: start + truncate_length]
        else:
            # Zero-pad up to `truncate_length`.
            waveform = torch.concat(
                [waveform, torch.zeros(1, truncate_length - waveform.shape[1])],
                dim=1)
        sample['audio'] = waveform
        yield sample
def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train',
             n_codebook=4):
    """ Upsample semantic tokens by repeating each token
        `repetition_factor` times.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'semantic_token' in sample
        # TODO: unify data processing key names
        if 'acoustic_token' not in sample:
            continue
        if 'sample_rate' in sample.keys():
            sample_rate = sample['sample_rate']
        else:
            sample_rate = 24000
        token = np.array(sample['semantic_token'][0][:-1])
        # Calculate the repetition factor for resampling
        repetition_factor = int(n_codebook * resample_rate / sample_rate)
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['semantic_token'] = np.array(
                [np.repeat(token, repetition_factor)])
        yield sample
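# Worked example (illustrative, using the defaults above): with
# resample_rate=48000, n_codebook=4 and a sample_rate of 24000, the
# repetition factor is int(4 * 48000 / 24000) = 8, i.e. each semantic token
# is repeated 8 times.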
def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        assert 'utt' in sample
        assert 'text_token' in sample
        waveform = sample['speech']
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        del sample['speech']
        yield sample
def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'],
                                               dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'],
                                               dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'],
                                                  dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'],
                                                  dim=0)
        yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(
            sample['text'], allowed_special=allowed_special)
        yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x
def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        if sample["chorus"] == "verse":
            sample["chorus"] = "verse1"
        # Concatenate the nested acoustic token arrays into one 1-D stream.
        if sample["acoustic_token"].shape[0] == 1:
            sample["acoustic_token"] = np.concatenate(
                sample["acoustic_token"][0])
        else:
            sample["acoustic_token"] = np.concatenate(sample["acoustic_token"])
        sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"])
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['acoustic_token'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['acoustic_token'].size(0))
    for x in buf:
        yield x
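# Sorting each buffer by acoustic_token length means a following batching
# stage groups similar-length samples together, which reduces padding waste.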
def static_batch(data, batch_size=32):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    data_empty = True
    for sample in data:
        data_empty = False
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if data_empty:
        raise ValueError("data is empty")
    if len(buf) > 0:
        yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'acoustic_token' in sample
        assert isinstance(sample['acoustic_token'], torch.Tensor)
        if 'semantic_token' in sample:
            new_sample_frames = sample['semantic_token'][0].shape[0]
        else:
            # Fall back to the acoustic token length when no semantic tokens
            # are present.
            new_sample_frames = sample['acoustic_token'].size(0)
        if "text_token" in sample:
            new_sample_frames += len(sample['text_token'])
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            if len(buf) > 0:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf
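# Worked example (illustrative numbers): with max_frames_in_batch=12000, if
# every sample is 1000 frames long, a new sample is accepted only while
# 1000 * (len(buf) + 1) <= 12000, so the batch holds at most 12 such samples
# before it is emitted.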
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000,
          mode='train'):
    """ Wrapper for static/dynamic batch
    """
    if mode == 'inference':
        return static_batch(data, 1)
    elif mode == 'processing':
        return static_batch(data, batch_size)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, mode='train'):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    if mode == "train":
        for sample in data:
            assert isinstance(sample, list)
            if len(sample) != 0:
                acoustic_feat_len = torch.tensor(
                    [x['acoustic_token'].size(0) for x in sample],
                    dtype=torch.int32)
                # Sort the batch by acoustic length, longest first.
                order = torch.argsort(acoustic_feat_len, descending=True)
                utts = [sample[i]['utt'] for i in order]
                acoustic_token = [
                    sample[i]['acoustic_token'].clone().to(torch.int32)
                    for i in order]
                acoustic_token_len = torch.tensor(
                    [i.size(0) for i in acoustic_token], dtype=torch.int32)
                acoustic_token = pad_sequence(acoustic_token,
                                              batch_first=True,
                                              padding_value=0)
                text = [sample[i]['text'] for i in order]
                text_token = [torch.tensor(sample[i]['text_token']).long()
                              for i in order]
                text_token_len = torch.tensor([i.size(0) for i in text_token],
                                              dtype=torch.int32)
                text_token = pad_sequence(text_token, batch_first=True,
                                          padding_value=0)
                time_start = torch.tensor(
                    [sample[i]['time_start'] for i in order])
                time_end = torch.tensor([sample[i]['time_end'] for i in order])
                if isinstance(sample[0]['chorus'], str):
                    chorus = torch.tensor(
                        [CHORUS[sample[i]['chorus']] for i in order])
                else:
                    chorus = [
                        torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
                        for i in order]
                    chorus = pad_sequence(chorus, batch_first=True,
                                          padding_value=-1)
                batch = {
                    "utts": utts,
                    "acoustic_token": acoustic_token,
                    "acoustic_token_len": acoustic_token_len,
                    "time_start": time_start,
                    "time_end": time_end,
                    "chorus": chorus,
                    "text": text,
                    "text_token": text_token,
                    "text_token_len": text_token_len,
                }
                if "semantic_token" in sample[0]:
                    semantic_token = [
                        torch.tensor(sample[i]['semantic_token'][0],
                                     dtype=torch.int32) for i in order]
                    semantic_token_len = torch.tensor(
                        [i.size(0) for i in semantic_token],
                        dtype=torch.int32)
                    semantic_token = pad_sequence(semantic_token,
                                                  batch_first=True,
                                                  padding_value=0)
                    batch.update({"semantic_token": semantic_token,
                                  "semantic_token_len": semantic_token_len})
                yield batch
            else:
                logging.warning("sample is empty []!")
    elif mode == "inference":
        for sample in data:
            assert isinstance(sample, list)
            utts = [sample[i]['utt'] for i in range(len(sample))]
            text = [sample[i]['text'] for i in range(len(sample))]
            text_token = [torch.tensor(sample[i]['text_token']).long()
                          for i in range(len(sample))]
            text_token_len = torch.tensor([i.size(0) for i in text_token],
                                          dtype=torch.int32)
            text_token = pad_sequence(text_token, batch_first=True,
                                      padding_value=0)
            time_start = torch.tensor(
                [sample[i]['time_start'] for i in range(len(sample))])
            time_end = torch.tensor(
                [sample[i]['time_end'] for i in range(len(sample))])
            if isinstance(sample[0]['chorus'], str):
                chorus = torch.tensor([CHORUS[sample[i]['chorus']]
                                       for i in range(len(sample))])
            else:
                chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
                          for i in range(len(sample))]
                chorus = pad_sequence(chorus, batch_first=True,
                                      padding_value=-1)
            if "acoustic_token" in sample[0]:
                acoustic_token = [
                    sample[i]['acoustic_token'].clone().to(torch.int32)
                    for i in range(len(sample))]
                acoustic_token_len = torch.tensor(
                    [i.size(0) for i in acoustic_token], dtype=torch.int32)
                acoustic_token = pad_sequence(acoustic_token,
                                              batch_first=True,
                                              padding_value=0)
            else:
                acoustic_token = None
                acoustic_token_len = None
            batch = {
                "utts": utts,
                "acoustic_token": acoustic_token,
                "acoustic_token_len": acoustic_token_len,
                "time_start": time_start,
                "time_end": time_end,
                "chorus": chorus,
                "text": text,
                "text_token": text_token,
                "text_token_len": text_token_len,
            }
            if "semantic_token" in sample[0]:
                semantic_token = [torch.tensor(sample[i]['semantic_token'][0],
                                               dtype=torch.int32)
                                  for i in range(len(sample))]
                semantic_token_len = torch.tensor(
                    [i.size(0) for i in semantic_token], dtype=torch.int32)
                semantic_token = pad_sequence(semantic_token,
                                              batch_first=True,
                                              padding_value=0)
                batch.update({"semantic_token": semantic_token,
                              "semantic_token_len": semantic_token_len})
            yield batch
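# Minimal usage sketch (an assumption about how these processors are chained;
# the stage order shown here is illustrative, not taken from this file):
#
#   stages = parquet_opener(file_list)
#   stages = clean_lyrics(stages)
#   stages = tokenize(stages, get_tokenizer, allowed_special='all')
#   stages = filter(stages, mode='train')
#   stages = shuffle(stages)
#   stages = sort(stages)
#   stages = batch(stages, batch_type='dynamic', max_frames_in_batch=12000)
#   stages = padding(stages, mode='train')
#   for minibatch in stages:
#       ...  # feed `minibatch` to the training loop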