|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
import logging |
|
import os |
|
from typing import Optional, Union |
|
|
|
import soundfile as sf |
|
import torch |
|
from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models |
|
from whisper.audio import log_mel_spectrogram |
|
from whisper.model import ModelDimensions |
|
|
|
from whisper_model import Whisper_ |
|
|
|
logger = logging.getLogger("dump_feature") |
|
|
|
|
|
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Load a `Whisper_` model for feature extraction.

    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    But we will load a `Whisper_` model for feature extraction.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to CUDA when available,
        otherwise CPU
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
        (honoring $XDG_CACHE_HOME when set)
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper_
        The Whisper model instance used for feature extraction

    Raises
    ------
    RuntimeError
        if `name` is neither an official model name nor a path to an existing file.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Use a context manager so the file handle is closed
            # (the original `open(name, "rb").read()` leaked it).
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        checkpoint = torch.load(fp, map_location=device)
    # Release the (possibly large) in-memory blob / path reference early.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
|
|
|
|
|
class WhisperFeatureReader(object):
    """Reads audio files and extracts intermediate Whisper encoder features."""

    def __init__(self, root, ckpt, layer, device):
        # `root` is the checkpoint download directory, `ckpt` the model name or
        # path, `layer` the encoder layer whose activations are extracted.
        self.device = device
        logger.info(f"device = {self.device}")

        loaded = load_model(name=ckpt, device=self.device, download_root=root)
        self.model: Whisper_ = loaded.eval()
        # Only the encoder is used for feature extraction; dropping the
        # decoder frees its parameters.
        self.model.decoder = None
        self.layer = layer

    def read_audio(self, path, ref_len=None):
        """Load a waveform from `path`; it must be sampled at 16 kHz."""
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        # Tolerate up to 160 samples (10 ms) of deviation from the expected
        # length before warning.
        mismatch = ref_len is not None and abs(ref_len - len(wav)) > 160
        if mismatch:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Return encoder hidden states for `path`, trimmed to the audio length."""
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            tensor_wav = torch.from_numpy(wav).float().to(self.device)
            mel = log_mel_spectrogram(tensor_wav)
            hidden = self.model.extract_features(
                mel.unsqueeze(0), target_layer=self.layer
            )
            # One output frame per 320 input samples — presumably the 160-sample
            # mel hop times the encoder's stride-2 convolution; trim away frames
            # that correspond to Whisper's padding.
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]
        return hidden.contiguous()
|
|