| """ | |
| Author: Speech Lab, Alibaba Group, China | |
| """ | |
| import logging | |
| from contextlib import contextmanager | |
| from distutils.version import LooseVersion | |
| from typing import Dict | |
| from typing import List | |
| from typing import Optional | |
| from typing import Tuple | |
| from typing import Union | |
| import torch | |
| from funasr_detach.layers.abs_normalize import AbsNormalize | |
| from funasr_detach.losses.label_smoothing_loss import ( | |
| LabelSmoothingLoss, # noqa: H301 | |
| ) | |
| from funasr_detach.models.ctc import CTC | |
| from funasr_detach.models.decoder.abs_decoder import AbsDecoder | |
| from funasr_detach.models.encoder.abs_encoder import AbsEncoder | |
| from funasr_detach.frontends.abs_frontend import AbsFrontend | |
| from funasr_detach.models.postencoder.abs_postencoder import AbsPostEncoder | |
| from funasr_detach.models.preencoder.abs_preencoder import AbsPreEncoder | |
| from funasr_detach.models.specaug.abs_specaug import AbsSpecAug | |
| from funasr_detach.models.transformer.utils.add_sos_eos import add_sos_eos | |
| from funasr_detach.metrics import ErrorCalculator | |
| from funasr_detach.metrics.compute_acc import th_accuracy | |
| from funasr_detach.train_utils.device_funcs import force_gatherable | |
| from funasr_detach.models.base_model import FunASRModel | |

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: provide a no-op stand-in for autocast
    @contextmanager
    def autocast(enabled=True):
        yield


class ESPnetSVModel(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        pooling_layer: torch.nn.Module,
        decoder: AbsDecoder,
    ):
        super().__init__()

        # note that eos is the same as sos (equivalent ID)
        self.vocab_size = vocab_size
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        self.pooling_layer = pooling_layer
        self.decoder = decoder
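
        # NOTE: forward() and collect_feats() also read attributes that are not
        # initialized here (self.ctc_weight, self.interctc_weight,
        # self.use_transducer_decoder, self.extract_feats_in_collect_stats) and
        # call loss helpers (_calc_ctc_loss, _calc_att_loss,
        # _calc_transducer_loss) that are not defined in this file; they are
        # presumably supplied by a subclass or by the surrounding training code.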

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer
        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
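
    # For reference, the attention branch in forward() above interpolates the two
    # losses: with ctc_weight = 0.3, loss = 0.3 * loss_ctc + 0.7 * loss_att;
    # ctc_weight = 0.0 keeps only the attention loss, ctc_weight = 1.0 only CTC.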

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

            # Pre-encoder, e.g. used for raw input data
            if self.preencoder is not None:
                feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> (Batch, Channel, Length2, Dim2)
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
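

# Illustrative sketch (not part of the original module): how encode() could be
# exercised with precomputed features and no frontend. It assumes an AbsEncoder
# subclass exposing an ESPnet-style interface, i.e. output_size() and a
# forward(feats, feats_lengths) that returns the encoded sequence and its
# lengths; _DummyEncoder below is hypothetical and exists only for this sketch.
# decoder=None is passed only because encode() never touches the decoder.
#
#     class _DummyEncoder(AbsEncoder):
#         def __init__(self, dim: int = 80):
#             super().__init__()
#             self._dim = dim
#
#         def output_size(self) -> int:
#             return self._dim
#
#         def forward(self, feats, feats_lengths):
#             # Identity encoder: pass features through unchanged
#             return feats, feats_lengths
#
#     model = ESPnetSVModel(
#         vocab_size=10,
#         token_list=[str(i) for i in range(10)],
#         frontend=None,
#         specaug=None,
#         normalize=None,
#         preencoder=None,
#         encoder=_DummyEncoder(80),
#         postencoder=None,
#         pooling_layer=torch.nn.Identity(),
#         decoder=None,
#     )
#     speech = torch.randn(2, 200, 80)           # (Batch, Length, Dim)
#     speech_lengths = torch.tensor([200, 180])  # (Batch,)
#     encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)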