Aatricks's picture
Upload folder using huggingface_hub
d9a2e19 verified
raw
history blame contribute delete
950 Bytes
import torch
import torch.nn as nn
from modules.cond import cast
class GEGLU(nn.Module):
"""#### Class representing the GEGLU activation function.
GEGLU is a gated activation function that is a combination of GELU and ReLU,
used to fire the neurons in the network.
#### Args:
- `dim_in` (int): The input dimension.
- `dim_out` (int): The output dimension.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = cast.manual_cast.Linear(dim_in, dim_out * 2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""#### Forward pass for the GEGLU activation function.
#### Args:
- `x` (torch.Tensor): The input tensor.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)