import os
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import requests
from pydantic import BaseModel, field_validator
from typing import Optional
from mbridge import AutoBridge
from estimate import estimate_from_config
from megatron.core import parallel_state as mpu
import argparse
import json
import tempfile

# The directory of the current script (main.py)
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))

app = FastAPI()

# Mount static files from the webui directory
app.mount("/static", StaticFiles(directory=WEBUI_DIR), name="static")


@app.get("/")
async def read_index():
    return FileResponse(os.path.join(WEBUI_DIR, 'index.html'))

@app.get("/style.css")
async def read_css():
    return FileResponse(os.path.join(WEBUI_DIR, 'style.css'))

@app.get("/script.js")
async def read_js():
    return FileResponse(os.path.join(WEBUI_DIR, 'script.js'))


SUPPORTED_MODELS = [
    "Qwen/Qwen3-235B-A22B",
    "Qwen/Qwen3-30B-A3B",
    "Qwen/Qwen3-32B",
    "Qwen/Qwen3-14B",
    "Qwen/Qwen3-8B",
    "Qwen/Qwen2.5-7B",
    "Qwen/Qwen2.5-14B",
    "Qwen/Qwen2.5-32B",
    "Qwen/Qwen2.5-72B",
    "moonshotai/Moonlight-16B-A3B",
    "moonshotai/Kimi-K2-Instruct",
    "deepseek-ai/DeepSeek-V3",
]


@app.get("/local-hf-configs")
async def get_supported_models():
    """Return the list of HF model identifiers supported by the UI."""
    return SUPPORTED_MODELS

@app.get("/get-megatron-config/{model_path:path}")
async def get_remote_hf_config(model_path: str):
    """Fetch the HuggingFace config.json for the given model id."""
    url = f"https://huggingface.co/{model_path}/raw/main/config.json"
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        return {"error": f"Failed to fetch config from {url}: {str(e)}"}


class MBridgeEstimateConfig(BaseModel):
    hf_model_path: str
    custom_hf_config: Optional[dict] = None  # Raw HF config.json content supplied by the UI
    
    # Hardware & Training
    num_gpus: int = 8
    mbs: int = 1
    seq_len: int = 4096
    use_distributed_optimizer: bool = True
    # Activation recompute settings
    recompute_granularity: str = "selective"
    recompute_method: str = "uniform"
    recompute_num_layers: Optional[int] = 1

    # Parallelism
    tp: int = 1
    pp: int = 1
    ep: int = 1
    cp: int = 1
    vpp: Optional[int] = None
    etp: Optional[int] = None

    # Pipeline stage layer counts
    num_layers_in_first_pipeline_stage: Optional[int] = None
    num_layers_in_last_pipeline_stage: Optional[int] = None

    @field_validator('num_gpus')
    @classmethod
    def num_gpus_must_be_multiple_of_8(cls, v):
        if v <= 0 or v % 8 != 0:
            raise ValueError('must be a positive multiple of 8')
        return v

def patch_parallel_states(config: MBridgeEstimateConfig):
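    """Override mbridge's default ParallelStates factory with the parallel
    sizes requested in the form (TP/PP/EP/CP/VPP/ETP)."""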
    from mbridge.core.parallel_states import ParallelStates
    ParallelStates.get_default_parallel_states = lambda: ParallelStates(
        tp_size=config.tp,
        pp_size=config.pp,
        ep_size=config.ep,
        cp_size=config.cp,
        vpp_size=config.vpp,
        etp_size=config.etp,
    )
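
# Illustrative request body for the /estimate_with_mbridge endpoint below (the
# parallelism values are example choices, not recommendations):
#   {
#     "hf_model_path": "Qwen/Qwen3-8B",
#     "num_gpus": 8,
#     "mbs": 1,
#     "seq_len": 4096,
#     "tp": 2,
#     "pp": 2,
#     "ep": 1,
#     "cp": 1
#   }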

@app.post("/estimate_with_mbridge")
async def estimate_with_mbridge(config: MBridgeEstimateConfig):
    # Validate Inputs
    if config.num_gpus <= 0 or config.num_gpus % 8 != 0:
        return {"error": "Total number of GPUs must be a positive multiple of 8."}
    
    parallel_product = config.tp * config.pp * config.cp
    if parallel_product == 0: # Avoid division by zero
        return {"error": "Parallelism dimensions (TP, PP, CP) cannot be zero."}
    
    if config.num_gpus % parallel_product != 0:
        return {"error": f"Number of GPUs ({config.num_gpus}) must be divisible by the product of TP*PP*CP ({parallel_product})."}

    patch_parallel_states(config)
    
    hf_model_path = config.hf_model_path
    # A custom config pasted in the UI is a raw HF config.json, so it is written
    # to a temporary file and loaded from there via AutoBridge.
    if config.custom_hf_config:
        try:
            # Create a temporary file to save the custom HF config
            with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".json", dir=os.path.join(WEBUI_DIR, 'model-configs')) as tmp:
                json.dump(config.custom_hf_config, tmp)
                tmp_path = tmp.name
            
            # Load the bridge from the temporary config file
            from transformers import AutoConfig
            AutoConfig.trust_remote_code = True
            bridge = AutoBridge.from_pretrained(tmp_path)
            tf_config = bridge.config
            hf_config = bridge.hf_config

        finally:
            # Ensure the temporary file is deleted
            if 'tmp_path' in locals() and os.path.exists(tmp_path):
                os.remove(tmp_path)
    else:
        # If no custom config, load from the original path; a bare filename is
        # assumed to live in the local model-configs directory.
        if not os.path.isabs(hf_model_path) and not hf_model_path.startswith(('http', './', '../')):
            hf_model_path = os.path.join(WEBUI_DIR, 'model-configs', hf_model_path)
        bridge = AutoBridge.from_pretrained(hf_model_path)
        tf_config = bridge.config
        hf_config = bridge.hf_config

    # --- Configuration Unification ---
    # Update the tf_config with values from the form. This makes tf_config the single source of truth.
    tf_config.tensor_model_parallel_size = config.tp
    tf_config.pipeline_model_parallel_size = config.pp
    tf_config.expert_model_parallel_size = config.ep
    tf_config.context_parallel_size = config.cp
    tf_config.recompute_granularity = config.recompute_granularity
    tf_config.recompute_method = config.recompute_method
    tf_config.recompute_num_layers = config.recompute_num_layers
    tf_config.num_layers_per_virtual_pipeline_stage = config.vpp if config.vpp and config.vpp > 1 else None
    
    if config.num_layers_in_first_pipeline_stage is not None:
        tf_config.num_layers_in_first_pipeline_stage = config.num_layers_in_first_pipeline_stage
    if config.num_layers_in_last_pipeline_stage is not None:
        tf_config.num_layers_in_last_pipeline_stage = config.num_layers_in_last_pipeline_stage

    # Create a minimal 'args' object with parameters not present in TransformerConfig
    args = argparse.Namespace()
    args.micro_batch_size = config.mbs
    args.seq_length = config.seq_len
    args.use_distributed_optimizer = config.use_distributed_optimizer
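    # Data-parallel size derived from the GPUs remaining after TP*PP*CP
    # (expert parallelism is not divided out here).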
    args.data_parallel_size = config.num_gpus // parallel_product
    args.expert_tensor_parallel_size = config.etp if config.etp else 1

    # These are required by the estimator but can be derived or defaulted
    args.transformer_impl = "transformer_engine"
    args.fp8 = False
    args.num_experts = getattr(tf_config, 'num_moe_experts', 1) # Needed for layer spec
    args.moe_grouped_gemm = True # Default
    args.qk_layernorm = tf_config.qk_layernorm
    args.multi_latent_attention = "deepseek" in getattr(hf_config, "model_type", "")
    args.padded_vocab_size = getattr(hf_config, "vocab_size")
    args.max_position_embeddings = getattr(hf_config, "max_position_embeddings")
    args.tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)


    # estimate_from_config returns a list of reports, one per pipeline-parallel rank
    raw_reports_list = estimate_from_config(tf_config, args)

    # The reports from estimate.py are already in GB, so no unit conversion is needed;
    # only the verbose 'details' section is stripped for the main display table.
    processed_reports = []
    for report in raw_reports_list:
        # Create a copy of the report and remove the 'details' key
        processed_report = report.copy()
        processed_report.pop('details', None)
        processed_reports.append(processed_report)

    return {
        "processed_report": processed_reports,
        "raw_report": raw_reports_list
    }
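

# Minimal sketch for launching the app directly, assuming `uvicorn` is installed;
# the host and port below are illustrative defaults, not part of the original setup.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)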