# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and
# Whisper (https://github.com/openai/whisper/)

import io
import logging
import os
from typing import Optional, Union

import soundfile as sf
import torch
from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
from whisper.audio import log_mel_spectrogram
from whisper.model import ModelDimensions

from whisper_model import Whisper_

logger = logging.getLogger("dump_feature")


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    But we load a `Whisper_` model for feature extraction instead.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root : str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory : bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper_
        The Whisper model instance, adapted for feature extraction
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
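
# A minimal usage sketch (the model name "base" and the checkpoint path below
# are illustrative placeholders, not values fixed by this module):
#
#     model = load_model("base")  # downloads to ~/.cache/whisper if needed
#     model = load_model("/path/to/checkpoint.pt", device="cpu", in_memory=True)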


class WhisperFeatureReader(object):

    def __init__(self, root, ckpt, layer, device):
        self.device = device
        logger.info(f"device = {self.device}")
        self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
        self.model.decoder = None  # the decoder is not needed for feature extraction; drop it to save memory
        self.layer = layer  # one-based index of the encoder layer to extract

    def read_audio(self, path, ref_len=None):
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
            hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
            # Whisper's mel hop is 160 samples and the encoder downsamples by 2,
            # so each output frame covers 320 input samples (50 frames/s at 16 kHz).
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]  # trim to the true audio length
        return hidden.contiguous()
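

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: "tiny" is the smallest official
    # checkpoint, layer 4 is an arbitrary (one-based) encoder layer, and
    # "/path/to/audio.wav" is a placeholder for any 16 kHz mono wav file.
    logging.basicConfig(level=logging.INFO)
    reader = WhisperFeatureReader(
        root=None,  # defaults to ~/.cache/whisper via load_model
        ckpt="tiny",
        layer=4,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    feats = reader.get_feats("/path/to/audio.wav")
    print(feats.shape)  # (audio_samples // 320, model hidden size)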