File size: 4,516 Bytes
bad4ddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# /// script
# dependencies = [
#     "torch",
#     "numpy",
#     "kernels",
# ]
# ///

import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
from collections import namedtuple
import os

# The upstream artifact directory is named by the UVNOTE_INPUT_SAVE_DATA
# environment variable; the current directory is used when it is unset.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

print(f"Loading weights from: {data_dir}")

# Each tensor is stored in its own "<name>.pt" file inside data_dir. The files
# are read in the order listed and bound to the matching module-level names.
_artifact_root = Path(data_dir)
(
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
) = (
    torch.load(_artifact_root / f"{stem}.pt")
    for stem in (
        "router_weight",
        "router_bias",
        "gate_up_proj",
        "gate_up_proj_bias",
        "down_proj",
        "down_proj_bias",
    )
)

print("Loaded shared weights from artifacts")
# Report the element sums of three of the loaded tensors.
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

def build_megablocks_model(device: torch.device, dtype: torch.dtype):
    """Build a MegaBlocks MoE MLP that uses the module-level loaded weights.

    The kernel package is fetched with ``get_kernel``; a router ``Linear`` and
    an ``experts`` attribute holder are then attached to the layer and filled
    from ``router_weight``, ``router_bias``, ``gate_up_proj``,
    ``gate_up_proj_bias``, ``down_proj`` and ``down_proj_bias``.

    Args:
        device: Device the router and the expert parameters are created on.
        dtype: Dtype the router and the expert parameters are cast to.

    Returns:
        The ``MegaBlocksMoeMLP`` instance with ``router`` and ``experts`` set.
    """
    # Fetch the optimized kernels from the Hugging Face hub. For a local build,
    # get_local_kernel(Path(<build dir>), "megablocks") can be used instead.
    kernel_pkg = get_kernel("kernels-community/megablocks")
    model = kernel_pkg.layers.MegaBlocksMoeMLP()

    # The namedtuple *class* (it is never instantiated) is used as a plain
    # attribute holder; everything below is set directly on that class.
    # NOTE(review): a simple namespace object would express this more clearly —
    # confirm the kernel only reads these via attribute access before changing.
    model.experts = namedtuple(
        "Experts",
        ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"],
    )
    experts = model.experts

    # Router: a fresh Linear whose weight and bias are overwritten with the
    # loaded values so the routing parameters come from the artifacts.
    router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device, dtype=dtype)
    with torch.no_grad():
        router.weight.copy_(router_weight.to(dtype))
        router.bias.copy_(router_bias.to(dtype))
    model.router = router
    # NOTE(review): TOP_K is imported from config but is not applied anywhere
    # in this function — confirm the kernel's own default matches it.

    # Scalar settings stored next to the weights (not in the declared fields).
    experts.alpha = 1.702
    experts.capacity_factor = 4

    # Each loaded tensor is cloned, moved to the target device/dtype and
    # wrapped in a Parameter before being attached under its own name.
    loaded_tensors = {
        "gate_up_proj": gate_up_proj,
        "gate_up_proj_bias": gate_up_proj_bias,
        "down_proj": down_proj,
        "down_proj_bias": down_proj_bias,
    }
    for attr_name, tensor in loaded_tensors.items():
        setattr(
            experts,
            attr_name,
            torch.nn.Parameter(tensor.clone().to(device, dtype=dtype)),
        )
    experts.hidden_size = HIDDEN_SIZE

    # Log weight statistics for comparison
    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
    print(f"[MegaBlocks] Gate/up projection shape: {tuple(experts.gate_up_proj.shape)}, sum: {experts.gate_up_proj.sum().item():.6f}")
    print(f"[MegaBlocks] Down projection shape: {tuple(experts.down_proj.shape)}, sum: {experts.down_proj.sum().item():.6f}")

    return model

# Create a wrapper to match the interface of other implementations
class MegaBlocksMoEWrapper(nn.Module):
    """Thin ``nn.Module`` adapter around a MegaBlocks MoE model.

    The wrapped model is stored as ``self.model``; ``forward`` delegates to it
    and returns the two values it produces.
    """

    def __init__(self, megablocks_model):
        """Store ``megablocks_model`` as the wrapped callable."""
        super().__init__()
        self.model = megablocks_model

    def forward(self, hidden_states):
        """Run the wrapped model on ``hidden_states``.

        Returns:
            A 2-tuple of the two values returned by the wrapped model, in the
            same order (the output first, then the routing weights).
        """
        # Per the original note, MegaBlocks expects input laid out as
        # (batch, seq_len, hidden_dim) — not checked here.
        # The wrapped model must return exactly two values; they are handed
        # back unchanged so the interface matches the other implementations.
        moe_output, routing_weights = self.model(hidden_states)
        return moe_output, routing_weights

# Run the model: seed first, then resolve the device and dtype from config.
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== MegaBlocks Implementation ===")
# Build the MegaBlocks model from the loaded weights and wrap it so it exposes
# the same (output, routing weights) interface as the other implementations.
megablocks_model = build_megablocks_model(device, dtype)
model = MegaBlocksMoEWrapper(megablocks_model).to(device=device, dtype=dtype)

# Benchmark configuration. Inputs are produced by bench_context from
# input_shape and input_seed_base — presumably a different tensor on each
# iteration; confirm in utils.bench_context.
tokens = BATCH_SIZE * SEQ_LEN
input_shape = (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
_bench_settings = dict(
    warmup=10,
    iters=50,
    device=device,
    dtype=dtype,
    tokens=tokens,
    save_json="megablocks_results.json",
    input_shape=input_shape,
    input_seed_base=INPUT_SEED,
)
with bench_context(**_bench_settings) as bench:
    output, stats = bench(model)
    # output[0] is the first element of what bench returned for the model.
    print(f"\nOutput sum: {output[0].sum().item():.6f}")