File size: 1,090 Bytes
0a5b75c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Module preamble. Statement order matters here: the sys.path entry below is
# appended before the ERes2NetV2 / kaldi imports run.
import sys,os,torch
# Adds "<current working directory>/eres2net" to the import search path.
# NOTE(review): presumably this is where ERes2NetV2.py and kaldi.py live, which
# would mean the process has to be started from the project root - confirm.
sys.path.append(f"{os.getcwd()}/eres2net")
# Checkpoint loaded by SV.__init__. The path is relative, so it is resolved
# against the current working directory at load time.
sv_path = "pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi
class SV:
    """Wrapper around a pretrained ERes2NetV2 embedding model.

    On construction the checkpoint at the module-level ``sv_path`` is loaded
    into ``ERes2NetV2(baseWidth=24, scale=4, expansion=4)``, the model is put
    in eval mode, optionally converted to float16, and moved to ``device``.

    Attributes set by ``__init__``:
        embedding_model: the loaded model, already on ``device``.
        is_half: ``bool``; True when the model (and the inputs given to
            ``compute_embedding3``) are converted to float16.
    """

    def __init__(self, device, is_half):
        """Load the checkpoint and place the model on ``device``.

        Args:
            device: Target passed to ``Module.to`` (e.g. ``"cpu"``/``"cuda"``).
            is_half: Truthy to run the model in float16, falsy for the
                checkpoint's own precision.
        """
        # Normalise the flag once. The previous code tested `== False` here
        # and `== True` in compute_embedding3, which disagree for non-bool
        # values (None, 2, strings, ...) and left a half-precision model
        # being fed float32 input.
        self.is_half = bool(is_half)

        # NOTE(review): weights_only=False allows arbitrary unpickling inside
        # torch.load; only use this with a trusted checkpoint file.
        pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False)
        embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
        embedding_model.load_state_dict(pretrained_state)
        embedding_model.eval()
        if self.is_half:
            embedding_model = embedding_model.half()
        self.embedding_model = embedding_model.to(device)

    def compute_embedding3(self, wav):
        """Compute filterbank features for ``wav`` and run ``forward3`` on them.

        Args:
            wav: Tensor that is iterated along its first dimension; every item
                gets a leading dimension via ``unsqueeze(0)`` before being
                passed to ``Kaldi.fbank``.
                NOTE(review): presumably shaped (batch, samples) at 16 kHz,
                matching ``sample_frequency=16000`` - confirm with callers.
                The tensor is not moved to another device here.

        Returns:
            The result of ``self.embedding_model.forward3`` on the stacked
            features, computed under ``torch.no_grad()``.
        """
        with torch.no_grad():
            if self.is_half:
                # Keep the input dtype in step with the half-precision model.
                wav = wav.half()
            # One fbank call per item; torch.stack requires every per-item
            # feature tensor to have the same shape.
            feat = torch.stack([
                Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0)
                for wav0 in wav
            ])
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb