Skip to content

Add XPU MoE decode kernel (FP16/BF16 + INT2/4/8 sym/asym + FP8)#1813

Open
Copilot wants to merge 17 commits into
mainfrom
copilot/add-xpu-moe-decode-implementation
Open

Add XPU MoE decode kernel (FP16/BF16 + INT2/4/8 sym/asym + FP8)#1813
Copilot wants to merge 17 commits into
mainfrom
copilot/add-xpu-moe-decode-implementation

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented May 14, 2026

Switch FP8 decode LUT to env-var runtime control (LUT on by default)

  • Replace compile-time macro with env-var runtime switch in auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp
    • Both decode paths always compiled: decode_fp8_e4m3_lut / decode_fp8_e5m2_lut and decode_fp8_e4m3_bits / decode_fp8_e5m2_bits
    • Templated dispatch helper decode_fp8<IsE4M3, UseLut> selects branch with if constexpr (no per-element runtime cost)
    • launch_fp8 gains a UseLut bool template parameter; kernel name tag MoEDecodeKernelFP8 updated accordingly
    • Host helper fp8_decode_use_lut() reads env var ARK_FP8_DECODE_USE_LUT once (cached static); default = ON; "0"/"false"/"off"/"no" (case-insensitive) → OFF
    • moe_gemm_decode FP8 path picks UseLut=true/false per launch for all 4 combinations (FP16/BF16 × E4M3/E5M2)
    • Updated header comments

Usage

# default: LUT on
./your_app
# disable LUT, use inline bit-manip decode
ARK_FP8_DECODE_USE_LUT=0 ./your_app

Copilot AI and others added 2 commits May 14, 2026 04:03
Copilot AI and others added 2 commits May 14, 2026 07:16
@a32543254 a32543254 changed the title Add XPU MoE decode kernel (FP16/BF16 + INT4 sym/asym) Add XPU MoE decode kernel (FP16/BF16 + INT2/4/8 sym/asym + FP8) May 15, 2026
@a32543254 a32543254 requested a review from luoyu-intel May 18, 2026 15:33
@a32543254 a32543254 marked this pull request as ready for review May 18, 2026 15:33
Copilot AI review requested due to automatic review settings May 18, 2026 15:33
@a32543254
Copy link
Copy Markdown

@copilot resolve the merge conflicts in this pull request

…ecode-implementation

# Conflicts:
#	auto_round_extension/ark/auto_round_kernel/ark.cpp

Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Copy link
Copy Markdown
Contributor Author

Copilot AI commented May 18, 2026

@copilot resolve the merge conflicts in this pull request

Merged origin/main in 430868d. The only conflict was in auto_round_extension/ark/auto_round_kernel/ark.cpp — I kept the new moe_gemm_decode_wrapper from this branch alongside the updated sage_dynamic_quant signature (added bias param) from main.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds an XPU-optimized MoE decode-phase GEMV kernel (small M per expert) with multiple weight formats, and wires it through the C++/PyTorch extension layer with corresponding unit tests.

Changes:

  • Added a SYCL decode GEMV kernel supporting FP16/BF16, INT8/INT4/INT2 (sym/asym), and FP8 (E4M3/E5M2) weights.
  • Exposed the kernel via pybind (moe_gemm_decode) and added a Python wrapper with argument validation.
  • Added unit tests covering the new decode paths and key validation error cases.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
auto_round_extension/ark/test/test_moe.py Adds decode-path unit tests plus packing/dequant reference helpers for INT2/4/8 and FP8.
auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp Introduces the new SYCL MoE decode GEMV kernel implementations and dispatch.
auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp Declares the new moe_gemm_decode API (but docs currently lag implementation).
auto_round_extension/ark/auto_round_kernel/ark.cpp Includes the new header and binds moe_gemm_decode via pybind.
auto_round_extension/ark/auto_round_kernel/init.py Adds the ARK.moe_gemm_decode Python wrapper and validation logic.
Comments suppressed due to low confidence (2)

auto_round_extension/ark/auto_round_kernel/init.py:871

  • num_tokens_per_expert is converted to int32/contiguous but its device is not validated. If it’s a CPU tensor, the kernel will treat a host pointer as device memory. Please ensure num_tokens_per_expert is on XPU (and matches activations.device), or move it to XPU explicitly before calling into the extension.
            weights = weights.contiguous()

        if num_tokens_per_expert.dtype != torch.int32:
            num_tokens_per_expert = num_tokens_per_expert.to(torch.int32)
        if not num_tokens_per_expert.is_contiguous():

