KingNish committed on
Commit
db9db4f
·
verified ·
1 Parent(s): c80ca96

Upload ./RepCodec/examples/hubert_feature_reader.py with huggingface_hub

Browse files
RepCodec/examples/hubert_feature_reader.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+
10
+ import fairseq
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
15
+
16
# Module-level logger registered under the name "dump_feature"; HubertFeatureReader uses it for its info and warning messages.
+ logger = logging.getLogger("dump_feature")
17
+
18
+
19
class HubertFeatureReader:
    """Read audio files and extract features from one layer of a fairseq model.

    The model and its task are loaded from a fairseq checkpoint. Audio is
    read at the task's configured sample rate, optionally layer-normalised
    (when the task config says so), and fed through the model's
    ``extract_features`` in chunks of at most ``max_chunk`` samples.
    """

    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        """Load the checkpoint and move the model to ``device``.

        Args:
            ckpt_path: Path handed to fairseq's
                ``load_model_ensemble_and_task`` (as a one-element list).
            layer: Value forwarded as ``output_layer`` to
                ``extract_features``.
            device: Target device for the model and for input tensors.
            max_chunk: Maximum number of waveform samples passed to the
                model in a single ``extract_features`` call.
        """
        loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        # The loader returns (models, cfg, task); cfg is not needed here.
        models, _cfg, loaded_task = loaded

        self.device = device
        logger.info(f"device = {self.device}")

        # Only the first model of the returned list is used, in eval mode.
        first_model = models[0]
        self.model = first_model.eval().to(self.device)
        self.task = loaded_task
        self.layer = layer
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        """Load ``path`` as a 1-D waveform at the task's sample rate.

        Args:
            path: Audio location understood by fairseq's
                ``get_features_or_waveform``.
            ref_len: Optional expected number of samples. When given and the
                loaded length differs from it by more than 160 samples, a
                warning is logged; the waveform is returned regardless.

        Returns:
            The waveform as a 1-D array.

        Raises:
            AssertionError: If the waveform is not 1-D after the 2-D
                reduction below.
        """
        target_rate = self.task.cfg.sample_rate
        wav = get_features_or_waveform(
            path,
            need_waveform=True,
            use_sample_rate=target_rate,
        )

        # A 2-D result is collapsed by averaging over its last axis.
        # NOTE(review): this assumes channels sit on the last axis -- confirm
        # against the layout get_features_or_waveform actually returns.
        if wav.ndim == 2:
            wav = wav.mean(-1)
        # NOTE(review): assert is stripped under ``python -O``.
        assert wav.ndim == 1, wav.ndim

        if ref_len is not None:
            if abs(ref_len - len(wav)) > 160:
                logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Extract features of layer ``self.layer`` for the audio at ``path``.

        Args:
            path: Audio location, forwarded to :meth:`read_audio`.
            ref_len: Optional expected sample count, forwarded to
                :meth:`read_audio`.

        Returns:
            The per-chunk features concatenated along dim 1, with dim 0
            squeezed away (presumably ``(frames, dim)`` -- confirm against
            the model's ``extract_features`` output shape).
        """
        wav = self.read_audio(path, ref_len=ref_len)

        def _encode(chunk):
            # extract_features returns a pair; only the first item is kept.
            chunk_feat, _ = self.model.extract_features(
                source=chunk,
                padding_mask=None,
                mask=False,
                output_layer=self.layer,
            )
            return chunk_feat

        with torch.no_grad():
            signal = torch.from_numpy(wav).float().to(self.device)
            if self.task.cfg.normalize:
                # Normalise over the full utterance before chunking.
                signal = F.layer_norm(signal, signal.shape)
            # Add a leading batch dimension of size 1.
            signal = signal.view(1, -1)

            step = self.max_chunk
            offsets = range(0, signal.size(1), step)
            pieces = [_encode(signal[:, lo: lo + step]) for lo in offsets]
        return torch.cat(pieces, 1).squeeze(0)