import torch
import torch.nn as nn
import torch.nn.functional as F

class UniDeepFsmn(nn.Module):
    """
    A single-layer Deep Feedforward Sequential Memory Network (FSMN) for unidirectional (causal) processing.
    This model uses a combination of linear layers and convolutional layers to process input features.

    Attributes:
    - input_dim (int): Number of input features.
    - output_dim (int): Number of output features.
    - lorder (int): Order (number of taps) of the causal memory filter.
    - hidden_size (int): Number of hidden units in the linear layer.
    """
    
    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if lorder is None:
            return  # If no filter order is provided, skip initialization

        self.lorder = lorder
        self.hidden_size = hidden_size

        # Linear layer to project input to hidden space
        self.linear = nn.Linear(input_dim, hidden_size)
        # Projection layer to output space
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        # Depthwise convolution layer for filtering
        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder, 1], [1, 1], groups=output_dim, bias=False)

    def forward(self, input):
        """
        Forward pass through the UniDeepFsmn model.
        
        Parameters:
        - input (Tensor): Input tensor of shape (batch_size, sequence_length, input_dim).
        
        Returns:
        - Tensor: Output tensor of shape (batch_size, sequence_length, input_dim);
          the residual connection requires output_dim == input_dim.
        """
        # Apply linear transformation and ReLU activation
        f1 = F.relu(self.linear(input))
        # Project to output dimension
        p1 = self.project(f1)

        # Reshape and pad the tensor for convolution
        x = torch.unsqueeze(p1, 1)  # Shape: (b, 1, T, output_dim)
        x_per = x.permute(0, 3, 2, 1)  # Permute to shape (b, output_dim, T, 1)
        y = F.pad(x_per, [0, 0, self.lorder - 1, 0])  # Left-pad the time axis for a causal filter

        # Add convolutional output to original input
        out = x_per + self.conv1(y)

        out1 = out.permute(0, 3, 2, 1)  # Restore shape (b, 1, T, output_dim)
        # squeeze(1) rather than squeeze(): a bare squeeze() would also drop a
        # batch dimension of size 1
        return input + out1.squeeze(1)  # Residual connection with the input

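# Illustrative usage sketch: a minimal shape check for UniDeepFsmn. The helper
# name and all sizes below are hypothetical; note that the residual connection
# requires output_dim == input_dim.
def _demo_uni_deep_fsmn():
    fsmn = UniDeepFsmn(input_dim=128, output_dim=128, lorder=20, hidden_size=128)
    x = torch.randn(4, 100, 128)  # (batch, time, features)
    y = fsmn(x)
    assert y.shape == x.shape  # residual output keeps the input shape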

class ComplexUniDeepFsmn(nn.Module):
    """
    A complex variant of the UniDeepFsmn that processes complex-valued input.
    This model has separate layers for the real and imaginary components.

    Attributes:
    - nIn (int): Number of input features.
    - nHidden (int): Number of hidden units in the FSMN layers.
    - nOut (int): Number of output features.
    """

    def __init__(self, nIn, nHidden=128, nOut=128):
        super(ComplexUniDeepFsmn, self).__init__()

        # Initialize FSMN layers for real and imaginary parts
        self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
        self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)

    def forward(self, x):
        """
        Forward pass through the ComplexUniDeepFsmn model.

        Parameters:
        - x (Tensor): Input tensor of shape (batch_size, channels, height, time, 2) where 
                      the last dimension represents the real and imaginary parts.

        Returns:
        - Tensor: Output tensor after processing.
        """
        # Reshape input to [b, c*h, T, 2]
        b, c, h, T, d = x.size()
        x = torch.reshape(x, (b, c * h, T, d))  # Flatten channel and height
        x = torch.transpose(x, 1, 2)  # Permute to [b, T, c*h, 2]

        # Process the real and imaginary parts
        real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
        imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
        
        # Process the second layer
        real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1)
        imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1)

        # Combine real and imaginary parts into output tensor
        output = torch.stack((real, imaginary), dim=-1)  # Shape: [b, T, c*h, 2]
        output = torch.transpose(output, 1, 2)  # Shape: [b, c*h, T, 2]
        output = torch.reshape(output, (b, c, h, T, d))  # Restore original shape

        return output

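# Illustrative usage sketch (hypothetical sizes): the 5-D input is flattened so
# that the FSMN feature size becomes c*h, so nIn must equal c*h, and the
# residual paths inside UniDeepFsmn additionally require nIn == nHidden == nOut.
def _demo_complex_uni_deep_fsmn():
    net = ComplexUniDeepFsmn(nIn=128, nHidden=128, nOut=128)
    x = torch.randn(2, 8, 16, 50, 2)  # (b, c, h, T, 2) with c*h == 128
    y = net(x)
    assert y.shape == x.shape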

class ComplexUniDeepFsmn_L1(nn.Module):
    """
    A complex variant of UniDeepFsmn for the first layer.
    This model processes complex-valued input and has two FSMN layers for the real and imaginary parts.

    Attributes:
    - nIn (int): Number of input features.
    - nHidden (int): Number of hidden units in the FSMN layers.
    - nOut (int): Number of output features.
    """

    def __init__(self, nIn, nHidden=128, nOut=128):
        super(ComplexUniDeepFsmn_L1, self).__init__()

        self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)

    def forward(self, x):
        """
        Forward pass through the ComplexUniDeepFsmn_L1 model.

        Parameters:
        - x (Tensor): Input tensor of shape (batch_size, channels, height, time, 2).

        Returns:
        - Tensor: Output tensor after processing.
        """
        b, c, h, T, d = x.size()
        x = torch.transpose(x, 1, 3)  # Shape: [b, T, h, c, 2]
        x = torch.reshape(x, (b * T, h, c, d))  # Fold time into the batch; the FSMN memory then runs along h

        # Process the real and imaginary parts
        real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
        imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])

        # Combine results and reshape back to original dimensions
        output = torch.stack((real, imaginary), dim=-1)  # Shape: [b*T, h, c, 2]
        output = torch.reshape(output, (b, T, h, c, d))  # Restore shape to [b, T, h, c, 2]
        output = torch.transpose(output, 1, 3)  # Shape: [b, c, h, T, 2]

        return output

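# Illustrative usage sketch (hypothetical sizes): unlike ComplexUniDeepFsmn,
# this variant folds time into the batch and applies the FSMN memory along the
# h axis, so nIn must equal c (and nHidden, for the residual path).
def _demo_complex_uni_deep_fsmn_l1():
    net = ComplexUniDeepFsmn_L1(nIn=64, nHidden=64)
    x = torch.randn(2, 64, 32, 10, 2)  # (b, c, h, T, 2) with c == nIn
    y = net(x)
    assert y.shape == x.shape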

class BidirectionalLSTM_L1(nn.Module):
    """
    Despite its name, a single-layer unidirectional GRU for processing sequences.

    Attributes:
    - nIn (int): Number of input features.
    - nHidden (int): Number of hidden units in the GRU.
    - nOut (int): Accepted for interface compatibility; unused by this layer.
    """
    
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM_L1, self).__init__()

        self.rnn = nn.GRU(nIn, nHidden, bidirectional=False)  # Using GRU instead of LSTM

    def forward(self, input):
        """
        Forward pass through the BidirectionalLSTM_L1 model.

        Parameters:
        - input (Tensor): Input tensor of shape (sequence_length, batch_size, input_dim).

        Returns:
        - Tensor: Output tensor after processing.
        """
        output, _ = self.rnn(input)  # Forward pass through GRU
        return output

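# Illustrative usage sketch (hypothetical sizes): nOut is ignored by this
# layer, so the output feature size is nHidden, not nOut.
def _demo_lstm_l1():
    net = BidirectionalLSTM_L1(nIn=128, nHidden=64, nOut=32)
    x = torch.randn(50, 4, 128)  # (T, b, nIn)
    y = net(x)
    assert y.shape == (50, 4, 64)  # (T, b, nHidden)
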
class BidirectionalLSTM_L2(nn.Module):
    """
    A unidirectional GRU (despite the class name suggesting a bidirectional
    LSTM) that processes input sequences and produces an output through a
    linear embedding layer.

    Attributes:
        rnn (nn.GRU): The GRU layer for processing the input sequences.
        embedding (nn.Linear): A linear layer that transforms the output of the GRU to the desired output dimension.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM_L2, self).__init__()

        # Initialize the GRU layer
        self.rnn = nn.GRU(nIn, nHidden, bidirectional=False)
        # Initialize the linear embedding layer
        self.embedding = nn.Linear(nHidden, nOut)

    def forward(self, input):
        """
        Forward pass through the GRU and linear embedding layers.

        Args:
            input (torch.Tensor): Input tensor of shape (T, b, nIn), where T is the sequence length, 
                                  b is the batch size, and nIn is the input feature size.

        Returns:
            torch.Tensor: Output tensor of shape (T, b, nOut), where nOut is the output feature size.
        """
        recurrent, _ = self.rnn(input)  # Process the input through the GRU layer
        # Get the shape of the recurrent output
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)  # Flatten the output for the embedding layer
        output = self.embedding(t_rec)  # Apply the linear embedding layer
        output = output.view(T, b, -1)  # Reshape back to (T, b, nOut)

        return output

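# Illustrative usage sketch (hypothetical sizes): time-major input, exactly as
# with a plain nn.GRU.
def _demo_lstm_l2():
    net = BidirectionalLSTM_L2(nIn=128, nHidden=64, nOut=32)
    x = torch.randn(50, 4, 128)  # (T, b, nIn)
    y = net(x)
    assert y.shape == (50, 4, 32)  # (T, b, nOut)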

class ComplexBidirectionalLSTM(nn.Module):
    """
    A complex-valued recurrent network, built from unidirectional GRU layers
    despite the class name, that processes input sequences containing real
    and imaginary components and produces a complex-valued output.

    Attributes:
        lstm_re_L1 (BidirectionalLSTM_L1): First layer for the real part of the input.
        lstm_im_L1 (BidirectionalLSTM_L1): First layer for the imaginary part of the input.
        lstm_re_L2 (BidirectionalLSTM_L2): Second layer for the real part of the input.
        lstm_im_L2 (BidirectionalLSTM_L2): Second layer for the imaginary part of the input.
    """

    def __init__(self, nIn, nHidden=128, nOut=1024):
        super(ComplexBidirectionalLSTM, self).__init__()

        # Initialize the first and second LSTM layers for real and imaginary components
        self.lstm_re_L1 = BidirectionalLSTM_L1(nIn, nHidden, nOut)
        self.lstm_im_L1 = BidirectionalLSTM_L1(nIn, nHidden, nOut)
        self.lstm_re_L2 = BidirectionalLSTM_L2(nHidden, nHidden, nOut)
        self.lstm_im_L2 = BidirectionalLSTM_L2(nHidden, nHidden, nOut)

    def forward(self, x):
        """
        Forward pass through the complex-valued bidirectional LSTM.

        Args:
            x (torch.Tensor): Input tensor of shape (b, c, h, T, 2) where:
                - b is the batch size,
                - c is the number of channels,
                - h is the height (feature) dimension,
                - T is the sequence length,
                - 2 represents the real and imaginary parts.

        Returns:
            torch.Tensor: Output tensor of shape (b, c, h, T, 2).
        """
        # Get the shape of the input tensor
        b, c, h, T, d = x.size()
        # Reshape the input for processing
        x = torch.reshape(x, (b, c*h, T, d))
        # Transpose to the time-major layout expected by the GRU
        x = torch.transpose(x, 0, 2)  # Shape: (T, c*h, b, 2)
        x = torch.transpose(x, 1, 2)  # Shape: (T, b, c*h, 2)

        # Process the real and imaginary parts through LSTM layers
        real_L1 = self.lstm_re_L1(x[..., 0]) - self.lstm_im_L1(x[..., 1])
        imaginary_L1 = self.lstm_re_L1(x[..., 1]) + self.lstm_im_L1(x[..., 0])
        real = self.lstm_re_L2(real_L1) - self.lstm_im_L2(imaginary_L1)
        imaginary = self.lstm_re_L2(imaginary_L1) + self.lstm_im_L2(real_L1)

        # Stack the real and imaginary parts to create the output tensor
        output = torch.stack((real, imaginary), dim=-1)  # Shape: (T, b, nOut, 2)
        output = torch.transpose(output, 1, 2)  # Shape: (T, nOut, b, 2)
        output = torch.transpose(output, 0, 2)  # Shape: (b, nOut, T, 2)
        output = torch.reshape(output, (b, c, h, T, d))  # Requires nOut == c*h

        return output

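# Illustrative usage sketch (hypothetical sizes): the final reshape back to
# (b, c, h, T, 2) only succeeds when nOut == c*h == nIn, so the sizes below
# are chosen to match.
def _demo_complex_lstm():
    net = ComplexBidirectionalLSTM(nIn=128, nHidden=64, nOut=128)
    x = torch.randn(2, 8, 16, 50, 2)  # (b, c, h, T, 2) with c*h == 128
    y = net(x)
    assert y.shape == x.shape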

class ComplexConv2d(nn.Module):
    """
    A complex-valued 2D convolutional layer that processes input tensors with real 
    and imaginary parts, returning a complex output.

    Attributes:
        conv_re (nn.Conv2d): Convolutional layer for the real part.
        conv_im (nn.Conv2d): Convolutional layer for the imaginary part.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
        super().__init__()

        # Initialize convolutional layers for real and imaginary components
        self.conv_re = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding,
                                 dilation=dilation, groups=groups, bias=bias, **kwargs)
        self.conv_im = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding,
                                 dilation=dilation, groups=groups, bias=bias, **kwargs)

    def forward(self, x):
        """
        Forward pass through the complex convolutional layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channel, axis1, axis2, 2) 
                             representing the real and imaginary parts.

        Returns:
            torch.Tensor: Output tensor containing the convolved real and imaginary parts.
        """
        # Apply convolution to the real and imaginary parts
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)  # Stack real and imaginary components

        return output

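# Illustrative usage sketch (hypothetical sizes): the forward pass realizes the
# complex product (W_re + i*W_im)(x_re + i*x_im), i.e.
# real = W_re*x_re - W_im*x_im and imag = W_re*x_im + W_im*x_re.
def _demo_complex_conv2d():
    conv = ComplexConv2d(in_channel=2, out_channel=4, kernel_size=3, padding=1)
    x = torch.randn(1, 2, 16, 16, 2)  # (b, c, H, W, 2)
    y = conv(x)
    assert y.shape == (1, 4, 16, 16, 2)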

class ComplexConvTranspose2d(nn.Module):
    """
    A complex-valued 2D transposed convolutional layer that processes input tensors 
    with real and imaginary parts, returning a complex output.

    Attributes:
        tconv_re (nn.ConvTranspose2d): Transposed convolutional layer for the real part.
        tconv_im (nn.ConvTranspose2d): Transposed convolutional layer for the imaginary part.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True, **kwargs):
        super().__init__()

        # Initialize transposed convolutional layers for real and imaginary components
        self.tconv_re = nn.ConvTranspose2d(in_channel, out_channel,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           padding=padding,
                                           output_padding=output_padding,
                                           groups=groups,
                                           bias=bias,
                                           dilation=dilation,
                                           **kwargs)
        self.tconv_im = nn.ConvTranspose2d(in_channel, out_channel,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           padding=padding,
                                           output_padding=output_padding,
                                           groups=groups,
                                           bias=bias,
                                           dilation=dilation,
                                           **kwargs)

    def forward(self, x):
        """
        Forward pass through the complex transposed convolutional layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channel, axis1, axis2, 2) 
                             representing the real and imaginary parts.

        Returns:
            torch.Tensor: Output tensor containing the transposed-convolved real and imaginary parts.
        """
        # Apply transposed convolution to the real and imaginary parts
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)  # Stack real and imaginary components

        return output

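# Illustrative usage sketch (hypothetical sizes): with stride=2, padding=1 and
# output_padding=1 the spatial dimensions double, the usual decoder-side
# counterpart of a stride-2 ComplexConv2d.
def _demo_complex_conv_transpose2d():
    tconv = ComplexConvTranspose2d(4, 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1)
    x = torch.randn(1, 4, 8, 8, 2)  # (b, c, H, W, 2)
    y = tconv(x)
    assert y.shape == (1, 2, 16, 16, 2)
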
class ComplexBatchNorm2d(nn.Module):
    """
    A complex-valued batch normalization layer that normalizes input tensors with 
    separate real and imaginary components.

    This layer applies batch normalization independently to the real and imaginary parts of the input,
    ensuring that each part is normalized appropriately. It is particularly useful in complex-valued networks,
    where inputs are represented as pairs of real and imaginary components.

    Attributes:
        bn_re (nn.BatchNorm2d): Batch normalization layer for the real part of the input.
        bn_im (nn.BatchNorm2d): Batch normalization layer for the imaginary part of the input.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, **kwargs):
        """
        Initializes the ComplexBatchNorm2d layer.

        Args:
            num_features (int): Number of features (channels) for the input.
            eps (float, optional): A value added to the denominator for numerical stability. Default is 1e-5.
            momentum (float, optional): Momentum for the running mean and variance. Default is 0.1.
            affine (bool, optional): If True, this layer has learnable parameters. Default is True.
            track_running_stats (bool, optional): If True, track the running mean and variance. Default is True.
        """
        super().__init__()
        # Initialize batch normalization layers for real and imaginary parts
        self.bn_re = nn.BatchNorm2d(num_features=num_features, momentum=momentum, affine=affine, eps=eps, track_running_stats=track_running_stats, **kwargs)
        self.bn_im = nn.BatchNorm2d(num_features=num_features, momentum=momentum, affine=affine, eps=eps, track_running_stats=track_running_stats, **kwargs)

    def forward(self, x):
        """
        Forward pass through the complex batch normalization layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channel, height, width, 2),
                              where the last dimension represents the real and imaginary parts.

        Returns:
            torch.Tensor: Output tensor containing the normalized real and imaginary components,
                          with the same shape as the input tensor.
        """
        # Apply batch normalization to the real part
        real = self.bn_re(x[..., 0])
        # Apply batch normalization to the imaginary part
        imag = self.bn_im(x[..., 1])
        # Stack the normalized real and imaginary parts back together
        output = torch.stack((real, imag), dim=-1)

        return output
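

# Illustrative usage sketch (hypothetical sizes): the real and imaginary parts
# are normalized independently here; this is simpler than whitening-based
# complex batch normalization, which also decorrelates the two parts jointly.
def _demo_complex_batch_norm2d():
    bn = ComplexBatchNorm2d(num_features=4)
    x = torch.randn(8, 4, 16, 16, 2)  # (b, c, H, W, 2)
    y = bn(x)
    assert y.shape == x.shape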