import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.mossformer_gan_se.conformer import ConformerBlock

class LearnableSigmoid(nn.Module):
    """A sigmoid activation with a learnable per-feature slope, scaled by
    a fixed factor: out = beta * sigmoid(slope * x).

    Args:
        in_features (int): Number of input features (one learnable slope per feature).
        beta (float, optional): Scaling factor for the sigmoid output. Default is 1.

    Attributes:
        beta (float): The scaling factor for the sigmoid output.
        slope (Parameter): Learnable per-feature slope of the sigmoid.
    """
    
    def __init__(self, in_features, beta=1):
        """Initializes the LearnableSigmoid module.

        Args:
            in_features (int): Number of input features.
            beta (float, optional): Scaling factor for the sigmoid output.
        """
        super().__init__()
        self.beta = beta  # Scaling factor for the sigmoid
        self.slope = nn.Parameter(torch.ones(in_features))  # Learnable slope parameter
        self.slope.requires_grad = True  # Parameters track gradients by default; kept explicit

    def forward(self, x):
        """Forward pass of the learnable sigmoid function.

        Args:
            x (torch.Tensor): Input tensor with shape [batch_size, in_features].

        Returns:
            torch.Tensor: The scaled sigmoid output tensor.
        """
        return self.beta * torch.sigmoid(self.slope * x)
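

def _demo_learnable_sigmoid():
    """Hedged usage sketch for LearnableSigmoid (not part of the original
    pipeline). The batch size and feature count below are illustrative
    assumptions chosen only to show the expected shapes.
    """
    act = LearnableSigmoid(in_features=10, beta=2.0)
    x = torch.randn(4, 10)
    out = act(x)  # elementwise; values lie in the open interval (0, beta)
    assert out.shape == (4, 10)
    return out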


#%% Spectrograms
def segment_specs(y, seg_length=15, seg_hop=4, max_length=None):
    """Segments a spectrogram into smaller segments for input to a CNN. 
    Each segment includes neighboring frequency bins to preserve 
    contextual information.

    Args:
        y (torch.Tensor): Input spectrogram tensor of shape [B, H, W], 
                          where B is batch size, H is number of mel bands, 
                          and W is the length of the spectrogram.
        seg_length (int): Length of each segment (must be odd). Default is 15.
        seg_hop (int): Hop length for segmenting the spectrogram. Default is 4.
        max_length (int, optional): Maximum number of windows allowed. If the number of
                                     windows exceeds this, a ValueError is raised; otherwise
                                     the segments are zero-padded to max_length windows.

    Returns:
        torch.Tensor: Segmented tensor with shape [B*n, C, H, seg_length], where n is the 
                      number of segments, C is the number of channels (always 1).
    
    Raises:
        ValueError: If seg_length is even or if the number of windows exceeds max_length.
    """
    # Ensure segment length is odd
    if seg_length % 2 == 0:
        raise ValueError('seg_length must be odd! (seg_length={})'.format(seg_length))
    
    # Convert input to tensor if it's not already
    if not torch.is_tensor(y):
        y = torch.tensor(y)

    B, _, _ = y.size()  # Extract batch size and dimensions
    for b in range(B):
        x = y[b, :, :]  # Extract the current batch's spectrogram
        n_wins = x.shape[1] - (seg_length - 1)  # Calculate number of windows
        
        # Segment the mel-spectrogram
        idx1 = torch.arange(seg_length)  # Indices for segment length
        idx2 = torch.arange(n_wins)  # Indices for number of windows
        idx3 = idx1.unsqueeze(0) + idx2.unsqueeze(1)  # Create indices for segments
        x = x.transpose(1, 0)[idx3, :].unsqueeze(1).transpose(3, 2)  # Rearrange dimensions for CNN input

        # Adjust segments based on hop length
        if seg_hop > 1:
            x = x[::seg_hop, :]  # Downsample segments
            n_wins = int(np.ceil(n_wins / seg_hop))  # Update number of windows

        # Pad the segments if max_length is specified
        if max_length is not None:
            if max_length < n_wins:
                raise ValueError('n_wins {} > max_length {}. Increase max_length to fit all windows.'.format(n_wins, max_length))
            x_padded = torch.zeros((max_length, x.shape[1], x.shape[2], x.shape[3]))  # Create a padded tensor
            x_padded[:n_wins, :] = x  # Fill the padded tensor with the segments
            x = x_padded  # Update x to the padded tensor

        # Concatenate segments from each batch
        if b == 0:
            z = x.unsqueeze(0)  # Initialize z for the first batch
        else:
            z = torch.cat((z, x.unsqueeze(0)), dim=0)  # Concatenate along the batch dimension

    # Reshape the final tensor for output
    B, n, c, f, t = z.size()
    z = z.view(B * n, c, f, t)  # Combine batch and segment dimensions
    return z  # Return the segmented spectrogram tensor
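

def _demo_segment_specs():
    """Hedged usage sketch for segment_specs (illustrative, not from the
    training code). With 63 frames, seg_length=15 and seg_hop=4:
    n_wins = 63 - 14 = 49 windows, downsampled to ceil(49 / 4) = 13
    segments per item, so a batch of 2 yields 2 * 13 = 26 rows.
    """
    y = torch.randn(2, 48, 63)  # [B, mel bands, frames] -- assumed sizes
    z = segment_specs(y, seg_length=15, seg_hop=4)
    assert z.shape == (26, 1, 48, 15)
    return z
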

class AdaptCNN(nn.Module):
    """
    AdaptCNN: A convolutional neural network (CNN) with adaptive max pooling that 
    can be used as a framewise model. This architecture is more flexible than a 
    standard CNN, which requires a fixed input dimension. The model consists of six 
    convolutional layers, with adaptive pooling at each layer to handle varying input sizes.

    Args:
        input_channels (int): Number of input channels (default is 2).
        c_out_1 (int): Number of output channels for the first convolutional layer (default is 16).
        c_out_2 (int): Number of output channels for the second convolutional layer (default is 32).
        c_out_3 (int): Number of output channels for the third and subsequent convolutional layers (default is 64).
        kernel_size (list or int): Size of the convolutional kernels (default is [3, 3]).
        dropout (float): Dropout rate for regularization (default is 0.2).
        pool_1 (list): Pooling parameters for the first adaptive pooling layer (default is [101, 7]).
        pool_2 (list): Pooling parameters for the second adaptive pooling layer (default is [50, 7]).
        pool_3 (list): Pooling parameters for the third adaptive pooling layer (default is [25, 5]).
        pool_4 (list): Pooling parameters for the fourth adaptive pooling layer (default is [12, 5]).
        pool_5 (list): Pooling parameters for the fifth adaptive pooling layer (default is [6, 3]).
        fc_out_h (int, optional): Number of output units for the final fully connected layer. If None, the output size is determined from previous layers.

    Attributes:
        name (str): Name of the model.
        dropout (Dropout): Dropout layer for regularization.
        conv1, conv2, conv3, conv4, conv5, conv6 (Conv2d): Convolutional layers.
        bn1, bn2, bn3, bn4, bn5, bn6 (BatchNorm2d): Batch normalization layers.
        fc (Linear, optional): Fully connected layer.
        fan_out (int): Output dimension of the final layer.
    """
    
    def __init__(self, 
                 input_channels=2,
                 c_out_1=16, 
                 c_out_2=32,
                 c_out_3=64,
                 kernel_size=[3, 3], 
                 dropout=0.2,
                 pool_1=[101, 7],
                 pool_2=[50, 7],
                 pool_3=[25, 5],
                 pool_4=[12, 5],
                 pool_5=[6, 3],
                 fc_out_h=None):
        """Initializes the AdaptCNN model with the specified parameters."""
        super().__init__()
        self.name = 'CNN_adapt'

        # Model parameters
        self.input_channels = input_channels
        self.c_out_1 = c_out_1
        self.c_out_2 = c_out_2
        self.c_out_3 = c_out_3
        self.kernel_size = kernel_size
        self.pool_1 = pool_1
        self.pool_2 = pool_2
        self.pool_3 = pool_3
        self.pool_4 = pool_4
        self.pool_5 = pool_5
        self.dropout_rate = dropout
        self.fc_out_h = fc_out_h

        # Dropout layer for regularization
        self.dropout = nn.Dropout2d(p=self.dropout_rate)
        
        # Ensure kernel_size is a tuple
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)
            
        # Set kernel size for the last convolutional layer
        self.kernel_size_last = (self.kernel_size[0], self.pool_5[1])
            
        # Determine padding for convolutional layers based on kernel size
        if self.kernel_size[1] == 1:
            self.cnn_pad = (1, 0)  # Pad the time axis only; a width-1 kernel needs no width padding
        else:
            self.cnn_pad = (1, 1)  # Pad both axes for square (e.g. 3x3) kernels
            
        # Define convolutional layers with batch normalization
        self.conv1 = nn.Conv2d(self.input_channels, self.c_out_1, self.kernel_size, padding=self.cnn_pad)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)

        self.conv2 = nn.Conv2d(self.conv1.out_channels, self.c_out_2, self.kernel_size, padding=self.cnn_pad)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)

        self.conv3 = nn.Conv2d(self.conv2.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)

        self.conv4 = nn.Conv2d(self.conv3.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)

        self.conv5 = nn.Conv2d(self.conv4.out_channels, self.c_out_3, self.kernel_size, padding=self.cnn_pad)
        self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)

        self.conv6 = nn.Conv2d(self.conv5.out_channels, self.c_out_3, self.kernel_size_last, padding=(1, 0))
        self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)
        
        # Define fully connected layer if output size is specified.
        # The flattened CNN output has conv6.out_channels * pool_5[0] features
        # (conv6 collapses the width axis to 1), matching the view() in forward().
        if self.fc_out_h:
            self.fc = nn.Linear(self.conv6.out_channels * self.pool_5[0], self.fc_out_h)
            self.fan_out = self.fc_out_h
        else:
            self.fan_out = (self.conv6.out_channels * self.pool_5[0])

    def forward(self, x):
        """Defines the forward pass of the AdaptCNN model.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, input_channels, height, width].

        Returns:
            torch.Tensor: Output tensor after passing through the CNN layers.
        """
        # Forward pass through each layer with ReLU activation and adaptive pooling
        x = F.relu(self.bn1(self.conv1(x)))  # First convolutional layer
        x = F.adaptive_max_pool2d(x, output_size=self.pool_1)  # Adaptive pooling after conv1

        x = F.relu(self.bn2(self.conv2(x)))  # Second convolutional layer
        x = F.adaptive_max_pool2d(x, output_size=self.pool_2)  # Adaptive pooling after conv2

        x = self.dropout(x)  # Apply dropout
        x = F.relu(self.bn3(self.conv3(x)))  # Third convolutional layer
        x = F.adaptive_max_pool2d(x, output_size=self.pool_3)  # Adaptive pooling after conv3

        x = self.dropout(x)  # Apply dropout
        x = F.relu(self.bn4(self.conv4(x)))  # Fourth convolutional layer
        x = F.adaptive_max_pool2d(x, output_size=self.pool_4)  # Adaptive pooling after conv4

        x = self.dropout(x)  # Apply dropout
        x = F.relu(self.bn5(self.conv5(x)))  # Fifth convolutional layer
        x = F.adaptive_max_pool2d(x, output_size=self.pool_5)  # Adaptive pooling after conv5

        x = self.dropout(x)  # Apply dropout
        x = F.relu(self.bn6(self.conv6(x)))  # Last convolutional layer
        
        # Flatten: conv6 collapses the width axis to 1, leaving
        # conv6.out_channels * pool_5[0] features per segment
        x = x.view(-1, self.conv6.out_channels * self.pool_5[0])
        
        # Apply fully connected layer if defined
        if self.fc_out_h:
            x = self.fc(x)  # Fully connected output
        
        return x  # Return the output tensor
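

def _demo_adapt_cnn():
    """Hedged smoke test for AdaptCNN (not part of the original pipeline).
    The [8, 2, 120, 15] input is an assumption matching two-channel
    segments with 120 mel bands; adaptive pooling makes the exact spatial
    size flexible. With the defaults the flattened output dimension is
    c_out_3 * pool_5[0] = 64 * 6 = 384.
    """
    cnn = AdaptCNN()
    x = torch.randn(8, 2, 120, 15)
    out = cnn(x)
    assert out.shape == (8, 384)
    return out
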

class PoolAttFF(nn.Module):
    """
    PoolAttFF: An attention pooling module with an additional feed-forward network.
    
    This module performs attention-based pooling on input features followed by a 
    feed-forward neural network. The attention mechanism helps in focusing on the 
    important parts of the input while pooling.

    Args:
        d_input (int): The dimensionality of the input features (default is 384).
        output_size (int): The size of the output after the feed-forward network (default is 1).
        h (int): The size of the hidden layer in the feed-forward network (default is 128).
        dropout (float): The dropout rate for regularization (default is 0.1).

    Attributes:
        linear1 (Linear): First linear layer transforming input features to hidden size.
        linear2 (Linear): Second linear layer producing attention scores.
        linear3 (Linear): Final linear layer producing the output.
        activation (function): Activation function used in the network (ReLU).
        dropout (Dropout): Dropout layer for regularization.
    """
    
    def __init__(self, d_input=384, output_size=1, h=128, dropout=0.1):
        """Initializes the PoolAttFF module with the specified parameters."""
        super().__init__()

        # Define the feed-forward layers
        self.linear1 = nn.Linear(d_input, h)  # First linear layer
        self.linear2 = nn.Linear(h, 1)         # Second linear layer for attention scores

        self.linear3 = nn.Linear(d_input, output_size)  # Final output layer

        self.activation = F.relu  # Activation function
        self.dropout = nn.Dropout(dropout)  # Dropout layer for regularization

    def forward(self, x):
        """Defines the forward pass of the PoolAttFF module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, d_input].

        Returns:
            torch.Tensor: Output tensor after attention pooling and feed-forward network.
        """
        # Compute attention scores
        att = self.linear2(self.dropout(self.activation(self.linear1(x))))
        att = att.transpose(2, 1)  # Transpose for softmax

        # Apply softmax to get attention weights
        att = F.softmax(att, dim=2)  # Softmax along the sequence length

        # Perform attention pooling
        x = torch.bmm(att, x)  # Batch matrix multiplication
        x = x.squeeze(1)  # Remove unnecessary dimension
        x = self.linear3(x)  # Final output layer

        return x  # Return the output tensor
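

def _demo_pool_att_ff():
    """Hedged usage sketch for PoolAttFF (illustrative only). The
    [2, 13, 384] input is an assumption: a batch of 2 sequences of 13
    frame embeddings with the default d_input of 384. Attention pooling
    collapses the sequence axis to a single output per item.
    """
    pool = PoolAttFF()
    x = torch.randn(2, 13, 384)
    out = pool(x)
    assert out.shape == (2, 1)
    return out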


class Discriminator(nn.Module):
    """
    Discriminator: A neural network that predicts a normalized PESQ value
    from a predicted magnitude spectrogram (x) and a ground truth
    spectrogram (y).

    The model segments and concatenates the two input spectrograms,
    processes them through a convolutional network (CNN), applies
    self-attention, and outputs a value between 0 and 1 using a
    learnable sigmoid.

    Args:
        ndf (int): Number of filters in the convolutional layers (not directly used in this implementation).
        in_channel (int): Number of input channels (default is 2).

    Attributes:
        dim (int): Dimensionality of the feature representation (default is 384).
        cnn (AdaptCNN): CNN model for feature extraction.
        att (Sequential): Sequential stack of Conformer blocks for attention processing.
        pool (PoolAttFF): Attention pooling module.
        sigmoid (LearnableSigmoid): Sigmoid layer for final output.
    """
    
    def __init__(self, ndf, in_channel=2):
        """Initializes the Discriminator with specified parameters."""
        super().__init__()
        self.dim = 384  # Dimensionality of the feature representation
        self.cnn = AdaptCNN()  # CNN model for feature extraction

        # Define attention layers using Conformer blocks
        self.att = nn.Sequential(
            ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
                           conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2),
            ConformerBlock(dim=self.dim, dim_head=self.dim // 4, heads=4,
                           conv_kernel_size=31, attn_dropout=0.2, ff_dropout=0.2)
        )

        # Define attention pooling module
        self.pool = PoolAttFF()
        self.sigmoid = LearnableSigmoid(1)  # Sigmoid layer for output normalization

    def forward(self, x, y):
        """Defines the forward pass of the Discriminator.

        Args:
            x (torch.Tensor): Predicted magnitude spectrogram of shape [batch_size, 1, n_mels, n_frames].
            y (torch.Tensor): Ground truth magnitude spectrogram of shape [batch_size, 1, n_mels, n_frames].

        Returns:
            torch.Tensor: Output tensor representing the predicted PESQ value.
        """
        B, _, _, _ = x.size()  # Get the batch size from input x
        x = segment_specs(x.squeeze(1))  # Segment the predicted spectrogram
        y = segment_specs(y.squeeze(1))  # Segment the ground truth spectrogram

        # Concatenate the processed waveforms
        xy = torch.cat([x, y], dim=1)  # Concatenate along the channel dimension
        cnn_out = self.cnn(xy)  # Extract features using CNN

        _, d = cnn_out.size()  # Get dimensions of CNN output
        cnn_out = cnn_out.view(B, -1, d)  # Reshape for attention processing
        att_out = self.att(cnn_out)  # Apply self-attention layers
        pool_out = self.pool(att_out)  # Apply attention pooling module
        
        out = self.sigmoid(pool_out)  # Normalize output using sigmoid function
        return out  # Return the predicted PESQ value
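

if __name__ == '__main__':
    # Hedged end-to-end smoke test (not part of the original file). The
    # [2, 1, 120, 63] spectrogram shape is an illustrative assumption; the
    # frame axis only needs at least seg_length=15 frames for
    # segment_specs to form windows. Running this requires the repo's
    # ConformerBlock import at the top of the file to resolve.
    disc = Discriminator(ndf=16)
    est = torch.randn(2, 1, 120, 63)  # stand-in for an enhanced magnitude spectrogram
    ref = torch.randn(2, 1, 120, 63)  # stand-in for the clean reference spectrogram
    score = disc(est, ref)
    print(score.shape)  # expected: torch.Size([2, 1])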