# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import logging
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

logger = logging.getLogger(__name__)


class RawAudioDataset(FairseqDataset):
    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats
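
    # Note (a sketch of the intent, not from the original comments): with
    # normalize=True each waveform is layer-normalized over its full length,
    # i.e. zero mean and unit variance per utterance, which is the input
    # preprocessing expected by models trained with normalization enabled.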

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav

        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end]
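
    # Illustrative example (hypothetical values): cropping a 10-sample
    # waveform to 4 samples keeps a random contiguous window,
    #
    #   wav = torch.arange(10.0)
    #   crop = self.crop_to_max_size(wav, 4)  # len(crop) == 4
    #
    # while a waveform already shorter than target_size is returned unchanged.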

    def _compute_mask_indices(self, dims, padding_mask):
        B, T, C = dims
        mask_indices, mask_channel_indices = None, None
        if self.mask_compute_kwargs["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_compute_kwargs["mask_prob"],
                self.mask_compute_kwargs["mask_length"],
                self.mask_compute_kwargs["mask_selection"],
                self.mask_compute_kwargs["mask_other"],
                min_masks=2,
                no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
                min_space=self.mask_compute_kwargs["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if self.mask_compute_kwargs["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_compute_kwargs["mask_channel_prob"],
                self.mask_compute_kwargs["mask_channel_length"],
                self.mask_compute_kwargs["mask_channel_selection"],
                self.mask_compute_kwargs["mask_channel_other"],
                no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
                min_space=self.mask_compute_kwargs["mask_channel_min_space"],
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices
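
    # Shape notes: mask_indices is a (B, T) boolean tensor over encoder
    # timesteps, and mask_channel_indices is expanded to (B, T, C) so it can
    # be broadcast against the encoder features; either is None when the
    # corresponding masking probability is 0.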

    def _bucket_tensor(self, tensor, num_pad, value):
        # Right-pad the last dimension by num_pad entries filled with `value`.
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(s) for s in sources]

        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                assert self.pad
                collated_sources[i] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i] = self.crop_to_max_size(source, target_size)

        input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)

        if self.compute_mask_indices:
            B = input["source"].size(0)
            T = self._get_mask_indices_dims(input["source"].size(-1))
            padding_mask_reshaped = input["padding_mask"].clone()
            extra = padding_mask_reshaped.size(1) % T
            if extra > 0:
                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
            padding_mask_reshaped = padding_mask_reshaped.view(
                padding_mask_reshaped.size(0), T, -1
            )
            padding_mask_reshaped = padding_mask_reshaped.all(-1)
            input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                padding_mask_reshaped,
            )
            input["mask_indices"] = mask_indices
            input["mask_channel_indices"] = mask_channel_indices
            out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = input
        return out
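
    # Usage sketch (assumes a concrete subclass such as FileAudioDataset below
    # and an existing manifest; the index values are hypothetical):
    #
    #   batch = dataset.collater([dataset[0], dataset[1]])
    #   batch["net_input"]["source"]        # (2, T) float tensor
    #   batch["net_input"]["padding_mask"]  # (2, T) bool tensor, pad=True only
    #   batch["id"]                         # LongTensor of example ids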

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        if size not in self._features_size_map:
            L_in = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]
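
    # The loop above applies the standard Conv1d output-length formula,
    #
    #   L_out = floor((L_in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1,
    #
    # once per layer in conv_feature_layers. For example, with the default
    # wav2vec 2.0 feature extractor (seven conv layers, 320x total stride),
    # one second of 16 kHz audio (16000 samples) maps to 49 encoder frames.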

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        if self.pad:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
            order.append(
                np.minimum(
                    np.array(self.sizes),
                    self.max_sample_size,
                )
            )
            # lexsort uses the last key as the primary one: sort by capped
            # size (descending via [::-1]), breaking ties with the random
            # permutation, so similarly sized examples end up batched together.
            return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def set_bucket_info(self, num_buckets):
        self.num_buckets = num_buckets
        if self.num_buckets > 0:
            self._collated_sizes = np.minimum(
                np.array(self.sizes),
                self.max_sample_size,
            )
            self.buckets = get_buckets(
                self._collated_sizes,
                self.num_buckets,
            )
            self._bucketed_sizes = get_bucketed_sizes(
                self._collated_sizes, self.buckets
            )
            logger.info(
                f"{len(self.buckets)} bucket(s) for the audio dataset: "
                f"{self.buckets}"
            )


class FileAudioDataset(RawAudioDataset):
    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        # The manifest's first line is the audio root directory; every
        # following line is "<relative path>\t<size in samples>".
        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except Exception:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        import soundfile as sf

        fn = self.fnames[index]
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)
        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}


class BinarizedAudioDataset(RawAudioDataset):
    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        from fairseq.data import data_utils, Dictionary

        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as f:
                self.root_dir = next(f).strip()
        else:
            self.root_dir = None

        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)

        lengths_path = os.path.join(data_dir, f"{split}.lengths")
        with open(lengths_path, "r") as f:
            for line in f:
                sz = int(line.rstrip())
                assert (
                    sz >= min_sample_size
                ), f"Sample of size {sz} is below min_sample_size={min_sample_size}; size filtering is not supported for binarized datasets"
                self.sizes.append(sz)
        self.sizes = np.array(self.sizes, dtype=np.int64)

        self.set_bucket_info(num_buckets)
        logger.info(f"loaded {len(self.fnames)} samples")

    def __getitem__(self, index):
        import soundfile as sf

        fname = self.fnames_dict.string(self.fnames[index], separator="")
        if self.root_dir:
            fname = os.path.join(self.root_dir, fname)

        wav, curr_sample_rate = sf.read(fname)
        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}