Spaces:
Running
Running
File size: 1,130 Bytes
306b4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F
class GatedMLP(nn.Module):
    """Gated MLP block (SwiGLU-style when ``activation`` is SiLU).

    The input is projected to twice the hidden width, split into a value
    half and a gate half, combined as ``value * activation(gate)``, and
    projected back to the output width.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the gated hidden layer. Defaults to
            ``int(8 * in_features / 3)``, then rounded up to a multiple
            of ``multiple_of``.
        out_features: Width of the output. Defaults to ``in_features``.
        activation: Gate nonlinearity (default ``F.silu``).
        bias: Whether the two linear layers carry bias terms.
        multiple_of: Hidden width is rounded up to a multiple of this
            value (hardware-friendly sizing).
        device, dtype: Forwarded to the ``nn.Linear`` submodules.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        super().__init__()
        linear_kwargs = {"device": device, "dtype": dtype}
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the hidden width up to the nearest multiple of `multiple_of`.
        hidden_features = -(-hidden_features // multiple_of) * multiple_of
        # fc1 emits value and gate halves in a single matmul.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **linear_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **linear_kwargs)

    def forward(self, x):
        """Apply the gated MLP: fc2(value * activation(gate))."""
        value, gate = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(value * self.activation(gate))
|