# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
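
"""Speech-to-text task for fairseq: loads TSV manifests described by a YAML
data config (S2TDataConfig), builds the target dictionary and tokenizers from
that config, and optionally attaches auxiliary multitask decoder tasks defined
in a second YAML file (MultitaskConfig)."""
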
import logging
from argparse import Namespace
from pathlib import Path
from typing import List

from fairseq.data import Dictionary, encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import MultitaskConfig
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
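    """Task for training and evaluating speech-to-text models on TSV manifests."""
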
    @classmethod
    def add_args(cls, parser):
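        """Add task-specific command-line arguments to the parser."""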
parser.add_argument("data", help="manifest root path") | |
parser.add_argument( | |
"--config-yaml", | |
type=str, | |
default="config.yaml", | |
help="Configuration YAML filename (under manifest root)", | |
) | |
parser.add_argument( | |
"--multitask-config-yaml", | |
type=str, | |
default=None, | |
help="Configuration YAML filename for the multitasks (under manifest root)", | |
) | |
parser.add_argument( | |
"--max-source-positions", | |
default=6000, | |
type=int, | |
metavar="N", | |
help="max number of tokens in the source sequence", | |
) | |
parser.add_argument( | |
"--max-target-positions", | |
default=1024, | |
type=int, | |
metavar="N", | |
help="max number of tokens in the target sequence", | |
) | |

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()
        if (
            self.data_cfg.prepend_tgt_lang_tag
            and self.data_cfg.prepend_bos_and_append_tgt_lang_tag
        ):
            raise ValueError(
                "Please set only one of the two options to avoid adding target token multiple times"
            )

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
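        # Optionally attach auxiliary decoding tasks from --multitask-config-yaml;
        # the task flagged as the first-pass decoder supplies the MT dictionary
        # and EOS token used by the dual-decoder sequence generator below.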
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, List)
                        if not self.eos_token_mt:
                            raise Warning(
                                "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
                            )
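
    # Map optional speaker labels (one per line in speaker_set_filename) to
    # integer IDs for speaker-aware models.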
    def _get_speaker_to_id(self):
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id

    @classmethod
    def setup_task(cls, args, **kwargs):
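        """Load the data config and target dictionary, then build the task."""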
        data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        dict_path = Path(args.data) / data_cfg.vocab_filename
        if not dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
        tgt_dict = Dictionary.load(dict_path.as_posix())
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )
        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            root=self.args.data,
            cfg=self.data_cfg,
            splits=split,
            tgt_dict=self.tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id,
            multitask=self.multitask_tasks,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args, from_checkpoint=False):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        args.speaker_to_id = self.speaker_to_id
        return super(SpeechToTextTask, self).build_model(args, from_checkpoint)
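
    # UnitY-style dual-decoder models are decoded with a multi-decoder
    # sequence generator that runs a first-pass (MT) decoder before the final
    # decoder, each with its own beam settings and length penalties.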
    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs,
    ):
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        lang_token_ids_aux = {
            i
            for s, i in self.tgt_dict_mt.indices.items()
            if TextTargetMultitaskData.is_lang_tag(s)
        }
        extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux)
        eos_id_mt = (
            self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None
        )
        assert eos_id_mt != self.tgt_dict_mt.unk()
        extra_gen_cls_kwargs["eos_mt"] = eos_id_mt
        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 0),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            len_penalty_mt=getattr(args, "lenpen_mt", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
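        """Build a sequence generator, stripping language tags from the output."""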
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }
        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        eos_token = (
            args.eos_token
            if "eos_token" in args and args.eos_token is not None
            else self.data_cfg.config.get("eos_token", None)
        )
        if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token:
            raise Warning(
                "Please provide --eos_token to replace eos in sequence generator"
            )
        eos_id = self.tgt_dict.index(eos_token) if eos_token else None
        extra_gen_cls_kwargs["eos"] = eos_id
        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None
        if has_dual_decoder:
            return self.build_generator_dual_decoder(
                models,
                args,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )
        else:
            return super().build_generator(
                models,
                args,
                seq_gen_cls=None,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
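        # Refresh each auxiliary task's loss weight for this update (weights
        # may be scheduled over training) and keep its decoder in train mode.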
        for task_name, task_obj in self.multitask_tasks.items():
            criterion.set_multitask_loss_weight(
                task_name, task_obj.args.get_loss_weight(update_num)
            )
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].train()

        loss, sample_size, logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        for task_name, task_obj in self.multitask_tasks.items():
            if task_name in model.multitask_decoders:
                model.multitask_decoders[task_name].eval()
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output

    def build_tokenizer(self, args):
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
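
    # In interactive mode, each input "line" is a path to an audio file; the
    # corresponding length is its number of feature frames (or samples).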
    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )


class DummyMultiTask(LegacyFairseqTask):
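    """Minimal stand-in task that carries the config and target dictionary of
    an auxiliary (multitask) decoder."""
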
    def __init__(self, args, tgt_dict, first_pass=False):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.first_pass = first_pass

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def is_first_pass_decoder(self):
        return self.first_pass

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
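        # Only CTC decoding is supported: prefer raw logits when the model
        # exposes get_logits (CTC needs no normalization), otherwise fall back
        # to log-probabilities, then decode batch-first emissions.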
        if self.args.decoder_type == "ctc":
            model = models[0]  # only support single model
            encoder_out = model(**sample)
            if hasattr(model, "get_logits"):
                emissions = model.get_logits(
                    encoder_out
                )  # no need to normalize emissions
            else:
                emissions = model.get_normalized_probs(encoder_out, log_probs=True)
            return generator.decode(
                emissions.transpose(0, 1).float().cpu().contiguous()
            )
        else:
            raise NotImplementedError("only ctc decoder is supported at the moment")

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        if self.args.decoder_type == "ctc":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, self.tgt_dict)
        else:
            raise NotImplementedError("only ctc decoder is supported at the moment")
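

# Example invocation (a sketch; the data root, subset names, and model/optimizer
# flags are illustrative and not defined in this file):
#
#   fairseq-train ${DATA_ROOT} \
#     --task speech_to_text --config-yaml config.yaml \
#     --train-subset train --valid-subset dev \
#     --max-source-positions 6000 --max-target-positions 1024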