"""Diagnostic checks for the MobileLLM-R1-950M MLX conversion (test_model.py).

Run from the model directory: python test_model.py
"""
import sys
from pathlib import Path

# Add the current directory to the Python path so the bundled model.py
# (which defines load_model) can be imported.
sys.path.append(str(Path.cwd()))

from model import load_model
from mlx.utils import tree_flatten


def run_diagnostic_checks():
    """Perform the verification checks outlined in the review."""
    print("--- Running Diagnostic Checks ---")

    # 1. Load the model and check for errors. Abort early on failure,
    # since every later check needs a loaded model.
    try:
        model = load_model(".")
        print("Successfully loaded model definition.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # 2. Print the total parameter count. tree_flatten collapses the nested
    # parameter dict into (name, array) pairs so the sizes can be summed.
    try:
        params = model.parameters()
        num_params = sum(p.size for _, p in tree_flatten(params))
        print(f"Total number of parameters: {num_params / 1e6:.2f}M")
    except Exception as e:
        print(f"Error calculating parameters: {e}")

    # 3. Verify MLP weight shapes. The config's use_dual_mlp flag selects
    # between a dual-branch MLP and a single SwiGLU MLP in this port.
    print("--- Verifying MLP Weight Shapes ---")
    try:
        first_block = model.layers[0]
        args = model.args
        print(f"use_dual_mlp detected: {args.use_dual_mlp}")
        if args.use_dual_mlp:
            g_up_shape = first_block.feed_forward.g_up.weight.shape
            p_up_shape = first_block.feed_forward.p_up.weight.shape
            print(f"Gated MLP branch (g_up) weight shape: {g_up_shape}")
            print(f"Plain MLP branch (p_up) weight shape: {p_up_shape}")
            assert g_up_shape == (args.intermediate_size, args.hidden_size)
            assert p_up_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("DualMLP weight shapes are correct.")
        else:
            gate_proj_shape = first_block.feed_forward.gate_proj.weight.shape
            up_proj_shape = first_block.feed_forward.up_proj.weight.shape
            print(f"SwiGLUMLP gate_proj weight shape: {gate_proj_shape}")
            print(f"SwiGLUMLP up_proj weight shape: {up_proj_shape}")
            assert gate_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            assert up_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("SwiGLUMLP weight shapes are correct.")
    except AttributeError as e:
        print(
            f"Error accessing MLP weights; the module structure is not as expected: {e}"
        )
    except AssertionError:
        print("Error: MLP weight shapes do not match the configuration.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying shapes: {e}")

    # 4. Verify the embedding shape against the config.
    print("--- Verifying Embedding Shape ---")
    try:
        embedding_shape = model.tok_embeddings.weight.shape
        print(f"Embedding weight shape: {embedding_shape}")
        args = model.args
        print(f"Expected embedding shape: ({args.vocab_size}, {args.hidden_size})")
        assert embedding_shape == (args.vocab_size, args.hidden_size)
        print("Embedding shape is correct.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying embedding shape: {e}")
print("--- Sanity Checking Loaded Weights ---")
try:
# Check expected attribute exists based on architecture
if model.args.use_dual_mlp:
_ = model.layers[0].feed_forward.g_gate.weight
_ = model.layers[0].feed_forward.g_up.weight
_ = model.layers[0].feed_forward.g_down.weight
_ = model.layers[0].feed_forward.p_up.weight
_ = model.layers[0].feed_forward.p_down.weight
print("Found dual-branch MLP weights in the model.")
else:
_ = model.layers[0].feed_forward.gate_proj.weight
_ = model.layers[0].feed_forward.up_proj.weight
_ = model.layers[0].feed_forward.down_proj.weight
print("Found SwiGLU MLP weights in the model.")
print("Weight presence sanity check passed.")
except Exception as e:
print(f"An error occurred during sanity check: {e}")
print("--- Diagnostic Checks Complete ---")


if __name__ == "__main__":
    run_diagnostic_checks()
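    # Uncomment to also run the optional mlx-lm generation smoke test:
    # run_generation_smoke_test()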