Porting MobileLLM-R1-950M to MLX and mlx-lm: Architectural Challenges and Solutions

I spent some time pairing with Gemini 2.5 Pro and later OpenAI Codex to drag the brand-new facebook/MobileLLM-R1-950M weights onto Apple Silicon. This write-up is the “why it wasn’t copy-paste” story, plus the gotchas that bit us until the model finally spoke clean English and quantized without drama.

Goal

Enable facebook/MobileLLM-R1-950M to run natively on Apple Silicon using MLX, then create quantized versions compatible with the mlx-lm ecosystem.


1. Why a Direct "Llama-4 Drop-In" Failed

Although the Hugging Face repo presents MobileLLM-R1-950M as a Llama-4-style dense model, its config and weights don't align cleanly with a stock Llama block. The deviations aren't quirks of MLX—they reflect this model's specific architecture:

  • MLP ambiguity
    Config advertises both intermediate_size and intermediate_size_mlp, suggesting a dual-branch feed-forward.
    Actual weights contain only a SwiGLU branch (gate_proj, up_proj, down_proj).
    → Solution: auto-detect the MLP variant from the weight names at load time (see the sketch after this list).

  • Grouped-Query Attention (GQA)
    num_attention_heads=24, num_key_value_heads=6.
    K/V tensors must be repeated to full head count for attention shapes to align correctly.

  • QK-norm and scaling
    Config includes use_qk_norm=True and attn_scale=0.1.
    We add the RMSNorm on Q/K as specified, but drop the extra 0.1 multiplier: applying it in MLX's scaled_dot_product_attention collapses the logits into gibberish.

  • RoPE gating
    Config lists all layers under no_rope_layers.
    Disabling RoPE everywhere would eliminate positional encoding entirely.
    → Treat "all layers disabled" as a config artifact and apply RoPE everywhere.
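
Below is a minimal sketch of how these points can fit together in an MLX attention module, assuming a SwiGLU-only checkpoint. The head counts come from the config; the hidden size and helper names are placeholders, and the actual implementation in this repo differs in detail.

```python
import mlx.core as mx
import mlx.nn as nn


def detect_mlp_variant(weights: dict) -> str:
    # Trust the weight names over the config's intermediate_size / intermediate_size_mlp pair.
    if any(k.endswith("mlp.gate_proj.weight") for k in weights):
        return "swiglu"  # only a gate_proj / up_proj / down_proj branch exists
    return "dense"       # fallback if no SwiGLU projections are found


class Attention(nn.Module):
    # dim=1536 is a placeholder; 24 query heads and 6 KV heads are from the config.
    def __init__(self, dim: int = 1536, n_heads: int = 24, n_kv_heads: int = 6):
        super().__init__()
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = dim // n_heads
        self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, dim, bias=False)
        # QK-norm from use_qk_norm=True: RMSNorm over the head dimension.
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)
        # RoPE on every layer, treating no_rope_layers as a config artifact.
        self.rope = nn.RoPE(self.head_dim)

    def __call__(self, x: mx.array, mask=None) -> mx.array:
        B, L, _ = x.shape
        q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        q, k = self.q_norm(q), self.k_norm(k)
        q, k = self.rope(q), self.rope(k)
        # GQA: repeat each KV head so the 6 KV heads line up with the 24 query heads.
        rep = self.n_heads // self.n_kv_heads
        k, v = mx.repeat(k, rep, axis=1), mx.repeat(v, rep, axis=1)
        # Standard 1/sqrt(head_dim) scaling only; the config's attn_scale=0.1 is dropped.
        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.head_dim ** -0.5, mask=mask
        )
        return self.o_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
```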


2. Prompt-Level Deviations

Even after the weights loaded correctly, default inference was disrupted by tokenizer and chat-template settings:

  • Chat template
    Default system prompt: "Please reason step-by-step and put your final answer within \boxed{}."
    Without overrides, the model produces verbose "reasoning" outputs.
    → Added CLI controls: --system, --disable-chat-template, --final-only.

  • Double BOS
    Both the tokenizer and the chat template inserted a BOS token, doubling it at the start of every prompt.
    → Fixed by encoding the templated prompt with add_special_tokens=False.

  • Premature EOS
    Template header tokens such as <|eot_id|> were treated as stop tokens, ending generation prematurely.
    → Limited the stopping criteria to the true EOS token only (see the prompt-handling sketch after this list).
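
A minimal sketch of the resulting prompt path, assuming the standard Hugging Face tokenizer API (the repo's scripts wrap this behind the CLI flags above; the message contents are just examples):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/MobileLLM-R1-950M")

messages = [
    # --system equivalent: override the default "\boxed{}" reasoning prompt.
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is 12 * 7?"},
]

# The chat template already emits a BOS token, so the encode step must not add
# special tokens again (this was the double-BOS bug).
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tokenizer(prompt, add_special_tokens=False).input_ids

# Stop generation only on the true EOS id; template headers such as <|eot_id|>
# are not added to the stopping criteria.
stop_ids = {tokenizer.eos_token_id}
```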


3. Sampling Stability

Sampling issues stemmed from API mismatches rather than model problems:

  • Applying top-p to normalized probabilities and then feeding the result to mx.random.categorical (which expects logits) produced repetition loops.
  • Solution: apply penalties → scale logits by temperature → top-p mask (filling rejected logits with float('-inf')) → mx.random.categorical over the masked logits.
  • Added CLI controls for temperature, repetition penalty, and frequency penalty (a sampling sketch follows this list).
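
A sketch of that sampling order, assuming repetition and frequency penalties have already been applied to the logits; the helper name is illustrative:

```python
import mlx.core as mx


def sample_next_token(logits: mx.array, temperature: float = 0.8, top_p: float = 0.95) -> mx.array:
    # logits: shape (vocab_size,) for the next position, penalties already applied.
    logits = logits / max(temperature, 1e-6)

    # Top-p in logit space: sort descending, keep the smallest prefix whose
    # probability mass reaches top_p, and mask the rest with -inf.
    order = mx.argsort(-logits)
    sorted_logits = logits[order]
    probs = mx.softmax(sorted_logits, axis=-1)
    cumulative = mx.cumsum(probs, axis=-1)
    keep = (cumulative - probs) < top_p            # the top token is always kept
    masked = mx.where(keep, sorted_logits, -float("inf"))

    # mx.random.categorical expects (unnormalized) logits, not probabilities.
    choice = mx.random.categorical(masked)
    return mx.take(order, choice)                  # map back to the vocabulary id
```

In the generation loop this runs once per step on the final position's logits.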

4. Quantization in mlx-lm: Why Custom Metadata Was Required

mlx-lm provides quantization hooks, but MobileLLM's architecture exposed several challenges:

  1. Frozen weights during sensitivity analysis → no gradients, hence empty sensitivity lists.
    → Avoid freezing weights during gradient computation.

  2. Re-quantizing already-quantized layers → type errors on the second pass.
    → Skip layers that are already QuantizedLinear (see the predicate sketch after this list).

  3. Embedding/norm dtype crashes
    Standard quantization re-quantized everything, but embeddings and norm layers must remain in floating point.
    → Introduced a metadata-driven approach: config.json records per-layer bit-widths, and only the listed layers are instantiated as QuantizedLinear.
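
A minimal sketch of the second and third fixes at quantization time, using mlx.nn.quantize's class_predicate hook (the path checks are illustrative, not the repo's exact pipeline):

```python
import mlx.nn as nn


def quant_predicate(path: str, module: nn.Module) -> bool:
    # Never touch layers that are already quantized (avoids the second-pass type errors).
    if isinstance(module, nn.QuantizedLinear):
        return False
    # Keep embeddings and norm layers in floating point.
    if "embed" in path or "norm" in path:
        return False
    # Quantize the remaining Linear projections.
    return isinstance(module, nn.Linear)


def quantize_model(model: nn.Module, bits: int = 4, group_size: int = 64) -> None:
    nn.quantize(model, group_size=group_size, bits=bits, class_predicate=quant_predicate)
```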

This metadata contract allows the 4-bit mixed-precision MobileLLM to be loaded cleanly by our metadata-aware custom_loader.py, making it compatible with the mlx-lm ecosystem (sketched below).
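
As a sketch of that contract (the metadata key and layout below are illustrative; custom_loader.py defines the actual format), the loader reads per-layer bit-widths from config.json and swaps only the listed layers to QuantizedLinear before loading the quantized weights. This relies on class_predicate being allowed to return a per-layer parameter dict, which recent MLX releases support:

```python
import json

import mlx.nn as nn


def build_quantized_structure(model: nn.Module, config_path: str) -> None:
    # Hypothetical metadata excerpt from config.json:
    # "quantization_per_layer": {
    #     "model.layers.0.self_attn.q_proj": {"bits": 4, "group_size": 64},
    #     "model.layers.0.mlp.gate_proj":    {"bits": 4, "group_size": 64}
    # }
    with open(config_path) as f:
        per_layer = json.load(f).get("quantization_per_layer", {})

    def predicate(path: str, module: nn.Module):
        spec = per_layer.get(path)
        if spec is None or not isinstance(module, nn.Linear):
            return False  # unlisted layers (embeddings, norms, ...) stay in float
        return {"bits": spec["bits"], "group_size": spec["group_size"]}

    # Only the listed layers become QuantizedLinear, each with its own bit-width;
    # the quantized weights are then loaded into this structure via model.load_weights(...).
    nn.quantize(model, class_predicate=predicate)
```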


5. End State

  • MLX path:
    Structural fixes (GQA, MLP detection), numerical fixes (QK-norm, RoPE, attn_scale), and prompt controls together yield fluent, stable inference.

  • mlx-lm path:
    Custom quantization pipeline produces FP16 and 4-bit models. These can be loaded with our metadata-aware custom_loader.py and used for inference with our provided scripts.
    Performance: measurable speedup and reduced memory usage on Apple Silicon, with minimal quality degradation.


Takeaway

The MobileLLM-R1-950M port required systematically addressing architectural mismatches (MLP variant detection, GQA handling, QK-norm implementation, RoPE configuration) and developing a metadata-driven quantization approach. Once these were resolved, the model became fully functional in MLX with both float and quantized inference paths.