File size: 6,440 Bytes
210822a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
from typing import Any, Union
from transformers.configuration_utils import PretrainedConfig
class XVectorConfig(PretrainedConfig):
    """Configuration for an x-vector style speaker model.

    Every constructor argument is stored as a flat attribute on the config.
    The arguments are additionally grouped into per-component dictionaries:

    * ``mel_spectrogram_config`` -- ``sample_rate`` through ``mel_norm``.
    * ``spectrogram_augmentation_config`` -- ``freq_masks`` through
      ``use_vectorized_spec_augment``.
    * ``encoder_config`` -- ``features`` (as ``feat_in``), ``filters``,
      ``kernel_sizes``, ``dilations`` and ``init_mode``.
    * ``decoder_config`` -- ``filters[-1]`` (as ``feat_in``),
      ``self.num_labels`` (as ``num_classes``), ``emb_sizes``, ``pool_mode``,
      ``angular``, ``attention_channels`` and ``init_mode``.
    * ``objective_config`` -- ``angular_scale``/``angular_margin`` (as
      ``scale``/``margin``) for the margin objectives, or ``label_smoothing``
      for ``'cross_entropy'``.

    The sub-config dictionaries are snapshots built in ``__init__``; changing
    an attribute afterwards (e.g. ``num_labels``) does not update them.

    Args:
        filters, kernel_sizes, dilations: Per-layer encoder settings. ``None``
            selects the defaults ``[512, 512, 512, 512, 1500]``,
            ``[5, 3, 3, 1, 1]`` and ``[1, 2, 3, 1, 1]`` respectively, with a
            fresh list created for each instance. ``filters`` must be
            non-empty because its last entry is used as the decoder input size.
        objective: One of ``'additive_angular_margin'``, ``'additive_margin'``
            or ``'cross_entropy'``.
        pad_token_id, bos_token_id, eos_token_id, **kwargs: Forwarded to
            ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``objective`` is not one of the supported values.
    """

    # Objectives that are parameterised by a scale and a margin.
    _MARGIN_OBJECTIVES = ('additive_angular_margin', 'additive_margin')
    # Every objective accepted by ``__init__``.
    _SUPPORTED_OBJECTIVES = _MARGIN_OBJECTIVES + ('cross_entropy',)

    # Attribute names gathered into ``mel_spectrogram_config``, in dict order.
    _MEL_SPECTROGRAM_KEYS = (
        "sample_rate",
        "window_size",
        "window_stride",
        "n_window_size",
        "n_window_stride",
        "window",
        "normalize",
        "n_fft",
        "preemph",
        "features",
        "lowfreq",
        "highfreq",
        "log",
        "log_zero_guard_type",
        "log_zero_guard_value",
        "dither",
        "pad_to",
        "frame_splicing",
        "exact_pad",
        "pad_value",
        "mag_power",
        "rng",
        "nb_augmentation_prob",
        "nb_max_freq",
        "use_torchaudio",
        "mel_norm",
    )

    # Attribute names gathered into ``spectrogram_augmentation_config``.
    _SPEC_AUGMENT_KEYS = (
        "freq_masks",
        "time_masks",
        "freq_width",
        "time_width",
        "rect_masks",
        "rect_time",
        "rect_freq",
        "mask_value",
        "use_vectorized_spec_augment",
    )

    def __init__(
        self,
        sample_rate: int = 16000,
        window_size: float = 0.02,
        window_stride: float = 0.01,
        n_window_size: Any = None,
        n_window_stride: Any = None,
        window: str = "hann",
        normalize: str = "per_feature",
        n_fft: Any = None,
        preemph: float = 0.97,
        features: int = 64,
        lowfreq: int = 0,
        highfreq: Any = None,
        log: bool = True,
        log_zero_guard_type: str = "add",
        log_zero_guard_value: Any = 2 ** -24,
        dither: float = 0.00001,
        pad_to: int = 16,
        frame_splicing: int = 1,
        exact_pad: bool = False,
        pad_value: int = 0,
        mag_power: float = 2,
        rng: Any = None,
        nb_augmentation_prob: float = 0,
        nb_max_freq: int = 4000,
        use_torchaudio: bool = True,
        mel_norm: str = "slaney",
        freq_masks: int = 0,
        time_masks: int = 0,
        freq_width: int = 10,
        time_width: int = 10,
        rect_masks: int = 0,
        rect_time: int = 5,
        rect_freq: int = 20,
        mask_value: float = 0,
        use_vectorized_spec_augment: bool = True,
        filters: Union[list, None] = None,
        kernel_sizes: Union[list, None] = None,
        dilations: Union[list, None] = None,
        init_mode: str = 'xavier_uniform',
        emb_sizes: Union[int, list] = 256,
        pool_mode: str = 'xvector',
        attention_channels: int = 128,
        objective: str = 'cross_entropy',  # additive_margin, additive_angular_margin, cross_entropy
        angular_scale: float = 30,
        angular_margin: float = 0.2,
        label_smoothing: float = 0.0,
        initializer_range: float = 0.02,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        # Fail fast: an unknown objective would otherwise leave
        # ``objective_config`` undefined and only surface much later.
        if objective not in self._SUPPORTED_OBJECTIVES:
            raise ValueError(
                f"Unsupported objective {objective!r}; "
                f"expected one of {self._SUPPORTED_OBJECTIVES}."
            )

        # Build fresh lists per instance; list literals in the signature would
        # be shared (and mutable) across every config created with defaults.
        if filters is None:
            filters = [512, 512, 512, 512, 1500]
        if kernel_sizes is None:
            kernel_sizes = [5, 3, 3, 1, 1]
        if dilations is None:
            dilations = [1, 2, 3, 1, 1]

        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.initializer_range = initializer_range

        # Mel-spectrogram configuration
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.window_stride = window_stride
        self.n_window_size = n_window_size
        self.n_window_stride = n_window_stride
        self.window = window
        self.normalize = normalize
        self.n_fft = n_fft
        self.preemph = preemph
        self.features = features
        self.lowfreq = lowfreq
        self.highfreq = highfreq
        self.log = log
        self.log_zero_guard_type = log_zero_guard_type
        self.log_zero_guard_value = log_zero_guard_value
        self.dither = dither
        self.pad_to = pad_to
        self.frame_splicing = frame_splicing
        self.exact_pad = exact_pad
        self.pad_value = pad_value
        self.mag_power = mag_power
        # NOTE(review): ``rng`` is stored verbatim; if a generator object is
        # passed, serializing the config will presumably fail — confirm.
        self.rng = rng
        self.nb_augmentation_prob = nb_augmentation_prob
        self.nb_max_freq = nb_max_freq
        self.use_torchaudio = use_torchaudio
        self.mel_norm = mel_norm
        self.mel_spectrogram_config = {key: getattr(self, key) for key in self._MEL_SPECTROGRAM_KEYS}

        # Spectrogram augmentation configuration
        self.freq_masks = freq_masks
        self.time_masks = time_masks
        self.freq_width = freq_width
        self.time_width = time_width
        self.rect_masks = rect_masks
        self.rect_time = rect_time
        self.rect_freq = rect_freq
        self.mask_value = mask_value
        self.use_vectorized_spec_augment = use_vectorized_spec_augment
        self.spectrogram_augmentation_config = {key: getattr(self, key) for key in self._SPEC_AUGMENT_KEYS}

        # Encoder configuration: its input width is the mel feature count.
        self.feat_in = features
        self.filters = filters
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.init_mode = init_mode
        self.encoder_config = {
            "feat_in": features,
            "filters": filters,
            "kernel_sizes": kernel_sizes,
            "dilations": dilations,
            "init_mode": init_mode,
        }

        # Decoder configuration: its input width is the last encoder filter count.
        self.emb_sizes = emb_sizes
        self.pool_mode = pool_mode
        # NOTE(review): only 'additive_angular_margin' sets angular=True, even
        # though 'additive_margin' also receives scale/margin — confirm intended.
        self.angular = objective == 'additive_angular_margin'
        self.attention_channels = attention_channels
        self.decoder_config = {
            "feat_in": filters[-1],
            "num_classes": self.num_labels,
            "emb_sizes": emb_sizes,
            "pool_mode": pool_mode,
            "angular": self.angular,
            "attention_channels": attention_channels,
            "init_mode": init_mode,
        }

        # Loss function configuration (objective already validated above).
        self.objective = objective
        if objective in self._MARGIN_OBJECTIVES:
            self.objective_config = {
                "scale": angular_scale,
                "margin": angular_margin,
            }
        else:  # 'cross_entropy'
            self.objective_config = {
                "label_smoothing": label_smoothing,
            }