# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary
logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    """Load a YAML config file and return its contents as a dict.

    Args:
        yaml_path: path to the YAML file.

    Raises:
        FileNotFoundError: if ``yaml_path`` does not exist.
        ImportError: if PyYAML is not installed.
        Exception: if the file exists but cannot be parsed as YAML.
    """
    if not yaml_path.is_file():
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
    try:
        # Imported lazily so this module stays importable without PyYAML.
        import yaml
    except ImportError as e:
        # The original code only printed a hint here and then crashed later
        # with a NameError on `yaml`; fail fast with an actionable error.
        raise ImportError("Please install PyYAML: pip install PyYAML") from e
    try:
        with open(yaml_path) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    except Exception as e:
        raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
    return config
class S2TDataConfig(object):
    """Wrapper class for data config YAML (speech-to-text)."""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        # Convert a path relative to the config root into an absolute POSIX
        # path; recurse into dict values. Anything that already resolves, or
        # is not a str/dict, is returned unchanged.
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        # `use_audio_input` must be a property here: as a plain method the
        # bound-method object would be unconditionally truthy.
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allowing train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching."""
        # deepcopy so later mutation of the returned list cannot alter
        # self.config (module-level import; the local re-import was redundant).
        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur

    def get_feature_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})
class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML (speech-to-speech)."""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        # Speech targets: no text pre-tokenizer applies.
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        # Speech targets: no subword tokenizer applies.
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        # `input_channels` is a property on the parent class; attribute
        # access here (and the multiplication below) depends on that.
        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)
class MultitaskConfig(object):
    """Wrapper class for multitask data config YAML: maps each top-level
    YAML key (task name) to a :class:`SingleTaskConfig`."""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' in the config file,
        the last task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
        the last task whose task_name includes 'target' and decoder_type is not ctc.
        """
        idx = -1
        # `is_first_pass_decoder` / `decoder_type` are properties on
        # SingleTaskConfig; attribute access here relies on that.
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx
class SingleTaskConfig(object):
    """Config wrapper for a single auxiliary task in a multitask setup."""

    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        # tgt_dict is None when no dictionary file is configured/present.
        self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        # `decoder_type` must be a property: comparing a bound method to
        # "ctc" would always be False.
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        """Loss weight at `num_updates`: fixed, or linearly decayed from
        `loss_weight_max` down to `loss_weight_min`."""
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                # NOTE(review): raising Warning aborts the caller; a
                # warnings.warn/logger.warning may be intended — confirm.
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})