
[Metal][Performance]: Add split-K for quantized matmul (small M)#3120

Merged
angeloskath merged 2 commits intoml-explore:mainfrom
Ziqiao-git:qmm-splitk-small-m
Mar 21, 2026

Conversation

@Ziqiao-git
Contributor

@Ziqiao-git Ziqiao-git commented Feb 12, 2026

Proposed changes

In issue #3086, it was observed that the quantized qmm kernel severely underutilizes the GPU for small M (e.g., M=12-32). For example, a configuration of D=2560 and M=12 yields only 80 threadgroups (assuming BM=BN=32), which is insufficient to saturate the GPU grid.
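To make the occupancy arithmetic concrete, here is a minimal sketch (not MLX source; the helper names are assumptions) of how the threadgroup count falls out of the tile sizes. With BM=BN=32, an M=12 row block occupies a single row of tiles, so only the D dimension contributes threadgroups:

```cpp
#include <cassert>

// Hypothetical helper, for illustration only: threadgroups launched by the
// standard qmm kernel for an (M x K) @ (K x D) problem with BM x BN tiles.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr int qmm_threadgroups(int M, int D, int BM = 32, int BN = 32) {
  // One threadgroup per (BM x BN) output tile; K is walked serially,
  // so it contributes no parallelism in the standard kernel.
  return ceil_div(M, BM) * ceil_div(D, BN);
}
```

For D=2560, M=12 this gives ceil(12/32) * ceil(2560/32) = 1 * 80 = 80 threadgroups, matching the figure above.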

This PR introduces a split-K variant (qmm_t_splitk) that partitions the K dimension across multiple threadgroups. This safely improves GPU occupancy and execution speed for small-batch inference scenarios, while falling back to the standard kernel for larger batches to prevent any performance regression.

What changed

  • Added a split-K variant of the quantized matrix multiplication kernel (qmm_t_splitk) in the Metal backend, conceptually similar to the existing fp16 steel_gemm_splitk.
  • Updated the dispatch logic in quantized.cpp to dynamically calculate the split factor, targeting ~512 threadgroups for optimal occupancy.
  • Added a fallback mechanism that automatically routes to the regular qmm kernel when split_k <= 1 (e.g., for large M).
  • Verified performance gains on Apple M3 Max (4-bit, group_size=64):
    • D=2560, M=12: 0.079ms -> 0.055ms (~30% faster)
    • D=4096, M=16: 0.155ms -> 0.117ms (~25% faster)
    • No regressions observed for large M configurations.
  • Verified correctness: Existing tests cover the new code path (27 tests, 1639 subtests pass).
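The dispatch heuristic described above (target ~512 threadgroups, fall back to the regular kernel when split_k <= 1) could be sketched roughly as follows. The function name, the power-of-two rounding, and the tile sizes are assumptions for illustration, not the actual quantized.cpp code:

```cpp
#include <cassert>

// Hypothetical sketch of the split-K selection heuristic (not the MLX
// implementation). Grow a power-of-two split factor until total
// threadgroups approach the ~512 occupancy target.
int choose_split_k(int M, int D, int K, int target_tg = 512,
                   int BM = 32, int BN = 32, int BK = 32) {
  int base_tg = ((M + BM - 1) / BM) * ((D + BN - 1) / BN);
  int split_k = 1;
  // Double the split while we remain under the target and each K
  // partition still covers at least one BK-sized chunk.
  while (base_tg * split_k * 2 <= target_tg && (K / (split_k * 2)) >= BK) {
    split_k *= 2;
  }
  return split_k; // split_k <= 1: route to the regular qmm kernel instead
}
```

Under this sketch, a small-M shape like M=12, D=2560, K=4096 splits K several ways, while a large-M shape like M=2048 already saturates the grid and keeps split_k at 1, which is the fallback path described above.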

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Ziqiao-git
Contributor Author

As a quick note on why these specific dimensions matter: The performance bottleneck for small $M$ sizes ($12 \sim 32$) directly impacts the verification step in Speculative Decoding.
By fixing the GPU underutilization here, we significantly speed up the time it takes to evaluate draft tokens. Given how important speculative decoding is for pushing the limits of inference speed on edge devices, this change should provide a meaningful boost to overall generation latency, making the backend more robust for future speculative decoding implementations.

@angeloskath
Member

Thanks, that is great! I'll take a look asap.

@Ziqiao-git Ziqiao-git changed the title metal: Add split-K for quantized matmul (small M) [Metal][Performance]: Add split-K for quantized matmul (small M) Feb 13, 2026
Member

@angeloskath angeloskath left a comment


This is great but unfortunately unfinished.

The fp quantizations are not implemented (it should be trivial to add them based on qmm_t_splitk_impl), and the qmv split-k is not used by anything, so it can be removed for starters, or if you want you can finish the implementation.

Finally, you don't need a qmm_t_splitk_impl and a qmv_splitk_impl; the point of qmv_impl, qmv_fast_impl, and qmm_t_impl is that one can adjust the input matrix offsets and call the implementation. See qvm_splitk for an example.
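The pattern the reviewer is pointing at can be illustrated with a simplified 1-D analogue (the names dot_impl, dot_splitk, and K_eff here are hypothetical, not the actual kernel identifiers): instead of a dedicated split-K implementation, each partition pre-offsets its input pointers and calls the one shared implementation over a reduced loop bound.

```cpp
#include <cassert>
#include <vector>

// Shared implementation: runs over whatever loop bound it is given.
float dot_impl(const float* a, const float* b, int K_eff) {
  float acc = 0.0f;
  for (int k = 0; k < K_eff; ++k) acc += a[k] * b[k];
  return acc;
}

// Split-K wrapper: offsets pointers per partition, reuses dot_impl,
// then accumulates the partial results in a second pass.
float dot_splitk(const float* a, const float* b, int K, int split_k) {
  int K_eff = K / split_k; // assume split_k divides K in this sketch
  std::vector<float> partials(split_k);
  for (int p = 0; p < split_k; ++p) {
    // Pre-offset the pointers, then call the shared impl unchanged.
    partials[p] = dot_impl(a + p * K_eff, b + p * K_eff, K_eff);
  }
  float acc = 0.0f;
  for (float v : partials) acc += v;
  return acc;
}
```

On the GPU the "partitions" are separate threadgroups and the accumulation is a second kernel pass, but the structural point is the same: one implementation, parameterized by offset and K_eff.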

@Ziqiao-git Ziqiao-git force-pushed the qmm-splitk-small-m branch 2 times, most recently from e7c8390 to 884b86f on February 24, 2026 at 01:27
@Ziqiao-git
Contributor Author

Addressed all feedback:

  1. Removed unused qmv_split_k code (impl, kernel, macro, and dispatch function)
  2. Removed qmm_t_splitk_impl: affine_qmm_t_splitk now pre-offsets pointers and calls qmm_t_impl directly (added a K_eff parameter for the loop bound, following the qvm_splitk pattern)
  3. Added fp quantization support: fp_qmm_t_impl also takes K_eff, new fp_qmm_t_splitk kernel wrapper, and instantiation macros in fp_quantized.metal

Build passes, pre-commit is clean, and the benchmark confirms split-K works correctly for both the affine and fp paths.

Benchmark Results (M1/M2/M3 / applegpu_g15s). Device: applegpu_g15s, Memory: 52 GB

qmv_batch_limit(D=4096, O=4096) = 12

D=4096 mode=mxfp8 (bits=8, group_size=32)

| M    | fp16    | quant   | ratio | fp16 kernel  | quant kernel |
|------|---------|---------|-------|--------------|--------------|
| 1    | 0.475ms | 0.036ms | 0.08x | gemv         | qmv          |
| 2    | 0.519ms | 0.061ms | 0.12x | split-K      | qmv          |
| 4    | 0.533ms | 0.111ms | 0.21x | split-K      | qmv          |
| 8    | 0.516ms | 0.212ms | 0.41x | split-K      | qmv          |
| 10   | 0.520ms | 0.262ms | 0.50x | split-K      | qmv          |
| 12   | 0.518ms | 0.121ms | 0.23x | split-K      | qmm_splitk   |
| 14   | 0.521ms | 0.122ms | 0.23x | split-K      | qmm_splitk   |
| 16   | 0.515ms | 0.122ms | 0.24x | split-K      | qmm_splitk   |
| 20   | 0.584ms | 0.121ms | 0.21x | split-K      | qmm_splitk   |
| 32   | 0.526ms | 0.125ms | 0.24x | split-K      | qmm_splitk   |
| 2048 | 6.318ms | 6.646ms | 1.05x | regular GEMM | qmm          |

(Tested across affine, mxfp8, and mxfp4; truncated for brevity, but all show similarly smooth transitions.)

Let me know if there is anything I missed.

Member

@jagrit06 jagrit06 left a comment


Requesting a few small changes, thanks for putting this together!

Add a split-K variant for quantized matrix multiplication that
partitions the K dimension across threadgroups when GPU occupancy
is low (small M).

- Reuse qmm_t_impl with a K_eff parameter for the loop bound,
  pre-offset pointers in the splitk wrapper (following qvm_splitk pattern)
- Remove unused qmv_split_k code
- Add fp quantization support (fp_qmm_t_splitk)
- Dynamic split_k selection targeting ~512 threadgroups
- Fallback to regular qmm when split_k <= 1
@Ziqiao-git Ziqiao-git requested a review from jagrit06 March 10, 2026 23:02
Member

@angeloskath angeloskath left a comment


Looks good to me. @jagrit06 feel free to run some benchmarks and tests and then merge it.

Member

@jagrit06 jagrit06 left a comment


There are no changes needed that should hold back a merge!

@Thump604

This directly benefits MTP speculative decoding verification on MoE models — exactly the small-M quantized matmul pattern where GPU underutilization is worst.

I run Qwen3.5-122B (5-bit MoE, 10B active) with MTP on M2 Ultra. The verification step does batched quantized matmuls at M=8-16 where this split-K work would help most. I also have SpecPrefill (attention-based sparse prefill) which uses the draft model for Q@K^T scoring at similar small batch sizes.

Both reviewers have approved and the benchmarks look solid. Happy to provide M2 Ultra numbers once this lands. Would be great to see this merged — it's been approved since early March.

@Thump604

Gentle nudge — this has been approved by both @angeloskath and @jagrit06 for 10 days now with all CI green. Is there anything blocking the merge? Happy to help with any remaining benchmarking if needed.

@angeloskath angeloskath merged commit 38ad257 into ml-explore:main Mar 21, 2026
16 checks passed
@angeloskath
Member

Sorry, nothing was really keeping us from merging it; I thought Jagrit would merge it since he was the last to review (perhaps he thought I would merge it 😅)
