Commit 28f1353 (unverified) · committed by CosyVoice
2 parents: 2898d5a, f6b5c42

Merge pull request #404 from FunAudioLLM/dev/lyuxiang.lx
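This change parallelizes tools/extract_embedding.py and tools/extract_speech_token.py: the per-utterance work moves into a module-level single_job() and is fanned out over a ThreadPoolExecutor sized by a new --num_thread argument (default 8).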

tools/extract_embedding.py CHANGED
@@ -13,14 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import onnxruntime
 import torch
 import torchaudio
-from tqdm import tqdm
-import onnxruntime
 import torchaudio.compliance.kaldi as kaldi
+from tqdm import tqdm
+
+
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    feat = kaldi.fbank(audio,
+                       num_mel_bins=80,
+                       dither=0,
+                       sample_frequency=16000)
+    feat = feat - feat.mean(dim=0, keepdim=True)
+    embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+    return utt, embedding
 
 
 def main(args):
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+    utt2embedding, spk2embedding = {}, {}
+    for future in tqdm(as_completed(all_task)):
+        utt, embedding = future.result()
+        utt2embedding[utt] = embedding
+        spk = utt2spk[utt]
+        if spk not in spk2embedding:
+            spk2embedding[spk] = []
+        spk2embedding[spk].append(embedding)
+    for k, v in spk2embedding.items():
+        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
+    torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
+    torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
+    args = parser.parse_args()
+
     utt2wav, utt2spk = {}, {}
     with open('{}/wav.scp'.format(args.dir)) as f:
         for l in f:
@@ -36,35 +72,6 @@ def main(args):
     option.intra_op_num_threads = 1
     providers = ["CPUExecutionProvider"]
     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
-
-    utt2embedding, spk2embedding = {}, {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        feat = kaldi.fbank(audio,
-                           num_mel_bins=80,
-                           dither=0,
-                           sample_frequency=16000)
-        feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
-        utt2embedding[utt] = embedding
-        spk = utt2spk[utt]
-        if spk not in spk2embedding:
-            spk2embedding[spk] = []
-        spk2embedding[spk].append(embedding)
-    for k, v in spk2embedding.items():
-        spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
-
-    torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
-    torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
-    args = parser.parse_args()
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
     main(args)
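For reviewers skimming the diff: single_job() reads module-level state (utt2wav, utt2spk, ort_session, executor) that the __main__ block now builds before calling main(), and the result dicts are written only from the main thread inside the as_completed() loop, so no locking is needed. Threads (rather than processes) pay off here because onnxruntime's InferenceSession.run() can be called concurrently from multiple threads and releases the GIL during inference. A minimal, self-contained sketch of the fan-out/fan-in pattern, with the per-utterance work replaced by a stand-in job:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    from tqdm import tqdm


    def single_job(item):
        # Stand-in for the real per-utterance work (load audio, compute
        # fbank features, run the ONNX model); the heavy native code
        # releases the GIL, which is what lets the threads overlap.
        return item, item * item


    if __name__ == "__main__":
        items = list(range(100))
        results = {}
        with ThreadPoolExecutor(max_workers=8) as executor:
            all_task = [executor.submit(single_job, item) for item in items]
            # as_completed() yields futures in completion order; total=
            # gives tqdm a denominator the bare generator cannot provide.
            for future in tqdm(as_completed(all_task), total=len(all_task)):
                item, result = future.result()  # re-raises worker exceptions
                results[item] = result
        assert len(results) == len(items)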
tools/extract_speech_token.py CHANGED
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import logging
 import torch
 from tqdm import tqdm
@@ -22,7 +23,36 @@ import torchaudio
 import whisper
 
 
+def single_job(utt):
+    audio, sample_rate = torchaudio.load(utt2wav[utt])
+    if sample_rate != 16000:
+        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+    if audio.shape[1] / 16000 > 30:
+        logging.warning('do not support extract speech token for audio longer than 30s')
+        speech_token = []
+    else:
+        feat = whisper.log_mel_spectrogram(audio, n_mels=128)
+        speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                              ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+    return utt, speech_token
+
+
 def main(args):
+    all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+    utt2speech_token = {}
+    for future in tqdm(as_completed(all_task)):
+        utt, speech_token = future.result()
+        utt2speech_token[utt] = speech_token
+    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dir", type=str)
+    parser.add_argument("--onnx_path", type=str)
+    parser.add_argument("--num_thread", type=int, default=8)
+    args = parser.parse_args()
+
     utt2wav = {}
     with open('{}/wav.scp'.format(args.dir)) as f:
         for l in f:
@@ -34,28 +64,6 @@ def main(args):
     option.intra_op_num_threads = 1
     providers = ["CUDAExecutionProvider"]
     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
-
-    utt2speech_token = {}
-    for utt in tqdm(utt2wav.keys()):
-        audio, sample_rate = torchaudio.load(utt2wav[utt])
-        if sample_rate != 16000:
-            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
-        if audio.shape[1] / 16000 > 30:
-            logging.warning('do not support extract speech token for audio longer than 30s')
-            speech_token = []
-        else:
-            feat = whisper.log_mel_spectrogram(audio, n_mels=128)
-            speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-                                                  ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
-        utt2speech_token[utt] = speech_token
-    torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--dir',
-                        type=str)
-    parser.add_argument('--onnx_path',
-                        type=str)
-    args = parser.parse_args()
+    executor = ThreadPoolExecutor(max_workers=args.num_thread)
     main(args)
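The tokenizer script reuses the same fan-out/fan-in structure; only the job body differs (Whisper log-mel features fed to the speech tokenizer ONNX model, with clips over 30 s skipped) along with the CUDAExecutionProvider session. For a quick sanity check of that body outside the batch pipeline, it can be run on a single clip; a sketch assuming local files (both file names below are placeholders, not repository paths):

    import numpy as np
    import onnxruntime
    import torchaudio
    import whisper

    # Placeholder paths: substitute a real clip and an exported tokenizer model.
    ort_session = onnxruntime.InferenceSession("speech_tokenizer.onnx",
                                               providers=["CPUExecutionProvider"])
    audio, sample_rate = torchaudio.load("example.wav")
    if sample_rate != 16000:
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
    # Mirror the 30 s guard from single_job() above.
    assert audio.shape[1] / 16000 <= 30, "clip too long for the tokenizer"
    feat = whisper.log_mel_spectrogram(audio, n_mels=128)
    speech_token = ort_session.run(
        None,
        {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
         ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
    print(len(speech_token), "tokens")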