"""
Patches to support multipack for Mixtral
"""
from typing import Tuple

import torch


def patch_mixtral_moe_forward_zero3() -> None:
    import torch.nn.functional as F

    def mlp_forward(self, hidden_states):
        # Gated expert MLP: w2(act_fn(w1(x)) * w3(x)). The routing weights are
        # applied in moe_forward below rather than inside the expert.
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
            hidden_states
        )
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
    def moe_forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Route each token to its top_k experts, run every expert over its
        # assigned rows, and combine the outputs with the normalized routing
        # weights. Returns the merged hidden states and the raw router logits.
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        topk_weight, topk_idx = torch.topk(
            routing_weights, self.top_k, dim=-1, sorted=False
        )
        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        topk_weight = topk_weight.to(hidden_states.dtype)
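        # duplicate each token top_k times so every (token, expert) assignment
        # gets its own row in hidden_states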
        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
        y = torch.empty_like(hidden_states)  # pylint: disable=invalid-name
        flat_topk_idx = topk_idx.view(-1)
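        # dispatch each row to its assigned expert; every expert is invoked,
        # even when no tokens were routed to it (empty slice)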
        for i in range(self.num_experts):
            expert = self.experts[i]
            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
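        # scale each expert's output by its normalized routing weight and sum
        # over the top_k experts for every token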
        y = (  # pylint: disable=invalid-name
            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
        ).sum(dim=1)
        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

    from transformers.models.mixtral.modeling_mixtral import (
        MixtralBLockSparseTop2MLP,
        MixtralSparseMoeBlock,
    )

    # Replace the stock forwards with the patched versions defined above.
    # (The "BLock" capitalization matches the upstream transformers class name.)
    MixtralBLockSparseTop2MLP.forward = mlp_forward
    MixtralSparseMoeBlock.forward = moe_forward
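
# Illustrative usage sketch (not part of the original module); the checkpoint
# name below is just an example. The patch is typically applied before the
# Mixtral model is instantiated so the replaced forwards take effect, e.g.:
#
#     from transformers import AutoModelForCausalLM
#
#     patch_mixtral_moe_forward_zero3()
#     model = AutoModelForCausalLM.from_pretrained(
#         "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
#     )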