#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import logging
import random
from contextlib import contextmanager
from distutils.version import LooseVersion
from itertools import permutations
from typing import Dict
from typing import Optional
from typing import Tuple, List, Union

import numpy as np
import torch
from torch.nn import functional as F

from funasr_detach.models.transformer.utils.nets_utils import to_device
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
from funasr_detach.models.decoder.abs_decoder import AbsDecoder
from funasr_detach.models.encoder.abs_encoder import AbsEncoder
from funasr_detach.frontends.abs_frontend import AbsFrontend
from funasr_detach.models.specaug.abs_specaug import AbsSpecAug
from funasr_detach.models.specaug.abs_profileaug import AbsProfileAug
from funasr_detach.layers.abs_normalize import AbsNormalize
from funasr_detach.train_utils.device_funcs import force_gatherable
from funasr_detach.models.base_model import FunASRModel
from funasr_detach.losses.label_smoothing_loss import (
    LabelSmoothingLoss,
    SequenceBinaryCrossEntropy,
)
from funasr_detach.utils.misc import int2vec
from funasr_detach.utils.hinter import hint_once

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0: provide a no-op context manager so that
    # `with autocast(False):` still works.
    @contextmanager
    def autocast(enabled=True):
        yield


class DiarSondModel(FunASRModel):
    """Speaker overlap-aware neural diarization (SOND) model.

    Reference: https://arxiv.org/abs/2211.10243
    """

    def __init__(
        self,
        vocab_size: int,
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        profileaug: Optional[AbsProfileAug],
        normalize: Optional[AbsNormalize],
        encoder: torch.nn.Module,
        speaker_encoder: Optional[torch.nn.Module],
        ci_scorer: torch.nn.Module,
        cd_scorer: Optional[torch.nn.Module],
        decoder: torch.nn.Module,
        token_list: list,
        lsm_weight: float = 0.1,
        length_normalized_loss: bool = False,
        max_spk_num: int = 16,
        label_aggregator: Optional[torch.nn.Module] = None,
        normalize_speech_speaker: bool = False,
        ignore_id: int = -1,
        speaker_discrimination_loss_weight: float = 1.0,
        inter_score_loss_weight: float = 0.0,
        inputs_type: str = "raw",
        model_regularizer_weight: float = 0.0,
        freeze_encoder: bool = False,
        onfly_shuffle_speaker: bool = True,
    ):
        super().__init__()

        self.encoder = encoder
        self.speaker_encoder = speaker_encoder
        self.ci_scorer = ci_scorer
        self.cd_scorer = cd_scorer
        self.normalize = normalize
        self.frontend = frontend
        self.specaug = specaug
        self.profileaug = profileaug
        self.label_aggregator = label_aggregator
        self.decoder = decoder
        self.token_list = token_list
        self.max_spk_num = max_spk_num
        self.normalize_speech_speaker = normalize_speech_speaker
        self.ignore_id = ignore_id
        self.model_regularizer_weight = model_regularizer_weight
        self.freeze_encoder = freeze_encoder
        self.onfly_shuffle_speaker = onfly_shuffle_speaker

        self.criterion_diar = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(
            normalize_length=length_normalized_loss
        )
        self.pse_embedding = self.generate_pse_embedding()
        self.power_weight = torch.from_numpy(
            2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]
        ).float()
        self.int_token_arr = torch.from_numpy(
            np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]
        ).int()
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0
        self.inputs_type = inputs_type
        self.to_regularize_parameters = None

    def get_regularize_parameters(self):
        to_regularize_parameters, normal_parameters = [], []
        for name, param in self.named_parameters():
            if (
                "encoder" in name
                and "weight" in name
                and "bn" not in name
                and (
                    "conv2" in name
                    or "conv1" in name
                    or "conv_sc" in name
                    or "dense" in name
                )
            ):
                to_regularize_parameters.append((name, param))
            else:
                normal_parameters.append((name, param))
        self.to_regularize_parameters = to_regularize_parameters
        return to_regularize_parameters, normal_parameters

    def generate_pse_embedding(self):
        embedding = np.zeros(
            (len(self.token_list), self.max_spk_num), dtype=np.float32
        )
        for idx, pse_label in enumerate(self.token_list):
            emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32)
            embedding[idx] = emb
        return torch.from_numpy(embedding)
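
    # Example (assuming max_spk_num = 4): the PSE token "5" expands to the
    # multi-hot vector [1, 0, 1, 0], i.e. speakers 0 and 2 are simultaneously
    # active, since 5 = 2**0 + 2**2. int2vec is expected to produce exactly
    # this LSB-first binary expansion, so each row of pse_embedding decodes
    # one power-set-encoding token back into per-speaker activities.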

    def rand_permute_speaker(self, raw_profile, raw_binary_labels):
        """Randomly permute the speaker order of profiles and labels together.

        raw_profile: (B, N, D)
        raw_binary_labels: (B, T, N)
        """
        assert (
            raw_profile.shape[1] == raw_binary_labels.shape[2]
        ), "Num profile: {}, Num label: {}".format(
            raw_profile.shape[1], raw_binary_labels.shape[-1]
        )
        profile = torch.clone(raw_profile)
        binary_labels = torch.clone(raw_binary_labels)
        bsz, num_spk = profile.shape[0], profile.shape[1]
        for i in range(bsz):
            idx = list(range(num_spk))
            random.shuffle(idx)
            profile[i] = profile[i][idx, :]
            binary_labels[i] = binary_labels[i][:, idx]
        return profile, binary_labels

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss

        Args:
            speech: (Batch, samples) or (Batch, frames, input_size)
            speech_lengths: (Batch,). Defaults to None for the chunk iterator,
                because the chunk iterator does not return speech_lengths;
                see espnet2/iterators/chunk_iter_factory.py.
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            binary_labels: (Batch, frames, max_spk_num)
            binary_labels_lengths: (Batch,)
        """
        assert speech.shape[0] <= binary_labels.shape[0], (
            speech.shape,
            binary_labels.shape,
        )
        batch_size = speech.shape[0]
        if self.freeze_encoder:
            hint_once("Freeze encoder", "freeze_encoder", rank=0)
            self.encoder.eval()
        self.forward_steps = self.forward_steps + 1
        if self.pse_embedding.device != speech.device:
            self.pse_embedding = self.pse_embedding.to(speech.device)
            self.power_weight = self.power_weight.to(speech.device)
            self.int_token_arr = self.int_token_arr.to(speech.device)

        if self.onfly_shuffle_speaker:
            hint_once(
                "On-the-fly shuffle speaker permutation.",
                "onfly_shuffle_speaker",
                rank=0,
            )
            profile, binary_labels = self.rand_permute_speaker(profile, binary_labels)
        # 0a. Aggregate time-domain labels to match forward outputs
        if self.label_aggregator is not None:
            binary_labels, binary_labels_lengths = self.label_aggregator(
                binary_labels, binary_labels_lengths
            )

        # 0b. Augment profiles
        if self.profileaug is not None and self.training:
            speech, profile, binary_labels = self.profileaug(
                speech,
                speech_lengths,
                profile,
                profile_lengths,
                binary_labels,
                binary_labels_lengths,
            )

        # 1. Calculate power-set encoding (PSE) labels
        pad_bin_labels = F.pad(
            binary_labels,
            (0, self.max_spk_num - binary_labels.shape[2]),
            "constant",
            0.0,
        )
        raw_pse_labels = torch.sum(
            pad_bin_labels * self.power_weight, dim=2, keepdim=True
        )
        pse_labels = torch.argmax(
            (raw_pse_labels.int() == self.int_token_arr).float(), dim=2
        )
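        # Worked example: a frame whose padded multi-hot label is [1, 1, 0, ...]
        # gives raw_pse_labels = 1 * 2**0 + 1 * 2**1 = 3; pse_labels is then the
        # index of the token "3" in token_list, i.e. a single class id that
        # encodes the whole set of simultaneously active speakers.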

        # 2. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths, profile, profile_lengths, return_inter_outputs=True
        )
        (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = (
            inter_outputs
        )

        # If the encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of `pred` might be slightly less than the
        # length of `pse_labels`. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = abs(pse_labels.shape[1] - pred.shape[1])
        if length_diff <= length_diff_tolerance:
            min_len = min(pred.shape[1], pse_labels.shape[1])
            pse_labels = pse_labels[:, :min_len]
            pred = pred[:, :min_len]
            cd_score = cd_score[:, :min_len]
            ci_score = ci_score[:, :min_len]

        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(
            cd_score, ci_score, pse_labels, binary_labels_lengths
        )
        regularizer_loss = None
        if (
            self.model_regularizer_weight > 0
            and self.to_regularize_parameters is not None
        ):
            regularizer_loss = self.calculate_regularizer_loss()
        label_mask = make_pad_mask(
            binary_labels_lengths, maxlen=pse_labels.shape[1]
        ).to(pse_labels.device)
        loss = (
            loss_diar
            + self.speaker_discrimination_loss_weight * loss_spk_dis
            + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd)
        )
        # if regularizer_loss is not None:
        #     loss = loss + regularizer_loss * self.model_regularizer_weight

        (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        ) = self.calc_diarization_error(
            pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding),
            label=F.embedding(pse_labels * (~label_mask), self.pse_embedding),
            length=binary_labels_lengths,
        )

        if speech_scored > 0 and num_frames > 0:
            sad_mr, sad_fr, mi, fa, cf, acc, der = (
                speech_miss / speech_scored,
                speech_falarm / speech_scored,
                speaker_miss / speaker_scored,
                speaker_falarm / speaker_scored,
                speaker_error / speaker_scored,
                correct / num_frames,
                (speaker_miss + speaker_falarm + speaker_error) / speaker_scored,
            )
        else:
            sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0

        stats = dict(
            loss=loss.detach(),
            loss_diar=loss_diar.detach() if loss_diar is not None else None,
            loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None,
            loss_inter_ci=(
                loss_inter_ci.detach() if loss_inter_ci is not None else None
            ),
            loss_inter_cd=(
                loss_inter_cd.detach() if loss_inter_cd is not None else None
            ),
            regularizer_loss=(
                regularizer_loss.detach() if regularizer_loss is not None else None
            ),
            sad_mr=sad_mr,
            sad_fr=sad_fr,
            mi=mi,
            fa=fa,
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def calculate_regularizer_loss(self):
        # Sum of L2 norms of the selected encoder weights
        # (see get_regularize_parameters).
        regularizer_loss = 0.0
        for name, param in self.to_regularize_parameters:
            regularizer_loss = regularizer_loss + torch.norm(param, p=2)
        return regularizer_loss

    def classification_loss(
        self,
        predictions: torch.Tensor,
        labels: torch.Tensor,
        prediction_lengths: torch.Tensor,
    ) -> torch.Tensor:
        mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1])
        pad_labels = labels.masked_fill(
            mask.to(predictions.device), value=self.ignore_id
        )
        loss = self.criterion_diar(predictions.contiguous(), pad_labels)
        return loss

    def speaker_discrimination_loss(
        self, profile: torch.Tensor, profile_lengths: torch.Tensor
    ) -> torch.Tensor:
        profile_mask = (
            torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0
        ).float()  # (B, N, 1)
        mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2))  # (B, N, N)
        mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask))
        eps = 1e-12
        coding_norm = (
            torch.linalg.norm(
                profile * profile_mask + (1 - profile_mask) * eps, dim=2, keepdim=True
            )
            * profile_mask
        )
        # profile: (Batch, N, dim)
        cos_theta = (
            F.cosine_similarity(
                profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps
            )
            * mask
        )
        cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps)
        loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum()
        return loss
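
    # The hinge above penalizes any positive cosine similarity between two
    # distinct non-empty profiles (pairs selected by `mask`), scaled by the
    # profile norm, pushing enrolled speaker embeddings apart from one another.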

    def calculate_multi_labels(self, pse_labels, pse_labels_lengths):
        mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1])
        padding_labels = pse_labels.masked_fill(
            mask.to(pse_labels.device), value=0
        ).to(pse_labels)
        multi_labels = F.embedding(padding_labels, self.pse_embedding)
        return multi_labels

    def internal_score_loss(
        self,
        cd_score: torch.Tensor,
        ci_score: torch.Tensor,
        pse_labels: torch.Tensor,
        pse_labels_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths)
        ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths)
        cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths)
        return ci_loss, cd_loss

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode_speaker(
        self,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with autocast(False):
            if profile.shape[1] < self.max_spk_num:
                profile = F.pad(
                    profile,
                    [0, 0, 0, self.max_spk_num - profile.shape[1], 0, 0],
                    "constant",
                    0.0,
                )
            profile_mask = (
                torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0
            ).float()
            profile = F.normalize(profile, dim=2)

            if self.speaker_encoder is not None:
                profile = self.speaker_encoder(profile, profile_lengths)[0]
                return profile * profile_mask, profile_lengths
            else:
                return profile, profile_lengths

    def encode_speech(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.encoder is not None and self.inputs_type == "raw":
            speech, speech_lengths = self.encode(speech, speech_lengths)
            speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1])
            speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float()
            return speech * speech_mask, speech_lengths
        else:
            return speech, speech_lengths

    @staticmethod
    def concate_speech_ivc(speech: torch.Tensor, ivc: torch.Tensor) -> torch.Tensor:
        # Tile speech features and speaker profiles so that every
        # (speaker, frame) pair gets a concatenated feature vector.
        nn, tt = ivc.shape[1], speech.shape[1]
        speech = speech.unsqueeze(dim=1)  # B x 1 x T x D
        speech = speech.expand(-1, nn, -1, -1)  # B x N x T x D
        ivc = ivc.unsqueeze(dim=2)  # B x N x 1 x D
        ivc = ivc.expand(-1, -1, tt, -1)  # B x N x T x D
        sd_in = torch.cat([speech, ivc], dim=3)  # B x N x T x 2D
        return sd_in

    def calc_similarity(
        self,
        speech_encoder_outputs: torch.Tensor,
        speaker_encoder_outputs: torch.Tensor,
        seq_len: torch.Tensor = None,
        spk_len: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1]
        d_sph = speech_encoder_outputs.shape[2]
        d_spk = speaker_encoder_outputs.shape[2]
        if self.normalize_speech_speaker:
            speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2)
            speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2)
        ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs)
        ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk])
        ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num)
        ge_len = torch.reshape(ge_len, [bb * self.max_spk_num])
        cd_simi = self.cd_scorer(ge_in, ge_len)[0]
        cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1])
        cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1])
        if isinstance(self.ci_scorer, AbsEncoder):
            ci_simi = self.ci_scorer(ge_in, ge_len)[0]
            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute(
                [0, 2, 1]
            )
        else:
            ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
        return ci_simi, cd_simi
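
    # Shape note: both scores come out as (B, T, max_spk_num). The CD
    # (context-dependent) scorer runs a sequence model over each
    # (speaker, frame) concatenation, whereas the CI (context-independent)
    # scorer may instead compare speech frames and profiles directly.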

    def post_net_forward(self, simi, seq_len):
        logits = self.decoder(simi, seq_len)[0]
        return logits

    def prediction_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        profile: torch.Tensor,
        profile_lengths: torch.Tensor,
        return_inter_outputs: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
        # speech encoding
        speech, speech_lengths = self.encode_speech(speech, speech_lengths)
        # speaker encoding
        profile, profile_lengths = self.encode_speaker(profile, profile_lengths)
        # calculate similarity
        ci_simi, cd_simi = self.calc_similarity(
            speech, profile, speech_lengths, profile_lengths
        )
        similarity = torch.cat([cd_simi, ci_simi], dim=2)
        # post-net forward
        logits = self.post_net_forward(similarity, speech_lengths)
        if return_inter_outputs:
            return logits, [
                (speech, speech_lengths),
                (profile, profile_lengths),
                (ci_simi, cd_simi),
            ]
        return logits

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim)
        encoder_outputs = self.encoder(feats, feats_lengths)
        encoder_out, encoder_out_lens = encoder_outputs[:2]

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = speech.shape[0]
        speech_lengths = (
            speech_lengths
            if speech_lengths is not None
            else torch.ones(batch_size).int() * speech.shape[1]
        )
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend (e.g. STFT and feature extraction):
            # the data loader may send a time-domain signal in this case.
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extraction
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    @staticmethod
    def calc_diarization_error(pred, label, length):
        # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND
        (batch_size, max_len, num_output) = label.size()
        # mask the padding part
        mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy()
        # pred and label have the shape (batch_size, max_len, num_output)
        label_np = label.data.cpu().numpy().astype(int)
        pred_np = (pred.data.cpu().numpy() > 0).astype(int)
        label_np = label_np * mask
        pred_np = pred_np * mask
        length = length.data.cpu().numpy()

        # compute speech activity detection error
        n_ref = np.sum(label_np, axis=2)
        n_sys = np.sum(pred_np, axis=2)
        speech_scored = float(np.sum(n_ref > 0))
        speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0)))
        speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0)))

        # compute speaker diarization error
        speaker_scored = float(np.sum(n_ref))
        speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0)))
        speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0)))
        n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2)
        speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map))
        correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output)
        num_frames = np.sum(length)
        return (
            correct,
            num_frames,
            speech_scored,
            speech_miss,
            speech_falarm,
            speaker_scored,
            speaker_miss,
            speaker_falarm,
            speaker_error,
        )
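

if __name__ == "__main__":
    # Minimal sanity-check sketch of the power-set encoding (PSE) used above,
    # under toy assumptions: max_spk_num = 4 and token_list = ["0", ..., "15"]
    # (the real model builds these from its configuration and token inventory).
    max_spk_num = 4
    token_list = [str(i) for i in range(2**max_spk_num)]
    power_weight = torch.from_numpy(
        2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]
    ).float()
    int_token_arr = torch.from_numpy(
        np.array(token_list).astype(int)[np.newaxis, np.newaxis, :]
    ).int()
    # One batch, two frames: frame 0 has speakers 0 and 2 active
    # (-> 2**0 + 2**2 = 5), frame 1 has only speaker 1 active (-> 2**1 = 2).
    binary_labels = torch.tensor([[[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]]])
    raw_pse_labels = torch.sum(binary_labels * power_weight, dim=2, keepdim=True)
    pse_labels = torch.argmax((raw_pse_labels.int() == int_token_arr).float(), dim=2)
    print(pse_labels)  # tensor([[5, 2]])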