# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and
# Whisper (https://github.com/openai/whisper/)

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions


class AudioEncoder_(AudioEncoder):
    """Whisper audio encoder that can return hidden states from an intermediate block.

    Construction is inherited unchanged from ``AudioEncoder``. The extra method
    ``extract_feature`` also accepts inputs longer than the positional-embedding
    table; in that case the table is regenerated to the required length.
    """

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None) -> Tensor:
        """Apply the conv front-end and the first ``target_layer`` residual blocks.

        Args:
            x: the mel spectrogram of the audio, shape ``(batch_size, n_mels, n_ctx)``.
            target_layer: how many leading blocks to run; applied as
                ``self.blocks[:target_layer]`` (so negative values count from the
                end). ``None`` runs every block.

        Returns:
            Hidden states laid out as ``(batch_size, n_frames, channels)``. ``ln_post``
            is never applied, even when all blocks are run.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # (batch, channels, frames) -> (batch, frames, channels)

        n_frames = x.shape[1]
        pos_emb = self.positional_embedding
        if n_frames > pos_emb.shape[0]:
            # Input is longer than the current table: rebuild it at the needed length.
            # Keep the existing buffer's dtype (e.g. float16 after ``.half()``) so the
            # module does not end up with a float32 buffer among half-precision ones.
            # NOTE(review): assumes whisper's ``sinusoids`` rows are independent of the
            # total length, so the first rows match the original table — confirm if the
            # upstream helper changes.
            # NOTE(review): the enlarged buffer stays persistent, so a state_dict saved
            # after a long input holds a larger ``positional_embedding`` than a freshly
            # constructed model expects.
            pos_emb = sinusoids(n_frames, pos_emb.shape[1]).to(
                device=x.device, dtype=pos_emb.dtype
            )
            self.register_buffer("positional_embedding", pos_emb)

        x = (x + pos_emb[:n_frames, :]).to(x.dtype)

        blocks = self.blocks if target_layer is None else self.blocks[:target_layer]
        for block in blocks:
            x = block(x)
        return x


class Whisper_(Whisper):
    """Whisper model whose encoder is an ``AudioEncoder_`` (intermediate-layer features)."""

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)
        # Replace the stock audio encoder with ours. It is built from the same
        # hyper-parameters and adds no parameters, so encoder weights keep the same keys.
        self.encoder = AudioEncoder_(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        """Return encoder hidden states for ``mel`` after the first ``target_layer`` blocks.

        Delegates to ``AudioEncoder_.extract_feature``; ``None`` runs all blocks
        (without ``ln_post``).
        """
        return self.encoder.extract_feature(mel, target_layer)