Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
58b0900
Add XPU MoE decode kernel with INT4 sym/asym and FP16/BF16 baselines
Copilot May 14, 2026
527eede
Document int4 sign-extension trick
Copilot May 14, 2026
78ecc0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
5dc9d95
Add INT8/INT2/FP8 decode MoE GEMV kernels and tests
Copilot May 14, 2026
f15093a
docs: clarify int2 bit-indexing notation in moe_gemm_decode
Copilot May 14, 2026
430868d
Merge remote-tracking branch 'origin/main' into copilot/add-xpu-moe-d…
May 18, 2026
4395884
test: add perf comparison UT — moe_gemm_decode vs default XPU MoE
Copilot May 19, 2026
a864bed
test: clearer skip reasons for moe_gemm_decode perf UT
Copilot May 20, 2026
407da75
fix(ark): correct duplicated bestla include path in sycl_tla_moe_deco…
Copilot May 26, 2026
70dc320
perf: vectorize moe_gemm_decode loads, parallelize expert-id fill, dr…
Copilot May 26, 2026
1da1977
feat(ark): add ARK_FP8_DECODE_USE_LUT switch for FP8 decode in MoE ke…
Copilot May 26, 2026
c297d37
feat(ark): make FP8 decode LUT switch runtime via ARK_FP8_DECODE_USE_…
Copilot May 26, 2026
26dbeaa
fix precommit
a32543254 May 27, 2026
72b19e9
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 May 27, 2026
608bf28
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 May 28, 2026
7ef8dbf
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 May 29, 2026
3dd5a8e
Merge branch 'main' into copilot/add-xpu-moe-decode-implementation
a32543254 Jun 5, 2026
dab6219
Apply remaining changes
Copilot Jun 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions auto_round_extension/ark/auto_round_kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,208 @@ def _ceil_div(a, b):
out = out[:, :, :Sq, :]
return out

def moe_gemm_decode(
self,
activations: torch.Tensor,
weights: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
*,
scales: Optional[torch.Tensor] = None,
zeros: Optional[torch.Tensor] = None,
weight_bits: int = 4,
group_size: int = 128,
asym: bool = False,
) -> torch.Tensor:
"""MoE GEMV optimized for the decode phase.

Each expert typically processes only 1-2 tokens (top-k routing with
small batch). Activations must already be gathered/sorted by expert
(same convention as ``moe_gemm``).

Args:
activations: ``[total_tokens, K]`` in fp16 or bf16.
weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are:

* Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16``
matching the activations dtype, ``K_packed == K``.
* Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``.
Sym (``asym=False``) reinterprets each byte as signed int8;
asym (``asym=True``) treats each byte as ``uint8`` with a
per-group zero-point.
* Int4 (``weight_bits=4``): ``torch.uint8`` packed,
``K_packed == K // 2`` (two 4-bit values per byte; low nibble
at the lower K index).
* Int2 (``weight_bits=2``): ``torch.uint8`` packed,
``K_packed == K // 4`` (four 2-bit values per byte; field j at
K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i).
* FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``):
``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must
be ``False`` (no zero-points for FP8).
num_tokens_per_expert: ``[E]`` int32. Sum must equal
``activations.shape[0]``.
scales: ``[E, N, K // group_size]`` in activations dtype. Required
for all quantized paths (int8/int4/int2/fp8); must be ``None``
for unquantized weights.
zeros: ``[E, N, K // group_size]`` in activations dtype. Required
when ``asym=True`` (int8/int4/int2 only); otherwise ``None``.
weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8
tensor (the FP8 sub-format is taken from ``weights.dtype``).
group_size: group along K for quantized weights (default 128).
asym: if ``True``, weights use unsigned encoding and ``zeros`` must
be provided. Not supported for FP8.

Returns:
outputs: ``[total_tokens, N]`` in the same dtype as activations.
"""
if activations.device.type != "xpu":
raise NotImplementedError("moe_gemm_decode is only supported on XPU")

if activations.dtype not in (torch.float16, torch.bfloat16):
raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}")

if activations.ndim != 2:
raise ValueError("activations must be 2D [total_tokens, K]")
if weights.ndim != 3:
raise ValueError("weights must be 3D [E, N, K_packed]")

