from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import logging
import math
import random

import torch

from funasr_detach.metrics import ErrorCalculator
from funasr_detach.metrics.compute_acc import th_accuracy
from funasr_detach.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr_detach.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr_detach.models.ctc import CTC
from funasr_detach.models.decoder.abs_decoder import AbsDecoder
from funasr_detach.models.encoder.abs_encoder import AbsEncoder
from funasr_detach.frontends.abs_frontend import AbsFrontend
from funasr_detach.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr_detach.models.specaug.abs_specaug import AbsSpecAug
from funasr_detach.layers.abs_normalize import AbsNormalize
from funasr_detach.train_utils.device_funcs import force_gatherable
from funasr_detach.models.base_model import FunASRModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class MFCCA(FunASRModel):
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA: Multi-Frame Cross-Channel Attention for multi-speaker ASR in multi-party meeting scenario
    https://arxiv.org/abs/2210.05265
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None = None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        preencoder: Optional[AbsPreEncoder] = None,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"

        super().__init__()
        # Note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.token_list = token_list.copy()
        self.mask_ratio = mask_ratio

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # We set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # Thanks to Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.rnnt_decoder = rnnt_decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None
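
    # Training objective computed in forward(): a weighted combination of the
    # CTC loss and the attention-decoder loss,
    #     loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att,
    # with ctc_weight == 0.0 meaning attention-only (self.ctc is None) and
    # ctc_weight == 1.0 meaning CTC-only (self.decoder is None).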

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        # Channel-dropout augmentation for 8-channel input: with probability
        # `mask_ratio`, keep only a random subset of the channels
        # (at least one channel is always retained).
        if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
            rate_num = random.random()
            if rate_num <= self.mask_ratio:
                retain_channel = math.ceil(random.random() * 8)
                if retain_channel > 1:
                    speech = speech[
                        :, :, torch.randperm(8)[0:retain_channel].sort().values
                    ]
                else:
                    speech = speech[:, :, torch.randperm(8)[0]]

        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)

        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths, channel_size = self._extract_feats(
            speech, speech_lengths
        )
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder.

        Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(
                speech, speech_lengths
            )

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(
            feats, feats_lengths, channel_size
        )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        if encoder_out.dim() == 4:
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        else:
            assert encoder_out.size(1) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )

        return encoder_out, encoder_out_lens
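
    # Shape note: encoder_out is either (Batch, Time, Dim) or, when the MFCCA
    # encoder keeps the channel axis, (Batch, Channel, Time, Dim); see the
    # dim() == 4 branches in encode() above and in _calc_ctc_loss() below,
    # where the CTC branch first averages over the channel axis.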

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
            channel_size = 1
        return feats, feats_lengths, channel_size

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        if encoder_out.dim() == 4:
            encoder_out = encoder_out.mean(1)
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        raise NotImplementedError
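

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It shows the tensor shapes forward() expects by wiring MFCCA to toy modules.
# ToyEncoder / ToyDecoder / ToyCTC are hypothetical stand-ins that only mimic
# the call signatures this file relies on; a real setup would pass the
# encoder, decoder, and CTC implementations provided by funasr_detach.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    vocab_size = 10  # toy vocabulary; the last id doubles as <sos>/<eos>
    feat_dim = 16  # pre-computed feature dimension (frontend=None below)
    enc_dim = 16  # toy encoder output dimension

    class ToyEncoder(torch.nn.Module):
        # (feats, feats_lengths, channel_size) -> (encoder_out, encoder_out_lens, None)
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(feat_dim, enc_dim)

        def forward(self, feats, feats_lengths, channel_size):
            return self.proj(feats), feats_lengths, None

    class ToyDecoder(torch.nn.Module):
        # (encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens) -> (logits, None)
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, enc_dim)
            self.out = torch.nn.Linear(enc_dim, vocab_size)

        def forward(self, encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens):
            return self.out(self.embed(ys_in_pad)), None

    class ToyCTC(torch.nn.Module):
        # (encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) -> scalar loss
        def __init__(self):
            super().__init__()
            self.out = torch.nn.Linear(enc_dim, vocab_size)

        def forward(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens):
            log_probs = self.out(encoder_out).log_softmax(-1).transpose(0, 1)
            targets = torch.cat(
                [ys_pad[b, : ys_pad_lens[b]] for b in range(ys_pad.size(0))]
            )
            return torch.nn.functional.ctc_loss(
                log_probs, targets, encoder_out_lens, ys_pad_lens, blank=0
            )

    model = MFCCA(
        vocab_size=vocab_size,
        token_list=[str(i) for i in range(vocab_size)],
        frontend=None,  # features are fed directly, so no STFT frontend
        specaug=None,
        normalize=None,
        encoder=ToyEncoder(),
        decoder=ToyDecoder(),
        ctc=ToyCTC(),
        ctc_weight=0.3,
        report_cer=False,
        report_wer=False,
    )

    speech = torch.randn(2, 50, feat_dim)  # (Batch, NFrames, Dim)
    speech_lengths = torch.tensor([50, 40])
    text = torch.tensor([[1, 2, 3, 4], [5, 6, -1, -1]])  # -1 = ignore_id padding
    text_lengths = torch.tensor([4, 2])

    loss, stats, weight = model(speech, speech_lengths, text, text_lengths)
    print({k: float(v) for k, v in stats.items() if v is not None})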