|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions |
|
|
|
|
|
class AudioEncoder_(AudioEncoder):
    """AudioEncoder variant that accepts inputs longer than the trained audio
    context (by growing the positional-embedding table on demand) and can
    return hidden states from an intermediate transformer block.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None) -> Tensor:
        """Encode a mel spectrogram, optionally stopping early.

        Parameters
        ----------
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        target_layer : Optional[int]
            number of transformer blocks to apply; ``None`` runs all blocks.

        Returns
        -------
        torch.Tensor
            hidden states after the first ``target_layer`` blocks.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, n_ctx, n_state) for the transformer

        length_x = x.shape[1]
        if length_x > self.positional_embedding.shape[0]:
            # Input is longer than the positional table built at init time:
            # grow the buffer.  Build it directly on the existing buffer's
            # dtype/device — sinusoids() returns float32 on CPU, which would
            # otherwise desync a model converted with .half()/.to(device).
            self.register_buffer(
                "positional_embedding",
                sinusoids(length_x, self.positional_embedding.shape[1]).to(
                    device=self.positional_embedding.device,
                    dtype=self.positional_embedding.dtype,
                ),
            )

        # No-op when the buffer already lives on x's device.
        self.positional_embedding = self.positional_embedding.to(x.device)
        x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)

        if target_layer is None:
            target_layer = len(self.blocks)

        for block in self.blocks[:target_layer]:
            x = block(x)

        return x
|
|
|
|
|
class Whisper_(Whisper):
    """Whisper model whose encoder exposes intermediate-layer features."""

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)

        # Replace the stock encoder with the feature-extracting variant,
        # built from the same dimensions as the original.
        d = self.dims
        self.encoder = AudioEncoder_(
            d.n_mels, d.n_audio_ctx, d.n_audio_state, d.n_audio_head, d.n_audio_layer
        )

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        """Return encoder hidden states for ``mel``, stopping after
        ``target_layer`` blocks (all blocks when ``None``)."""
        return self.encoder.extract_feature(mel, target_layer)
|
|