import torch

__all__ = ["DeepSpeech"]

class FullyConnected(torch.nn.Module):
    """Fully connected layer with a clipped ReLU activation and optional dropout.

    Args:
        n_feature: Number of input features.
        n_hidden: Internal hidden unit size.
        dropout: Dropout probability applied after the activation; ``0.0`` disables it.
        relu_max_clip: Maximum value at which the ReLU output is clipped. (Default: ``20``)
    """

    def __init__(self, n_feature: int, n_hidden: int, dropout: float, relu_max_clip: int = 20) -> None:
        super(FullyConnected, self).__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        # Clipped ReLU: min(max(x, 0), relu_max_clip), as used in the DeepSpeech paper.
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip)
        if self.dropout:
            x = torch.nn.functional.dropout(x, self.dropout, self.training)
        return x
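
# A minimal shape check for FullyConnected (illustrative sketch only; the sizes
# below are arbitrary assumptions, not values the model itself uses). The linear
# layer acts on the last dimension, so any leading dimensions pass through:
#
#     >>> layer = FullyConnected(n_feature=64, n_hidden=128, dropout=0.1)
#     >>> layer(torch.rand(8, 100, 64)).shape   # (batch, time, feature) in
#     torch.Size([8, 100, 128])                 # (batch, time, hidden) out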


class DeepSpeech(torch.nn.Module):
    """DeepSpeech architecture introduced in
    *Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.

    Args:
        n_feature: Number of input features.
        n_hidden: Internal hidden unit size. (Default: ``2048``)
        n_class: Number of output classes. (Default: ``40``)
        dropout: Dropout probability for the fully connected layers. (Default: ``0.0``)
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super(DeepSpeech, self).__init__()
        self.n_hidden = n_hidden
        # Three fully connected layers, then a single bidirectional RNN,
        # then one more fully connected layer before the output projection.
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
        self.bi_rnn = torch.nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True)
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).

        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # N x C x T x F
        x = self.fc1(x)
        # N x C x T x H
        x = self.fc2(x)
        # N x C x T x H
        x = self.fc3(x)
        # N x C x T x H
        x = x.squeeze(1)
        # N x T x H
        x = x.transpose(0, 1)
        # T x N x H
        x, _ = self.bi_rnn(x)
        # T x N x 2H: sum the forward and backward halves so the fifth
        # (non-recurrent) layer takes both directions as input.
        x = x[:, :, : self.n_hidden] + x[:, :, self.n_hidden :]
        # T x N x H
        x = self.fc4(x)
        # T x N x H
        x = self.out(x)
        # T x N x n_class
        x = x.permute(1, 0, 2)
        # N x T x n_class
        x = torch.nn.functional.log_softmax(x, dim=2)
        # N x T x n_class
        return x
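

# Minimal usage sketch (illustrative assumption, not part of the module's API):
# run a random (batch, channel=1, time, feature) input through the model and
# check the (batch, time, class) log-probability output that a CTC loss expects.
# The feature size and batch/time dimensions below are arbitrary.
if __name__ == "__main__":
    model = DeepSpeech(n_feature=64)  # hypothetical feature size
    x = torch.rand(4, 1, 200, 64)  # N x C x T x F
    log_probs = model(x)  # N x T x n_class
    print(log_probs.shape)  # torch.Size([4, 200, 40])
    # log_softmax is taken over classes, so probabilities along dim 2 sum to 1.
    print(torch.allclose(log_probs.exp().sum(dim=2), torch.ones(4, 200)))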