KingNish committed (verified)
Commit 95c6462 · 1 Parent(s): 4cda277

Upload ./RepCodec/examples/whisper_feature_reader.py with huggingface_hub

RepCodec/examples/whisper_feature_reader.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright (c) ByteDance, Inc. and its affiliates.
+ # Copyright (c) Chutong Meng
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ # Based on fairseq (https://github.com/facebookresearch/fairseq) and
+ # Whisper (https://github.com/openai/whisper/)
+ 
+ import io
+ import logging
+ import os
+ from typing import Optional, Union
+ 
+ import soundfile as sf
+ import torch
+ from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
+ from whisper.audio import log_mel_spectrogram
+ from whisper.model import ModelDimensions
+ 
+ from whisper_model import Whisper_
+ 
+ logger = logging.getLogger("dump_feature")
+ 
+ 
+ def load_model(
+     name: str,
+     device: Optional[Union[str, torch.device]] = None,
+     download_root: Optional[str] = None,
+     in_memory: bool = False,
+ ) -> Whisper_:
+     """
+     Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
+     But we will load a `Whisper_` model for feature extraction.
+ 
+     Parameters
+     ----------
+     name : str
+         one of the official model names listed by `whisper.available_models()`, or
+         path to a model checkpoint containing the model dimensions and the model state_dict.
+     device : Union[str, torch.device]
+         the PyTorch device to put the model into
+     download_root : str
+         path to download the model files; by default, it uses "~/.cache/whisper"
+     in_memory : bool
+         whether to preload the model weights into host memory
+ 
+     Returns
+     -------
+     model : Whisper_
+         The `Whisper_` model instance used for feature extraction
+     """
+ 
+     if device is None:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+     if download_root is None:
+         default = os.path.join(os.path.expanduser("~"), ".cache")
+         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
+ 
+     if name in _MODELS:
+         checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+         alignment_heads = _ALIGNMENT_HEADS[name]
+     elif os.path.isfile(name):
+         checkpoint_file = open(name, "rb").read() if in_memory else name
+         alignment_heads = None
+     else:
+         raise RuntimeError(
+             f"Model {name} not found; available models = {available_models()}"
+         )
+ 
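+     # at this point checkpoint_file is either raw bytes (in_memory=True) or a filesystem path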
+     with (
+         io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+     ) as fp:
+         checkpoint = torch.load(fp, map_location=device)
+     del checkpoint_file
+ 
+     dims = ModelDimensions(**checkpoint["dims"])
+     model = Whisper_(dims)
+     model.load_state_dict(checkpoint["model_state_dict"])
+ 
+     if alignment_heads is not None:
+         model.set_alignment_heads(alignment_heads)
+ 
+     return model.to(device)
+ 
+ 
+ class WhisperFeatureReader(object):
+     def __init__(self, root, ckpt, layer, device):
+         self.device = device
+         logger.info(f"device = {self.device}")
+ 
+         self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
+         self.model.decoder = None  # to save some memory by deleting the decoder
+         self.layer = layer  # one-based
+ 
+     def read_audio(self, path, ref_len=None):
+         wav, sample_rate = sf.read(path)
+         assert sample_rate == 16000, sample_rate
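+         # warn if the audio deviates from the expected reference length by more
+         # than one frame hop (160 samples at 16 kHz)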
+         if ref_len is not None and abs(ref_len - len(wav)) > 160:
+             logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
+         return wav
+ 
+     def get_feats(self, path, ref_len=None):
+         wav = self.read_audio(path, ref_len)
+         audio_length = len(wav)
+         with torch.no_grad():
+             mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
+             hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
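+             # one encoder frame covers 320 input samples: the mel spectrogram uses a
+             # hop length of 160 samples and the encoder convolutions downsample by a
+             # further 2x, so trim the padded output back to the true audio length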
+             feature_length = audio_length // 320
+             hidden = hidden[0, :feature_length]
+         return hidden.contiguous()
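
A minimal usage sketch (not part of the commit), assuming this file sits next to the `whisper_model.py` that defines `Whisper_` and that a 16 kHz mono wav exists at the given path; the checkpoint name "large", layer 32, and the file name are illustrative values, not prescribed by this file:

    from whisper_feature_reader import WhisperFeatureReader

    # root=None falls back to the default whisper cache directory;
    # layer is one-based, so 32 reads the last encoder layer of "large"
    reader = WhisperFeatureReader(root=None, ckpt="large", layer=32, device="cuda")

    # returns a (num_frames, hidden_dim) torch.Tensor of encoder features
    feats = reader.get_feats("utterance_16khz.wav")
    print(feats.shape)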