File size: 1,322 Bytes
0dabde8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import librosa
from fairseq import checkpoint_utils
import torch.nn.functional as F


def get_model(vec_path):
    """Load a fairseq checkpoint and return its first model in eval mode.

    Args:
        vec_path: Path to the checkpoint file; handed to fairseq as a
            one-element list of checkpoint paths.

    Returns:
        The first entry of the model list returned by
        ``checkpoint_utils.load_model_ensemble_and_task``, after ``eval()``
        has been called on it.
    """
    print("load model(s) from {}".format(vec_path))
    # The loader returns (models, cfg, task); only the model list is needed.
    ensemble, _cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        [vec_path], suffix=""
    )
    encoder = ensemble[0]
    encoder.eval()
    return encoder


@torch.no_grad()
def get_content(hmodel, wav_16k_tensor, device='cuda', layer=12):
    """Run ``hmodel`` over a waveform and return the features it extracts.

    Runs with gradient tracking disabled.

    Args:
        hmodel: Object exposing ``extract_features(source=...,
            padding_mask=..., output_layer=...)``. It is not moved by this
            function, so it must already live on ``device``.
        wav_16k_tensor: Waveform tensor. The name suggests 16 kHz audio and
            the caller in this file passes shape ``(1, num_samples)`` --
            TODO confirm which other shapes the model accepts.
        device: Device the waveform is moved to before the forward pass.
        layer: Forwarded unchanged as ``output_layer``.

    Returns:
        The first element of the value returned by
        ``hmodel.extract_features``.
    """
    wav_16k_tensor = wav_16k_tensor.to(device)

    # Zero-pad (400 - 320) // 2 = 40 samples on each side of the last
    # dimension so that the output length is len(audio) // 320.
    # NOTE(review): 400 / 320 presumably are the feature extractor's
    # receptive field and stride in samples -- confirm against the model.
    side_pad = (400 - 320) // 2
    feats = F.pad(wav_16k_tensor, (side_pad, side_pad))

    # Nothing is masked. zeros_like builds the all-False mask directly on the
    # features' device instead of allocating on the CPU and copying it over.
    padding_mask = torch.zeros_like(feats, dtype=torch.bool)

    outputs = hmodel.extract_features(
        source=feats,
        padding_mask=padding_mask,
        output_layer=layer,
    )
    return outputs[0]


if __name__ == '__main__':
    # Smoke test: extract features for the first 100 * 320 samples of
    # test.wav and print them.
    # Fall back to the CPU so the script also runs on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    audio, _sr = librosa.load('test.wav', sr=16000)
    audio = audio[:100*320]

    model = get_model('../../ckpts/checkpoint_best_legacy_500.pt')
    model = model.to(device)

    # from_numpy + unsqueeze adds the leading batch dimension without going
    # through a Python list of ndarrays, which is slow and makes PyTorch warn.
    wav = torch.from_numpy(audio).unsqueeze(0)
    content = get_content(model, wav, device=device)
    print(content)