auto_round_extension/ark/auto_round_kernel/init.py:896

  • group_size is used in modulo/division checks (e.g., K % group_size) without validating group_size > 0. Passing group_size=0 will raise a ZeroDivisionError rather than a clear ValueError. Please add an explicit check that group_size is a positive integer before any modulo/division operations.
            if scales is None:
                raise ValueError("scales is required for FP8 weights")
            if scales.dtype != activations.dtype:
                raise ValueError("scales dtype must match activations dtype")
            if K % group_size != 0:

Comment thread auto_round_extension/ark/auto_round_kernel/__init__.py
…LUT env var (default on)

Agent-Logs-Url: https://github.com/intel/auto-round/sessions/0f88e20f-9644-4ebb-8cd3-2052a9f7f2e9

Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
@a32543254
Copy link
Copy Markdown

a32543254 commented May 27, 2026

MoE GEMM Decode Performance Results

Environment: XPU available, has_moe_gemm_decode=True · pytest 9.0.2 · Python 3.13.12
Comparison: ark.moe_gemm_decode vs baseline (per-expert A @ W.T, with dequant for quantized cases)
Result: ✅ All 18 tests PASSED

Summary

Variant Best speedup Worst speedup
FP16 3.66x 1.05x
BF16 3.34x 1.11x
INT4 sym/asym 3.31x 1.01x
INT8 sym/asym 3.27x 1.01x
INT2 sym/asym 3.36x 1.01x
FP8 e4m3fn 3.09x 0.74x ⚠️
FP8 e5m2 3.16x 0.90x ⚠️

⚠️ FP8 shows regression on the ffn-dn shape (E=8, N=4096, K=14336).


FP weights

float16 — ark.moe_gemm_decode vs per-expert A @ W.T
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 0.6804 0.5403 1.26x
medium E=8 2048 2048 6 2.7801 0.7590 3.66x
large E=8 4096 4096 6 3.0044 1.1605 2.59x
ffn-up E=8 14336 4096 6 4.2550 1.8198 2.34x
ffn-dn E=8 4096 14336 6 2.8281 2.7060 1.05x
bfloat16 — ark.moe_gemm_decode vs per-expert A @ W.T
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 0.4270 0.1379 3.10x
medium E=8 2048 2048 6 0.8708 0.2604 3.34x
large E=8 4096 4096 6 1.3611 0.8015 1.70x
ffn-up E=8 14336 4096 6 16.7235 7.5298 2.22x
ffn-dn E=8 4096 14336 6 10.1009 9.1060 1.11x

INT4 (group_size=128)

INT4 sym, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5283 4.4911 1.01x
medium E=8 2048 2048 6 14.7504 4.5228 3.26x
large E=8 4096 4096 6 15.2537 4.8477 3.15x
ffn-up E=8 14336 4096 6 16.7572 5.3848 3.11x
ffn-dn E=8 4096 14336 6 10.0767 6.7913 1.48x
INT4 sym, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5358 4.4735 1.01x
medium E=8 2048 2048 6 14.7139 4.5603 3.23x
large E=8 4096 4096 6 15.2614 4.8515 3.15x
ffn-up E=8 14336 4096 6 16.7543 5.3772 3.12x
ffn-dn E=8 4096 14336 6 9.9935 6.8599 1.46x
INT4 asym, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5470 4.4868 1.01x
medium E=8 2048 2048 6 14.7237 4.5170 3.26x
large E=8 4096 4096 6 15.2905 4.7301 3.23x
ffn-up E=8 14336 4096 6 16.7591 5.0727 3.30x
ffn-dn E=8 4096 14336 6 10.0460 6.6348 1.51x
INT4 asym, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5531 4.4723 1.02x
medium E=8 2048 2048 6 14.7281 4.5118 3.26x
large E=8 4096 4096 6 15.2519 4.7181 3.23x
ffn-up E=8 14336 4096 6 16.7251 5.0467 3.31x
ffn-dn E=8 4096 14336 6 10.0143 6.6172 1.51x

INT8 (group_size=128)

INT8 sym, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5348 4.4776 1.01x
medium E=8 2048 2048 6 14.7624 4.5413 3.25x
large E=8 4096 4096 6 15.2835 4.7855 3.19x
ffn-up E=8 14336 4096 6 16.7979 5.3490 3.14x
ffn-dn E=8 4096 14336 6 10.0627 6.7321 1.49x
INT8 sym, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5361 4.4548 1.02x
medium E=8 2048 2048 6 14.8570 4.5442 3.27x
large E=8 4096 4096 6 15.2680 4.7650 3.20x
ffn-up E=8 14336 4096 6 16.7520 5.3738 3.12x
ffn-dn E=8 4096 14336 6 10.0337 6.7042 1.50x
INT8 asym, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5353 4.4759 1.01x
medium E=8 2048 2048 6 14.7176 4.5600 3.23x
large E=8 4096 4096 6 15.2383 4.7803 3.19x
ffn-up E=8 14336 4096 6 16.7403 5.4200 3.09x
ffn-dn E=8 4096 14336 6 10.0434 7.3737 1.36x
INT8 asym, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5362 4.4826 1.01x
medium E=8 2048 2048 6 14.7305 4.5788 3.22x
large E=8 4096 4096 6 15.2690 4.7784 3.20x
ffn-up E=8 14336 4096 6 16.7455 5.3965 3.10x
ffn-dn E=8 4096 14336 6 10.0108 7.4007 1.35x

