import torch
from torch import nn
from torch.nn import functional as F

from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def gating_forward_kernel(
    weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
):
    # Project up to 2 * hidden channels, split into a gate half and a value
    # half, gate with the activation, then project back down to `dim`.
    x = F.linear(x, weight_in)
    B, T, _ = x.shape
    x = x.view(B, T, 2, -1)
    x = activation(x[..., 0, :]) * x[..., 1, :]
    x = F.linear(x, weight_out)
    return x
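

# For reference, a minimal eager-mode sketch of what the kernel above
# computes (assuming weight_in has shape (2 * hidden, dim) and weight_out
# has shape (dim, hidden)):
#
#   u, v = F.linear(x, weight_in).chunk(2, dim=-1)
#   out = F.linear(activation(u) * v, weight_out)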


class ActivationGating(nn.Module):
    """
    Gating FFN layer, using the given activation.

    Args:
        dim (int): dimension of the input and output of the transformer.
        dim_feedforward (int): dimension of the feedforward layer this gating
            replaces; the actual hidden size is derived from it (see __init__).
        activation (any callable Tensor to Tensor): activation function to use.
        **factory_kwargs: other kwargs passed to the linear layers, in
            particular device and dtype.
    """

    _fsdp_final = True

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # The gated FFN holds three (hidden, dim) weight blocks in total
        # (linear_in packs two of them), i.e. 3 * dim * hidden params, so
        # `hidden` is shrunk to keep the parameter count within the
        # 2 * dim * dim_feedforward budget of a standard two-layer FFN
        # (the budget enforced by make_gating below).
        if dim_feedforward == 4 * dim:
            hidden = (21 * dim) // 8
        else:
            hidden = (2 * dim_feedforward) // 3
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        return gating_forward_kernel(
            self.linear_in.weight, self.linear_out.weight, self.activation, x
        )
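

# A usage sketch with illustrative values (not from the original module):
#
#   ffn = ActivationGating(512, 2048, F.silu, device="cpu", dtype=torch.float32)
#   y = ffn(torch.randn(2, 16, 512))  # (B, T, dim) -> (B, T, dim)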


def _get_activation(name: str):
    if name in ["sigmoid", "tanh", "relu"]:
        return getattr(torch, name)
    elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
        return getattr(torch.nn.functional, name)
    elif name == "identity":
        return torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {name}")
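
# For example, _get_activation("gelu") resolves to F.gelu, while
# _get_activation("identity") returns an nn.Identity() instance, which is
# equally callable on tensors.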


def _make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    return ActivationGating(
        dim, dim_feedforward, _get_activation(name), **factory_kwargs
    )


def make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
    max_params = 2 * dim * dim_feedforward
    params = sum(p.numel() for p in gating.parameters())
    assert (
        params <= max_params
    ), f"{name} gating has {params} params, max is {max_params}"
    return gating
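

# Usage sketch with illustrative numbers: for dim = 512 and
# dim_feedforward = 2048 = 4 * dim, hidden = (21 * 512) // 8 = 1344, so the
# gating FFN holds 3 * 512 * 1344 = 2_064_384 params, under the
# max_params = 2 * 512 * 2048 = 2_097_152 budget checked above:
#
#   gating = make_gating("silu", dim=512, dim_feedforward=2048)
#   assert sum(p.numel() for p in gating.parameters()) == 2_064_384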