HoneyTian committed
Commit 5821056 · Parent(s): 2225ef6
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -43,7 +43,7 @@ def get_args():
     return args
 
 
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
+def target_second_signal_generator(data_dir: str, duration: int = 6, sample_rate: int = 8000, max_epoch: int = 20000):
     data_dir = Path(data_dir)
     for epoch_idx in range(max_epoch):
         for filename in data_dir.glob("**/*.wav"):
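Note: raising the default duration from 2 to 6 seconds changes how much audio each generated example carries. A back-of-the-envelope check (my own illustration, assuming the generator yields fixed-length clips of duration * sample_rate samples; the slicing logic is outside this hunk):

sample_rate = 8000  # default in the signature above
for duration in (2, 6):
    # 2 s -> 16000 samples, 6 s -> 48000 samples per generated clip
    print(duration, duration * sample_rate)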
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py CHANGED
@@ -30,7 +30,7 @@ def get_args():
     parser.add_argument("--agg", default=3, type=int)
     parser.add_argument("--frame_duration_ms", default=30, type=int)
     parser.add_argument("--padding_duration_ms", default=30, type=int)
-    parser.add_argument("--silence_duration_threshold", default=0.3, type=float)
+    parser.add_argument("--silence_duration_threshold", default=0.0, type=float)
 
     args = parser.parse_args()
     return args
toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py ADDED
@@ -0,0 +1,66 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Tuple
+
+from toolbox.torchaudio.configuration_utils import PretrainedConfig
+
+
+class FSMNVadConfig(PretrainedConfig):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 win_size: int = 240,
+                 hop_size: int = 80,
+                 win_type: str = "hann",
+
+                 in_channels: int = 64,
+                 hidden_size: int = 128,
+
+                 lr: float = 0.001,
+                 lr_scheduler: str = "CosineAnnealingLR",
+                 lr_scheduler_kwargs: dict = None,
+
+                 max_epochs: int = 100,
+                 clip_grad_norm: float = 10.,
+                 seed: int = 1234,
+
+                 num_workers: int = 4,
+                 batch_size: int = 4,
+                 eval_steps: int = 25000,
+
+                 **kwargs
+                 ):
+        super(FSMNVadConfig, self).__init__(**kwargs)
+        # transform
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        # encoder
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+
+        # train
+        self.lr = lr
+        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
+
+        self.max_epochs = max_epochs
+        self.clip_grad_norm = clip_grad_norm
+        self.seed = seed
+
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+        self.eval_steps = eval_steps
+
+
+def main():
+    config = FSMNVadConfig()
+    config.to_yaml_file("config.yaml")
+    return
+
+
+if __name__ == "__main__":
+    main()
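A minimal usage sketch for the new configuration class. Only the constructor and to_yaml_file appear in this commit; the override values below are illustrative, and no loading helper is assumed:

from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig

# Override a few training fields; everything else keeps the defaults above
# (sample_rate=8000, nfft=512, win_size=240, hop_size=80, ...).
config = FSMNVadConfig(
    batch_size=32,
    lr=1e-4,
    lr_scheduler_kwargs={"T_max": 100},  # stored as-is on config.lr_scheduler_kwargs
)
config.to_yaml_file("config.yaml")  # serialization helper from PretrainedConfig, as used in main() above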
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py CHANGED
@@ -1,10 +1,11 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
+"""
+https://github.com/modelscope/FunASR/blob/main/funasr/models/fsmn_vad_streaming/encoder.py
+
+"""
 from typing import Tuple, Dict, List
-import copy
-import os
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py CHANGED
@@ -10,8 +10,50 @@ https://github.com/lovemefan/fsmn-vad
 https://github.com/modelscope/FunASR/blob/main/funasr/models/fsmn_vad_streaming/encoder.py
 
 """
+import os
+from typing import Optional, Union
 
+import torch
+import torch.nn as nn
 
+from toolbox.torchaudio.configuration_utils import CONFIG_FILE
+from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
+from toolbox.torchaudio.modules.conv_stft import ConvSTFT
+from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN
+
+
+MODEL_FILE = "model.pt"
+
+
+class FSMNVadModel(nn.Module):
+    def __init__(self, config: FSMNVadConfig):
+        super(FSMNVadModel, self).__init__()
+        self.config = config
+        self.eps = 1e-12
+
+        self.stft = ConvSTFT(
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            win_type=config.win_type,
+            power=1,
+            requires_grad=False
+        )
+
+        self.fsmn_encoder = FSMN(
+            input_size=400,
+            input_affine_size=140,
+            hidden_size=250,
+            basic_block_layers=4,
+            basic_block_hidden_size=128,
+            basic_block_lorder=20,
+            basic_block_rorder=0,
+            basic_block_lstride=1,
+            basic_block_rstride=0,
+            output_affine_size=140,
+            output_size=248,
+            use_softmax=True,
+        )
 
 
 if __name__ == "__main__":
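This hunk only wires up __init__; no forward pass is added in this commit. A small smoke test, assuming the FSMN class in fsmn_encoder.py accepts the keyword arguments used above, would simply build the module and count its parameters:

from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadModel

config = FSMNVadConfig()
model = FSMNVadModel(config)  # builds the frozen ConvSTFT and the FSMN encoder as wired above
print(sum(p.numel() for p in model.parameters()))  # total parameter count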
toolbox/torchaudio/modules/freq_bands/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/modules/freq_bands/erb_bands.py ADDED
@@ -0,0 +1,176 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class ErbBandsNumpy(object):
+
+    @staticmethod
+    def freq2erb(freq_hz: float) -> float:
+        """
+        https://www.cnblogs.com/LXP-Never/p/16011229.html
+        1 / (24.7 * 9.265) = 0.00436976
+        """
+        return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
+
+    @staticmethod
+    def erb2freq(n_erb: float) -> float:
+        return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
+
+    @classmethod
+    def get_erb_widths(cls, sample_rate: int, nfft: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
+        """
+        https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
+        :param sample_rate:
+        :param nfft:
+        :param erb_bins: number of ERB (Equivalent Rectangular Bandwidth) bands.
+        :param min_freq_bins_for_erb: minimum number of frequency bins per ERB band.
+        :return:
+        """
+        nyq_freq = sample_rate / 2.
+        freq_width: float = sample_rate / nfft
+
+        min_erb: float = cls.freq2erb(0.)
+        max_erb: float = cls.freq2erb(nyq_freq)
+
+        erb = [0] * erb_bins
+        step = (max_erb - min_erb) / erb_bins
+
+        prev_freq_bin = 0
+        freq_over = 0
+        for i in range(1, erb_bins + 1):
+            f = cls.erb2freq(min_erb + i * step)
+            freq_bin = int(round(f / freq_width))
+            freq_bins = freq_bin - prev_freq_bin - freq_over
+
+            if freq_bins < min_freq_bins_for_erb:
+                freq_over = min_freq_bins_for_erb - freq_bins
+                freq_bins = min_freq_bins_for_erb
+            else:
+                freq_over = 0
+            erb[i - 1] = freq_bins
+            prev_freq_bin = freq_bin
+
+        erb[erb_bins - 1] += 1
+        too_large = sum(erb) - (nfft / 2 + 1)
+        if too_large > 0:
+            erb[erb_bins - 1] -= too_large
+        return np.array(erb, dtype=np.uint64)
+
+    @staticmethod
+    def get_erb_filter_bank(erb_widths: np.ndarray,
+                            normalized: bool = True,
+                            inverse: bool = False,
+                            ):
+        num_freq_bins = int(np.sum(erb_widths))
+        num_erb_bins = len(erb_widths)
+
+        fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
+
+        points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
+        for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
+            fb[b: b + w, i] = 1
+
+        if inverse:
+            fb = fb.T
+            if not normalized:
+                fb /= np.sum(fb, axis=1, keepdims=True)
+        else:
+            if normalized:
+                fb /= np.sum(fb, axis=0)
+        return fb
+
+    @staticmethod
+    def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
+        """
+        ERB filterbank and transform to decibel scale.
+
+        :param spec: Spectrum of shape [B, C, T, F].
+        :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths,
+            where B is the number of ERB bins.
+        :param db: Whether to transform the output into decibel scale. Defaults to `True`.
+        :return:
+        """
+        # complex spec to power spec. (real * real + image * image)
+        spec_ = np.abs(spec) ** 2
+
+        # spec to erb feature.
+        erb_feat = np.matmul(spec_, erb_fb)
+
+        if db:
+            erb_feat = 10 * np.log10(erb_feat + 1e-10)
+
+        erb_feat = np.array(erb_feat, dtype=np.float32)
+        return erb_feat
+
+
+class ErbBands(nn.Module):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 erb_bins: int = 32,
+                 min_freq_bins_for_erb: int = 2,
+                 ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.erb_bins = erb_bins
+        self.min_freq_bins_for_erb = min_freq_bins_for_erb
+
+        erb_fb, erb_fb_inv = self.init_erb_fb()
+        erb_fb = torch.tensor(erb_fb, dtype=torch.float32, requires_grad=False)
+        erb_fb_inv = torch.tensor(erb_fb_inv, dtype=torch.float32, requires_grad=False)
+        self.erb_fb = nn.Parameter(erb_fb, requires_grad=False)
+        self.erb_fb_inv = nn.Parameter(erb_fb_inv, requires_grad=False)
+
+    def init_erb_fb(self):
+        erb_widths = ErbBandsNumpy.get_erb_widths(
+            sample_rate=self.sample_rate,
+            nfft=self.nfft,
+            erb_bins=self.erb_bins,
+            min_freq_bins_for_erb=self.min_freq_bins_for_erb,
+        )
+        erb_fb = ErbBandsNumpy.get_erb_filter_bank(
+            erb_widths=erb_widths,
+            normalized=True,
+            inverse=False,
+        )
+        erb_fb_inv = ErbBandsNumpy.get_erb_filter_bank(
+            erb_widths=erb_widths,
+            normalized=True,
+            inverse=True,
+        )
+        return erb_fb, erb_fb_inv
+
+    def erb_scale(self, spec: torch.Tensor, db: bool = True):
+        # spec shape: (b, t, f)
+        spec_erb = torch.matmul(spec, self.erb_fb)
+        if db:
+            spec_erb = 10 * torch.log10(spec_erb + 1e-10)
+        return spec_erb
+
+    def erb_scale_inv(self, spec_erb: torch.Tensor):
+        spec = torch.matmul(spec_erb, self.erb_fb_inv)
+        return spec
+
+
+def main():
+
+    erb_bands = ErbBands()
+
+    spec = torch.randn(size=(2, 199, 257), dtype=torch.float32)
+    spec_erb = erb_bands.erb_scale(spec)
+    print(spec_erb.shape)
+
+    spec = erb_bands.erb_scale_inv(spec_erb)
+    print(spec.shape)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
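freq2erb and erb2freq implement the ERB-rate scale used by the referenced DeepFilterNet code, erb(f) = 9.265 * ln(1 + f / (24.7 * 9.265)), and are exact inverses of each other. A quick round-trip check (illustrative only):

from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBandsNumpy

# erb2freq(freq2erb(f)) should recover f up to floating-point error.
for f in (0.0, 100.0, 1000.0, 4000.0):
    e = ErbBandsNumpy.freq2erb(f)
    print(f, e, ErbBandsNumpy.erb2freq(e))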
toolbox/torchaudio/modules/freq_bands/mel_bands.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/webrtcvad/vad.py CHANGED
@@ -24,13 +24,15 @@ class WebRTCVad(object):
                  frame_duration_ms: int = 30,
                  padding_duration_ms: int = 300,
                  silence_duration_threshold: float = 0.3,
-                 sample_rate: int = 8000
+                 sample_rate: int = 8000,
+                 ring_buffer_activity_threshold: float = 0.9,
                  ):
         self.agg = agg
         self.frame_duration_ms = frame_duration_ms
         self.padding_duration_ms = padding_duration_ms
         self.silence_duration_threshold = silence_duration_threshold
         self.sample_rate = sample_rate
+        self.ring_buffer_activity_threshold = ring_buffer_activity_threshold
 
         self._vad = webrtcvad.Vad(mode=agg)
 
@@ -110,7 +112,7 @@ class WebRTCVad(object):
             self.ring_buffer.append((frame, is_speech))
             num_voiced = len([f for f, speech in self.ring_buffer if speech])
 
-            if num_voiced > 0.9 * self.ring_buffer.maxlen:
+            if num_voiced > self.ring_buffer_activity_threshold * self.ring_buffer.maxlen:
                 self.triggered = True
 
                 for f, _ in self.ring_buffer:
@@ -120,7 +122,7 @@ class WebRTCVad(object):
             self.voiced_frames.append(frame)
             self.ring_buffer.append((frame, is_speech))
             num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
-            if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
+            if num_unvoiced > self.ring_buffer_activity_threshold * self.ring_buffer.maxlen:
                 self.triggered = False
                 segment = [
                     np.concatenate([f.signal for f in self.voiced_frames]),
@@ -204,12 +206,12 @@ def get_args():
     )
     parser.add_argument(
        "--padding_duration_ms",
-        default=300,
+        default=30,
        type=int,
    )
    parser.add_argument(
        "--silence_duration_threshold",
-        default=0.3,
+        default=0.0,
        type=float,
        help="minimum silence duration, in seconds."
    )
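The previously hard-coded 0.9 ratio is now the constructor argument ring_buffer_activity_threshold: the VAD enters the triggered (speech) state once more than that fraction of frames in the ring buffer are voiced, and leaves it once more than that fraction are unvoiced. A construction example with the new keyword (only the constructor is shown in this diff, so no other method calls are assumed):

from toolbox.webrtcvad.vad import WebRTCVad

# Flip the speech/non-speech state when 70% of the ring buffer agrees,
# instead of the previous fixed 90%.
vad = WebRTCVad(
    agg=3,
    frame_duration_ms=30,
    padding_duration_ms=300,
    silence_duration_threshold=0.3,
    sample_rate=8000,
    ring_buffer_activity_threshold=0.7,
)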