MiXaiLL76 commited on
Commit
7b3e285
·
1 Parent(s): bcda6d8

add threading

Browse files
Files changed (1) hide show
  1. tools/extract_embedding.py +94 -30
tools/extract_embedding.py CHANGED
@@ -18,53 +18,117 @@ import torchaudio
18
  from tqdm import tqdm
19
  import onnxruntime
20
  import torchaudio.compliance.kaldi as kaldi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def main(args):
24
  utt2wav, utt2spk = {}, {}
25
- with open('{}/wav.scp'.format(args.dir)) as f:
26
  for l in f:
27
- l = l.replace('\n', '').split()
28
  utt2wav[l[0]] = l[1]
29
- with open('{}/utt2spk'.format(args.dir)) as f:
30
  for l in f:
31
- l = l.replace('\n', '').split()
32
  utt2spk[l[0]] = l[1]
33
 
34
- option = onnxruntime.SessionOptions()
35
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
36
- option.intra_op_num_threads = 1
37
- providers = ["CPUExecutionProvider"]
38
- ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
 
39
 
40
  utt2embedding, spk2embedding = {}, {}
41
- for utt in tqdm(utt2wav.keys()):
42
- audio, sample_rate = torchaudio.load(utt2wav[utt])
43
- if sample_rate != 16000:
44
- audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
45
- feat = kaldi.fbank(audio,
46
- num_mel_bins=80,
47
- dither=0,
48
- sample_frequency=16000)
49
- feat = feat - feat.mean(dim=0, keepdim=True)
50
- embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
51
- utt2embedding[utt] = embedding
52
- spk = utt2spk[utt]
53
- if spk not in spk2embedding:
54
- spk2embedding[spk] = []
55
- spk2embedding[spk].append(embedding)
 
 
 
 
56
  for k, v in spk2embedding.items():
57
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
58
 
59
- torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
- torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61
 
62
 
63
  if __name__ == "__main__":
64
  parser = argparse.ArgumentParser()
65
- parser.add_argument('--dir',
66
- type=str)
67
- parser.add_argument('--onnx_path',
68
- type=str)
69
  args = parser.parse_args()
70
  main(args)
 
18
  from tqdm import tqdm
19
  import onnxruntime
20
  import torchaudio.compliance.kaldi as kaldi
21
+ from queue import Queue, Empty
22
+ from threading import Thread
23
+
24
+
25
+ class ExtractEmbedding:
26
+ def __init__(self, model_path: str, queue: Queue, out_queue: Queue):
27
+ self.model_path = model_path
28
+ self.queue = queue
29
+ self.out_queue = out_queue
30
+ self.is_run = True
31
+
32
+ def run(self):
33
+ self.consumer_thread = Thread(target=self.consumer)
34
+ self.consumer_thread.start()
35
+
36
+ def stop(self):
37
+ self.is_run = False
38
+ self.consumer_thread.join()
39
+
40
+ def consumer(self):
41
+ option = onnxruntime.SessionOptions()
42
+ option.graph_optimization_level = (
43
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
44
+ )
45
+ option.intra_op_num_threads = 1
46
+ providers = ["CPUExecutionProvider"]
47
+ ort_session = onnxruntime.InferenceSession(
48
+ self.model_path, sess_options=option, providers=providers
49
+ )
50
+
51
+ while self.is_run:
52
+ try:
53
+ utt, wav_file = self.queue.get(timeout=1)
54
+
55
+ audio, sample_rate = torchaudio.load(wav_file)
56
+ if sample_rate != 16000:
57
+ audio = torchaudio.transforms.Resample(
58
+ orig_freq=sample_rate, new_freq=16000
59
+ )(audio)
60
+ feat = kaldi.fbank(
61
+ audio, num_mel_bins=80, dither=0, sample_frequency=16000
62
+ )
63
+ feat = feat - feat.mean(dim=0, keepdim=True)
64
+ embedding = (
65
+ ort_session.run(
66
+ None,
67
+ {
68
+ ort_session.get_inputs()[0]
69
+ .name: feat.unsqueeze(dim=0)
70
+ .cpu()
71
+ .numpy()
72
+ },
73
+ )[0]
74
+ .flatten()
75
+ .tolist()
76
+ )
77
+ self.out_queue.put((utt, embedding))
78
+ except Empty:
79
+ self.is_run = False
80
+ break
81
 
82
 
83
  def main(args):
84
  utt2wav, utt2spk = {}, {}
85
+ with open("{}/wav.scp".format(args.dir)) as f:
86
  for l in f:
87
+ l = l.replace("\n", "").split()
88
  utt2wav[l[0]] = l[1]
89
+ with open("{}/utt2spk".format(args.dir)) as f:
90
  for l in f:
91
+ l = l.replace("\n", "").split()
92
  utt2spk[l[0]] = l[1]
93
 
94
+ input_queue = Queue()
95
+ output_queue = Queue()
96
+ consumers = [
97
+ ExtractEmbedding(args.onnx_path, input_queue, output_queue)
98
+ for _ in range(args.num_thread)
99
+ ]
100
 
101
  utt2embedding, spk2embedding = {}, {}
102
+ for utt in tqdm(utt2wav.keys(), desc="Load data"):
103
+ input_queue.put((utt, utt2wav[utt]))
104
+
105
+ for c in consumers:
106
+ c.run()
107
+
108
+ with tqdm(desc="Process data: ", total=len(utt2wav)) as pbar:
109
+ while any([c.is_run for c in consumers]):
110
+ try:
111
+ utt, embedding = output_queue.get(timeout=1)
112
+ utt2embedding[utt] = embedding
113
+ spk = utt2spk[utt]
114
+ if spk not in spk2embedding:
115
+ spk2embedding[spk] = []
116
+ spk2embedding[spk].append(embedding)
117
+ pbar.update(1)
118
+ except Empty:
119
+ continue
120
+
121
  for k, v in spk2embedding.items():
122
  spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
123
 
124
+ torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
125
+ torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
126
 
127
 
128
  if __name__ == "__main__":
129
  parser = argparse.ArgumentParser()
130
+ parser.add_argument("--dir", type=str)
131
+ parser.add_argument("--onnx_path", type=str)
132
+ parser.add_argument("--num_thread", type=int, default=8)
 
133
  args = parser.parse_args()
134
  main(args)