from typing import Optional, Tuple

import torch

from wenet.ssl.bestrq.mask import compute_mask_indices
from wenet.utils.mask import make_pad_mask


class BestRQModel(torch.nn.Module):
    """BEST-RQ: speech pre-training with a random-projection quantizer.

    See "Self-supervised Learning with Random-projection Quantizer for
    Speech Recognition" (https://arxiv.org/abs/2202.01855).
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        input_dim: int = 256,
        embedding_dim: int = 256,
        num_embeddings: int = 8192,
        num_codebooks: int = 1,
        dropout_rate: float = 0.1,
        mask_prob: float = 0.01,
        mask_length: int = 10,
        min_masks: int = 2,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()
        assert mask_prob > 0.0
        self.mask_prob = mask_prob
        # NOTE: audio shorter than mask_length should be filtered out upstream
        self.mask_length = mask_length
        self.min_masks = min_masks

        self.input_dropout = torch.nn.Dropout(dropout_rate)

        # Frozen random codebooks, [num_codebooks, embedding_dim, num_embeddings].
        # Kept as non-trainable Parameters so they move with the module and are
        # saved in checkpoints.
        self.embeddings = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, embedding_dim, num_embeddings),
            requires_grad=False,
        )
        torch.nn.init.normal_(self.embeddings)

        # Frozen random projection, [input_dim, embedding_dim]
        self.projection = torch.nn.parameter.Parameter(
            torch.empty(input_dim, embedding_dim), requires_grad=False
        )
        torch.nn.init.xavier_normal_(self.projection)

        # Learned embedding used to replace masked frames, [input_dim]
        self.mask_emb = torch.nn.parameter.Parameter(torch.empty(input_dim))
        torch.nn.init.normal_(self.mask_emb, mean=0, std=0.1)

        self.input_layer_norm = torch.nn.LayerNorm(input_dim, eps=layer_norm_epsilon)

        self.encoder = encoder
        # Per-codebook output head, [num_codebooks, encoder_dim, num_embeddings]
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(), num_embeddings)
        )
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_length: Optional[torch.Tensor] = None,
    ):
        # should support non-streaming and streaming
        # TODO(Mddct): support streaming in the future,
        # e.g. full attention and chunk or dynamic chunk training

        # 1 forward subsampling
        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        unmasked_xs = xs
        # 2 mask features
        # 2.0 apply mask
        masked_xs, masked_masks = self._apply_mask(xs)
        # 2.1 get nearest embedding
        target_ids = self._nearest_embedding_idx(unmasked_xs)
        # 3 forward xxx-former blocks
        out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb, masks)
        # 4 get logits
        out = out.unsqueeze(1)  # [B, 1, T', dim]
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0
        )  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out, top_n_out)  # [B, num_codebooks, T', num_embeddings]
        # 5 compute loss
        loss = self._compute_loss(out, target_ids, out_mask.squeeze(1) * masked_masks)
        return {"loss": loss}

    def _compute_loss(
        self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ):
        input = input.transpose(1, 3)  # [B, num_embeddings, T', num_codebooks]
        entropy = torch.nn.functional.cross_entropy(
            input, target, reduction="none"
        )  # [B, T', num_codebooks]
        # zero out non-masked positions so only masked frames contribute
        loss = entropy * mask.unsqueeze(2)
        return loss.sum() / (mask.sum() * loss.size(2))

    def _forward_encoder_blocks(
        self,
        xs: torch.Tensor,
        xs_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ):
        masks = xs_masks
        for layer in self.encoder.encoders:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks from before the encoder layers.
        return xs, masks

    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
        xs = self.input_layer_norm(xs)
        xs = self.input_dropout(xs)
        xs = torch.matmul(xs, self.projection.to(xs.device))
        B, T, C = xs.size()
        flattened_input = xs.view(-1, C)
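        # Nearest-codebook lookup: for each codebook, compute squared
        # distances between all projected frames and all codebook entries
        # in one shot via the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2x.e,
        # broadcast over codebooks, then argmin over entries for target ids.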
        embeddings = self.embeddings.to(
            xs.device
        )  # [num_codebooks, embedding_dim, num_embeddings]
        # [num_codebooks, B*T, num_embeddings]
        distance = (
            torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0)
            + torch.sum(embeddings**2, dim=1, keepdim=True)
            - 2 * torch.matmul(flattened_input.unsqueeze(0), embeddings)
        )
        out = torch.argmin(distance, dim=-1)  # [num_codebooks, B*T]
        out = out.transpose(0, 1)  # [B*T, num_codebooks]
        return out.reshape(B, T, -1)  # [B, T, num_codebooks]

    def _apply_mask(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        masks = compute_mask_indices(
            xs.size()[:-1],
            self.mask_prob,
            self.mask_length,
            self.min_masks,
            device=xs.device,
        )
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]
        mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
        xs = torch.where(masks_expand, mask_emb, xs)
        return xs, masks

    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks
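

# A minimal smoke-test sketch, not part of the model itself. It assumes
# wenet's ConformerEncoder with default arguments and 80-dim fbank-like
# features; any wenet encoder exposing `embed`, `encoders`, `global_cmvn`,
# `normalize_before`/`after_norm` and `output_size()` should work the same.
if __name__ == "__main__":
    from wenet.transformer.encoder import ConformerEncoder

    encoder = ConformerEncoder(input_size=80, output_size=256)
    # input_dim must match the subsampled feature dim, i.e. encoder.output_size()
    model = BestRQModel(encoder=encoder, input_dim=encoder.output_size())

    xs = torch.randn(2, 1000, 80)        # (B, T, feat_dim) dummy features
    xs_lens = torch.tensor([1000, 800])  # valid frames per utterance
    loss = model(xs, xs_lens)["loss"]
    loss.backward()
    print(f"bestrq loss: {loss.item():.4f}")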