CosyVoice commited on
Commit
2665b06
·
unverified ·
2 Parent(s): e19e80f 1d05ae5

Merge pull request #356 from MiXaiLL76/main

Browse files

Implemented fast processing of extract_embedding

Files changed (1) hide show
  1. tools/extract_embedding.py +48 -24
tools/extract_embedding.py CHANGED
@@ -13,58 +13,82 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
 
 
 
 
16
  import torch
17
  import torchaudio
18
- from tqdm import tqdm
19
- import onnxruntime
20
  import torchaudio.compliance.kaldi as kaldi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def main(args):
24
  utt2wav, utt2spk = {}, {}
25
- with open('{}/wav.scp'.format(args.dir)) as f:
26
  for l in f:
27
- l = l.replace('\n', '').split()
28
  utt2wav[l[0]] = l[1]
29
- with open('{}/utt2spk'.format(args.dir)) as f:
30
  for l in f:
31
- l = l.replace('\n', '').split()
32
  utt2spk[l[0]] = l[1]
33
 
 
 
34
  option = onnxruntime.SessionOptions()
35
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
 
 
36
  option.intra_op_num_threads = 1
37
  providers = ["CPUExecutionProvider"]
38
- ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  utt2embedding, spk2embedding = {}, {}
41
- for utt in tqdm(utt2wav.keys()):
42
- audio, sample_rate = torchaudio.load(utt2wav[utt])
43
- if sample_rate != 16000:
44
- audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
45
- feat = kaldi.fbank(audio,
46
- num_mel_bins=80,
47
- dither=0,
48
- sample_frequency=16000)
49
- feat = feat - feat.mean(dim=0, keepdim=True)
50
- embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
51
  utt2embedding[utt] = embedding
52
  spk = utt2spk[utt]
53
  if spk not in spk2embedding:
54
  spk2embedding[spk] = []
55
  spk2embedding[spk].append(embedding)
 
56
  for k, v in spk2embedding.items():
57
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
58
 
59
- torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
- torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61
 
62
 
63
  if __name__ == "__main__":
64
  parser = argparse.ArgumentParser()
65
- parser.add_argument('--dir',
66
- type=str)
67
- parser.add_argument('--onnx_path',
68
- type=str)
69
  args = parser.parse_args()
70
  main(args)
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  import argparse
16
+ import os
17
+ from concurrent.futures import ThreadPoolExecutor
18
+
19
+ import onnxruntime
20
  import torch
21
  import torchaudio
 
 
22
  import torchaudio.compliance.kaldi as kaldi
23
+ from tqdm import tqdm
24
+ from itertools import repeat
25
+
26
+
27
+ def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
28
+ audio, sample_rate = torchaudio.load(wav_file)
29
+ if sample_rate != 16000:
30
+ audio = torchaudio.transforms.Resample(
31
+ orig_freq=sample_rate, new_freq=16000
32
+ )(audio)
33
+ feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
34
+ feat = feat - feat.mean(dim=0, keepdim=True)
35
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
36
+ return (utt, embedding)
37
 
38
 
39
  def main(args):
40
  utt2wav, utt2spk = {}, {}
41
+ with open("{}/wav.scp".format(args.dir)) as f:
42
  for l in f:
43
+ l = l.replace("\n", "").split()
44
  utt2wav[l[0]] = l[1]
45
+ with open("{}/utt2spk".format(args.dir)) as f:
46
  for l in f:
47
+ l = l.replace("\n", "").split()
48
  utt2spk[l[0]] = l[1]
49
 
50
+ assert os.path.exists(args.onnx_path), "onnx_path not exists"
51
+
52
  option = onnxruntime.SessionOptions()
53
+ option.graph_optimization_level = (
54
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
55
+ )
56
  option.intra_op_num_threads = 1
57
  providers = ["CPUExecutionProvider"]
58
+ ort_session = onnxruntime.InferenceSession(
59
+ args.onnx_path, sess_options=option, providers=providers
60
+ )
61
+
62
+ all_utt = utt2wav.keys()
63
+
64
+ with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
65
+ results = list(
66
+ tqdm(
67
+ executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
68
+ total=len(utt2wav),
69
+ desc="Process data: "
70
+ )
71
+ )
72
 
73
  utt2embedding, spk2embedding = {}, {}
74
+ for utt, embedding in results:
 
 
 
 
 
 
 
 
 
75
  utt2embedding[utt] = embedding
76
  spk = utt2spk[utt]
77
  if spk not in spk2embedding:
78
  spk2embedding[spk] = []
79
  spk2embedding[spk].append(embedding)
80
+
81
  for k, v in spk2embedding.items():
82
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
83
 
84
+ torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
85
+ torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
86
 
87
 
88
  if __name__ == "__main__":
89
  parser = argparse.ArgumentParser()
90
+ parser.add_argument("--dir", type=str)
91
+ parser.add_argument("--onnx_path", type=str)
92
+ parser.add_argument("--num_thread", type=int, default=8)
 
93
  args = parser.parse_args()
94
  main(args)