from typing import Optional, Tuple
import torch

from wenet.ssl.bestrq.mask import compute_mask_indices
from wenet.utils.mask import make_pad_mask


class BestRQModel(torch.nn.Module):
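    """BEST-RQ style self-supervised pre-training wrapper.

    Features are projected through a frozen random matrix into one or more
    frozen random codebooks; the ids of the nearest codebook entries of the
    unmasked features serve as discrete targets, and the encoder is trained
    to predict those targets at masked positions with a cross-entropy loss.
    See "Self-supervised Learning with Random-projection Quantizer for
    Speech Recognition" (Chiu et al., 2022).
    """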
    def __init__(
        self,
        encoder: torch.nn.Module,
        input_dim: int = 256,
        embedding_dim: int = 256,
        num_embeddings: int = 8192,
        num_codebooks: int = 1,
        dropout_rate: float = 0.1,
        mask_prob: float = 0.01,
        mask_length: int = 10,
        min_masks: int = 2,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
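        """
        Args:
            encoder: acoustic encoder (e.g. a wenet conformer/transformer
                encoder) providing ``embed``, ``encoders``,
                ``normalize_before``/``after_norm`` and ``output_size()``.
            input_dim: feature dimension after the encoder's subsampling,
                i.e. the encoder output size.
            embedding_dim: dimension of the random projection space.
            num_embeddings: number of entries per codebook.
            num_codebooks: number of independent random codebooks.
            dropout_rate: dropout applied before the random projection.
            mask_prob: mask probability passed to ``compute_mask_indices``.
            mask_length: length (in frames) of each mask span.
            min_masks: minimum number of mask spans per utterance.
            layer_norm_epsilon: epsilon of the input LayerNorm.
        """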
        super().__init__()

        assert mask_prob > 0.0

        self.mask_prob = mask_prob
        # NOTE: utterances shorter than mask_length frames should be filtered out upstream
        self.mask_length = mask_length
        self.min_masks = min_masks

        self.input_dropout = torch.nn.Dropout(dropout_rate)

        # frozen random codebooks: [num_codebooks, embedding_dim, num_embeddings]
        random_embedding_weight = torch.empty(
            num_codebooks, embedding_dim, num_embeddings, requires_grad=False
        )
        self.embeddings = torch.nn.init.normal_(random_embedding_weight)

        random_projection_weight = torch.empty(
            input_dim, embedding_dim, requires_grad=False
        )
        self.projection = torch.nn.init.xavier_normal_(random_projection_weight)

        # learnable embedding used to replace masked frames
        self.mask_emb = torch.nn.parameter.Parameter(
            torch.empty(input_dim), requires_grad=True
        )
        torch.nn.init.normal_(self.mask_emb, mean=0, std=0.1)

        self.input_layer_norm = torch.nn.LayerNorm(input_dim, layer_norm_epsilon)
        self.encoder = encoder
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(), num_embeddings)
        )
        # per-codebook output head mapping encoder outputs to codebook logits;
        # initialized explicitly (std=0.02 is a conventional choice)
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_length: Optional[torch.Tensor] = None,
    ):
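        """Compute the BEST-RQ masked-prediction loss.

        Args:
            xs: padded input features, shape [B, T, feat_dim].
            xs_lens: number of valid frames per utterance, shape [B].
            text: unused, kept for trainer interface compatibility.
            text_length: unused, kept for trainer interface compatibility.

        Returns:
            Dict with the cross-entropy loss under key "loss".
        """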
        # should support non-streaming and streaming modes
        # TODO(Mddct): streaming in the future,
        # e.g. full attention and chunk or dynamic chunk training
        # 1 forward subsampling
        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        unmasked_xs = xs
        # 2 mask features
        # 2.0 apply mask
        masked_xs, masked_masks = self._apply_mask(xs)
        # 2.1 get nearest embedding
        target_ids = self._nearest_embedding_idx(unmasked_xs)
        # 3 forward encoder (conformer/transformer) blocks
        out, out_mask = self._forward_encoder_blocks(masked_xs, masks, pos_emb, masks)
        # 4 get logits
        out = out.unsqueeze(1)  # [B, 1, T', dim]
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0
        )  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out, top_n_out)  # [B, num_codebooks, T', num_embeddings]

        # 5 compute loss
        loss = self._compute_loss(out, target_ids, out_mask.squeeze(1) * masked_masks)
        return {"loss": loss}

    def _compute_loss(
        self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ):
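        """Cross entropy against codebook ids, averaged over selected frames.

        Args:
            input: logits, shape [B, num_codebooks, T', num_embeddings].
            target: codebook ids, shape [B, T', num_codebooks].
            mask: [B, T'] selector, true for frames that are both valid
                (non-padded) and masked.
        """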
        input = input.transpose(1, 3)  # [B, num_embeddings, T', num_codebooks]
        entropy = torch.nn.functional.cross_entropy(
            input, target, reduction="none"
        )  # [B, T', num_codebooks]
        # zero out the loss at padded and unmasked positions
        loss = entropy * mask.unsqueeze(2)
        return loss.sum() / (mask.sum() * loss.size(2))

    def _forward_encoder_blocks(
        self,
        xs: torch.Tensor,
        xs_masks: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor,
    ):
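        """Run the encoder layers on the already subsampled features."""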
        masks = xs_masks
        for layer in self.encoder.encoders:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks computed before the encoder layers
        return xs, masks

    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
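        """Quantize each frame to its nearest entry in every codebook.

        Frames are projected with the frozen random matrix, then matched to
        the closest codebook vector by squared L2 distance.

        Returns:
            Codebook indices of shape [B, T, num_codebooks].
        """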
        xs = self.input_layer_norm(xs)
        xs = self.input_dropout(xs)
        xs = torch.matmul(xs, self.projection.to(xs.device))

        B, T, C = xs.size()
        flattened_input = xs.view(-1, C)
        embeddings = self.embeddings.to(
            xs.device
        )  # [num_codebooks, embedding_dim, num_embeddings]
        # [num_codebooks, B*T, num_embeddings]
        distance = (
            torch.sum(flattened_input**2, dim=1, keepdim=True).unsqueeze(0)
            + torch.sum(embeddings**2, dim=1, keepdim=True)
            - 2 * torch.matmul(flattened_input.unsqueeze(0), embeddings)
        )

        out = torch.argmin(distance, dim=-1)  # [num_codebooks, B*T]
        out = out.transpose(0, 1)  # [B*T, num_codebooks]
        return out.reshape(B, T, -1)  # [B, T, num_codebooks]

    def _apply_mask(self, xs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
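        """Replace randomly sampled spans of frames with the mask embedding.

        Returns:
            The masked features and the boolean [B, T] mask marking the
            replaced positions.
        """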
        masks = compute_mask_indices(
            xs.size()[:-1],
            self.mask_prob,
            self.mask_length,
            self.min_masks,
            device=xs.device,
        )
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]

        mask_emb = self.mask_emb.to(xs.device).view(1, 1, -1)
        xs = torch.where(masks_expand, mask_emb, xs)
        return xs, masks

    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
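        """Apply global CMVN (if configured) and the encoder's subsampling
        and positional embedding front-end."""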
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks
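

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training recipe. It assumes
    # wenet's ConformerEncoder with default settings (output size 256, which
    # matches input_dim above) and 80-dim fbank-like features; adjust both to
    # your actual configuration.
    from wenet.transformer.encoder import ConformerEncoder

    encoder = ConformerEncoder(input_size=80)
    model = BestRQModel(encoder=encoder)
    xs = torch.randn(2, 200, 80)  # [B, T, feat_dim] dummy features
    xs_lens = torch.tensor([200, 160])  # valid frames per utterance
    print(model(xs, xs_lens)["loss"])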