# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn.functional as F

from megablocks.layers.arguments import Arguments


class FFN(torch.nn.Module):
    """Two-layer feed-forward network with a tanh-approximate GELU activation."""

    def __init__(self, args: Arguments):
        super().__init__()
        # Input projection: [hidden_size, ffn_hidden_size]. Allocated with
        # torch.empty, so the weights are uninitialized until the caller
        # fills them in.
        self.w1 = torch.nn.Parameter(
            torch.empty(
                args.hidden_size,
                args.ffn_hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )
        # Output projection back to the model dimension:
        # [ffn_hidden_size, hidden_size].
        self.w2 = torch.nn.Parameter(
            torch.empty(
                args.ffn_hidden_size,
                args.hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )

    def forward(self, x):
        # y = GELU(x @ w1) @ w2
        return torch.matmul(
            F.gelu(torch.matmul(x, self.w1), approximate='tanh'),
            self.w2,
        )
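

# A minimal usage sketch (hypothetical values; the remaining Arguments
# fields are assumed to take the library's defaults). Because w1 and w2
# are allocated uninitialized above, the sketch initializes them before
# the forward pass:
#
#   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
#                    device=torch.device('cpu'))
#   ffn = FFN(args)
#   for p in ffn.parameters():
#       torch.nn.init.normal_(p, std=0.02)
#   x = torch.randn(8, args.hidden_size)  # [num_tokens, hidden_size]
#   y = ffn(x)                            # [num_tokens, hidden_size]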


class GLU(FFN):
    """Gated linear unit variant of the FFN (GEGLU): the GELU branch is
    gated elementwise by a second linear projection before the output
    projection."""

    def __init__(self, args: Arguments):
        super().__init__(args)
        # Gate projection, same shape as w1: [hidden_size, ffn_hidden_size].
        self.v1 = torch.nn.Parameter(
            torch.empty(
                args.hidden_size,
                args.ffn_hidden_size,
                device=args.device,
                dtype=torch.float16 if args.fp16 else torch.float32,
            ),
        )

    def forward(self, x):
        # y = (GELU(x @ w1) * (x @ v1)) @ w2
        x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1)
        return torch.matmul(x1, self.w2)
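

# A minimal smoke test, not part of the library API: a sketch assuming
# Arguments accepts these fields as keywords and defaults the rest.
if __name__ == '__main__':
    args = Arguments(
        hidden_size=64,
        ffn_hidden_size=256,
        device=torch.device('cpu'),
    )
    glu = GLU(args)
    # w1, w2, and v1 are created uninitialized above; fill them for the demo.
    for p in glu.parameters():
        torch.nn.init.normal_(p, std=0.02)
    x = torch.randn(8, args.hidden_size)
    print(glu(x).shape)  # expected: torch.Size([8, 64])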
