Yan Bai
add
55e1701
import os
import glob
from fastapi import FastAPI, Body
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import requests
from pydantic import BaseModel, field_validator
from typing import Optional
from mbridge import AutoBridge
from estimate import estimate_from_config
from megatron.core import parallel_state as mpu
import argparse
import json
import tempfile
# The directory of the current script (main.py)
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
app = FastAPI()
# Mount static files from the webui directory
app.mount("/static", StaticFiles(directory=WEBUI_DIR), name="static")
@app.get("/")
async def read_index():
return FileResponse(os.path.join(WEBUI_DIR, 'index.html'))
@app.get("/style.css")
async def read_css():
return FileResponse(os.path.join(WEBUI_DIR, 'style.css'))
@app.get("/script.js")
async def read_js():
return FileResponse(os.path.join(WEBUI_DIR, 'script.js'))
SUPPORTED_MODELS = [
"Qwen/Qwen3-235B-A22B",
"Qwen/Qwen3-30B-A3B",
"Qwen/Qwen3-32B",
"Qwen/Qwen3-14B",
"Qwen/Qwen3-8B",
"Qwen/Qwen2.5-7B",
"Qwen/Qwen2.5-14B",
"Qwen/Qwen2.5-32B",
"Qwen/Qwen2.5-72B",
"moonshotai/Moonlight-16B-A3B",
"moonshotai/Kimi-K2-Instruct",
"deepseek-ai/DeepSeek-V3",
]
@app.get("/local-hf-configs")
async def get_supported_models():
"""Return the list of HF model identifiers supported by the UI."""
return SUPPORTED_MODELS
@app.get("/get-megatron-config/{model_path:path}")
async def get_remote_hf_config(model_path: str):
"""Fetch the HuggingFace config.json for the given model id."""
url = f"https://huggingface.co/{model_path}/raw/main/config.json"
try:
resp = requests.get(url, timeout=10)
resp.raise_for_status()
return resp.json()
except Exception as e:
return {"error": f"Failed to fetch config from {url}: {str(e)}"}
class MBridgeEstimateConfig(BaseModel):
hf_model_path: str
custom_hf_config: Optional[dict] = None # Renamed for clarity
# Hardware & Training
num_gpus: int = 8
mbs: int = 1
seq_len: int = 4096
use_distributed_optimizer: bool = True
# Recompute settings are now part of the main config
recompute_granularity: str = "selective"
recompute_method: str = "uniform"
recompute_num_layers: Optional[int] = 1
# Parallelism
tp: int = 1
pp: int = 1
ep: int = 1
cp: int = 1
vpp: Optional[int] = None
etp: Optional[int] = None
# Pipeline stage layer counts
num_layers_in_first_pipeline_stage: Optional[int] = None
num_layers_in_last_pipeline_stage: Optional[int] = None
@field_validator('num_gpus')
def num_gpus_must_be_multiple_of_8(cls, v):
if v <= 0 or v % 8 != 0:
raise ValueError('must be a positive multiple of 8')
return v
def patch_parallel_states(config: MBridgeEstimateConfig):
from mbridge.core.parallel_states import ParallelStates
ParallelStates.get_default_parallel_states = lambda: ParallelStates(
tp_size=config.tp,
pp_size=config.pp,
ep_size=config.ep,
cp_size=config.cp,
vpp_size=config.vpp,
etp_size=config.etp,
)
@app.post("/estimate_with_mbridge")
async def estimate_with_mbridge(config: MBridgeEstimateConfig):
# Validate Inputs
if config.num_gpus <= 0 or config.num_gpus % 8 != 0:
return {"error": "Total number of GPUs must be a positive multiple of 8."}
parallel_product = config.tp * config.pp * config.cp
if parallel_product == 0: # Avoid division by zero
return {"error": "Parallelism dimensions (TP, PP, CP) cannot be zero."}
if config.num_gpus % parallel_product != 0:
return {"error": f"Number of GPUs ({config.num_gpus}) must be divisible by the product of TP*PP*CP ({parallel_product})."}
patch_parallel_states(config)
# If the path is just a filename, assume it's in our local model-configs dir
hf_model_path = config.hf_model_path
# This logic needs to change. The custom config from the UI is an HF config, not a Megatron config.
# We need to load it via a temporary file.
if config.custom_hf_config:
try:
# Create a temporary file to save the custom HF config
with tempfile.NamedTemporaryFile(mode='w+', delete=False, suffix=".json", dir=os.path.join(WEBUI_DIR, 'model-configs')) as tmp:
json.dump(config.custom_hf_config, tmp)
tmp_path = tmp.name
# Load the bridge from the temporary config file
from transformers import AutoConfig
AutoConfig.trust_remote_code = True
bridge = AutoBridge.from_pretrained(tmp_path)
tf_config = bridge.config
hf_config = bridge.hf_config
finally:
# Ensure the temporary file is deleted
if 'tmp_path' in locals() and os.path.exists(tmp_path):
os.remove(tmp_path)
else:
# If no custom config, load from the original path
if not os.path.isabs(hf_model_path) and not hf_model_path.startswith(('http', './', '../')):
hf_model_path = os.path.join(WEBUI_DIR, 'model-configs', hf_model_path)
bridge = AutoBridge.from_pretrained(hf_model_path)
tf_config = bridge.config
hf_config = bridge.hf_config
# --- Configuration Unification ---
# Update the tf_config with values from the form. This makes tf_config the single source of truth.
tf_config.tensor_model_parallel_size = config.tp
tf_config.pipeline_model_parallel_size = config.pp
tf_config.expert_model_parallel_size = config.ep
tf_config.context_parallel_size = config.cp
tf_config.recompute_granularity = config.recompute_granularity
tf_config.recompute_method = config.recompute_method
tf_config.recompute_num_layers = config.recompute_num_layers
tf_config.num_layers_per_virtual_pipeline_stage = config.vpp if config.vpp and config.vpp > 1 else None
if config.num_layers_in_first_pipeline_stage is not None:
tf_config.num_layers_in_first_pipeline_stage = config.num_layers_in_first_pipeline_stage
if config.num_layers_in_last_pipeline_stage is not None:
tf_config.num_layers_in_last_pipeline_stage = config.num_layers_in_last_pipeline_stage
# print(tf_config)
# Create a minimal 'args' object with parameters not present in TransformerConfig
args = argparse.Namespace()
args.micro_batch_size = config.mbs
args.seq_length = config.seq_len
args.use_distributed_optimizer = config.use_distributed_optimizer
args.data_parallel_size = config.num_gpus // parallel_product
args.expert_tensor_parallel_size = config.etp if config.etp else 1
# These are required by the estimator but can be derived or defaulted
args.transformer_impl = "transformer_engine"
args.fp8 = False
args.num_experts = getattr(tf_config, 'num_moe_experts', 1) # Needed for layer spec
args.moe_grouped_gemm = True # Default
args.qk_layernorm = tf_config.qk_layernorm
args.multi_latent_attention = "deepseek" in getattr(hf_config, "model_type", "")
args.padded_vocab_size = getattr(hf_config, "vocab_size")
args.max_position_embeddings = getattr(hf_config, "max_position_embeddings")
args.tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
# This function now returns a list of reports, one for each PP rank
raw_reports_list = estimate_from_config(tf_config, args)
# The report from estimate.py now has the correct units (GB), so no conversion is needed.
# We just need to remove the complex 'details' part for the main display table.
processed_reports = []
for report in raw_reports_list:
# Create a copy of the report and remove the 'details' key
processed_report = report.copy()
processed_report.pop('details', None)
processed_reports.append(processed_report)
return {
"processed_report": processed_reports,
"raw_report": raw_reports_list
}