if not activations.is_contiguous():
activations = activations.contiguous()
Comment thread
a32543254 marked this conversation as resolved.
if not weights.is_contiguous():
weights = weights.contiguous()

if num_tokens_per_expert.dtype != torch.int32:
num_tokens_per_expert = num_tokens_per_expert.to(torch.int32)
if not num_tokens_per_expert.is_contiguous():
num_tokens_per_expert = num_tokens_per_expert.contiguous()

total_tokens, K = activations.shape
num_experts = weights.shape[0]
N = weights.shape[1]

if num_tokens_per_expert.shape[0] != num_experts:
raise ValueError(
f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}"
)

# Detect FP8 weight dtype first (overrides weight_bits).
is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

# Validate weight layout / dtype combination.
if is_fp8:
if asym:
raise ValueError("FP8 weights do not support asym=True")
if weights.shape[2] != K:
raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}")
if scales is None:
raise ValueError("scales is required for FP8 weights")
if scales.dtype != activations.dtype:
raise ValueError("scales dtype must match activations dtype")
if K % group_size != 0:
raise ValueError("K must be a multiple of group_size")
expected_scale_shape = (num_experts, N, K // group_size)
if tuple(scales.shape) != expected_scale_shape:
raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}")
if zeros is not None:
raise ValueError("zeros must be None for FP8 weights")
weight_dtype = ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2
if not scales.is_contiguous():
scales = scales.contiguous()
elif weight_bits == 16:
if weights.dtype != activations.dtype:
raise ValueError("Unquantized weights must match activations dtype")
if weights.shape[2] != K:
raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}")
weight_dtype = cvt_dtype(activations.dtype)
if scales is not None or zeros is not None:
raise ValueError("scales/zeros must be None when weight_bits=16")
elif weight_bits in (8, 4, 2):
if weights.dtype != torch.uint8:
raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8")
if weight_bits == 8:
k_packed_expected = K
k_div = 1
elif weight_bits == 4:
k_packed_expected = K // 2
k_div = 2
else: # weight_bits == 2
k_packed_expected = K // 4
k_div = 4
if K % k_div != 0:
raise ValueError(f"K must be a multiple of {k_div} for weight_bits={weight_bits}")
if weights.shape[2] != k_packed_expected:
raise ValueError(
f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} "
f"({k_packed_expected})"
)
if scales is None:
raise ValueError(f"scales is required for int{weight_bits} weights")
if scales.dtype != activations.dtype:
raise ValueError("scales dtype must match activations dtype")
if K % group_size != 0:
raise ValueError("K must be a multiple of group_size")
# Group_size constraints per dtype.
if weight_bits == 4 and (group_size & 1) != 0:
raise ValueError("group_size must be even for int4 weights")
if weight_bits == 2 and (group_size & 3) != 0:
raise ValueError("group_size must be a multiple of 4 for int2 weights")
expected_scale_shape = (num_experts, N, K // group_size)
if tuple(scales.shape) != expected_scale_shape:
raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}")
if asym:
if zeros is None:
raise ValueError("zeros is required when asym=True")
if zeros.dtype != activations.dtype:
raise ValueError("zeros dtype must match activations dtype")
if tuple(zeros.shape) != expected_scale_shape:
raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}")
else:
if zeros is not None:
raise ValueError("zeros must be None when asym=False")
weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits]
if not scales.is_contiguous():
scales = scales.contiguous()
if asym and not zeros.is_contiguous():
zeros = zeros.contiguous()
else:
raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)")

if N % 16 != 0:
raise ValueError(f"N must be a multiple of 16 (got {N})")

expected_total = int(num_tokens_per_expert.sum().item())
if expected_total != total_tokens:
raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})")

lib = self.get_lib(activations)
stream = get_stream(activations)
outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype)
# Scratch buffer mapping each token to its expert id; filled on-device
# inside the kernel wrapper so we avoid host-device sync.
expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32)

scales_ptr = scales.data_ptr() if scales is not None else 0
zeros_ptr = zeros.data_ptr() if zeros is not None else 0

