import torch
import torch.nn as nn

from modules.cond import cast


class GEGLU(nn.Module):
"""#### Class representing the GEGLU activation function. | |
GEGLU is a gated activation function that is a combination of GELU and ReLU, | |
used to fire the neurons in the network. | |
#### Args: | |
- `dim_in` (int): The input dimension. | |
- `dim_out` (int): The output dimension. | |
""" | |
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # A single projection produces both halves at once, so the output
        # width is doubled here and split later in `forward`.
        self.proj = cast.manual_cast.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the GEGLU activation function.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        # Split the projection into value and gate along the last axis,
        # then scale the value by GELU(gate).
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)
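

# A minimal usage sketch (an illustration, not part of the module): feed a
# batch of token embeddings through GEGLU. The shapes below are assumptions
# chosen for the demo; running this requires `modules.cond.cast` to be
# importable so that `cast.manual_cast.Linear` resolves.
if __name__ == "__main__":
    layer = GEGLU(dim_in=512, dim_out=2048)
    hidden = torch.randn(1, 77, 512)  # (batch, tokens, dim_in)
    out = layer(hidden)               # value and gate computed in one pass
    print(out.shape)                  # torch.Size([1, 77, 2048])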