# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    try:
        import yaml
    except ImportError:
        raise ImportError("Please install PyYAML: pip install PyYAML")
    config = {}
    if yaml_path.is_file():
        try:
            with open(yaml_path) as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except Exception as e:
            raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
    else:
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    return config


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        """Resolve a config value against the config YAML's directory: if a
        string path does not exist as given but does exist under that
        directory, return the resolved path; dicts are converted recursively."""
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returns a
        dictionary where `tokenizer` provides the tokenizer name and the
        remaining items provide tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returns a
        dictionary where `bpe` provides the tokenizer name and the
        remaining items provide tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allows the train-set wildcard
        `_train`, the evaluation-set wildcard `_eval` and the general
        wildcard `*` for matching."""
        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur
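    # A minimal sketch of how the wildcard lookup above resolves, assuming a
    # `feature_transforms` section like the following (transform names and
    # split names are illustrative):
    #
    #   feature_transforms:
    #     _train: [specaugment]
    #     _eval: []
    #     dev_other: [global_cmvn]
    #
    # get_transforms("feature_", "dev_other", is_train=False) -> ["global_cmvn"]
    # get_transforms("feature_", "train_clean_100", is_train=True) -> ["specaugment"]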

    def get_feature_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})


class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Optional[Dict]:
        return None

    @property
    def bpe_tokenizer(self) -> Optional[Dict]:
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels
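    # Worked example: with `input_channels: 1` and a `_train` feature-transform
    # list that contains "delta_deltas", this property returns 1 * 3 = 3;
    # without "delta_deltas" it stays at 1.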

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for the multitask config YAML"""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' in the config file,
            the last task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
            the last task whose task_name includes 'target' and decoder_type is not ctc.
        """
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx
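    # A minimal sketch of a multitask config YAML, assuming two tasks (task
    # names, dict paths and loss weights are illustrative):
    #
    #   source_letter:
    #     decoder_type: ctc
    #     dict: src_dict.txt
    #     loss_weight: 8.0
    #   target_letter:
    #     decoder_type: transformer
    #     dict: tgt_dict.txt
    #     loss_weight: 8.0
    #
    # With no `is_first_pass_decoder` flag set, `first_pass_decoder_task_index`
    # falls back to the last task whose name starts with "target" and whose
    # decoder_type is "transformer", i.e. index 1 here.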


class SingleTaskConfig(object):
    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight
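    # Worked example of the decay schedule (hypothetical values): with
    # loss_weight_max=1.0, loss_weight_min=0.0001 and
    # loss_weight_decay_steps=10000, the per-update step size is
    # (1.0 - 0.0001) / 10000 ~= 1e-4, so at num_updates=5000 the weight is
    # roughly 0.5; it is clamped at loss_weight_min once the decay completes.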

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})