Spaces:
Running
on
Zero
Running
on
Zero
File size: 17,133 Bytes
8e8cd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 |
import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
from models.mossformer_gan_se.conformer import ConformerBlock
class LearnableSigmoid(nn.Module):
"""A learnable sigmoid activation function that scales the output
based on the input features.
Args:
in_features (int): The number of input features for the sigmoid function.
beta (float, optional): A scaling factor for the sigmoid output. Default is 1.
Attributes:
beta (float): The scaling factor for the sigmoid function.
slope (Parameter): Learnable parameter that adjusts the slope of the sigmoid.
"""
def __init__(self, in_features, beta=1):
"""Initializes the LearnableSigmoid module.
Args:
in_features (int): Number of input features.
beta (float, optional): Scaling factor for the sigmoid output.
"""
super().__init__()
self.beta = beta # Scaling factor for the sigmoid
self.slope = nn.Parameter(torch.ones(in_features)) # Learnable slope parameter
self.slope.requiresGrad = True # Ensure gradient updates
def forward(self, x):
"""Forward pass of the learnable sigmoid function.
Args:
x (torch.Tensor): Input tensor with shape [batch_size, in_features].
Returns:
torch.Tensor: The scaled sigmoid output tensor.
"""
return self.beta * torch.sigmoid(self.slope * x)
#%% Spectrograms
def segment_specs(y, seg_length=15, seg_hop=4, max_length=None):
"""Segments a spectrogram into smaller segments for input to a CNN.
Each segment includes neighboring frequency bins to preserve
contextual information.
Args:
y (torch.Tensor): Input spectrogram tensor of shape [B, H, W],
where B is batch size, H is number of mel bands,
and W is the length of the spectrogram.
seg_length (int): Length of each segment (must be odd). Default is 15.
seg_hop (int): Hop length for segmenting the spectrogram. Default is 4.
max_length (int, optional): Maximum number of windows allowed. If the number of
windows exceeds this, a ValueError is raised.
Returns:
torch.Tensor: Segmented tensor with shape [B*n, C, H, seg_length], where n is the
number of segments, C is the number of channels (always 1).
Raises:
ValueError: If seg_length is even or if the number of windows exceeds max_length.
"""
# Ensure segment length is odd
if seg_length % 2 == 0:
raise ValueError('seg_length must be odd! (seg_length={})'.format(seg_length))
# Convert input to tensor if it's not already
if not torch.is_tensor(y):
y = torch.tensor(y)
B, _, _ = y.size() # Extract batch size and dimensions
for b in range(B):
x = y[b, :, :] # Extract the current batch's spectrogram
n_wins = x.shape[1] - (seg_length - 1) # Calculate number of windows
# Segment the mel-spectrogram
idx1 = torch.arange(seg_length) # Indices for segment length
idx2 = torch.arange(n_wins) # Indices for number of windows
idx3 = idx1.unsqueeze(0) + idx2.unsqueeze(1) # Create indices for segments
x = x.transpose(1, 0)[idx3, :].unsqueeze(1).transpose(3, 2) # Rearrange dimensions for CNN input
# Adjust segments based on hop length
if seg_hop > 1:
x = x[::seg_hop, :] # Downsample segments
n_wins = int(np.ceil(n_wins / seg_hop)) # Update number of windows
# Pad the segments if max_length is specified
if max_length is not None:
if max_length < n_wins:
raise ValueError('n_wins {} > max_length {}. Increase max window length max_segments!'.format(n_wins, max_length))
x_padded = torch.zeros((max_length, x.shape[1], x.shape[2], x.shape[3])) # Create a padded tensor
x_padded[:n_wins, :] = x # Fill the padded tensor with the segments
x = x_padded # Update x to the padded tensor
# Concatenate segments from each batch
if b == 0:
z = x.unsqueeze(0) # Initialize z for the first batch
else:
z = torch.cat((z, x.unsqueeze(0)), axis=0) # Concatenate to z
# Reshape the final tensor for output
B, n, c, f, t = z.size()
z = z.view(B * n, c, f, t) # Combine batch and segment dimensions
return z # Return the segmented spectrogram tensor
class AdaptCNN(nn.Module):
"""
AdaptCNN: A convolutional neural network (CNN) with adaptive max pooling that
can be used as a framewise model. This architecture is more flexible than a
standard CNN, which requires a fixed input dimension. The model consists of six
convolutional layers, with adaptive pooling at each layer to handle varying input sizes.
Args:
input_channels (int): Number of input channels (default is 2).
c_out_1 (int): Number of output channels for the first convolutional layer (default is 16).
c_out_2 (int): Number of output channels for the second convolutional layer (default is 32).
c_out_3 (int): Number of output channels for the third and subsequent convolutional layers (default is 64).
kernel_size (list or int): Size of the convolutional kernels (default is [3, 3]).
dropout (float): Dropout rate for regularization (default is 0.2).
pool_1 (list): Pooling parameters for the first adaptive pooling layer (default is [101, 7]).
pool_2 (list): Pooling parameters for the second adaptive pooling layer (default is [50, 7]).
pool_3 (list): Pooling parameters for the third adaptive pooling layer (default is [25, 5]).
pool_4 (list): Pooling parameters for the fourth adaptive pooling layer (default is [12, 5]).
pool_5 (list): Pooling parameters for the fifth adaptive pooling layer (default is [6, 3]).
fc_out_h (int, optional): Number of output units for the final fully connected layer. If None, the output size is determined from previous layers.
Attributes:
name (str): Name of the model.
dropout (Dropout): Dropout layer for regularization.
conv1, conv2, conv3, conv4, conv5, conv6 (Conv2d): Convolutional layers.
bn1, bn2, bn3, bn4, bn5, bn6 (BatchNorm2d): Batch normalization layers.
fc (Linear, optional): Fully connected layer.
fan_out (int): Output dimension of the final layer.
"""
def __init__(self,
input_channels=2,
c_out_1=16,
c_out_2=32,
c_out_3=64,
kernel_size=[3, 3],
dropout=0.2,
pool_1=[101, 7],
pool_2=[50, 7],
pool_3=[25, 5],
pool_4=[12, 5],
pool_5=[6, 3],
fc_out_h=None):
"""Initializes the AdaptCNN model with the specified parameters."""
super().__init__()
self.name = 'CNN_adapt'
# Model parameters
self.input_channels = input_channels
self.c_out_1 = c_out_1
self.c_out_2 = c_out_2
self.c_out_3 = c_out_3
self.kernel_size = kernel_size
self.pool_1 = pool_1
self.pool_2 = pool_2
self.pool_3 = pool_3
self.pool_4 = pool_4
self.pool_5 = pool_5
self.dropout_rate = dropout
self.fc_out_h = fc_out_h
# Dropout layer for regularization
self.dropout = nn.Dropout2d(p=self.dropout_rate)
# Ensure kernel_size is a tuple
if isinstance(self.kernel_size, int):
self.kernel_size = (self.kernel_size, self.kernel_size)
# Set kernel size for the last convolutional layer
self.kernel_size_last = (self.kernel_size[0], self.pool_5[1])
# Determine padding for convolutional layers based on kernel size
if self.kernel_size[1] == 1:
self.cnn_pad = (1, 0) # No padding needed for 1D convolution
else:
self.cnn_pad = (1, 1) # Padding for 2D convolution
# Define convolutional layers with batch normalization
self.conv1 = nn.Conv2d(self.input_channels, self.c_out_1, self.kernel_size, padding=self.cnn_pad)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.conv2 = nn.Conv2d(self.conv1.out_channels, self.c_out_2, self.kernel_size, padding=self.cnn_pad)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.conv3 = nn.Conv2d(self.conv2.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)
self.conv4 = nn.Conv2d(self.conv3.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)
self.conv5 = nn.Conv2d(self.conv4.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)
self.conv6 = nn.Conv2d(self.conv5.out_channels, self.c_out_3, self.kernel_size_last, padding=(1, 0))
self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)
# Define fully connected layer if output size is specified
if self.fc_out_h:
self.fc = nn.Linear(self.conv6.out_channels * self.pool_3[0], self.fc_out_h)
self.fan_out = self.fc_out_h
else:
self.fan_out = (self.conv6.out_channels * self.pool_3[0])
def forward(self, x):
"""Defines the forward pass of the AdaptCNN model.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, input_channels, height, width].
Returns:
torch.Tensor: Output tensor after passing through the CNN layers.
"""
# Forward pass through each layer with ReLU activation and adaptive pooling
x = F.relu(self.bn1(self.conv1(x))) # First convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_1)) # Adaptive pooling after conv1
x = F.relu(self.bn2(self.conv2(x))) # Second convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_2)) # Adaptive pooling after conv2
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn3(self.conv3(x))) # Third convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_3)) # Adaptive pooling after conv3
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn4(self.conv4(x))) # Fourth convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_4)) # Adaptive pooling after conv4
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn5(self.conv5(x))) # Fifth convolutional layer
x = F.adaptive_max_pool2d(x, output_size=(self.pool_5)) # Adaptive pooling after conv5
x = self.dropout(x) # Apply dropout
x = F.relu(self.bn6(self.conv6(x))) # Last convolutional layer
# Flatten the output for the fully connected layer
x = x.view(-1, self.conv6.out_channels * self.pool_5[0])
# Apply fully connected layer if defined
if self.fc_out_h:
x = self.fc(x) # Fully connected output
return x # Return the output tensor
class PoolAttFF(nn.Module):
"""
PoolAttFF: An attention pooling module with an additional feed-forward network.
This module performs attention-based pooling on input features followed by a
feed-forward neural network. The attention mechanism helps in focusing on the
important parts of the input while pooling.
Args:
d_input (int): The dimensionality of the input features (default is 384).
output_size (int): The size of the output after the feed-forward network (default is 1).
h (int): The size of the hidden layer in the feed-forward network (default is 128).
dropout (float): The dropout rate for regularization (default is 0.1).
Attributes:
linear1 (Linear): First linear layer transforming input features to hidden size.
linear2 (Linear): Second linear layer producing attention scores.
linear3 (Linear): Final linear layer producing the output.
activation (function): Activation function used in the network (ReLU).
dropout (Dropout): Dropout layer for regularization.
"""
def __init__(self, d_input=384, output_size=1, h=128, dropout=0.1):
"""Initializes the PoolAttFF module with the specified parameters."""
super().__init__()
# Define the feed-forward layers
self.linear1 = nn.Linear(d_input, h) # First linear layer
self.linear2 = nn.Linear(h, 1) # Second linear layer for attention scores
self.linear3 = nn.Linear(d_input, output_size) # Final output layer
self.activation = F.relu # Activation function
self.dropout = nn.Dropout(dropout) # Dropout layer for regularization
def forward(self, x):
"""Defines the forward pass of the PoolAttFF module.
Args:
x (torch.Tensor): Input tensor of shape [batch_size, seq_len, d_input].
Returns:
torch.Tensor: Output tensor after attention pooling and feed-forward network.
"""
# Compute attention scores
att = self.linear2(self.dropout(self.activation(self.linear1(x))))
att = att.transpose(2, 1) # Transpose for softmax
# Apply softmax to get attention weights
att = F.softmax(att, dim=2) # Softmax along the sequence length
# Perform attention pooling
x = torch.bmm(att, x) # Batch matrix multiplication
x = x.squeeze(1) # Remove unnecessary dimension
x = self.linear3(x) # Final output layer
return x # Return the output tensor
class Discriminator(nn.Module):
"""
Discriminator: A neural network that predicts a normalized PESQ value
between a predicted waveform (x) and a ground truth waveform (y).
The model concatenates the two input waveforms, processes them through
a convolutional network (CNN), applies self-attention, and outputs a
value between 0 and 1 using a sigmoid activation function.
Args:
ndf (int): Number of filters in the convolutional layers (not directly used in this implementation).
in_channel (int): Number of input channels (default is 2).
Attributes:
dim (int): Dimensionality of the feature representation (default is 384).
cnn (AdaptCNN): CNN model for feature extraction.
att (Sequential): Sequential stack of Conformer blocks for attention processing.
pool (PoolAttFF): Attention pooling module.
sigmoid (LearnableSigmoid): Sigmoid layer for final output.
"""
def __init__(self, ndf, in_channel=2):
"""Initializes the Discriminator with specified parameters."""
super().__init__()
self.dim = 384 # Dimensionality of the feature representation
self.cnn = AdaptCNN() # CNN model for feature extraction
# Define attention layers using Conformer blocks
self.att = nn.Sequential(
ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2),
ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2)
)
# Define attention pooling module
self.pool = PoolAttFF()
self.sigmoid = LearnableSigmoid(1) # Sigmoid layer for output normalization
def forward(self, x, y):
"""Defines the forward pass of the Discriminator.
Args:
x (torch.Tensor): Predicted waveform tensor of shape [batch_size, 1, height, width].
y (torch.Tensor): Ground truth waveform tensor of shape [batch_size, 1, height, width].
Returns:
torch.Tensor: Output tensor representing the predicted PESQ value.
"""
B, _, _, _ = x.size() # Get the batch size from input x
x = segment_specs(x.squeeze(1)) # Segment and process predicted waveform
y = segment_specs(y.squeeze(1)) # Segment and process ground truth waveform
# Concatenate the processed waveforms
xy = torch.cat([x, y], dim=1) # Concatenate along the channel dimension
cnn_out = self.cnn(xy) # Extract features using CNN
_, d = cnn_out.size() # Get dimensions of CNN output
cnn_out = cnn_out.view(B, -1, d) # Reshape for attention processing
att_out = self.att(cnn_out) # Apply self-attention layers
pool_out = self.pool(att_out) # Apply attention pooling module
out = self.sigmoid(pool_out) # Normalize output using sigmoid function
return out # Return the predicted PESQ value
|