Spaces:
Running
Running
import math | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def transform_wb_pesq_range(x: float) -> float:
    """Map a narrow-band (raw ITU-T P.862) PESQ score onto the wide-band ITU-T P.862.2 scale.

    The metric of ITU-T P.862, often called the "PESQ score", is defined for narrow-band
    signals and spans exactly [-0.5, 4.5]. This module instead uses the ITU-T P.862.2 metric,
    commonly known as "wide-band PESQ" and referred to here as the "PESQ score". The raw score
    is passed through the logistic mapping

        y = 0.999 + (4.999 - 0.999) / (1 + exp(-1.3669 * x + 3.8224))

    so the result lies strictly between 0.999 and 4.999.

    Args:
        x (float): Narrow-band PESQ score.

    Returns:
        (float): Wide-band PESQ score.
    """
    floor, ceiling = 0.999, 4.999
    slope, offset = -1.3669, 3.8224
    return floor + (ceiling - floor) / (1 + math.exp(slope * x + offset))
# (low, high) output range used by the PESQ prediction head.
# Low: P.862.2 applies a different input filter than P.862, so the raw score's lower bound is
# no longer -0.5 and the true lower bound is hard to determine; 1.0 is a reasonable approximation.
# High: the wide-band mapping of the narrow-band maximum score (4.5).
PESQRange: Tuple[float, float] = (1.0, transform_wb_pesq_range(4.5))
class RangeSigmoid(nn.Module):
    """Sigmoid rescaled onto an arbitrary output interval.

    Computes ``sigmoid(x) * (high - low) + low`` element-wise, so outputs lie in the open
    interval ``(low, high)``.

    Args:
        val_range (Tuple[float, float], optional): ``(low, high)`` bounds of the output. A list of
            two numbers is also accepted; it is stored as a tuple of floats. (Default: ``(0.0, 1.0)``)

    Raises:
        ValueError: If ``val_range`` is not a tuple/list with exactly two elements.
    """

    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
        super(RangeSigmoid, self).__init__()
        # Explicit check rather than ``assert`` so validation is not stripped under ``python -O``.
        if not isinstance(val_range, (tuple, list)) or len(val_range) != 2:
            raise ValueError(f"val_range must be a (low, high) pair of numbers. Found {val_range!r}.")
        self.val_range: Tuple[float, float] = (float(val_range[0]), float(val_range[1]))
        self.sigmoid: nn.modules.Module = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Squash ``x`` into ``(low, high)``.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            (torch.Tensor): Tensor of the same shape with values in ``(low, high)``.
        """
        low, high = self.val_range
        return self.sigmoid(x) * (high - low) + low
class Encoder(nn.Module):
    """Encoder that turns a 1D waveform into a 2D (channel x frame) representation.

    A single bias-free ``Conv1d`` (one input channel, kernel ``win_len``, hop ``win_len // 2``)
    is followed by a ReLU.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
        win_len (int, optional): Kernel size of the Conv1D layer; the stride is half of it. (Default: 32)
    """

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super().__init__()
        hop = win_len // 2
        self.conv1d = nn.Conv1d(1, feat_dim, kernel_size=win_len, stride=hop, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the waveforms through the convolution and ReLU.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
        """
        mono = x.unsqueeze(dim=1)  # add the single input-channel axis
        return F.relu(self.conv1d(mono))
class SingleRNN(nn.Module):
    """Single-layer bidirectional RNN followed by a linear projection back to ``input_size``.

    Args:
        rnn_type (str): Name of the ``torch.nn`` recurrent class to build ("RNN", "LSTM" or "GRU").
        input_size (int): Feature dimension of the input and of the projected output.
        hidden_size (int): Hidden size of each RNN direction.
        dropout (float, optional): Forwarded to the RNN constructor. (Default: 0.0)
    """

    def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        rnn_cls = getattr(nn, rnn_type)
        # NOTE(review): torch RNNs apply ``dropout`` only between stacked layers, so with
        # num_layers=1 a non-zero value has no effect (PyTorch emits a warning).
        self.rnn: nn.modules.Module = rnn_cls(
            input_size,
            hidden_size,
            num_layers=1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, hence 2 * hidden_size features into the projection.
        self.proj = nn.Linear(2 * hidden_size, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the RNN and projection.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, seq, input_size)`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(batch, seq, input_size)`.
        """
        sequence_out = self.rnn(x)[0]  # drop the final hidden state(s)
        return self.proj(sequence_out)
class DPRNN(nn.Module):
    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.

    The input sequence is cut into overlapping chunks. Each of the ``num_blocks`` blocks runs
    an intra-chunk ("row") RNN and then an inter-chunk ("column") RNN. Each is followed by
    GroupNorm and a residual connection. A 1x1 ``Conv2d`` + ``PReLU`` then maps the features
    to ``d_model``, and the chunks are overlap-added back into a sequence.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
        d_model (int, optional): Number of output features of the final 1x1 convolution. (Default: 256)
        chunk_size (int, optional): Chunk size of input for DPRNN. Must be positive. (Default: 100)
        chunk_stride (int, optional): Stride of chunk input for DPRNN. Must be positive. (Default: 50)

    Raises:
        ValueError: If ``chunk_size`` or ``chunk_stride`` is not positive.
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super(DPRNN, self).__init__()
        # Validate before building any submodule. A zero stride makes ``out[..., :-0]`` empty,
        # which otherwise only surfaces later as an opaque shape error inside ``chunking``.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive. Found {chunk_size}.")
        if chunk_stride <= 0:
            raise ValueError(f"chunk_stride must be positive. Found {chunk_stride}.")
        self.num_blocks = num_blocks
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Zero-pad ``x`` so it splits into two chunk grids offset by ``chunk_stride``.

        ``chunk_stride`` zeros are added on the left and ``rest + chunk_stride`` zeros on the
        right. This makes the lengths of both ``out[..., :-chunk_stride]`` and
        ``out[..., chunk_stride:]`` whole multiples of ``chunk_size``.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, feat_dim, time)`.

        Returns:
            (torch.Tensor, int): The padded tensor and ``rest`` (in ``[1, chunk_size]``), the
            extra right padding that :meth:`merging` trims off again.
        """
        seq_len = x.shape[-1]
        rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
        return out, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Split ``x`` into overlapping chunks.

        Two chunk grids are cut from the padded input, one at offset 0 and one at offset
        ``chunk_stride``. Their chunks are interleaved along the chunk axis.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, feat_dim, time)`.

        Returns:
            (torch.Tensor, int): Chunks with dimensions `(batch, feat_dim, chunk_size, num_chunks)`
            and the ``rest`` padding needed by :meth:`merging`.
        """
        out, rest = self.pad_chunk(x)
        batch_size, feat_dim, _ = out.shape
        segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
        segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
        # Concatenating on the last axis and re-viewing interleaves the grids:
        # [grid1_chunk0, grid2_chunk0, grid1_chunk1, grid2_chunk1, ...].
        out = torch.cat([segments1, segments2], dim=3)
        out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
        return out, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        """Overlap-add chunks produced by :meth:`chunking` back into a sequence.

        Every time step is covered once by each of the two chunk grids, and the two
        contributions are summed, so ``merging(*chunking(x)) == 2 * x``.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, dim, chunk_size, num_chunks)`.
            rest (int): Extra right padding returned by :meth:`chunking`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(batch, dim, time)`.
        """
        batch_size, dim, _, _ = x.shape
        out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
        # The first half of each pair comes from the grid at offset 0, the second half from the
        # grid at offset ``chunk_stride``. The trims align both onto the same time positions.
        out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
        out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
        out = out1 + out2
        if rest > 0:
            out = out[:, :, :-rest]
        return out.contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the dual-path blocks.

        Args:
            x (torch.Tensor): Tensor with dimensions `(batch, feat_dim, time)`.

        Returns:
            (torch.Tensor): Tensor with dimensions `(batch, time, d_model)`.
        """
        x, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = x.shape  # dim1 = chunk_size, dim2 = number of chunks
        out = x
        for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
            # Intra-chunk pass: every chunk is one sequence of length chunk_size.
            row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1)
            row_out = row_rnn(row_in)
            row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
            out = out + row_norm(row_out)
            # Inter-chunk pass: every within-chunk position is one sequence across chunks.
            col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1)
            col_out = col_rnn(col_in)
            col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
            out = out + col_norm(col_out)
        out = self.conv(out)
        out = self.merging(out, rest)
        return out.transpose(1, 2).contiguous()
class AutoPool(nn.Module):
    """Softmax-weighted pooling along one dimension with a learnable scale ``alpha``.

    The weights are ``softmax(alpha * x)`` taken along ``pool_dim``. The output is the sum of
    ``x * weights`` over that dimension, so ``pool_dim`` is removed from the result.

    Args:
        pool_dim (int, optional): Dimension to pool over. (Default: 1)
    """

    def __init__(self, pool_dim: int = 1) -> None:
        super().__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        # Assigning an nn.Parameter registers it under the name "alpha".
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` over ``pool_dim``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): ``x`` pooled over ``pool_dim`` (that dimension removed).
        """
        attention = self.softmax(x * self.alpha)
        return (x * attention).sum(dim=self.pool_dim)
class SquimObjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            List(torch.Tensor): List of score Tensors, one per branch. Each Tensor has dimension `(batch,)`.

        Raises:
            ValueError: If ``x`` is not a 2D Tensor.
        """
        if x.ndim != 2:
            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
        # Rescale each waveform to an RMS of 1/20 so scores do not depend on input gain.
        # The RMS is floored at 1e-8 so an all-zero (silent) waveform yields zeros instead of
        # 0/0 = NaN. Inputs with RMS >= 1e-8 are normalized exactly as before.
        rms = torch.mean(x**2, dim=1, keepdim=True) ** 0.5
        x = x / (torch.clamp(rms, min=1e-8) * 20)
        out = self.encoder(x)
        out = self.dprnn(out)
        scores = []
        for branch in self.branches:
            scores.append(branch(out).squeeze(dim=1))
        return scores
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Build the per-metric prediction branch that follows the DPRNN.

    The branch is ``Sequential(TransformerEncoderLayer, AutoPool, head)``. The head is
    ``Linear -> PReLU -> Linear(d_model, 1)``, which for "stoi" and "pesq" is followed by a
    ``RangeSigmoid``.

    Args:
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        metric (str): The metric name to predict.

    Returns:
        (nn.Module): Returned module to predict corresponding metric score.
    """
    transformer = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
    pooling = AutoPool()
    head_layers: List[nn.Module] = [nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)]
    # STOI is squashed into RangeSigmoid's default (0.0, 1.0) and PESQ into PESQRange.
    # Any other metric name (e.g. "sisdr") keeps the unbounded linear output.
    if metric == "stoi":
        head_layers.append(RangeSigmoid())
    elif metric == "pesq":
        head_layers.append(RangeSigmoid(val_range=PESQRange))
    return nn.Sequential(transformer, pooling, nn.Sequential(*head_layers))
def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.

    Args:
        feat_dim (int): The feature dimension after Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
            Defaults to ``chunk_size // 2`` when ``None``.

    Returns:
        SquimObjective: Model whose branches predict STOI, PESQ and SI-SDR, in that order.
    """
    stride = chunk_size // 2 if chunk_stride is None else chunk_stride
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, stride)
    branches = nn.ModuleList([_create_branch(d_model, nhead, metric) for metric in ("stoi", "pesq", "sisdr")])
    return SquimObjective(encoder, dprnn, branches)
def squim_objective_base() -> SquimObjective:
    """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments.

    ``chunk_stride`` is left unset, so the factory uses ``chunk_size // 2``.
    """
    base_config = dict(
        feat_dim=256,
        win_len=64,
        d_model=256,
        nhead=4,
        hidden_dim=256,
        num_blocks=2,
        rnn_type="LSTM",
        chunk_size=71,
    )
    return squim_objective_model(**base_config)