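# Position-wise feed-forward (MLP) block with a SwiGLU activation, in the
# style of LLaMA's FeedForward module (w1: gate, w3: up, w2: down projection).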
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        """
        Initializes the multilayer perceptron (MLP) module.

        Args:
            dim: The input and output dimensionality.
            hidden_dim: The dimensionality of the hidden layer.
        """
        super().__init__()
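        # Three bias-free projections for the SwiGLU feed-forward block:
        # w1 (gate) and w3 (up) map dim -> hidden_dim; w2 (down) maps back to dim.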
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the MLP module.

        Args:
            x: The input tensor of shape (batch_size, dim).

        Returns:
            The output tensor of shape (batch_size, dim).
        """
        output = self.w2(F.silu(self.w1(x)) * self.w3(x))
        return output
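

# A minimal usage sketch (illustrative; the dimensions below are arbitrary
# example values, not from the original source):
if __name__ == "__main__":
    mlp = MLP(dim=512, hidden_dim=2048)
    x = torch.randn(4, 512)  # (batch_size, dim)
    out = mlp(x)
    print(out.shape)  # torch.Size([4, 512])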