from torch import nn
import torch

# Reference: Squeeze-and-Excitation Networks, https://arxiv.org/pdf/1709.01507
class SELayer(nn.Module):
    """
    Squeeze-and-Excitation Layer (SELayer) for enhancing channel-wise feature responses.

    The SELayer implements the Squeeze-and-Excitation block as proposed in the paper,
    which adaptively recalibrates channel-wise feature responses by modeling the interdependencies
    between channels.

    Args:
        channel (int): The number of input channels.
        reduction (int): The reduction ratio for the number of channels in the bottleneck.
                         Default is 16.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()

        # Adaptive average pooling to generate a global descriptor
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Fully connected layers for the real part
        self.fc_r = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # Reduce channels
            nn.ReLU(inplace=True),                     # Activation function
            nn.Linear(channel // reduction, channel),  # Restore channels
            nn.Sigmoid()                               # Sigmoid activation to scale outputs
        )

        # Fully connected layers for the imaginary part
        self.fc_i = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # Reduce channels
            nn.ReLU(inplace=True),                     # Activation function
            nn.Linear(channel // reduction, channel),  # Restore channels
            nn.Sigmoid()                               # Sigmoid activation to scale outputs
        )

    def forward(self, x):
        """
        Forward pass for the SELayer.

        The forward method applies the squeeze-and-excitation operation on the input tensor `x`.
        It computes the channel-wise attention weights for both the real and imaginary parts 
        of the input.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, D, H, W), where:
                              B - batch size,
                              C - number of channels,
                              D - depth,
                              H - height,
                              W - width.

        Returns:
            torch.Tensor: Output tensor after applying channel-wise attention, 
                          same shape as input `x`.
        """
        # Extract the batch size and number of channels
        b, c, _, _, _ = x.size()

        # Compute the squeeze operation for the real part
        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)  # Global average pooling
        # Compute the squeeze operation for the imaginary part
        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)  # Global average pooling
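
        # The two excitation branches are combined in the pattern of a complex
        # product (a + ib)(c + id) = (ac - bd) + i(ad + bc): the real output is
        # fc_r(x_r) - fc_i(x_i) and the imaginary output is fc_r(x_i) + fc_i(x_r).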

        # Calculate channel-wise attention for the real part
        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1)
        # Calculate channel-wise attention for the imaginary part
        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1)

        # Concatenate the real and imaginary attention weights along the last
        # (real/imaginary) dimension
        y = torch.cat([y_r, y_i], 4)

        # Scale the input features by the attention weights
        return x * y
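

# Minimal usage sketch (illustrative only, not part of the original module):
# builds a random complex-valued feature map with the real and imaginary parts
# stacked in the last dimension and passes it through the layer. The shapes and
# hyperparameters below are arbitrary assumptions chosen for demonstration.
if __name__ == "__main__":
    layer = SELayer(channel=64, reduction=16)
    x = torch.randn(2, 64, 32, 32, 2)  # (B, C, H, W, 2): last dim holds [real, imaginary]
    out = layer(x)
    print(out.shape)  # expected: torch.Size([2, 64, 32, 32, 2])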