lib.moe_gemm_decode(
stream,
activations.data_ptr(),
weights.data_ptr(),
scales_ptr,
zeros_ptr,
outputs.data_ptr(),
expert_id_per_token.data_ptr(),
cvt_dtype(activations.dtype),
weight_dtype,
N,
K,
group_size,
num_tokens_per_expert.data_ptr(),
num_experts,
total_tokens,
bool(asym),
)
return outputs


def moe_gemm(
activations: torch.Tensor,
Expand Down
12 changes: 12 additions & 0 deletions auto_round_extension/ark/auto_round_kernel/ark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ typedef uintptr_t torch_ptr;
// Only include declarations, implementations are in separate .cpp files
#include "sycl_tla_common.hpp"
#include "sycl_tla_moe.hpp"
#include "sycl_tla_moe_decode.hpp"
#include "sycl_tla_sdpa.hpp"
#endif
#else
Expand Down Expand Up @@ -222,6 +223,16 @@ static void moe_gemm_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr
(void*)outputs, (BTLA_DTYPE)(dtype), N, K, (int*)num_tokens_per_expert, num_experts);
}

static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales,
torch_ptr zeros, torch_ptr outputs, torch_ptr expert_id_per_token_buf,
int act_dtype, int weight_dtype, int N, int K, int group_size,
torch_ptr num_tokens_per_expert, int num_experts, int total_tokens, bool asym) {
ark::moe_gemm_decode((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr,
zeros ? (void*)zeros : nullptr, (void*)outputs, (int*)expert_id_per_token_buf,
(BTLA_DTYPE)(act_dtype), (BTLA_DTYPE)(weight_dtype), N, K, group_size,
(int*)num_tokens_per_expert, num_experts, total_tokens, asym);
}

static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr bias, torch_ptr output, torch_ptr scale_out,
int num_rows, int head_dim, int block_size) {
auto* q = (sycl::queue*)stream;
Expand Down Expand Up @@ -439,5 +450,6 @@ PYBIND11_MODULE(PY_NAME, m) {
m.def("sage_dynamic_quant_layout", &ark::sage_dynamic_quant_layout);
m.def("sage_dynamic_quant_v_layout", &ark::sage_dynamic_quant_v_layout);
m.def("moe_gemm", &ark::moe_gemm_wrapper);
m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper);
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,40 @@ namespace ark {
void moe_gemm(sycl::queue* q, void* activations, void* weights, void* scales, void* outputs, BTLA_DTYPE dtype, int N,
int K, int* num_tokens_per_expert, int num_experts);

/**
* @brief MoE GEMV optimized for the decode phase (M per expert is typically
* 1-2 tokens). Supports unquantized FP16/BF16 weights and int4 (S4_CLIP)
* weights with group-wise scales and optional zero-points.
*
* Implementation is header-only in `sycl_tla_moe_decode.hpp`.
*
* @param q SYCL queue
* @param activations [total_tokens, K] in `act_dtype`
* @param weights Unquantized: [num_experts, N, K] in act_dtype
* Int4: packed [num_experts, N, K/2] uint8
* @param scales [num_experts, N, K/group_size] (act_dtype),
* ignored when weight_dtype is FP16/BF16
* @param zeros [num_experts, N, K/group_size] (act_dtype) or
* nullptr; required when asym==true
* @param outputs [total_tokens, N] in act_dtype
* @param expert_id_per_token_buf [total_tokens] int32 scratch buffer (device)
* @param act_dtype BTLA_DTYPE::F16 or BTLA_DTYPE::BF16
* @param weight_dtype BTLA_DTYPE::F16/BF16/S4_CLIP
* @param N Output feature dim (must be multiple of 16)
* @param K Input feature dim
* @param group_size Quantization group along K (int4 only); must
* divide K and be even. Default 128.
* @param num_tokens_per_expert [num_experts] int32
* @param num_experts Number of experts
* @param total_tokens Sum of num_tokens_per_expert (== rows of
* activations / outputs)
* @param asym Whether int4 weights are asymmetric
* (zeros required when true).
Comment thread
a32543254 marked this conversation as resolved.
*/
void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs,
int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K,
int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym);

// ========================================================================
// Public API
// ========================================================================
Expand Down
Loading
Loading