M3Site / esm /layers /ffn.py
anonymousforpaper's picture
Upload 103 files
224a33f verified
raw
history blame contribute delete
688 Bytes
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# NOT CURRENTLY USED
class SwiGLU(nn.Module):
    """Parameter-free SwiGLU gating activation.

    Splits the last dimension of the input into two halves ``(a, b)`` and
    returns ``silu(a) * b``, so the last dimension of the output is half the
    size of the input's.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gate to ``x``.

        Args:
            x: Tensor whose last dimension has an even size.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is halved.

        Raises:
            ValueError: If the last dimension of ``x`` has an odd size.
        """
        last_dim = x.size(-1)
        # An odd size makes ``chunk`` return uneven pieces. For size 3 the
        # pieces are (2, 1), which would broadcast silently and produce a
        # wrong result instead of failing, so reject odd sizes up front.
        if last_dim % 2 != 0:
            raise ValueError(
                "SwiGLU expects an even-sized last dimension, "
                f"got {last_dim}"
            )
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value
class FFN(nn.Module):
    """Feed-forward block that chains three caller-supplied stages.

    The input is passed through ``in_proj``, then ``activation``, then
    ``out_proj``. Each stage is stored as an attribute of the same name.
    """

    def __init__(self, in_proj, activation, out_proj) -> None:
        """Store the three stages in the order they are applied."""
        super().__init__()
        # Assignment order is kept as-is: it determines the order in which
        # nn.Module registers any submodules among the stages.
        self.in_proj = in_proj
        self.activation = activation
        self.out_proj = out_proj

    def forward(self, x: Tensor) -> Tensor:
        """Return ``out_proj(activation(in_proj(x)))``."""
        return self.out_proj(self.activation(self.in_proj(x)))