# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F
class GatedMLP(nn.Module):
    """Gated MLP block.

    `fc1` projects the input to twice the hidden width; the result is split
    into a value half and a gate half, the gate is passed through the
    activation, the two halves are multiplied, and `fc2` projects back to
    the output width.  When `hidden_features` is not given it defaults to
    8/3 of `in_features`, rounded up to a multiple of `multiple_of`.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        super().__init__()
        tensor_opts = {"device": device, "dtype": dtype}
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the hidden width up to the nearest multiple of `multiple_of`
        # (ceil-division), which keeps matmul shapes hardware-friendly.
        hidden_features = -(-hidden_features // multiple_of) * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **tensor_opts)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **tensor_opts)

    def forward(self, x):
        """Apply the gated MLP along the last dimension of `x`."""
        value, gate = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(value * self.activation(gate))