# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import numpy as np
import re
torchaudio.set_audio_backend('soundfile')
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
CHORUS = {"intro": 0, "chorus": 1, "verse1": 2, "verse2": 3, "verse": 2,
"outro": 4}
metadata_pattern = re.compile(r'^\[(ti|ar|al|by|offset):.*\]$')
timestamp_pattern = re.compile(r'^\[\d{2}:\d{2}\.\d{2}\](.*)$')
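# Illustrative examples of what these patterns capture (made-up strings, not
# from any dataset): metadata_pattern matches LRC header tags such as
# "[ti:Some Title]" or "[offset:500]", which are dropped, while
# timestamp_pattern matches timed lyric lines such as "[00:12.34]hello world"
# and exposes the lyric text after the timestamp as group(1).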
def parquet_opener(data, mode='train', audio_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
df = pq.read_table(url).to_pandas()
for i in df.index:
sample.update(dict(df.loc[i]))
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
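# Usage sketch (illustrative; "data/part-00000.parquet" is a hypothetical
# path): each parquet row is merged into the incoming sample dict and yielded
# as its own sample.
#
#   for s in parquet_opener([{"src": "data/part-00000.parquet"}]):
#       ...  # one dict per row, carrying "src" plus the parquet columns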
def clean_lyrics(data, mode="train"):
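    """ Strip LRC metadata tags and timestamped prefixes from sample["text"],
    keeping only the lyric text itself (one line per lyric).
        Inplace operation.
    """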
for sample in data:
lyrics = sample["text"]
cleaned = []
for line in lyrics.splitlines():
if metadata_pattern.match(line):
continue
timestamp_match = timestamp_pattern.match(line)
if timestamp_match:
lyric = timestamp_match.group(1).strip()
if lyric:
cleaned.append(lyric)
else:
if line.strip():
cleaned.append(line.strip())
sample["text"] = '\n'.join(cleaned)
yield sample
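# Example (illustrative input, mirroring the behaviour above):
#
#   samples = [{"text": "[ti:Demo]\n[00:01.00]first line\nplain line"}]
#   list(clean_lyrics(samples))[0]["text"]
#   # -> "first line\nplain line"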
def cut_by_length(data, max_length=8000, num_times=4, mode="train"):
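    """ Truncate over-long token sequences.

    Semantic tokens are clipped to `max_length`; acoustic tokens are clipped
    to `max_length * num_times` (assuming roughly `num_times` acoustic tokens
    per semantic token).
    """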
for sample in data:
if "semantic_token" in sample:
sample["semantic_token"] = [
sample["semantic_token"][0][:max_length]]
if "acoustic_token" not in sample:
sample["acoustic_token"] = sample["speech_token"]
sample["acoustic_token"] = sample["acoustic_token"][
:max_length * num_times]
yield sample
def filter(data,
max_length=22500, # 22500 #5min #10240
max_acoustic_length=45000,
min_length=10,
min_acoustic_length=150,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if mode == "train":
for sample in data:
if "semantic_token" in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
else:
                # use the token count, not the raw token list
                new_sample_frames = len(sample['speech_token'])
if "text_token" in sample:
new_sample_frames += len(sample['text_token'])
if new_sample_frames > max_length or new_sample_frames < min_length:
print(f"skipped 1 item length={new_sample_frames}")
continue
sample["chorus"] = sample["chorus"].split(",")
if not isinstance(sample["time_start"], np.ndarray):
sample["time_start"] = [sample["time_start"]]
sample["time_end"] = [sample["time_end"]]
for i, t in enumerate(sample["chorus"]):
if sample["chorus"][i] == "verse":
sample["chorus"][i] = "verse1"
yield sample
if mode == "train_flow":
for sample in data:
if "semantic_token" in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
if "acoustic_token" in sample:
target_sample_frames = sample['acoustic_token'][0].shape[0]
            if (new_sample_frames > max_length
                    or new_sample_frames < min_acoustic_length
                    or new_sample_frames < min_length
                    or target_sample_frames > max_acoustic_length):
                print(f"skipped 1 item length={new_sample_frames}, "
                      f"target_length={target_sample_frames}")
                continue
yield sample
elif mode == "inference":
for sample in data:
yield sample
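# Note on the default length budget above (back-of-the-envelope only, not
# asserted elsewhere in this file): if the "#5min" note on max_length is taken
# at face value, 22500 semantic tokens over ~300 s implies roughly 75 tokens
# per second, and max_acoustic_length = 45000 allows twice that many acoustic
# tokens for the paired sequence.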
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop samples whose rate is below this value
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
waveform = sample['audio']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length -
waveform.shape[1])],
dim=1)
sample['audio'] = waveform
yield sample
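# Example (illustrative shapes): with the default truncate_length of 24576
# samples, a (1, 30000) waveform is randomly cropped to (1, 24576), while a
# (1, 10000) waveform is zero-padded on the right to (1, 24576).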
def upsample(data, resample_rate=48000, min_sample_rate=16000, mode='train',
n_codebook=4):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'semantic_token' in sample
# TODO: unify data processing key names
if 'acoustic_token' not in sample:
continue
if 'sample_rate' in sample.keys():
sample_rate = sample['sample_rate']
else:
sample_rate = 24000
token = np.array(sample['semantic_token'][0][:-1])
# Calculate the repetition factor for resampling
repetition_factor = int(n_codebook * resample_rate / sample_rate)
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['semantic_token'] = np.array(
[np.repeat(token, repetition_factor)])
yield sample
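# Worked example of the repetition factor (using the defaults above): with
# n_codebook=4, resample_rate=48000 and a source sample_rate of 24000,
# repetition_factor = int(4 * 48000 / 24000) = 8, so every semantic token is
# repeated 8 times.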
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'],
dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'],
dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'],
dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'],
dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'],
allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
if sample["chorus"] == "verse":
sample["chorus"] = "verse1"
if sample["acoustic_token"].shape[0] == 1:
sample["acoustic_token"] = np.concatenate(
sample["acoustic_token"][0])
else:
sample["acoustic_token"] = np.concatenate(sample["acoustic_token"])
sample["acoustic_token"] = torch.from_numpy(sample["acoustic_token"])
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['acoustic_token'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['acoustic_token'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=32):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
data_empty = True
for sample in data:
data_empty = False
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if data_empty:
raise ValueError("data is empty")
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'acoustic_token' in sample
assert isinstance(sample['acoustic_token'], torch.Tensor)
if 'semantic_token' in sample:
new_sample_frames = sample['semantic_token'][0].shape[0]
else:
            # fall back to the acoustic token length when no semantic tokens exist
            new_sample_frames = sample['acoustic_token'].size(0)
if "text_token" in sample:
new_sample_frames += len(sample['text_token'])
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
if len(buf) > 0:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
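# Worked example of the padding-aware budget (illustrative numbers): with
# max_frames_in_batch=12000 and a buffer of 3 samples whose longest is 3000
# frames, adding a 3500-frame sample would give 3500 * 4 = 14000 padded
# frames, so the current buffer of 3 is yielded and the new sample starts the
# next batch.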
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000,
mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
elif mode == 'processing':
return static_batch(data, batch_size)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
if mode == "train":
for sample in data:
assert isinstance(sample, list)
if len(sample) != 0:
acoustic_feat_len = torch.tensor(
[x['acoustic_token'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(acoustic_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
acoustic_token = [
sample[i]['acoustic_token'].clone().to(torch.int32) for i in
order]
acoustic_token_len = torch.tensor(
[i.size(0) for i in acoustic_token], dtype=torch.int32)
acoustic_token = pad_sequence(acoustic_token,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']).long() for i
in order]
text_token_len = torch.tensor([i.size(0) for i in text_token],
dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True,
padding_value=0)
time_start = torch.tensor(
[sample[i]['time_start'] for i in order])
time_end = torch.tensor([sample[i]['time_end'] for i in order])
if isinstance(sample[0]['chorus'], str):
chorus = torch.tensor(
[CHORUS[sample[i]['chorus']] for i in order])
else:
chorus = [
torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
for i in order]
chorus = pad_sequence(chorus, batch_first=True,
padding_value=-1)
batch = {
"utts" : utts,
"acoustic_token" : acoustic_token,
"acoustic_token_len": acoustic_token_len,
"time_start" : time_start,
"time_end" : time_end,
"chorus" : chorus,
"text" : text,
"text_token" : text_token,
"text_token_len" : text_token_len,
}
if "semantic_token" in sample[0]:
semantic_token = [
torch.tensor(sample[i]['semantic_token'][0],
dtype=torch.int32) for i in order]
semantic_token_len = torch.tensor(
[i.size(0) for i in semantic_token],
dtype=torch.int32)
semantic_token = pad_sequence(semantic_token,
batch_first=True,
padding_value=0)
batch.update({"semantic_token" : semantic_token,
"semantic_token_len": semantic_token_len})
yield batch
else:
logging.info("WARNING: sample is empty []!")
elif mode == "inference":
for sample in data:
assert isinstance(sample, list)
utts = [sample[i]['utt'] for i in range(len(sample))]
text = [sample[i]['text'] for i in range(len(sample))]
text_token = [torch.tensor(sample[i]['text_token']).long() for i in
range(len(sample))]
text_token_len = torch.tensor([i.size(0) for i in text_token],
dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True,
padding_value=0)
time_start = torch.tensor(
[sample[i]['time_start'] for i in range(len(sample))])
time_end = torch.tensor(
[sample[i]['time_end'] for i in range(len(sample))])
if isinstance(sample[0]['chorus'], str):
chorus = torch.tensor([CHORUS[sample[i]['chorus']] for i in
range(len(sample))])
else:
chorus = [torch.tensor([CHORUS[t] for t in sample[i]['chorus']])
for i in range(len(sample))]
chorus = pad_sequence(chorus, batch_first=True,
padding_value=-1)
if "acoustic_token" in sample[0]:
acoustic_token = [
sample[i]['acoustic_token'].clone().to(torch.int32) for i in
range(len(sample))]
acoustic_token_len = torch.tensor(
[i.size(0) for i in acoustic_token], dtype=torch.int32)
acoustic_token = pad_sequence(acoustic_token,
batch_first=True,
padding_value=0)
else:
acoustic_token = None
acoustic_token_len = None
batch = {
"utts" : utts,
"acoustic_token" : acoustic_token,
"acoustic_token_len": acoustic_token_len,
"time_start" : time_start,
"time_end" : time_end,
"chorus" : chorus,
"text" : text,
"text_token" : text_token,
"text_token_len" : text_token_len,
}
if "semantic_token" in sample[0]:
semantic_token = [torch.tensor(sample[i]['semantic_token'][0],
dtype=torch.int32) for i in
range(len(sample))]
semantic_token_len = torch.tensor(
[i.size(0) for i in semantic_token], dtype=torch.int32)
semantic_token = pad_sequence(semantic_token,
batch_first=True,
padding_value=0)
batch.update({"semantic_token" : semantic_token,
"semantic_token_len": semantic_token_len})
yield batch
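

if __name__ == "__main__":
    # Minimal smoke test with synthetic samples (illustrative only; real data
    # comes from the parquet files handled by parquet_opener above, and every
    # field value below is made up).
    fake_samples = [
        {
            "utt": f"utt_{i}",
            "text": "hello world",
            "text_token": [1, 2, 3],
            "acoustic_token": torch.randint(0, 1024, (100 + 10 * i,),
                                            dtype=torch.int32),
            "semantic_token": [np.random.randint(0, 4096, size=(50,))],
            "time_start": 0.0,
            "time_end": 10.0,
            "chorus": "verse",
        }
        for i in range(4)
    ]
    # Static batches of 2, then padded into tensors by the train branch above.
    for b in padding(batch(fake_samples, batch_type="static", batch_size=2)):
        print(b["utts"], b["acoustic_token"].shape, b["semantic_token"].shape)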