Add XPU MoE decode kernel (FP16/BF16 + INT2/4/8 sym/asym + FP8) #1813
Copilot wants to merge 17 commits into
Conversation
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
for more information, see https://pre-commit.ci
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
@copilot resolve the merge conflicts in this pull request
…ecode-implementation # Conflicts: # auto_round_extension/ark/auto_round_kernel/ark.cpp Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Merged
Pull request overview
This PR adds an XPU-optimized MoE decode-phase GEMV kernel (small M per expert) with multiple weight formats, and wires it through the C++/PyTorch extension layer with corresponding unit tests.
Changes:
- Added a SYCL decode GEMV kernel supporting FP16/BF16, INT8/INT4/INT2 (sym/asym), and FP8 (E4M3/E5M2) weights.
- Exposed the kernel via pybind (moe_gemm_decode) and added a Python wrapper with argument validation.
- Added unit tests covering the new decode paths and key validation error cases.
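As a point of comparison for the decode path, here is a minimal pure-Python sketch of the per-expert `A @ W.T` computation that the performance section uses as its baseline. The function name `moe_decode_reference`, the assumption that activation rows arrive already sorted by expert, and the `[expert][N][K]` weight layout are illustrative assumptions, not taken from the PR.

```python
def moe_decode_reference(activations, weights, num_tokens_per_expert):
    """Hypothetical reference for a grouped MoE GEMV.

    activations: list of token rows (each of length K), assumed sorted by expert.
    weights: weights[e] is an N x K matrix (list of N rows of length K).
    num_tokens_per_expert: how many consecutive rows belong to each expert.
    Returns one output row of length N per token, i.e. A @ W.T per expert.
    """
    if sum(num_tokens_per_expert) != len(activations):
        raise ValueError("num_tokens_per_expert must sum to the number of rows")
    outputs = []
    row = 0
    for expert, count in enumerate(num_tokens_per_expert):
        w = weights[expert]
        for _ in range(count):
            a = activations[row]
            # Each output element is a dot product of the token with one weight row.
            outputs.append([sum(x * y for x, y in zip(a, w_row)) for w_row in w])
            row += 1
    return outputs
```

In the decode phase each expert typically sees only a handful of rows, which is why the work reduces to a batch of matrix-vector products rather than a large GEMM.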
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| auto_round_extension/ark/test/test_moe.py | Adds decode-path unit tests plus packing/dequant reference helpers for INT2/4/8 and FP8. |
| auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp | Introduces the new SYCL MoE decode GEMV kernel implementations and dispatch. |
| auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp | Declares the new moe_gemm_decode API (but docs currently lag implementation). |
| auto_round_extension/ark/auto_round_kernel/ark.cpp | Includes the new header and binds moe_gemm_decode via pybind. |
| auto_round_extension/ark/auto_round_kernel/__init__.py | Adds the ARK.moe_gemm_decode Python wrapper and validation logic. |
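To make the "packing/dequant reference helpers" in the table concrete, the following is a rough sketch of group-wise INT4 dequantization in pure Python. The nibble order (low nibble first), the implicit zero point of 8 for the symmetric case, and the helper names `unpack_int4` and `dequant_int4` are all assumptions made for illustration; the PR's actual packing layout may differ.

```python
def unpack_int4(packed):
    """Split each byte into two 4-bit codes, low nibble first (assumed order)."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)
        codes.append((byte >> 4) & 0xF)
    return codes


def dequant_int4(packed, scales, group_size, zeros=None):
    """Dequantize 4-bit codes group by group along K.

    Symmetric (zeros is None): value = (code - 8) * scale  (assumed convention).
    Asymmetric: value = (code - zeros[group]) * scale.
    """
    if group_size <= 0:
        raise ValueError("group_size must be a positive integer")
    codes = unpack_int4(packed)
    if len(codes) % group_size != 0:
        raise ValueError("number of elements must be divisible by group_size")
    out = []
    for i, code in enumerate(codes):
        group = i // group_size
        zero_point = 8 if zeros is None else zeros[group]
        out.append((code - zero_point) * scales[group])
    return out
```

A test helper of this shape lets the kernel output be compared against a plain matmul on the dequantized weights.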
Comments suppressed due to low confidence (2)
auto_round_extension/ark/auto_round_kernel/__init__.py:871
- num_tokens_per_expert is converted to int32/contiguous but its device is not validated. If it’s a CPU tensor, the kernel will treat a host pointer as device memory. Please ensure num_tokens_per_expert is on XPU (and matches activations.device), or move it to XPU explicitly before calling into the extension.
weights = weights.contiguous()
if num_tokens_per_expert.dtype != torch.int32:
num_tokens_per_expert = num_tokens_per_expert.to(torch.int32)
if not num_tokens_per_expert.is_contiguous():
auto_round_extension/ark/auto_round_kernel/__init__.py:896
- group_size is used in modulo/division checks (e.g., K % group_size) without validating group_size > 0. Passing group_size=0 will raise a ZeroDivisionError rather than a clear ValueError. Please add an explicit check that group_size is a positive integer before any modulo/division operations.
if scales is None:
raise ValueError("scales is required for FP8 weights")
if scales.dtype != activations.dtype:
raise ValueError("scales dtype must match activations dtype")
if K % group_size != 0:
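One way to address the group_size comment is sketched below. The helper name `check_group_size` and its error messages are hypothetical and not part of the PR; the point is only that the positivity check runs before any modulo or division.

```python
def check_group_size(k, group_size):
    """Validate group_size before it is used as a divisor.

    Returns the number of quantization groups along K.
    """
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(group_size, bool) or not isinstance(group_size, int):
        raise ValueError("group_size must be an integer")
    if group_size <= 0:
        raise ValueError("group_size must be a positive integer")
    if k % group_size != 0:
        raise ValueError("K must be divisible by group_size")
    return k // group_size
```

With this ordering, `group_size=0` produces a ValueError with a clear message instead of a ZeroDivisionError from `K % group_size`.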
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/132db2ab-85c0-45b6-81a7-b9baaa533e5e Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
…LUT env var (default on) Agent-Logs-Url: https://github.com/intel/auto-round/sessions/0f88e20f-9644-4ebb-8cd3-2052a9f7f2e9 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
MoE GEMM Decode Performance Results
Environment: XPU available,
Summary
FP weights
float16: ark.moe_gemm_decode vs per-expert A @ W.T
bfloat16: ark.moe_gemm_decode vs per-expert A @ W.T
INT4 (group_size=128)
INT4 sym, act=float16
INT4 sym, act=bfloat16
INT4 asym, act=float16
INT4 asym, act=bfloat16
INT8 (group_size=128)
INT8 sym, act=float16
INT8 sym, act=bfloat16
INT8 asym, act=float16
INT8 asym, act=bfloat16
INT2 (group_size=128)
INT2 sym, act=float16
INT2 sym, act=bfloat16
INT2 asym, act=float16
INT2 asym, act=bfloat16
FP8 (group_size=128)
FP8 e4m3fn, act=float16
FP8 e4m3fn, act=bfloat16
FP8 e5m2, act=float16
FP8 e5m2, act=bfloat16
Notes
/azp run Unit-Test-CUDA-AutoRound
Azure Pipelines could not run because the pipeline triggers exclude this branch/path.
Signed-off-by: Dong, Bo1 <bo1.dong@intel.com>
Switch FP8 decode LUT to env-var runtime control (LUT on by default)
auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp
- decode_fp8_e4m3_lut / decode_fp8_e5m2_lut and decode_fp8_e4m3_bits / decode_fp8_e5m2_bits
- decode_fp8<IsE4M3, UseLut> selects branch with if constexpr (no per-element runtime cost)
- launch_fp8 gains a UseLut bool template parameter; kernel name tag MoEDecodeKernelFP8 updated accordingly
- fp8_decode_use_lut() reads env var ARK_FP8_DECODE_USE_LUT once (cached static); default = ON; "0"/"false"/"off"/"no" (case-insensitive) → OFF
- moe_gemm_decode FP8 path picks UseLut=true/false per launch for all 4 combinations (FP16/BF16 × E4M3/E5M2)
Usage
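The environment-variable semantics and the LUT-versus-bits choice described above can be sketched in pure Python as follows. This is an illustrative analogue, not the PR's C++ code: the `environ` parameter, the single `decode_fp8_bits` helper, and the module-level `_LUT` table are assumptions, and the sketch skips the one-time caching that the C++ `static` provides. The bit layouts used are the standard FP8 ones (E4M3FN: bias 7, no infinities; E5M2: bias 15, IEEE-style infinities and NaNs).

```python
import math
import os

_OFF_VALUES = {"0", "false", "off", "no"}


def fp8_decode_use_lut(environ=None):
    """Return True unless ARK_FP8_DECODE_USE_LUT is set to an 'off' value."""
    env = os.environ if environ is None else environ  # environ is a test hook
    value = env.get("ARK_FP8_DECODE_USE_LUT")
    if value is None:
        return True  # default = ON
    return value.lower() not in _OFF_VALUES


def decode_fp8_bits(byte, is_e4m3):
    """Decode one FP8 byte arithmetically from its sign/exponent/mantissa bits."""
    exp_bits, man_bits, bias = (4, 3, 7) if is_e4m3 else (5, 2, 15)
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    exp_max = (1 << exp_bits) - 1
    if is_e4m3:
        # E4M3FN has no infinities; only all-ones exponent and mantissa is NaN.
        if exp == exp_max and man == (1 << man_bits) - 1:
            return math.nan
    elif exp == exp_max:
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:  # zero and subnormals
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)


# 256-entry tables, built once, indexed by the raw byte.
_LUT = {fmt: [decode_fp8_bits(b, fmt) for b in range(256)] for fmt in (True, False)}


def decode_fp8(byte, is_e4m3, use_lut):
    """Pick the table lookup or the arithmetic decode, mirroring UseLut."""
    if use_lut:
        return _LUT[is_e4m3][byte]
    return decode_fp8_bits(byte, is_e4m3)
```

Under these assumptions, leaving the variable unset or setting it to any value outside the four listed strings keeps the LUT path, while something like `ARK_FP8_DECODE_USE_LUT=off` in the environment selects the bit-manipulation path.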