|
from typing import Callable |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def softmax(x: torch.Tensor, dim: int | tuple[int, ...]) -> torch.Tensor: |
|
""" |
|
Compute the softmax along the specified dimensions. |
|
This function adds the option to specify multiple dimensions |
|
|
|
Args: |
|
x (torch.Tensor): Input tensor. |
|
dims (int or tuple[int]): The dimension or list of dimensions along which the softmax probabilities are computed. |
|
|
|
Returns: |
|
torch.Tensor: Output tensor containing softmax probabilities along the specified dimensions. |
|
""" |
|
dtype = x.dtype |
|
x = x.to(torch.float32) |
|
max_vals = torch.amax(x, dim=dim, keepdim=True) |
|
e_x = torch.exp(x - max_vals) |
|
sum_exp = e_x.sum(dim=dim, keepdim=True) |
|
return (e_x / sum_exp).to(dtype) |
|
|
|
|
|
class SteelSoftMoEV3(nn.Module):
    """
    A wrapper class to create a Soft Mixture of Experts layer.

    Every expert processes ``slots_per_expert`` slots; each slot is a softmax
    weighted average of all input tokens, and each output token is a softmax
    weighted average of all expert slot outputs.

    From "From Sparse to Soft Mixtures of Experts"
    https://arxiv.org/pdf/2308.00951.pdf
    """

    def __init__(
        self,
        config,
        layer: Callable,
    ) -> None:
        """
        Args:
            config: Configuration object. ``config.hidden_size`` is used as the
                feature dimensionality and ``config.n_experts`` as the number
                of experts. ``config.slots_per_expert`` is optional and
                defaults to 1 when the attribute is absent. The same object is
                passed to ``layer`` to build every expert.
            layer (Callable): Factory invoked as ``layer(config)`` once per
                expert. The returned objects are stored in an
                ``nn.ModuleList``, so they must be ``nn.Module`` instances.
        """
        super().__init__()

        self.dim = config.hidden_size
        self.num_experts = config.n_experts
        self.slots_per_expert = getattr(config, "slots_per_expert", 1)
        # Input / phi normalization (sec. 2.3 of the paper) is always enabled.
        self.normalize = True

        # One d-dimensional slot embedding per (expert, slot) pair. The values
        # are overwritten by the normal init below, so no zero-fill is needed.
        self.phi = nn.Parameter(
            torch.empty(self.dim, self.num_experts, self.slots_per_expert)
        )
        if self.normalize:
            # Learnable scalar applied to the normalized phi in forward().
            self.scale = nn.Parameter(torch.ones(1))

        nn.init.normal_(self.phi, mean=0, std=1 / self.dim**0.5)

        # Independent expert instances, all built from the same config.
        self.experts = nn.ModuleList(
            [layer(config) for _ in range(self.num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Soft-MoE layer (algorithm 1 from paper).

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, input_dim].

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, seq_len, input_dim].
        """
        assert (
            x.shape[-1] == self.dim
        ), f"Input feature dim of {x.shape[-1]} does not match layer dim of {self.dim}"
        assert (
            len(x.shape) == 3
        ), f"Input expected to have 3 dimensions but has {len(x.shape)}"

        phi = self.phi

        if self.normalize:
            # L2-normalize tokens along the feature axis and phi along its
            # feature axis, then rescale phi by the learnable scalar.
            x = F.normalize(x, dim=2)
            phi = self.scale * F.normalize(phi, dim=0)

        # Token/slot affinities: [batch, tokens, experts, slots].
        logits = torch.einsum("bmd,dnp->bmnp", x, phi)
        # Dispatch weights: normalized over the token axis for each slot.
        d = softmax(logits, dim=1)
        # Combine weights: normalized jointly over (expert, slot) for each token.
        c = softmax(logits, dim=(2, 3))

        # Slot inputs as weighted averages of tokens: [batch, experts, slots, dim].
        xs = torch.einsum("bmd,bmnp->bnpd", x, d)

        # Expert i only sees its own slots; outputs are re-stacked on the
        # expert axis.
        ys = torch.stack(
            [f_i(xs[:, i, :, :]) for i, f_i in enumerate(self.experts)], dim=1
        )

        # Mix slot outputs back into one vector per token.
        y = torch.einsum("bnpd,bmnp->bmd", ys, c)

        return y