INT2 (group_size=128)

INT2 sym, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5327 4.4757 1.01x
medium E=8 2048 2048 6 14.6971 4.5187 3.25x
large E=8 4096 4096 6 15.2535 4.7488 3.21x
ffn-up E=8 14336 4096 6 16.7776 5.1435 3.26x
ffn-dn E=8 4096 14336 6 10.0886 5.3883 1.87x
INT2 sym, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5332 4.4797 1.01x
medium E=8 2048 2048 6 14.7150 4.5087 3.26x
large E=8 4096 4096 6 15.2786 4.7380 3.22x
ffn-up E=8 14336 4096 6 16.7121 5.1214 3.26x
ffn-dn E=8 4096 14336 6 10.0805 5.4094 1.86x
INT2 asym, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5359 4.4716 1.01x
medium E=8 2048 2048 6 14.7147 4.5261 3.25x
large E=8 4096 4096 6 15.2664 4.7136 3.24x
ffn-up E=8 14336 4096 6 16.7168 5.0459 3.31x
ffn-dn E=8 4096 14336 6 10.0423 5.3374 1.88x
INT2 asym, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5347 4.4518 1.02x
medium E=8 2048 2048 6 14.7113 4.4943 3.27x
large E=8 4096 4096 6 15.2581 4.6863 3.26x
ffn-up E=8 14336 4096 6 16.7285 4.9859 3.36x
ffn-dn E=8 4096 14336 6 10.0013 5.2893 1.89x

FP8 (group_size=128)

FP8 e4m3fn, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5283 4.5969 0.99x
medium E=8 2048 2048 6 14.7270 4.7608 3.09x
large E=8 4096 4096 6 15.2502 6.8569 2.22x
ffn-up E=8 14336 4096 6 16.7447 11.2829 1.48x
ffn-dn E=8 4096 14336 6 10.0653 13.5504 ⚠️ 0.74x
FP8 e4m3fn, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5295 4.6064 0.98x
medium E=8 2048 2048 6 14.7143 4.7683 3.09x
large E=8 4096 4096 6 15.2756 6.8129 2.24x
ffn-up E=8 14336 4096 6 16.7385 11.6655 1.43x
ffn-dn E=8 4096 14336 6 10.0356 13.4167 ⚠️ 0.75x
FP8 e5m2, act=float16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5368 4.5700 0.99x
medium E=8 2048 2048 6 14.7401 4.6979 3.14x
large E=8 4096 4096 6 15.2908 5.4020 2.83x
ffn-up E=8 14336 4096 6 16.7413 9.4213 1.78x
ffn-dn E=8 4096 14336 6 10.0432 11.2030 ⚠️ 0.90x
FP8 e5m2, act=bfloat16
shape N K tokens baseline (ms) ark (ms) speedup
small E=4 1024 1024 3 4.5477 4.5759 0.99x
medium E=8 2048 2048 6 14.7776 4.6837 3.16x
large E=8 4096 4096 6 15.2996 5.3868 2.84x
ffn-up E=8 14336 4096 6 16.7595 9.6810 1.73x
ffn-dn E=8 4096 14336 6 10.0290 11.1763 ⚠️ 0.90x

Notes

  • INT/FP-quant baselines include dequant cost, which inflates baseline times vs. ark's fused kernel.
  • INT2/INT4/INT8 all show consistent ~3x speedup on mid/large square shapes.
  • FP8 regresses on the ffn-dn (K=14336) shape — worth investigating before enabling FP8 by default for down-projections.

@chensuyue
Copy link
Copy Markdown
Contributor

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines could not run because the pipeline triggers exclude this branch/path.

@chensuyue
Copy link
Copy Markdown
Contributor

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines could not run because the pipeline triggers exclude this branch/path.

a32543254 and others added 2 commits May 27, 2026 23:03
@chensuyue
Copy link
Copy Markdown
Contributor

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines could not run because the pipeline triggers exclude this branch/path.

@chensuyue
Copy link
Copy Markdown
Contributor

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines could not run because the pipeline triggers exclude this branch/path.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants