import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import parse as V
from torch.nn import init
from torch.nn.parameter import Parameter

from models.mossformer_gan_se.fsmn import UniDeepFsmn
from models.mossformer_gan_se.conv_module import ConvModule
from models.mossformer_gan_se.mossformer import MossFormer
from models.mossformer_gan_se.se_layer import SELayer
from models.mossformer_gan_se.get_layer_from_string import get_layer
from models.mossformer_gan_se.discriminator import Discriminator

# Check if the installed version of PyTorch is 1.9.0 or higher
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")


class MossFormerGAN_SE_16K(nn.Module):
    """
    MossFormerGAN_SE_16K: A GAN-based speech enhancement model for 16kHz input audio.

    This model integrates a synchronous attention network (SyncANet) for 
    feature extraction. Depending on the mode (train or inference), it may 
    also include a discriminator for adversarial training.

    Args:
        args (Namespace): Arguments containing configuration parameters, 
                          including 'fft_len' and 'mode'.
    """

    def __init__(self, args):
        """Initializes the MossFormerGAN_SE_16K model."""
        super(MossFormerGAN_SE_16K, self).__init__()
        
        # Initialize SyncANet with specified number of channels and features
        self.model = SyncANet(num_channel=64, num_features=args.fft_len // 2 + 1)

        # Initialize discriminator if in training mode
        if args.mode == 'train':
            self.discriminator = Discriminator(ndf=16)
        else:
            self.discriminator = None

    def forward(self, x):
        """
        Defines the forward pass of the MossFormerGAN_SE_16K model.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 2, time_frames, freq_bins], holding the real and imaginary parts of the noisy spectrogram.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Output tensors representing the real and imaginary parts.
        """
        output_real, output_imag = self.model(x)  # Get real and imaginary outputs from the model
        return output_real, output_imag  # Return the outputs
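
# Usage sketch (illustrative only, not part of the training pipeline; assumes
# args.fft_len=400, i.e. 201 frequency bins, and a non-'train' mode so that no
# discriminator is instantiated):
#
#     from argparse import Namespace
#     model = MossFormerGAN_SE_16K(Namespace(fft_len=400, mode='inference'))
#     noisy = torch.randn(1, 2, 100, 201)   # [B, real/imag, T, F]
#     real, imag = model(noisy)             # each [1, 1, 100, 201]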


class FSMN_Wrap(nn.Module):
    """
    FSMN_Wrap: A wrapper around the UniDeepFsmn module to facilitate 
    integration into the larger model architecture.

    Args:
        nIn (int): Number of input features.
        nHidden (int): Number of hidden features in the FSMN (default is 128).
        lorder (int): Order of the FSMN (default is 20).
        nOut (int): Number of output features (default is 128). Currently unused; the wrapped FSMN preserves the nIn feature dimension.
    """

    def __init__(self, nIn, nHidden=128, lorder=20, nOut=128):
        """Initializes the FSMN_Wrap module with specified parameters."""
        super(FSMN_Wrap, self).__init__()

        # Initialize the UniDeepFsmn module
        self.fsmn = UniDeepFsmn(nIn, nHidden, lorder, nHidden)

    def forward(self, x):
        """
        Defines the forward pass of the FSMN_Wrap module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, time, freq].

        Returns:
            torch.Tensor: Output tensor of the same shape [batch_size, channels, time, freq].
        """
        # Shape of input x: [b, c, T, h]
        b, c, T, h = x.size()

        # Permute x to reshape it for FSMN processing: [b, T, h, c]
        x = x.permute(0, 2, 3, 1)  # Change dimensions to [b, T, h, c]
        x = torch.reshape(x, (b * T, h, c))  # Reshape to [b*T, h, c]

        # Pass through the FSMN
        output = self.fsmn(x)  # output: [b*T, h, c]

        # Reshape output back to original dimensions
        output = torch.reshape(output, (b, T, h, c))  # output: [b, T, h, c]

        return output.permute(0, 3, 1, 2)  # Final output shape: [b, c, T, h]
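
# Shape sketch (illustrative): FSMN_Wrap processes each time frame independently
# along the frequency axis and preserves the input shape.
#
#     fsmn = FSMN_Wrap(nIn=64, nHidden=64, lorder=5)
#     y = fsmn(torch.randn(2, 64, 10, 101))   # [B, C, T, F] -> [2, 64, 10, 101]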

class DilatedDenseNet(nn.Module):
    """
    DilatedDenseNet: A dilated dense network for feature extraction.

    This network consists of a series of dilated convolutions organized in a dense block structure,
    allowing for efficient feature reuse and capturing multi-scale information.

    Args:
        depth (int): The number of layers in the dense block (default is 4).
        in_channels (int): The number of input channels for the first layer (default is 64).
    """

    def __init__(self, depth=4, in_channels=64):
        """Initializes the DilatedDenseNet with specified depth and input channels."""
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # Default padding (unused; per-layer pads are created in the loop below)
        self.twidth = 2  # Temporal width for convolutions
        self.kernel_size = (self.twidth, 3)  # Kernel size for convolutions

        # Initialize dilated convolutions, padding, normalization, and FSMN for each layer
        for i in range(self.depth):
            dil = 2 ** i  # Dilation factor for the current layer
            pad_length = self.twidth + (dil - 1) * (self.twidth - 1) - 1  # Calculate padding length
            setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((1, 1, pad_length, 0), value=0.))
            setattr(self, 'conv{}'.format(i + 1),
                    nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size,
                              dilation=(dil, 1)))  # Convolution layer
            setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))  # Normalization
            setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))  # Activation function
            setattr(self, 'fsmn{}'.format(i + 1), FSMN_Wrap(nIn=self.in_channels, nHidden=self.in_channels, lorder=5, nOut=self.in_channels))

    def forward(self, x):
        """
        Defines the forward pass for the DilatedDenseNet.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output tensor after processing through the dense network.
        """
        skip = x  # Initialize skip connection with input
        for i in range(self.depth):
            # Apply padding, convolution, normalization, activation, and FSMN in sequence
            out = getattr(self, 'pad{}'.format(i + 1))(skip)
            out = getattr(self, 'conv{}'.format(i + 1))(out)
            out = getattr(self, 'norm{}'.format(i + 1))(out)
            out = getattr(self, 'prelu{}'.format(i + 1))(out)
            out = getattr(self, 'fsmn{}'.format(i + 1))(out)
            skip = torch.cat([out, skip], dim=1)  # Concatenate outputs for dense connectivity
        return out  # Return the final output
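
# Shape sketch (illustrative): the causal time padding and (1, 1) frequency padding
# are chosen so every layer preserves the [B, C, T, F] shape; dense concatenation
# only grows the input channels of each successive convolution.
#
#     dense = DilatedDenseNet(depth=4, in_channels=64)
#     y = dense(torch.randn(2, 64, 10, 101))   # -> [2, 64, 10, 101]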


class DenseEncoder(nn.Module):
    """
    DenseEncoder: A dense encoding module for feature extraction from input data.

    This module consists of a series of convolutional layers followed by a 
    dilated dense network for robust feature learning.

    Args:
        in_channel (int): Number of input channels for the encoder.
        channels (int): Number of output channels for each convolutional layer (default is 64).
    """

    def __init__(self, in_channel, channels=64):
        """Initializes the DenseEncoder with specified input channels and feature size."""
        super(DenseEncoder, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, channels, (1, 1), (1, 1)),  # Initial convolution layer
            nn.InstanceNorm2d(channels, affine=True),  # Normalization layer
            nn.PReLU(channels)  # Activation function
        )
        self.dilated_dense = DilatedDenseNet(depth=4, in_channels=channels)  # Dilated Dense Network
        self.conv_2 = nn.Sequential(
            nn.Conv2d(channels, channels, (1, 3), (1, 2), padding=(0, 1)),  # Second convolution layer
            nn.InstanceNorm2d(channels, affine=True),  # Normalization layer
            nn.PReLU(channels)  # Activation function
        )

    def forward(self, x):
        """
        Defines the forward pass for the DenseEncoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_channel, height, width].

        Returns:
            torch.Tensor: Output tensor after processing through the encoder.
        """
        x = self.conv_1(x)  # Process through the first convolutional layer
        x = self.dilated_dense(x)  # Process through the dilated dense network
        x = self.conv_2(x)  # Process through the second convolutional layer
        return x  # Return the final output
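
# Shape sketch (illustrative): conv_2 strides by 2 along frequency, so the
# frequency axis is roughly halved (201 -> 101 below) while time is preserved.
#
#     enc = DenseEncoder(in_channel=3, channels=64)
#     y = enc(torch.randn(2, 3, 10, 201))   # -> [2, 64, 10, 101]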


class SPConvTranspose2d(nn.Module):
    """
    SPConvTranspose2d: A sub-pixel transposed convolution layer.

    This module emulates a transposed convolution by producing r groups of output
    channels with an ordinary convolution and interleaving them along the last
    (frequency) axis, pixel-shuffle style, allowing for efficient upsampling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (tuple): Size of the convolution kernel.
        r (int): Upsampling rate (default is 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        """Initializes the SPConvTranspose2d with specified parameters."""
        super(SPConvTranspose2d, self).__init__()
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)  # Padding for input
        self.out_channels = out_channels  # Store number of output channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))  # Convolution layer
        self.r = r  # Store the upsampling rate

    def forward(self, x):
        """
        Defines the forward pass for the SPConvTranspose2d module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, in_channels, height, width].

        Returns:
            torch.Tensor: Output tensor after transposed convolution operation.
        """
        x = self.pad1(x)  # Apply padding to input
        out = self.conv(x)  # Perform convolution operation
        batch_size, nchannels, H, W = out.shape  # Get output shape
        out = out.view((batch_size, self.r, nchannels // self.r, H, W))  # Reshape output for separation
        out = out.permute(0, 2, 3, 4, 1)  # Rearrange dimensions
        out = out.contiguous().view((batch_size, nchannels // self.r, H, -1))  # Final output shape
        return out  # Return the final output
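
# Shape sketch (illustrative): the r groups of output channels produced by the
# convolution are interleaved along the last axis, upsampling frequency by r.
#
#     up = SPConvTranspose2d(64, 64, (1, 3), r=2)
#     y = up(torch.randn(2, 64, 10, 101))   # -> [2, 64, 10, 202]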

class MaskDecoder(nn.Module):
    """
    MaskDecoder: A decoder module for estimating masks used in audio processing.

    This module utilizes a dilated dense network to capture features and 
    applies sub-pixel convolution to upscale the output. It produces 
    a mask that can be applied to the magnitude of audio signals.

    Args:
        num_features (int): The number of features in the output mask.
        num_channel (int): The number of channels in intermediate layers (default is 64).
        out_channel (int): The number of output channels for the final output mask (default is 1).
    """

    def __init__(self, num_features, num_channel=64, out_channel=1):
        """Initializes the MaskDecoder with specified parameters."""
        super(MaskDecoder, self).__init__()
        self.dense_block = DilatedDenseNet(depth=4, in_channels=num_channel)  # Dense feature extraction
        self.sub_pixel = SPConvTranspose2d(num_channel, num_channel, (1, 3), 2)  # Sub-pixel convolution for upsampling
        self.conv_1 = nn.Conv2d(num_channel, out_channel, (1, 2))  # Convolution layer to produce mask
        self.norm = nn.InstanceNorm2d(out_channel, affine=True)  # Normalization layer
        self.prelu = nn.PReLU(out_channel)  # Activation function
        self.final_conv = nn.Conv2d(out_channel, out_channel, (1, 1))  # Final convolution layer
        self.prelu_out = nn.PReLU(num_features, init=-0.25)  # Final activation for output mask

    def forward(self, x):
        """
        Defines the forward pass for the MaskDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output mask tensor after processing through the decoder.
        """
        x = self.dense_block(x)  # Feature extraction using dilated dense block
        x = self.sub_pixel(x)  # Upsample the features
        x = self.conv_1(x)  # Convolution to estimate the mask
        x = self.prelu(self.norm(x))  # Apply normalization and activation
        x = self.final_conv(x).permute(0, 3, 2, 1).squeeze(-1)  # Final convolution and rearrangement
        return self.prelu_out(x).permute(0, 2, 1).unsqueeze(1)  # Final output shape
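
# Shape sketch (illustrative, assuming num_features=201): the sub-pixel layer
# doubles the 101 encoder bins to 202 and conv_1 (kernel (1, 2), no padding)
# trims one bin, recovering the original 201 frequency bins.
#
#     dec = MaskDecoder(num_features=201, num_channel=64, out_channel=1)
#     m = dec(torch.randn(2, 64, 10, 101))   # -> [2, 1, 10, 201]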


class ComplexDecoder(nn.Module):
    """
    ComplexDecoder: A decoder module for estimating complex-valued outputs.

    This module processes features through a dilated dense network and a 
    sub-pixel convolution layer to generate two output channels representing 
    the real and imaginary parts of the complex output.

    Args:
        num_channel (int): The number of channels in intermediate layers (default is 64).
    """

    def __init__(self, num_channel=64):
        """Initializes the ComplexDecoder with specified parameters."""
        super(ComplexDecoder, self).__init__()
        self.dense_block = DilatedDenseNet(depth=4, in_channels=num_channel)  # Dense feature extraction
        self.sub_pixel = SPConvTranspose2d(num_channel, num_channel, (1, 3), 2)  # Sub-pixel convolution for upsampling
        self.prelu = nn.PReLU(num_channel)  # Activation function
        self.norm = nn.InstanceNorm2d(num_channel, affine=True)  # Normalization layer
        self.conv = nn.Conv2d(num_channel, 2, (1, 2))  # Convolution layer to produce complex outputs

    def forward(self, x):
        """
        Defines the forward pass for the ComplexDecoder.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: Output tensor containing real and imaginary parts.
        """
        x = self.dense_block(x)  # Feature extraction using dilated dense block
        x = self.sub_pixel(x)  # Upsample the features
        x = self.prelu(self.norm(x))  # Apply normalization and activation
        x = self.conv(x)  # Generate complex output
        return x  # Return the output tensor
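
# Shape sketch (illustrative): same upsampling path as MaskDecoder, but the final
# convolution emits two channels (real and imaginary parts).
#
#     dec = ComplexDecoder(num_channel=64)
#     y = dec(torch.randn(2, 64, 10, 101))   # -> [2, 2, 10, 201]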


class SyncANet(nn.Module):
    """
    SyncANet: A synchronous attention network for enhancing noisy speech spectrograms.

    This network integrates dense encoding, synchronous attention blocks, 
    and separate decoders for estimating masks and complex-valued outputs.

    Args:
        num_channel (int): The number of channels in the network (default is 64).
        num_features (int): The number of features for the mask decoder (default is 201).
    """

    def __init__(self, num_channel=64, num_features=201):
        """Initializes the SyncANet with specified parameters."""
        super(SyncANet, self).__init__()
        self.dense_encoder = DenseEncoder(in_channel=3, channels=num_channel)  # Dense encoder for input
        self.n_layers = 6  # Number of synchronous attention layers
        self.blocks = nn.ModuleList([])  # List to hold attention blocks
        
        # Initialize attention blocks
        for _ in range(self.n_layers):
            self.blocks.append(
                SyncANetBlock(
                    emb_dim=num_channel,
                    emb_ks=2,
                    emb_hs=1,
                    n_freqs=int(num_features//2)+1,
                    hidden_channels=num_channel*2,
                    n_head=4,
                    approx_qk_dim=512,
                    activation='prelu',
                    eps=1.0e-5,
                )
            )

        self.mask_decoder = MaskDecoder(num_features, num_channel=num_channel, out_channel=1)  # Mask decoder
        self.complex_decoder = ComplexDecoder(num_channel=num_channel)  # Complex decoder

    def forward(self, x):
        """
        Defines the forward pass for the SyncANet.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 2, time_frames, freq_bins] holding the real and imaginary parts of the noisy spectrogram.

        Returns:
            list: List containing the real and imaginary parts of the output tensor.
        """
        out_list = []  # List to store outputs
        mag = torch.sqrt(x[:, 0, :, :]**2 + x[:, 1, :, :]**2).unsqueeze(1)  # Calculate magnitude
        noisy_phase = torch.angle(torch.complex(x[:, 0, :, :], x[:, 1, :, :])).unsqueeze(1)  # Calculate phase
        x_in = torch.cat([mag, x], dim=1)  # Concatenate magnitude and input for processing

        x = self.dense_encoder(x_in)  # Feature extraction using dense encoder
        for ii in range(self.n_layers):
            x = self.blocks[ii](x)  # Pass through attention blocks

        mask = self.mask_decoder(x)  # Estimate mask from features
        out_mag = mask * mag  # Apply mask to magnitude

        complex_out = self.complex_decoder(x)  # Generate complex output
        mag_real = out_mag * torch.cos(noisy_phase)  # Real part of the output
        mag_imag = out_mag * torch.sin(noisy_phase)  # Imaginary part of the output
        final_real = mag_real + complex_out[:, 0, :, :].unsqueeze(1)  # Final real output
        final_imag = mag_imag + complex_out[:, 1, :, :].unsqueeze(1)  # Final imaginary output
        out_list.append(final_real)  # Append real output to list
        out_list.append(final_imag)  # Append imaginary output to list

        return out_list  # Return list of outputs
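
# End-to-end shape sketch (illustrative, num_features=201): the magnitude of the
# noisy spectrogram is concatenated with its real/imaginary parts, encoded, passed
# through the attention blocks, and decoded into a magnitude mask plus a complex
# residual that are combined into enhanced real/imaginary spectra.
#
#     net = SyncANet(num_channel=64, num_features=201)
#     real, imag = net(torch.randn(1, 2, 100, 201))   # each [1, 1, 100, 201]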

class FFConvM(nn.Module):
    """
    FFConvM: A feedforward convolutional module combining linear layers, normalization, 
    non-linear activation, and convolution operations.

    This module processes input tensors through a sequence of transformations, including 
    normalization, a linear layer with a SiLU activation, a convolutional operation, and 
    dropout for regularization.

    Args:
        dim_in (int): The number of input features (dimensionality of input).
        dim_out (int): The number of output features (dimensionality of output).
        norm_klass (nn.Module): The normalization class to be applied (default is nn.LayerNorm).
        dropout (float): The dropout probability for regularization (default is 0.1).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass=nn.LayerNorm,
        dropout=0.1
    ):
        """Initializes the FFConvM with specified parameters."""
        super().__init__()
        
        # Define the sequential model
        self.mdl = nn.Sequential(
            norm_klass(dim_in),  # Apply normalization to input
            nn.Linear(dim_in, dim_out),  # Linear transformation to dim_out
            nn.SiLU(),  # Non-linear activation using SiLU (Sigmoid Linear Unit)
            ConvModule(dim_out),  # Convolution operation on the output
            nn.Dropout(dropout)  # Dropout layer for regularization
        )

    def forward(self, x):
        """
        Defines the forward pass for the FFConvM.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, dim_in].

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, dim_out] after processing.
        """
        output = self.mdl(x)  # Pass input through the sequential model
        return output  # Return the processed output
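
# Usage sketch (illustrative; assumes ConvModule keeps the [batch, seq_len, dim]
# layout, as its use inside SyncANetBlock implies):
#
#     ff = FFConvM(dim_in=128, dim_out=256)
#     y = ff(torch.randn(8, 50, 128))   # -> [8, 50, 256]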

class SyncANetBlock(nn.Module):
    """
    SyncANetBlock implements a modified version of the MossFormer (GatedFormer) module,
    inspired by the TF-GridNet architecture (https://arxiv.org/abs/2211.12433). 
    It combines gated triple-attention schemes and Feedforward Sequential Memory Network (FSMN) modules
    to enhance computational efficiency and overall performance in audio processing tasks.

    Attributes:
        emb_dim (int): Dimensionality of the embedding.
        emb_ks (int): Kernel size for embeddings.
        emb_hs (int): Stride size for embeddings.
        n_freqs (int): Number of frequency bands.
        hidden_channels (int): Number of hidden channels.
        n_head (int): Number of attention heads.
        approx_qk_dim (int): Approximate dimension for query-key matrices.
        activation (str): Activation function to use.
        eps (float): Small value to avoid division by zero in normalization layers.
    """
    
    def __getitem__(self, key):
        """ 
        Allows accessing module attributes using indexing.
        
        Args:
            key: Attribute name to retrieve.
        
        Returns:
            The requested attribute.
        """
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
    ):
        """
        Initializes the SyncANetBlock with the specified parameters.

        Args:
            emb_dim (int): Dimensionality of the embedding.
            emb_ks (int): Kernel size for embeddings.
            emb_hs (int): Stride size for embeddings.
            n_freqs (int): Number of frequency bands.
            hidden_channels (int): Number of hidden channels.
            n_head (int): Number of attention heads. Default is 4.
            approx_qk_dim (int): Approximate dimension for query-key matrices. Default is 512.
            activation (str): Activation function to use. Default is "prelu".
            eps (float): Small value to avoid division by zero in normalization layers. Default is 1e-5.
        """
        super().__init__()

        in_channels = emb_dim * emb_ks  # Calculate the number of input channels

        ## Intra modules: Modules for internal processing within the block
        self.Fconv = nn.Conv2d(emb_dim, in_channels, kernel_size=(1, emb_ks), stride=(1, 1), groups=emb_dim)
        self.intra_norm = LayerNormalization4D(emb_dim, eps=eps)  # Layer normalization
        self.intra_to_u = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.intra_to_v = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.intra_rnn = self._build_repeats(in_channels, hidden_channels, 20, hidden_channels, repeats=1)  # FSMN layers
        self.intra_mossformer = MossFormer(dim=emb_dim, group_size=n_freqs)  # MossFormer module

        # Linear transformation for intra module output
        self.intra_linear = nn.ConvTranspose1d(
            hidden_channels, emb_dim, emb_ks, stride=emb_hs
        )
        self.intra_se = SELayer(channel=emb_dim, reduction=1)  # Squeeze-and-excitation layer

        ## Inter modules: Modules for external processing between blocks
        self.inter_norm = LayerNormalization4D(emb_dim, eps=eps)  # Layer normalization
        self.inter_to_u = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.inter_to_v = FFConvM(
            dim_in=in_channels,
            dim_out=hidden_channels,
            norm_klass=nn.LayerNorm,
            dropout=0.1,
        )
        self.inter_rnn = self._build_repeats(in_channels, hidden_channels, 20, hidden_channels, repeats=1)  # FSMN layers
        self.inter_mossformer = MossFormer(dim=emb_dim, group_size=256)  # MossFormer module

        # Linear transformation for inter module output
        self.inter_linear = nn.ConvTranspose1d(
            hidden_channels, emb_dim, emb_ks, stride=emb_hs
        )
        self.inter_se = SELayer(channel=emb_dim, reduction=1)  # Squeeze-and-excitation layer

        # Approximate query-key dimension calculation
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0  # Ensure emb_dim is divisible by n_head

        # Define attention convolution layers for each head
        for ii in range(n_head):
            self.add_module(
                f"attn_conv_Q_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                f"attn_conv_K_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, E, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((E, n_freqs), eps=eps),
                ),
            )
            self.add_module(
                f"attn_conv_V_{ii}",
                nn.Sequential(
                    nn.Conv2d(emb_dim, emb_dim // n_head, 1),
                    get_layer(activation)(),
                    LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps),
                ),
            )
        
        # Final attention concatenation projection
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        # Store parameters for further processing
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """
        Constructs a sequence of UniDeepFSMN modules.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            lorder (int): Order of the filter.
            hidden_size (int): Hidden size for the FSMN.
            repeats (int): Number of times to repeat the module. Default is 1.

        Returns:
            nn.Sequential: A sequence of UniDeepFSMN modules.
        """
        repeats = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*repeats)

    def forward(self, x):
        """Performs a forward pass through the SyncANetBlock.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, Q] where 
                              B is batch size, C is number of channels, 
                              T is temporal dimension, and Q is frequency dimension.

        Returns:
            torch.Tensor: Output tensor of the same shape [B, C, T, Q].
        """
        B, C, old_T, old_Q = x.shape
        
        # Calculate new dimensions for padding
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        
        # Pad the input tensor to match the new dimensions
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # Intra-process
        input_ = x
        intra_rnn = self.intra_norm(input_)  # Normalize input for intra-process
        intra_rnn = self.Fconv(intra_rnn)    # Apply depthwise convolution
        intra_rnn = (
            intra_rnn.transpose(1, 2).contiguous().view(B * T, C * self.emb_ks, -1)
        )  # Reshape for subsequent operations

        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape for processing
        intra_rnn_u = self.intra_to_u(intra_rnn)  # Linear transformation
        intra_rnn_v = self.intra_to_v(intra_rnn)  # Linear transformation
        intra_rnn_u = self.intra_rnn(intra_rnn_u)  # Apply FSMN
        intra_rnn = intra_rnn_v * intra_rnn_u  # Element-wise multiplication
        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape back
        intra_rnn = self.intra_linear(intra_rnn)  # Linear projection
        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape for mossformer
        intra_rnn = intra_rnn.view([B, T, Q, C])  # Reshape for mossformer
        intra_rnn = self.intra_mossformer(intra_rnn)  # Apply MossFormer
        intra_rnn = intra_rnn.transpose(1, 2)  # Reshape back
        intra_rnn = intra_rnn.view([B, T, C, Q])  # Reshape back
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()  # Final reshape
        intra_rnn = self.intra_se(intra_rnn)  # Squeeze-and-excitation layer
        intra_rnn = intra_rnn + input_  # Residual connection

        # Inter-process
        input_ = intra_rnn
        inter_rnn = self.inter_norm(input_)  # Normalize input for inter-process
        inter_rnn = (
            inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)
        )  # Reshape for processing
        inter_rnn = F.unfold(
            inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
        )  # Extract sliding windows
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape for further processing
        inter_rnn_u = self.inter_to_u(inter_rnn)  # Linear transformation
        inter_rnn_v = self.inter_to_v(inter_rnn)  # Linear transformation
        inter_rnn_u = self.inter_rnn(inter_rnn_u)  # Apply FSMN
        inter_rnn = inter_rnn_v * inter_rnn_u  # Element-wise multiplication
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape back
        inter_rnn = self.inter_linear(inter_rnn)  # Linear projection
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape for mossformer
        inter_rnn = inter_rnn.view([B, Q, T, C])  # Reshape for mossformer
        inter_rnn = self.inter_mossformer(inter_rnn)  # Apply MossFormer
        inter_rnn = inter_rnn.transpose(1, 2)  # Reshape back
        inter_rnn = inter_rnn.view([B, Q, C, T])  # Final reshape
        inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous()  # Permute for SE layer
        inter_rnn = self.inter_se(inter_rnn)  # Squeeze-and-excitation layer
        inter_rnn = inter_rnn + input_  # Residual connection

        # Attention mechanism
        inter_rnn = inter_rnn[..., :old_T, :old_Q]  # Trim to original shape

        batch = inter_rnn
        all_Q, all_K, all_V = [], [], []
        
        # Compute query, key, and value for each attention head
        for ii in range(self.n_head):
            all_Q.append(self["attn_conv_Q_%d" % ii](batch))  # Query
            all_K.append(self["attn_conv_K_%d" % ii](batch))  # Key
            all_V.append(self["attn_conv_V_%d" % ii](batch))  # Value

        Q = torch.cat(all_Q, dim=0)  # Concatenate all queries
        K = torch.cat(all_K, dim=0)  # Concatenate all keys
        V = torch.cat(all_V, dim=0)  # Concatenate all values

        # Reshape for attention calculation
        Q = Q.transpose(1, 2)
        Q = Q.flatten(start_dim=2)  # Flatten for attention calculation
        K = K.transpose(1, 2)
        K = K.flatten(start_dim=2)  # Flatten for attention calculation
        V = V.transpose(1, 2)  # Reshape for attention calculation
        old_shape = V.shape
        V = V.flatten(start_dim=2)  # Flatten for attention calculation
        emb_dim = Q.shape[-1]

        # Compute scaled dot-product attention
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # Attention matrix
        attn_mat = F.softmax(attn_mat, dim=2)  # Softmax over attention scores
        V = torch.matmul(attn_mat, V)  # Weighted sum of values

        V = V.reshape(old_shape)  # Reshape back
        V = V.transpose(1, 2)  # Final reshaping
        emb_dim = V.shape[1]

        batch = V.view([self.n_head, B, emb_dim, old_T, -1])  # Reshape for multi-head
        batch = batch.transpose(0, 1)  # Permute for batch processing
        batch = batch.contiguous().view(
            [B, self.n_head * emb_dim, old_T, -1]
        )  # Final reshape for concatenation
        batch = self["attn_concat_proj"](batch)  # Final linear projection

        # Combine inter-process result with attention output
        out = batch + inter_rnn
        return out  # Return the output tensor
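
# Shape sketch (illustrative, mirroring how SyncANet instantiates the block for
# num_features=201): time and frequency are padded for the unfold / transposed-conv
# pair and trimmed back afterwards, so the output shape matches the input.
#
#     blk = SyncANetBlock(emb_dim=64, emb_ks=2, emb_hs=1, n_freqs=101,
#                         hidden_channels=128, n_head=4)
#     y = blk(torch.randn(1, 64, 100, 101))   # -> [1, 64, 100, 101]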

class LayerNormalization4D(nn.Module):
    """
    LayerNormalization4D applies layer normalization to 4D tensors 
    (e.g., [B, C, T, F]), where B is the batch size, C is the number of channels,
    T is the temporal dimension, and F is the frequency dimension.

    Attributes:
        gamma (nn.Parameter): Learnable scaling parameter.
        beta (nn.Parameter): Learnable shifting parameter.
        eps (float): Small value for numerical stability during variance calculation.
    """

    def __init__(self, input_dimension, eps=1e-5):
        """
        Initializes the LayerNormalization4D layer.

        Args:
            input_dimension (int): The number of channels in the input tensor.
            eps (float, optional): Small constant added for numerical stability.
        """
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Scale parameter
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))   # Shift parameter
        init.ones_(self.gamma)  # Initialize gamma to 1
        init.zeros_(self.beta)   # Initialize beta to 0
        self.eps = eps  # Set the epsilon value

    def forward(self, x):
        """
        Forward pass for the layer normalization.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            torch.Tensor: Normalized output tensor of the same shape.
        """
        if x.ndim == 4:
            _, C, _, _ = x.shape  # Extract the number of channels
            stat_dim = (1,)  # Dimension to compute statistics over
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))

        # Compute mean and standard deviation along the specified dimension
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, F]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, F]

        # Normalize the input tensor and apply learnable parameters
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B, C, T, F]
        return x_hat
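
# In effect, for every (b, t, f) position the C channel values are standardized,
# then rescaled and shifted per channel (a minimal sketch of the computation above):
#
#     x_hat[b, c, t, f] = gamma[c] * (x[b, c, t, f] - mu[b, t, f]) / std[b, t, f] + beta[c]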

class LayerNormalization4DCF(nn.Module):
    """
    LayerNormalization4DCF applies layer normalization to 4D tensors 
    (e.g., [B, C, T, F]), computing statistics jointly over the channel (C) and frequency (F) dimensions.
    
    Attributes:
        gamma (nn.Parameter): Learnable scaling parameter.
        beta (nn.Parameter): Learnable shifting parameter.
        eps (float): Small value for numerical stability during variance calculation.
    """

    def __init__(self, input_dimension, eps=1e-5):
        """
        Initializes the LayerNormalization4DCF layer.

        Args:
            input_dimension (tuple): A tuple containing the dimensions of the input tensor 
                                     (number of channels, frequency dimension).
            eps (float, optional): Small constant added for numerical stability.
        """
        super().__init__()
        assert len(input_dimension) == 2, "Input dimension must be a tuple of length 2."
        param_size = [1, input_dimension[0], 1, input_dimension[1]]  # Shape based on input dimensions
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))  # Scale parameter
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))   # Shift parameter
        init.ones_(self.gamma)  # Initialize gamma to 1
        init.zeros_(self.beta)   # Initialize beta to 0
        self.eps = eps  # Set the epsilon value

    def forward(self, x):
        """
        Forward pass for the layer normalization.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            torch.Tensor: Normalized output tensor of the same shape.
        """
        if x.ndim == 4:
            stat_dim = (1, 3)  # Dimensions to compute statistics over
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))

        # Compute mean and standard deviation along the specified dimensions
        mu_ = x.mean(dim=stat_dim, keepdim=True)  # [B, 1, T, 1]
        std_ = torch.sqrt(
            x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B, 1, T, 1]

        # Normalize the input tensor and apply learnable parameters
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta  # [B, C, T, F]
        return x_hat
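

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; assumes fft_len=400, i.e. 201
    # frequency bins, and a non-'train' mode so no discriminator is built).
    from argparse import Namespace

    model = MossFormerGAN_SE_16K(Namespace(fft_len=400, mode='inference'))
    noisy_spec = torch.randn(1, 2, 100, 201)         # [B, real/imag, T, F]
    with torch.no_grad():
        enhanced_real, enhanced_imag = model(noisy_spec)
    print(enhanced_real.shape, enhanced_imag.shape)  # each [1, 1, 100, 201]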