# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from argparse import Namespace

from fairseq.data import Dictionary, encoders
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    get_features_or_waveform,
)
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)
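
# Registered under the name "speech_to_text" (selectable via --task): this
# task wires fairseq's training loop to S2T TSV manifests. It loads the
# target dictionary, builds the tokenizers declared in the data config, and
# creates SpeechToTextDataset instances per split.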
@register_task("speech_to_text")
class SpeechToTextTask(LegacyFairseqTask):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
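
    # The task holds the parsed args, the target dictionary, and the S2T
    # data config loaded from <data>/<config-yaml>.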
    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()
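
    # The optional speaker-set file (named by "speaker_set_filename" in the
    # data config) lists one speaker ID per line; indices follow file order.
    # A hypothetical example file containing
    #
    #     spk_001
    #     spk_002
    #
    # yields {"spk_001": 0, "spk_002": 1}.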
    def _get_speaker_to_id(self):
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml)
        dict_path = Path(args.data) / data_cfg.vocab_filename
        if not dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}")
        tgt_dict = Dictionary.load(dict_path.as_posix())
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )
        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)
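
    # When the data config prepends a target-language tag as the first
    # target token (the BOS position), the criterion must skip it when
    # computing the loss, hence the --ignore-prefix-size 1 requirement.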
    def build_criterion(self, args):
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)
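
    # Each split is loaded from its TSV manifest under the data root; the
    # epoch and seed are passed through to the dataset creator so training
    # splits can be reshuffled per epoch.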
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id,
        )
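
    # The source side is raw audio, so only a target dictionary exists.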
    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions
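
    # Input feature dimensions live in the data config rather than on the
    # CLI, so they are copied onto the model args before the model is built.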
    def build_model(self, args):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        args.speaker_to_id = self.speaker_to_id
        return super().build_model(args)
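
    # Language-tag tokens are stripped from generator output so decoded
    # hypotheses contain only real text tokens.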
    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }
        if extra_gen_cls_kwargs is None:
            extra_gen_cls_kwargs = {}
        extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        return super().build_generator(
            models,
            args,
            seq_gen_cls=seq_gen_cls,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs,
        )
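
    # Both the pre-tokenizer (e.g. Moses) and the subword model (e.g.
    # SentencePiece) are declared in the data config YAML, not on the CLI.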
    def build_tokenizer(self, args):
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
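
    # In interactive mode each input "line" is a path to an audio file; the
    # reported lengths are frame (or waveform sample) counts.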
    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        return SpeechToTextDataset(
            "interactive", False, self.data_cfg, src_tokens, src_lengths
        )
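

# A minimal usage sketch (paths and hyperparameters are hypothetical; it
# assumes a manifest root prepared with fairseq's S2T data scripts, holding
# config.yaml, the vocabulary file, and per-split TSV manifests):
#
#   fairseq-train /path/to/manifest_root \
#     --task speech_to_text --config-yaml config.yaml \
#     --train-subset train --valid-subset dev \
#     --arch s2t_transformer_s --criterion label_smoothed_cross_entropy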