# Copyright (c) 2024, Tri Dao, Albert Gu.

from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        # Default hidden size uses the common 8/3 expansion for gated MLPs,
        # compensating for the gate consuming half of fc1's output width.
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        # Round the hidden size up to a multiple of `multiple_of` for hardware efficiency.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        # fc1 produces both the value and gate branches in a single fused projection.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        # Split the fused projection into the value and gate halves.
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y
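

# Minimal usage sketch (not part of the original file): constructs a GatedMLP
# and runs a forward pass to show that the trailing feature dimension is
# preserved, since out_features defaults to in_features. The shapes and sizes
# below are illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    import torch

    mlp = GatedMLP(in_features=256)
    x = torch.randn(2, 16, 256)  # (batch, seq_len, features), chosen arbitrarily
    out = mlp(x)
    assert out.shape == x.shape
    print(out.shape)  # torch.Size([2, 16, 256])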