[Metal][Performance]: Add split-K for quantized matmul (small M) #3120
angeloskath merged 2 commits into ml-explore:main
Conversation
As a quick note on why these specific dimensions matter: the performance bottleneck for small M is GPU underutilization — with so few rows, the kernel launches too few threadgroups to saturate the grid.
Thanks, that is great! I'll take a look asap.
angeloskath
left a comment
This is great but unfortunately unfinished.
The fp quantizations are not implemented (it should be trivial to add based on qmm_t_splitk_impl). The qmv split-k is not used by anything, so it can be removed for starters, or you can finish the implementation if you want.
Finally, you don't need both a qmm_t_splitk_impl and a qmv_splitk_impl. The point of qmv_impl, qmv_fast_impl and qmm_t_impl is that one can adjust the input matrix offsets and call the implementation. See qvm_splitk for an example.
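For readers following along, here is a minimal CPU-side C++ toy of the pattern being suggested (hypothetical names and signatures, not MLX's actual kernel code): the split-K wrapper pre-offsets the operand pointers for each K-partition, calls the one shared implementation with an effective K, and reduces the partial outputs at the end.

```cpp
#include <cstdio>
#include <vector>

// Shared implementation: out[M x N] += x * w^T, where x is M x K row-major,
// w is N x K row-major, and callers pass pre-offset pointers plus the
// leading dimension (ld) of the full K axis.
void mm_t_impl(const float* x, const float* w, float* out,
               int M, int N, int K_eff, int ld) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K_eff; ++k)
        acc += x[m * ld + k] * w[n * ld + k];
      out[m * N + n] += acc;
    }
}

// Split-K wrapper: no separate *_splitk_impl. Offset the pointers for each
// K-partition, call the shared impl with an effective K, then reduce.
void mm_t_splitk(const float* x, const float* w, float* out,
                 int M, int N, int K, int split_k) {
  std::vector<float> partials(split_k * M * N, 0.0f);
  const int k_eff = K / split_k;  // assume split_k divides K for simplicity
  for (int s = 0; s < split_k; ++s)  // on the GPU, one partition per tid.z
    mm_t_impl(x + s * k_eff, w + s * k_eff,
              partials.data() + s * M * N, M, N, k_eff, K);
  for (int s = 0; s < split_k; ++s)  // final accumulation pass
    for (int i = 0; i < M * N; ++i)
      out[i] += partials[s * M * N + i];
}

int main() {
  const int M = 2, N = 3, K = 8;
  std::vector<float> x(M * K), w(N * K), ref(M * N, 0.0f), out(M * N, 0.0f);
  for (int i = 0; i < M * K; ++i) x[i] = 0.1f * i;
  for (int i = 0; i < N * K; ++i) w[i] = 0.05f * i;
  mm_t_impl(x.data(), w.data(), ref.data(), M, N, K, K);    // single pass
  mm_t_splitk(x.data(), w.data(), out.data(), M, N, K, 4);  // 4 K-partitions
  std::printf("ref[0]=%.4f out[0]=%.4f\n", ref[0], out[0]); // values match
}
```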
force-pushed from e7c8390 to 884b86f
Addressed all feedback:
Build passes, pre-commit is clean, and the benchmark confirms split-K working correctly for both the affine and fp paths.

Benchmark results (M1/M2/M3 / applegpu_g15s):

```
Device: applegpu_g15s   Memory: 52 GB
qmv_batch_limit(D=4096, O=4096) = 12
============================================================
1    0.475ms    0.036ms    0.08x    gemv    qmv
```

(Tested across affine, mxfp8, and mxfp4; truncated for brevity, but all show similar smooth transitions.) Let me know if there is anything I missed.
jagrit06
left a comment
Requesting a few small changes, thanks for putting this together!
Add a split-K variant for quantized matrix multiplication that partitions the K dimension across threadgroups when GPU occupancy is low (small M).
- Reuse qmm_t_impl with a K_eff parameter for the loop bound; pre-offset pointers in the splitk wrapper (following the qvm_splitk pattern)
- Remove unused qmv_split_k code
- Add fp quantization support (fp_qmm_t_splitk)
- Dynamic split_k selection targeting ~512 threadgroups (see the sketch below)
- Fall back to regular qmm when split_k <= 1
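As a hedged sketch of what the dynamic-selection bullet above describes (the function name `choose_split_k` and the constants are illustrative, not the actual code in quantized.cpp):

```cpp
#include <algorithm>

// Illustrative heuristic: grow the grid toward ~512 threadgroups by splitting
// K, but never split so far that each partition has too little work left.
int choose_split_k(int M, int N, int K, int BM = 32, int BN = 32) {
  const int target_groups = 512;
  const int base_groups = ((M + BM - 1) / BM) * ((N + BN - 1) / BN);
  int split_k = target_groups / std::max(base_groups, 1);
  split_k = std::min(split_k, K / 512);  // keep >= 512 of K per partition (illustrative cap)
  return std::max(split_k, 1);           // split_k <= 1 -> regular qmm path
}
```

For example, `choose_split_k(12, 2560, 4096)` returns 6, growing the grid from 80 base threadgroups to 480; for large M, base_groups already exceeds the target, the function returns 1, and the dispatch takes the regular qmm fallback.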
force-pushed from 884b86f to c83ba7f
angeloskath
left a comment
Looks good to me. @jagrit06 feel free to run some benchmarks and tests and then merge it.
jagrit06
left a comment
There are no changes needed that should hold back a merge!
This directly benefits MTP speculative decoding verification on MoE models: exactly the small-M quantized matmul pattern where GPU underutilization is worst. I run Qwen3.5-122B (5-bit MoE, 10B active) with MTP on an M2 Ultra. The verification step does batched quantized matmuls at M=8-16, where this split-K work would help most. I also have SpecPrefill (attention-based sparse prefill), which uses the draft model for Q@K^T scoring at similar small batch sizes.

Both reviewers have approved and the benchmarks look solid. Happy to provide M2 Ultra numbers once this lands. Would be great to see this merged; it's been approved since early March.
Gentle nudge: this has been approved by both @angeloskath and @jagrit06 for 10 days now with all CI green. Is there anything blocking the merge? Happy to help with any remaining benchmarking if needed.
Sorry, nothing was really keeping us from merging it. I thought Jagrit would merge it since he was the last to review (perhaps he thought I would merge it 😅).
Proposed changes
In issue #3086, it was observed that the quantized `qmm` kernel severely underutilizes the GPU for small `M` (e.g., `M=12-32`). For example, a configuration of `D=2560` and `M=12` yields only 80 threadgroups (assuming `BM=BN=32`), which is insufficient to saturate the GPU grid (a worked version of this arithmetic appears below, after the change list).

This PR introduces a split-K variant (`qmm_t_splitk`) that partitions the `K` dimension across multiple threadgroups. This safely improves GPU occupancy and execution speed for small-batch inference scenarios, while falling back to the standard kernel for larger batches to prevent any performance regression.

What changed
- Added a split-K kernel (`qmm_t_splitk`) in the Metal backend, conceptually similar to the existing fp16 `steel_gemm_splitk`.
- Updated the dispatch logic in `quantized.cpp` to dynamically calculate the split factor, targeting ~512 threadgroups for optimal occupancy.
- Falls back to the standard `qmm` kernel when `split_k <= 1` (e.g., for large `M`).

Benchmarks (`group_size=64`):

- `D=2560, M=12`: 0.079ms -> 0.055ms (~30% faster)
- `D=4096, M=16`: 0.155ms -> 0.117ms (~25% faster)
- No regression observed for larger `M` configurations.
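To make the occupancy arithmetic from the description concrete, a small self-contained check (the `split_k` value of 6 here illustrates the ~512-threadgroup target and is not necessarily the factor the dispatch picks):

```cpp
#include <cstdio>

int main() {
  // Grid size for the regular kernel at D=2560, M=12, BM=BN=32:
  const int M = 12, D = 2560, BM = 32, BN = 32;
  const int groups = ((M + BM - 1) / BM) * ((D + BN - 1) / BN);
  std::printf("threadgroups without split-K: %d\n", groups);      // 1 * 80 = 80
  // Splitting K across 6 partitions multiplies the grid accordingly:
  std::printf("threadgroups with split_k=6:  %d\n", groups * 6);  // 480, near the ~512 target
  return 0;
}
```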
Checklist

Put an `x` in the boxes that apply.

- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes