# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
# 
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and 
# Whisper (https://github.com/openai/whisper/)

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions


class AudioEncoder_(AudioEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_frames)
            the mel spectrogram of the audio
        target_layer : Optional[int]
            number of transformer blocks to run; None (the default) runs all of them

        Returns the hidden states after the last block that was run.
        """
        # two convolutional layers; conv2 has stride 2, halving the time axis
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch_size, n_frames // 2, n_audio_state)

        # if the input is longer than the pre-computed positional embedding,
        # grow the sinusoidal embedding on the fly and keep it as the buffer
        length_x = x.shape[1]
        if length_x > self.positional_embedding.shape[0]:
            self.register_buffer(
                "positional_embedding",
                sinusoids(length_x, self.positional_embedding.shape[1]).to(x.device),
            )
        x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)

        if target_layer is None:
            target_layer = len(self.blocks)

        # run only the first `target_layer` transformer blocks
        for block in self.blocks[:target_layer]:
            x = block(x)

        return x


class Whisper_(Whisper):
    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)
        # replace audio encoder with our audio encoder
        self.encoder = AudioEncoder_(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        return self.encoder.extract_feature(mel, target_layer)
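

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # `openai-whisper` package is installed; the checkpoint name ("base") and
    # the audio path ("sample.wav") are illustrative placeholders.
    import whisper

    base = whisper.load_model("base")  # official checkpoint (plain Whisper)
    model = Whisper_(base.dims)        # same dimensions, patched encoder
    model.load_state_dict(base.state_dict())
    model.eval()

    audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))
    mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)  # (1, n_mels, 3000)

    with torch.no_grad():
        # hidden states after the 6th encoder block (all blocks for "base")
        features = model.extract_features(mel, target_layer=6)
    print(features.shape)  # (1, 1500, n_audio_state)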