# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.utils.data

from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

F0_FRAME_SPACE = 0.005  # sec

logger = logging.getLogger(__name__)


class ExpressiveCodeDataConfig(object):
    def __init__(self, json_path):
        with open(json_path, "r") as f:
            self.config = json.load(f)
        self._manifests = self.config["manifests"]

    @property
    def manifests(self):
        return self._manifests

    @property
    def n_units(self):
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """pre-computed f0 statistics path"""
        return self.config.get("f0_stats", None)

    @property
    def f0_vq_type(self):
        """naive or precomp"""
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        key = "log" if log else "linear"
        if norm_mean and norm_std:
            key += "_mean_std_norm"
        elif norm_mean:
            key += "_mean_norm"
        else:
            key += "_none_norm"
        return self.config["f0_vq_naive_quantizer"][key]

    @property
    def f0_vq_n_units(self):
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """how to parse speaker label from audio path"""
        return self.config.get("multispkr", None)


def get_f0(audio, rate=16000):
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError:
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`) "
            "and librosa (`pip install librosa`)."
        )

    assert audio.ndim == 1
    frame_length = 20.0  # ms
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    f0 = pitch.samp_values
    return f0
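
# Usage sketch (assumes a mono 16 kHz waveform; "utt.wav" is a placeholder):
#   wav, sr = load_wav("utt.wav")
#   f0 = get_f0(wav, rate=sr)  # one value per F0_FRAME_SPACE (5 ms) frame,
#                              # 0 for unvoiced frames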


def interpolate_f0(f0):
    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise ImportError("Please install scipy (`pip install scipy`)")

    orig_t = np.arange(f0.shape[0])
    f0_interp = f0[:]
    ii = f0_interp != 0
    if ii.sum() > 1:
        f0_interp = interp1d(
            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
        )(orig_t)
        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
    return f0_interp
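
# Example: linear interpolation through unvoiced (zero) frames; frames outside
# the voiced region stay 0 because of fill_value=0:
#   >>> interpolate_f0(torch.tensor([0.0, 100.0, 0.0, 200.0, 0.0]))
#   tensor([  0., 100., 150., 200.,   0.])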


def naive_quantize(x, edges):
    bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
    return bin_idx
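
# Example: each value is mapped to the number of edges it exceeds, so N edges
# produce bin indices in 0..N:
#   >>> edges = torch.tensor([100.0, 150.0, 200.0])
#   >>> naive_quantize(torch.tensor([90.0, 120.0, 180.0, 250.0]), edges)
#   tensor([0, 1, 2, 3])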


def load_wav(full_path):
    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile (`pip install SoundFile`)")

    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate


def parse_code(code_str, dictionary, append_eos):
    code, duration = torch.unique_consecutive(
        torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
    )

    code = " ".join(map(str, code.tolist()))
    code = dictionary.encode_line(code, append_eos=append_eos).short()

    if append_eos:
        duration = torch.cat((duration, duration.new_zeros((1,))), dim=0)  # eos
    duration = duration.short()
    return code, duration
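
# Example: "7 7 7 4 4" collapses to codes [7, 4] with durations [3, 2]; with
# append_eos=True an eos symbol is appended to the code sequence and a 0 is
# appended to the durations to keep the two aligned.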


def parse_manifest(manifest, dictionary):
    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        for line in info.readlines():
            sample = eval(line.strip())
            if "cpc_km100" in sample:
                k = "cpc_km100"
            elif "hubert_km100" in sample:
                k = "hubert_km100"
            elif "phone" in sample:
                k = "phone"
            else:
                assert False, "unknown format"
            code = sample[k]
            code, duration = parse_code(code, dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers
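
# Each manifest line is a Python dict literal (hence the eval above), e.g.:
#   {"audio": "/data/utt1.wav", "hubert_km100": "1 1 4 4 7", "speaker": "spk0"}
# where the code key is one of "cpc_km100", "hubert_km100", or "phone".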


def parse_speaker(path, method):
    if isinstance(path, str):
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    elif method == "parent_parent_name":
        return path.parent.parent.name
    elif method == "_":
        return path.name.split("_")[0]
    elif method == "single":
        return "A"
    elif callable(method):
        return method(path)
    else:
        raise NotImplementedError()
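
# Example: method="parent_name" maps "/data/spk3/utt1.wav" to "spk3";
# method="_" maps "spk3_utt1.wav" to "spk3"; method="single" maps every
# file to the single speaker label "A".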


def get_f0_by_filename(filename, tgt_sampling_rate):
    audio, sampling_rate = load_wav(filename)
    if sampling_rate != tgt_sampling_rate:
        raise ValueError(
            "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
        )

    # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
    f0 = get_f0(audio, rate=tgt_sampling_rate)
    f0 = torch.from_numpy(f0.astype(np.float32))
    return f0


def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )
    if diff > 0:
        f0 = f0[:targ_len]
    elif diff < 0:
        f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)

    f0_offset = 0.0
    seg_f0s = []
    for dur in durations:
        f0_dur = dur.item() * f0_code_ratio
        seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
        seg_f0 = seg_f0[seg_f0 != 0]
        if len(seg_f0) == 0:
            seg_f0 = torch.tensor(0).type(seg_f0.type())
        else:
            seg_f0 = seg_f0.mean()
        seg_f0s.append(seg_f0)
        f0_offset += f0_dur

    assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
    return torch.tensor(seg_f0s)
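
# Example: with durations [2, 3] and f0_code_ratio=4, the 20 f0 frames are
# split into segments of 8 and 12 frames, and each segment is reduced to the
# mean of its voiced (non-zero) values, yielding one f0 value per code token.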


class Paddings(object):
    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        self.code = code_val
        self.dur = dur_val
        self.f0 = f0_val


class Shifts(object):
    def __init__(self, shifts_str, pads):
        self._shifts = list(map(int, shifts_str.split(",")))
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        self.extra_length = max(s for s in self._shifts)
        self.pads = pads

    @property
    def dur(self):
        return self._shifts[0]

    @property
    def f0(self):
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        assert seq.ndim == 1
        bos = seq.new_full((left_pad_num,), pad)
        eos = seq.new_full((right_pad_num,), pad)
        seq = torch.cat([bos, seq, eos])
        mask = torch.ones_like(seq).bool()
        mask[left_pad_num : len(seq) - right_pad_num] = 0
        return seq, mask

    def __call__(self, code, dur, f0):
        if self.extra_length == 0:
            code_mask = torch.zeros_like(code).bool()
            dur_mask = torch.zeros_like(dur).bool()
            f0_mask = torch.zeros_like(f0).bool()
            return code, code_mask, dur, dur_mask, f0, f0_mask

        code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
        dur, dur_mask = self.shift_one(
            dur, self.dur, self.extra_length - self.dur, self.pads.dur
        )
        f0, f0_mask = self.shift_one(
            f0, self.f0, self.extra_length - self.f0, self.pads.f0
        )
        return code, code_mask, dur, dur_mask, f0, f0_mask
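
# Example: shifts="1,2" right-shifts dur by 1 and f0 by 2 relative to code, so
# extra_length=2: code gets 2 trailing pads, dur gets 1 leading and 1 trailing
# pad, and f0 gets 2 leading pads; the returned masks mark the padded positions.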


class CodeDataset(FairseqDataset):
    def __init__(
        self,
        manifest,
        dictionary,
        dur_dictionary,
        f0_dictionary,
        config,
        discrete_dur,
        discrete_f0,
        log_f0,
        normalize_f0_mean,
        normalize_f0_std,
        interpolate_f0,
        return_filename=False,
        strip_filename=True,
        shifts="0,0",
        return_continuous_f0=False,
    ):
        random.seed(1234)
        self.dictionary = dictionary
        self.dur_dictionary = dur_dictionary
        self.f0_dictionary = f0_dictionary
        self.config = config

        # duration config
        self.discrete_dur = discrete_dur

        # pitch config
        self.discrete_f0 = discrete_f0
        self.log_f0 = log_f0
        self.normalize_f0_mean = normalize_f0_mean
        self.normalize_f0_std = normalize_f0_std
        self.interpolate_f0 = interpolate_f0

        self.return_filename = return_filename
        self.strip_filename = strip_filename
        self.f0_code_ratio = config.code_hop_size / (
            config.sampling_rate * F0_FRAME_SPACE
        )

        # use lazy loading to avoid sharing file handlers across workers
        self.manifest = manifest
        self._codes = None
        self._durs = None
        self._f0s = None
        with open(f"{manifest}.leng.txt", "r") as f:
            lengs = [int(line.rstrip()) for line in f]
            edges = np.cumsum([0] + lengs)
            self.starts, self.ends = edges[:-1], edges[1:]
        with open(f"{manifest}.path.txt", "r") as f:
            self.file_names = [line.rstrip() for line in f]
        logger.info(f"num entries: {len(self.starts)}")

        if os.path.exists(f"{manifest}.f0_stat.pt"):
            self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
        elif config.f0_stats:
            self.f0_stats = torch.load(config.f0_stats)

        self.multispkr = config.multispkr
        if config.multispkr:
            with open(f"{manifest}.speaker.txt", "r") as f:
                self.spkrs = [line.rstrip() for line in f]
            self.id_to_spkr = sorted(self.spkrs)
            self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}

        self.pads = Paddings(
            dictionary.pad(),
            0,  # use 0 for duration padding
            f0_dictionary.pad() if discrete_f0 else -5.0,
        )
        self.shifts = Shifts(shifts, pads=self.pads)
        self.return_continuous_f0 = return_continuous_f0

    def get_data_handlers(self):
        logging.info(f"loading data for {self.manifest}")
        self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
        self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")

        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                self._f0s = np.load(
                    f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
                )
            elif self.config.f0_vq_type == "naive":
                self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
                quantizers_path = self.config.get_f0_vq_naive_quantizer(
                    self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
                )
                quantizers = torch.load(quantizers_path)
                n_units = self.config.f0_vq_n_units
                self._f0_quantizer = torch.from_numpy(quantizers[n_units])
            else:
                raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
        else:
            self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")

    def preprocess_f0(self, f0, stats):
        """
        1. interpolate
        2. log transform (keep unvoiced frame 0)
        """
        # TODO: change this to be dependent on config for naive quantizer
        f0 = f0.clone()
        if self.interpolate_f0:
            f0 = interpolate_f0(f0)

        mask = f0 != 0  # only process voiced frames
        if self.log_f0:
            f0[mask] = f0[mask].log()
        if self.normalize_f0_mean:
            mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
            f0[mask] = f0[mask] - mean
        if self.normalize_f0_std:
            std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
            f0[mask] = f0[mask] / std
        return f0
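
    # Example: with log_f0 and mean normalization, a voiced frame at 200 Hz
    # becomes log(200) - logf0_mean while unvoiced frames stay exactly 0, so
    # 0 remains usable as the unvoiced marker downstream.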

    def _get_raw_item(self, index):
        start, end = self.starts[index], self.ends[index]
        if self._codes is None:
            self.get_data_handlers()
        code = torch.from_numpy(np.array(self._codes[start:end])).long()
        dur = torch.from_numpy(np.array(self._durs[start:end]))
        f0 = torch.from_numpy(np.array(self._f0s[start:end]))
        return code, dur, f0

    def __getitem__(self, index):
        code, dur, f0 = self._get_raw_item(index)
        code = torch.cat([code.new([self.dictionary.bos()]), code])

        # use 0 for eos and bos
        dur = torch.cat([dur.new([0]), dur])
        if self.discrete_dur:
            dur = self.dur_dictionary.encode_line(
                " ".join(map(str, dur.tolist())), append_eos=False
            ).long()
        else:
            dur = dur.float()

        # TODO: find a more elegant approach
        raw_f0 = None
        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                f0 = self.f0_dictionary.encode_line(
                    " ".join(map(str, f0.tolist())), append_eos=False
                ).long()
            else:
                f0 = f0.float()
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
                if self.return_continuous_f0:
                    raw_f0 = f0
                    raw_f0 = torch.cat(
                        [raw_f0.new([self.f0_dictionary.bos()]), raw_f0]
                    )
                f0 = naive_quantize(f0, self._f0_quantizer)
            f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
        else:
            f0 = f0.float()
            if self.multispkr:
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
            else:
                f0 = self.preprocess_f0(f0, self.f0_stats)
            f0 = torch.cat([f0.new([0]), f0])

        if raw_f0 is not None:
            *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
        else:
            raw_f0_mask = None
        code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
        if raw_f0_mask is not None:
            assert (raw_f0_mask == f0_mask).all()

        # is a padded frame if either input or output is padded
        feats = {
            "source": code[:-1],
            "target": code[1:],
            "mask": code_mask[1:].logical_or(code_mask[:-1]),
            "dur_source": dur[:-1],
            "dur_target": dur[1:],
            "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
            "f0_source": f0[:-1],
            "f0_target": f0[1:],
            "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
        }

        if raw_f0 is not None:
            feats["raw_f0"] = raw_f0[1:]

        if self.return_filename:
            fname = self.file_names[index]
            feats["filename"] = (
                fname if not self.strip_filename else Path(fname).with_suffix("").name
            )
        return feats

    def __len__(self):
        return len(self.starts)

    def size(self, index):
        return self.ends[index] - self.starts[index] + self.shifts.extra_length

    def num_tokens(self, index):
        return self.size(index)

    def collater(self, samples):
        pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
        if len(samples) == 0:
            return {}

        src_tokens = data_utils.collate_tokens(
            [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
        )

        tgt_tokens = data_utils.collate_tokens(
            [s["target"] for s in samples],
            pad_idx=pad_idx,
            eos_idx=pad_idx,  # appending padding, eos is there already
            left_pad=False,
        )

        src_durs, tgt_durs = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.dur,
                eos_idx=self.pads.dur,
                left_pad=False,
            )
            for k in ["dur_source", "dur_target"]
        ]

        src_f0s, tgt_f0s = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            for k in ["f0_source", "f0_target"]
        ]

        mask, dur_mask, f0_mask = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=1,
                eos_idx=1,
                left_pad=False,
            )
            for k in ["mask", "dur_mask", "f0_mask"]
        ]

        src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
        n_tokens = sum(len(s["source"]) for s in samples)

        result = {
            "nsentences": len(samples),
            "ntokens": n_tokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "dur_src": src_durs,
                "f0_src": src_f0s,
            },
            "target": tgt_tokens,
            "dur_target": tgt_durs,
            "f0_target": tgt_f0s,
            "mask": mask,
            "dur_mask": dur_mask,
            "f0_mask": f0_mask,
        }

        if "filename" in samples[0]:
            result["filename"] = [s["filename"] for s in samples]

        # TODO: remove this hack into the inference dataset
        if "prefix" in samples[0]:
            result["prefix"] = [s["prefix"] for s in samples]

        if "raw_f0" in samples[0]:
            raw_f0s = data_utils.collate_tokens(
                [s["raw_f0"] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            result["raw_f0"] = raw_f0s
        return result
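

# Usage sketch (hypothetical paths and dictionaries; mirrors how a fairseq task
# would typically build batches from this dataset):
#   config = ExpressiveCodeDataConfig("config.json")
#   dataset = CodeDataset(
#       manifest="train", dictionary=code_dict, dur_dictionary=dur_dict,
#       f0_dictionary=f0_dict, config=config, discrete_dur=False,
#       discrete_f0=False, log_f0=True, normalize_f0_mean=True,
#       normalize_f0_std=False, interpolate_f0=True,
#   )
#   batch = dataset.collater([dataset[0], dataset[1]])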