import torch
from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor

class AstralQuantizer(torch.nn.Module):
    """Extracts SSL features from 16 kHz audio, encodes them, and vector-quantizes
    the result into discrete content tokens."""

    def __init__(
            self,
            tokenizer_name: str,
            ssl_model_name: str,
            ssl_output_layer: int,
            encoder: torch.nn.Module,
            quantizer: torch.nn.Module,
            skip_ssl: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.quantizer = quantizer
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load SSL model from Huggingface
        self.ssl_model_name = ssl_model_name
        self.ssl_output_layer = ssl_output_layer
        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)

        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
            self.ssl_model = None
        else:
            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
            # Truncate the transformer to the requested output layer and drop the final
            # layer norm, so last_hidden_state is that intermediate layer's output.
            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
            self.ssl_model.encoder.layer_norm = torch.nn.Identity()

    def load_separate_checkpoint(self, checkpoint_path):
        params = torch.load(checkpoint_path, map_location='cpu')['net']
        # Strip the "module." prefix left by DataParallel/DDP-wrapped checkpoints.
        for key in params.keys():
            for k in list(params[key].keys()):
                if k.startswith("module."):
                    params[key][k[len("module."):]] = params[key][k]
                    del params[key][k]
        self.encoder.load_state_dict(params['encoder'])
        self.quantizer.load_state_dict(params['vq'])
        # Optional sub-modules that only exist in some configurations.
        if getattr(self, 'decoder', None) is not None:
            self.decoder.load_state_dict(params['decoder'])
        if getattr(self, 'asr_decoder', None) is not None:
            self.asr_decoder.load_state_dict(params['predictor'], strict=False)

    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
        ssl_fn = self.ssl_model if self.ssl_model else ssl_model
        assert ssl_fn is not None, "An external ssl_model must be provided when in-class SSL loading is skipped (skip_ssl=True)"
        # Trim each padded waveform to its true length, then let the feature
        # extractor re-pad the batch and build the attention mask.
        waves_16k_input_list = [
            waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
            for bib in range(len(waves_16k))
        ]
        alt_inputs = self.ssl_feature_extractor(
            waves_16k_input_list,
            return_tensors='pt',
            return_attention_mask=True,
            padding=True,
            sampling_rate=16000
        ).to(waves_16k.device)
        feature_lens = alt_inputs.attention_mask.sum(-1) // 320  # 320-sample hop at 16 kHz -> 50 Hz HuBERT frame rate

        outputs = ssl_fn(
            alt_inputs.input_values,
            attention_mask=alt_inputs.attention_mask,
        )
        last_hidden_states = outputs.last_hidden_state
        # Drop padding frames and keep lengths consistent with the trimmed output.
        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
        last_hidden_states = last_hidden_states.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x_hidden = self.encoder(last_hidden_states, feature_lens)
        x_hidden = x_hidden.transpose(1, 2)  # back to (B, T, C) for quantization
        x_quantized, indices = self.quantizer(x_hidden)[:2]
        return x_quantized, indices, feature_lens
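

# --- Usage sketch (illustrative, not part of the original module) ---
# The encoder/quantizer below are hypothetical stand-ins and the HuggingFace
# repo names are assumptions; real runs plug in the project's own modules and
# restore their weights with load_separate_checkpoint().
if __name__ == "__main__":
    class _DummyEncoder(torch.nn.Module):
        def forward(self, x, lens):
            return x  # passthrough over (B, C, T) features

    class _DummyQuantizer(torch.nn.Module):
        def forward(self, x):
            # Mimics a VQ interface: (quantized features, codebook indices)
            return x, torch.zeros(x.shape[:2], dtype=torch.long, device=x.device)

    model = AstralQuantizer(
        tokenizer_name="openai/whisper-small",        # assumed tokenizer repo
        ssl_model_name="facebook/hubert-base-ls960",  # assumed SSL backbone
        ssl_output_layer=9,
        encoder=_DummyEncoder(),
        quantizer=_DummyQuantizer(),
    )
    waves_16k = torch.randn(2, 16000)                 # two 1-second clips at 16 kHz
    wave_16k_lens = torch.tensor([16000, 12000])
    with torch.no_grad():
        quantized, indices, feature_lens = model(waves_16k, wave_16k_lens)
    print(quantized.shape, indices.shape, feature_lens)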