|
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .module import NeuralModule
from .tdnn_attention import StatsPoolLayer, AttentivePoolLayer, init_weights
from .cnn import Conv1d
from .normalization import BatchNorm1d

|
class TDNNLayer(nn.Module):
    """A single TDNN layer: a dilated 1-D convolution followed by ReLU.

    The weights are stored in an ``nn.Linear`` over ``in_conv_dim * kernel_size``
    inputs and reshaped into a convolution kernel in the forward pass.
    """

    def __init__(self, in_conv_dim: int, out_conv_dim: int, kernel_size: int, dilation: int):
        super().__init__()
        self.in_conv_dim = in_conv_dim
        self.out_conv_dim = out_conv_dim
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        hidden_states: tensor of shape (B, in_conv_dim, T)
        output: tensor of shape (B, out_conv_dim, T - dilation * (kernel_size - 1))
        """
        # Reshape the Linear weight into an (out_channels, in_channels, kernel_size)
        # kernel and apply it as a dilated 1-D convolution.
        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
        hidden_states = F.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
        hidden_states = self.activation(hidden_states)
        return hidden_states
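

# Illustrative, standalone sketch (not part of the original module): a quick
# check of TDNNLayer's shape contract. All hyperparameters are arbitrary.
def _demo_tdnn_layer():
    layer = TDNNLayer(in_conv_dim=64, out_conv_dim=128, kernel_size=3, dilation=2)
    x = torch.randn(4, 64, 200)  # (batch, in_conv_dim, time)
    y = layer(x)
    # With no padding, a dilated conv shrinks time by dilation * (kernel_size - 1).
    assert y.shape == (4, 128, 200 - 2 * (3 - 1))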
|
|
|
|
|
class XVectorEncoder(NeuralModule):
    """
    x-vector encoder: a stack of dilated 1-D convolution (TDNN) blocks, each
    followed by LeakyReLU and batch normalization.

    input:
        feat_in: input feature size (e.g. number of mel-spectrogram bins)
        filters: list of output channel sizes, one per TDNN block
        kernel_sizes: list of kernel sizes, one per TDNN block
        dilations: list of dilation factors, one per TDNN block
        init_mode: weight initialization scheme (default: 'xavier_uniform')

    output:
        outputs: encoded output of shape (B, D, T)
        output_length: input lengths, passed through unchanged
    """
|
    def __init__(
        self,
        feat_in: int,
        filters: list,
        kernel_sizes: list,
        dilations: list,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        # Each block is a dilated Conv1d followed by LeakyReLU and batch norm.
        in_channels = feat_in
        for block_index in range(len(filters)):
            out_channels = filters[block_index]
            self.blocks.extend(
                [
                    Conv1d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_sizes[block_index],
                        dilation=dilations[block_index],
                    ),
                    nn.LeakyReLU(),
                    BatchNorm1d(input_size=out_channels),
                ]
            )
            in_channels = out_channels

        self.apply(lambda x: init_weights(x, mode=init_mode))
|
    def forward(self, audio_signal: torch.Tensor, length: Optional[torch.Tensor] = None):
        """
        audio_signal: tensor of shape (B, D, T)
        output: tensor of shape (B, D', T), plus the unmodified lengths
        """
        # The custom Conv1d/BatchNorm1d blocks expect channels-last (B, T, D)
        # input, hence the transposes in and out of that layout.
        x = audio_signal.transpose(1, 2)
        for layer in self.blocks:
            x = layer(x)
        output = x.transpose(1, 2)
        return output, length
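

# Illustrative sketch (not part of the original module): builds the encoder
# with the classic x-vector TDNN configuration (Snyder et al., 2018) purely as
# an example. The printed shape assumes the custom Conv1d from .cnn pads so
# that the time dimension is preserved.
def _demo_xvector_encoder():
    encoder = XVectorEncoder(
        feat_in=80,
        filters=[512, 512, 512, 512, 1500],
        kernel_sizes=[5, 3, 3, 1, 1],
        dilations=[1, 2, 3, 1, 1],
    )
    audio_signal = torch.randn(4, 80, 200)  # (B, D, T) mel-spectrogram features
    lengths = torch.full((4,), 200)
    output, output_length = encoder(audio_signal, lengths)
    print(output.shape)  # expected: torch.Size([4, 1500, 200])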
|
|
|
|
|
class SpeakerDecoder(NeuralModule):
    """
    Speaker Decoder creates the final neural layers that map the output of an
    encoder (e.g. the Jasper encoder) to an embedding layer, followed by a
    speaker-based softmax loss.

    Args:
        feat_in (int): Number of channels being input to this module.
        num_classes (int): Number of unique speakers in the dataset.
        emb_sizes (int or list): Sizes of the intermediate embedding layers
            (the returned speaker embedding is taken from the last of these
            layers). Defaults to 256.
        pool_mode (str): Pooling strategy. Options are:
            'xvector': mean and variance statistics pooling (default)
            'tap': temporal average pooling (mean only)
            'attention': attention-based pooling
        angular (bool): If True, the final classifier has no bias and both the
            embeddings and the classifier weights are L2-normalized, producing
            cosine-similarity logits for angular-margin losses.
            Defaults to False.
        attention_channels (int): Number of hidden channels in the attention
            pooling layer. Defaults to 128.
        init_mode (str): Describes how neural network parameters are
            initialized. Options are ['xavier_uniform', 'xavier_normal',
            'kaiming_uniform', 'kaiming_normal'].
            Defaults to 'xavier_uniform'.
    """
|
    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        self.angular = angular
        # layer[: self.emb_id] selects the first two modules of an embedding
        # block; for the 'linear' affine type this drops the trailing ReLU, so
        # the returned embedding is pre-activation.
        self.emb_id = 2
        bias = not self.angular  # angular-margin losses need a bias-free classifier
        emb_sizes = [emb_sizes] if isinstance(emb_sizes, int) else emb_sizes

        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()
        if self.pool_mode in ('xvector', 'tap'):
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'
        else:
            raise ValueError(f"Unsupported pool_mode '{pool_mode}'; expected 'xvector', 'tap', or 'attention'.")

        # Chain embedding layers: pooled features -> emb_sizes[0] -> emb_sizes[1] -> ...
        shapes = [self._pooling.feat_in]
        for size in emb_sizes:
            shapes.append(int(size))

        emb_layers = []
        for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
            layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
            emb_layers.append(layer)

        self.emb_layers = nn.ModuleList(emb_layers)

        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)

        self.apply(lambda x: init_weights(x, mode=init_mode))
|
    def affine_layer(
        self,
        inp_shape,
        out_shape,
        learn_mean=True,
        affine_type='conv',
    ):
        if affine_type == 'conv':
            # Used with attention pooling, whose output keeps a trailing time
            # dimension: normalize, then project with a 1x1 convolution.
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )
        else:
            # Used with statistics/average pooling on flat (B, C) vectors.
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )

        return layer
|
    def forward(self, encoder_output, length: Optional[torch.Tensor] = None):
        pool = self._pooling(encoder_output, length)
        embs = []

        for layer in self.emb_layers:
            # `emb` is the block's pre-activation output (see `self.emb_id`).
            pool, emb = layer(pool), layer[: self.emb_id](pool)
            embs.append(emb)

        # Drop the trailing singleton time dimension left by attention pooling.
        pool = pool.squeeze(-1)
        if self.angular:
            # L2-normalize the embeddings and the classifier weights so the
            # logits are cosine similarities, as angular-margin losses expect.
            pool = F.normalize(pool, p=2, dim=1)
            out = F.linear(pool, F.normalize(self.final.weight, p=2, dim=1))
        else:
            out = self.final(pool)

        return out, embs[-1].squeeze(-1)
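

# Illustrative sketch (not part of the original module): runs the decoder on a
# dummy encoder output. num_classes=7205 is just an example (the number of
# speakers in the combined VoxCeleb 1+2 dev sets); the shapes assume
# StatsPoolLayer returns (B, 2 * feat_in) in 'xvector' mode.
def _demo_speaker_decoder():
    decoder = SpeakerDecoder(feat_in=1500, num_classes=7205, emb_sizes=512, pool_mode='xvector')
    encoder_output = torch.randn(4, 1500, 200)  # (B, D, T)
    logits, embedding = decoder(encoder_output)
    print(logits.shape, embedding.shape)  # expected: (4, 7205) and (4, 512)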