Optimize naive top-k masking in fused router#2783

Open
yosh20004 wants to merge 4 commits into NVIDIA:main from yosh20004:perf/naive-topk-and-mask

Conversation

@yosh20004 yosh20004 commented Mar 19, 2026

Description

Refactor the fused router's naive top-k selection path to track already-selected entries with a per-lane mask.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Refactor naive_topk_and_mask in transformer_engine/common/fused_router/utils.h to track per-lane selections with a local bitmask
  • Preserve the existing function interface and current fused router call sites

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Mar 19, 2026

Greptile Summary

This PR refactors naive_topk_and_mask in transformer_engine/common/fused_router/utils.h to replace the original O(k) linear scan over topk_indices (the is_masked lambda) with a per-lane 32-bit bitmask (local_mask) that tracks already-selected elements in register. The function is also templatised on the score type T to match the existing call sites. The core reduction is changed from an XOR-butterfly (__shfl_xor_sync) to a down-tree (__shfl_down_sync) with a subsequent broadcast from lane 0.

Key points:

  • Correctness of bitmask approach: Correct for data_size <= 1024 (32 bits × 32 lanes). The assert added to enforce this is compiled out in release builds (NDEBUG), leaving the previously identified shift-by-≥-32 UB active in production.
  • Warp reduction: The __shfl_down_sync reduction correctly delivers the global max to lane 0, which is then broadcast via __shfl_sync. However, this introduces a subtle tie-breaking behaviour change relative to the original XOR butterfly, which is undocumented.
  • __syncwarp() restructuring: Moving __syncwarp() outside the topk loop is valid in the new design because topk_indices is no longer read back within the loop body (the mask is maintained in registers). Call sites supply their own __syncwarp() after the function anyway.
  • Template parameter T: The inner reduction is always done in CompType (float), so the templatisation is cosmetic for the current supported types (fp32, fp16, bf16).
  • The PR leaves several checklist items incomplete (no new tests, no new warnings confirmation, no documentation updates).
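The per-lane bookkeeping summarized above can be modelled on the host. The following Python sketch is an illustration only, not the kernel (the real code distributes the scan across 32 warp lanes and reduces via shuffles); it shows why bit j of lane lane_id's 32-bit mask suffices to mark element lane_id + 32*j as selected, for data_size ≤ 1024:

```python
THREADS_PER_WARP = 32

def naive_topk_bitmask(scores, topk):
    """Host-side model of the per-lane bitmask top-k selection.

    Each lane owns indices lane_id, lane_id + 32, lane_id + 64, ...;
    bit j of a lane's 32-bit mask marks element lane_id + 32*j as
    already selected. Valid only while data_size <= 32 * 32 = 1024.
    """
    data_size = len(scores)
    assert data_size <= 32 * THREADS_PER_WARP
    masks = [0] * THREADS_PER_WARP  # one 32-bit local_mask per lane
    out_idx, out_val = [], []
    for _ in range(topk):
        best_val, best_idx = float("-inf"), -1
        # Steps 1-2 (per-lane scan + warp reduction), flattened serially:
        for i, v in enumerate(scores):
            lane, bit = i % THREADS_PER_WARP, i // THREADS_PER_WARP
            if (masks[lane] >> bit) & 1:
                continue  # already selected -> treated as -inf
            if v > best_val:
                best_val, best_idx = v, i
        out_idx.append(best_idx)
        out_val.append(best_val)
        # Step 4: the owning lane sets its local bit
        masks[best_idx % THREADS_PER_WARP] |= 1 << (best_idx // THREADS_PER_WARP)
    return out_idx, out_val
```

For distinct scores this matches torch.topk's indices; with duplicate values the bitmask still guarantees an element is selected at most once, which is the property the refactor has to preserve.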

Confidence Score: 2/5

  • The PR has functional correctness issues that are not fully resolved — the data_size > 1024 UB guard is stripped in release builds, and the tie-breaking behaviour change is undocumented.
  • Score 2 reflects: the core bitmask algorithm is logically sound within its data_size <= 1024 bound, but (1) that bound is only enforced by a debug-only assert that is a no-op in optimised CUDA builds, meaning the previously identified shift-by-≥-32 UB is still present in production; (2) the tie-breaking change introduced by switching from XOR-butterfly to down-tree reduction is a silent behavioural delta not called out in the PR description; (3) no tests were added to cover the refactored path.
  • transformer_engine/common/fused_router/utils.h — specifically the assert guard (lines 213–214) and the warp reduction (lines 238–247).

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/utils.h Refactors naive_topk_and_mask to use a per-lane 32-bit bitmask instead of a linear O(k) scan, and templatises the function on score type T. The key correctness concern is that the assert guarding the data_size <= 1024 constraint is stripped in release builds, leaving the previously identified shift-by-≥32 UB unprotected in production. The __shfl_down_sync reduction is logically correct (lane 0 receives the true global max, which is then broadcast), but introduces a subtle tie-breaking behaviour change compared to the original XOR-butterfly reduction.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["naive_topk_and_mask called\n(all 32 warp lanes)"] --> B["assert data_size ≤ 1024\n⚠ stripped in release builds"]
    B --> C["for k = 0 .. topk-1"]

    C --> D["Step 1: Per-lane local max\nfor i = lane_id; i < data_size; i += 32\n  check local_mask bit_idx\n  select unmasked scores[i]"]
    D --> E["Step 2: __shfl_down_sync tree reduction\ns = 16 → 8 → 4 → 2 → 1\nLane 0 accumulates global max"]
    E --> F["Broadcast from lane 0\n__shfl_sync(0xffffffff, global_max_idx, 0)\n__shfl_sync(0xffffffff, global_max_val, 0)"]
    F --> G["Step 3: Lane 0 writes\ntopk_indices[k] = global_max_idx\ntopk_scores[k] = global_max_val"]
    G --> H["Step 4: Owning lane sets bit\nif global_max_idx % 32 == lane_id\n  local_mask |= 1u << local_bit_pos"]
    H --> C

    C -->|"k == topk"| I["__syncwarp()\nreturn"]

Comments Outside Diff (1)

  1. transformer_engine/common/fused_router/utils.h, line 213-214 (link)

    assert is silently stripped in release builds, leaving the UB unguarded

    The newly added assert(data_size <= 1024 …) is compiled out whenever NDEBUG is defined, which is the standard configuration for all optimized/release CUDA builds. This means the constraint guarding against the 32-bit local_mask overflow (and the resulting shift-by-≥-32 UB identified in the earlier review) does not exist at runtime in production.

    Device-side CUDA asserts obey the same NDEBUG stripping rule as host asserts: without -G or --device-debug, the check is a no-op in any -O build. A runtime check that is always active (e.g. a kernel-visible NVTE_CHECK-style condition or a guarded early-return) would be needed to actually enforce this constraint in release mode.
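The same stripping behaviour exists in Python (`assert` is removed under `python -O`), which makes the contrast easy to demonstrate: a guard that must hold in production has to be an explicit, always-active check. A minimal sketch (the function name and message are illustrative, not TE API; `NVTE_CHECK` is the host-side analogue in this codebase):

```python
def check_topk_capacity(data_size, threads_per_warp=32, mask_bits=32):
    """Always-active capacity guard for the per-lane bitmask.

    Unlike `assert`, which is compiled out under NDEBUG (or `python -O`),
    this check runs in every build configuration.
    """
    limit = mask_bits * threads_per_warp  # 1024 for a 32-bit local_mask
    if data_size > limit:
        raise ValueError(
            f"data_size={data_size} exceeds per-lane bitmask capacity {limit}"
        )
```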

Last reviewed commit: "[Common] Add top-k l..."

int bit_idx = 0;
for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
  CompType cur_val = 0.0f;
  if constexpr (std::is_same_v<CompType, double>) {

P2 if constexpr checks CompType, not T — double branch is permanently dead

The condition std::is_same_v<CompType, double> checks the alias CompType, which is unconditionally using CompType = float (line 19 of this file). This condition is therefore always false, making the entire double branch (__double_as_longlong, __longlong_as_double, 0xFFF0000000000000ULL) unreachable dead code.

The intent appears to be std::is_same_v<T, double> so that the branch activates when the function is instantiated with a double-precision type. However, note that even after that fix, the else branch also casts via static_cast<CompType>(scores[i]) (i.e., to float), so true double-precision correctness would still need a broader refactor. For now, the dead branch adds confusion with no benefit.

Suggested change
- if constexpr (std::is_same_v<CompType, double>) {
+ if constexpr (std::is_same_v<T, double>) {

Comment on lines +218 to +263

+    for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
+      CompType cur_val = 0.0f;
+      if constexpr (std::is_same_v<CompType, double>) {
+        uint64_t mask = -(uint64_t)((local_mask >> bit_idx) & 1u);
+        uint64_t x_bits = __double_as_longlong(static_cast<CompType>(scores[i]));
+        uint64_t result_bits = (~mask & x_bits) | (mask & 0xFFF0000000000000ULL);
+        cur_val = __longlong_as_double(result_bits);
+      } else {
+        uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u);
+        uint32_t x_bits = __float_as_uint(static_cast<CompType>(scores[i]));
+        uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u);
+        cur_val = __uint_as_float(result_bits);
+      }
+      if (cur_val > local_max_val) {
+        local_max_val = cur_val;
+        local_max_idx = i;
+      }
+      bit_idx++;
+    }
-    // Warp shuffle between threads
-    for (int s = 16; s > 0; s /= 2) {
-      auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
-      auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
-      if (shuffled_val > val) {
-        val = shuffled_val;
-        index = shuffled_index;
+
+    // 2) Warp reduction to find global max and index.
+    CompType global_max_val = local_max_val;
+    int global_max_idx = local_max_idx;
+    for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
+      CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s);
+      int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s);
+      if (shuffled_val > global_max_val) {
+        global_max_val = shuffled_val;
+        global_max_idx = shuffled_idx;
+      }
+    }
+    global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0);
+    global_max_val = __shfl_sync(0xffffffff, global_max_val, 0);
+
+    // 3) Write top-k result.
     if (lane_id == 0) {
-      topk_indices[k] = index;
-      topk_scores[k] = val;
+      topk_indices[k] = global_max_idx;
+      topk_scores[k] = static_cast<T>(global_max_val);
     }
+
+    // 4) Mark selected element in owning lane's local mask.
+    if (global_max_idx >= 0 && (global_max_idx % kThreadsPerWarp) == lane_id) {
+      int local_bit_pos = global_max_idx / kThreadsPerWarp;
+      if (local_bit_pos < 32) {
+        local_mask |= (1u << local_bit_pos);
+      }

P2 32-bit local_mask silently breaks correctness for data_size > 1024

local_mask is a uint32_t, giving each lane exactly 32 bits to track up to 32 local elements. Each lane owns element indices at lane_id, lane_id + 32, lane_id + 64, …, so the maximum trackable data_size is 32 * kThreadsPerWarp = 1024.

When data_size > 1024, two related problems arise:

1. Undefined-behavior shift in the inner loop (line 226 / 221):

uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u);

bit_idx increments to 32+ when data_size > 1024. Right-shifting a uint32_t by ≥ 32 is undefined behaviour in C++. On CUDA PTX the hardware clamps the result to 0, so the element is effectively never masked — but that itself causes the second problem.

2. Silent double-selection via the guard on line 261:

if (local_bit_pos < 32) {
    local_mask |= (1u << local_bit_pos);
}

When global_max_idx / kThreadsPerWarp ≥ 32, the guard silently skips the bit-set, leaving the already-selected element unmasked. In the next top-k iteration the same element is eligible for selection again, producing duplicate indices in topk_indices.

The original is_masked lambda had no such limit because it scanned the full topk_indices array. While current usage (num_experts ≤ a few hundred) keeps data_size well under 1024, the silent failure mode is dangerous and should at minimum be guarded with a compile-time or runtime assertion:

// At function entry or as a static_assert at the call site:
assert(data_size <= static_cast<int>(sizeof(local_mask) * 8 * kThreadsPerWarp) &&
       "local_mask too small for data_size > 1024");

Or switch to a uint64_t mask (doubling the safe range to 2048) and add the assertion for anything larger.
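The double-selection failure mode is easy to reproduce in a host-side model. This Python sketch (illustrative only) mimics the PTX clamp, where a shift by ≥ 32 makes the mask bit read as 0, together with the `local_bit_pos < 32` guard skipping the bit-set:

```python
THREADS_PER_WARP = 32

def topk_with_clamped_mask(scores, topk):
    """Models the data_size > 1024 failure: a mask bit at position >= 32
    always reads as 0 (the element looks unmasked), and the
    `local_bit_pos < 32` guard skips the bit-set, so an already-selected
    element can win again in the next iteration."""
    masks = [0] * THREADS_PER_WARP
    picked = []
    for _ in range(topk):
        best_val, best_idx = float("-inf"), -1
        for i, v in enumerate(scores):
            lane, bit = i % THREADS_PER_WARP, i // THREADS_PER_WARP
            # PTX-style clamp: shift by >= 32 yields 0, i.e. "not masked"
            bitval = (masks[lane] >> bit) & 1 if bit < 32 else 0
            if bitval:
                continue
            if v > best_val:
                best_val, best_idx = v, i
        picked.append(best_idx)
        bit = best_idx // THREADS_PER_WARP
        if bit < 32:  # guard silently skips element indices >= 1024
            masks[best_idx % THREADS_PER_WARP] |= 1 << bit
    return picked
```

With the maximum located at element index 1024, the model returns the same index twice, i.e. duplicate entries in topk_indices.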

Comment on lines +239 to +250
CompType global_max_val = local_max_val;
int global_max_idx = local_max_idx;
for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s);
int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s);
if (shuffled_val > global_max_val) {
global_max_val = shuffled_val;
global_max_idx = shuffled_idx;
}
}
global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0);
global_max_val = __shfl_sync(0xffffffff, global_max_val, 0);

P2 __shfl_down_sync reduction does not guarantee a stable winner on ties

The original code used a butterfly (__shfl_xor_sync) reduction pattern where all threads naturally converge to the same result in 5 rounds. The new __shfl_down_sync tree reduction still correctly delivers the global maximum to lane 0 (then broadcast via __shfl_sync), but the tie-breaking behaviour when two lanes hold the same maximum value is different from the original: with shuffled_val > global_max_val (strict >), the lower-indexed lane's index is preserved on a tie, which subtly differs from the XOR-butterfly order.

This is unlikely to matter for a floating-point router where exact ties are rare, but it is a behavioural change not called out in the PR description. If determinism across refactors matters for reproducibility testing, this should be documented or a consistent tie-breaking rule (e.g., prefer smaller index) should be explicitly enforced in both the value comparison and the index comparison:

if (shuffled_val > global_max_val ||
    (shuffled_val == global_max_val && shuffled_idx < global_max_idx)) {
    global_max_val = shuffled_val;
    global_max_idx = shuffled_idx;
}
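The order-dependence on ties can be seen in a serial stand-in for the warp comparator (function names are illustrative; the real reduction runs over shuffled lane values, but the comparison rule is the same):

```python
def reduce_max(pairs):
    """Strict '>' comparison: on a tie, whichever entry is encountered
    first in reduction order wins, so the winning index depends on the
    reduction topology (XOR butterfly vs. down-tree)."""
    best_val, best_idx = pairs[0]
    for val, idx in pairs[1:]:
        if val > best_val:
            best_val, best_idx = val, idx
    return best_val, best_idx

def reduce_max_stable(pairs):
    """Suggested rule: on equal values, prefer the smaller index.
    The result is independent of traversal order."""
    best_val, best_idx = pairs[0]
    for val, idx in pairs[1:]:
        if val > best_val or (val == best_val and idx < best_idx):
            best_val, best_idx = val, idx
    return best_val, best_idx
```

With two lanes holding the same maximum, `reduce_max` returns different indices depending on input order, while `reduce_max_stable` always returns the smaller index.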

@yosh20004 (Author)

yosh20004 commented Mar 19, 2026

Performance test script

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
from typing import Any, Callable, Optional, Tuple, cast

import pytest
import torch

from transformer_engine.pytorch.router import (
    fused_compute_score_for_moe_aux_loss,
    fused_moe_aux_loss,
    fused_topk_with_score_function,
)


seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)


def _cuda() -> Any:
    return getattr(torch, "cuda", None)


def _require_perf_env() -> None:
    if _cuda() is None or not _cuda().is_available():
        pytest.skip("CUDA is not available.")
    if os.getenv("TE_RUN_PERF_TESTS", "1") != "1":
        pytest.skip("Set TE_RUN_PERF_TESTS=1 to enable router performance tests.")


def _benchmark_cuda_kernel(
    fn: Callable[[], object], warmup: int = 500, iters: int = 2000
) -> float:
    start_event = _cuda().Event(enable_timing=True)
    end_event = _cuda().Event(enable_timing=True)

    for _ in range(warmup):
        fn()
    _cuda().synchronize()

    start_event.record()
    for _ in range(iters):
        fn()
    end_event.record()
    _cuda().synchronize()

    return start_event.elapsed_time(end_event) / iters


def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    group_scores = (
        scores.view(num_tokens, num_groups, -1)
        .topk(topk // group_topk, dim=-1)[0]
        .sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, -1)
    )
    masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
    return probs, top_indices


def topk_score_function_pytorch(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    num_tokens, num_experts = logits.shape

    def compute_topk(
        scores: torch.Tensor,
        topk_value: int,
        num_groups_value: Optional[int] = None,
        group_topk_value: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if group_topk_value:
            assert num_groups_value is not None
            return group_limited_topk(
                scores=scores,
                topk=topk_value,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups_value,
                group_topk=group_topk_value,
            )
        return torch.topk(scores, k=topk_value, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function in ("sigmoid", "sqrtsoftplus"):
        if score_function == "sigmoid":
            scores = torch.sigmoid(logits.float()).type_as(logits)
        else:
            scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits)
        if expert_bias is not None:
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(
                scores_for_routing, topk, num_groups, group_topk
            )
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = (
            scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
        )
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return topk_masked_gates, topk_map


def compute_scores_for_aux_loss_pytorch(
    logits: torch.Tensor, topk: int, score_function: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float())
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    elif score_function == "sqrtsoftplus":
        scores = torch.nn.functional.softplus(logits.float()).sqrt()
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    _, top_indices = torch.topk(scores, k=topk, dim=1)
    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return routing_map, scores


def aux_loss_pytorch(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    topk: int,
    num_experts: int,
    moe_aux_loss_coeff: float,
) -> torch.Tensor:
    aggregated_probs_per_expert = probs.sum(dim=0)
    return torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
        num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens)
    )


def _make_router_logits(
    dtype: torch.dtype, num_tokens: int, num_experts: int, score_function: str
) -> torch.Tensor:
    if score_function in ("sigmoid", "sqrtsoftplus"):
        offset = (
            torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda")
            * 1e-4
        )
        logits = (
            torch.arange(
                -num_experts // 2, num_experts // 2, device="cuda", dtype=dtype
            )
            * 1e-2
        )
        return logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)

    logits = (
        torch.arange(
            -num_tokens * num_experts // 2,
            num_tokens * num_experts // 2,
            device="cuda",
            dtype=dtype,
        )
        * 1e-4
    )
    return logits.view(num_tokens, num_experts)


def _make_router_bias(num_experts: int, dtype: torch.dtype) -> torch.Tensor:
    bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
    return torch.flip(bias, dims=[0])


def _make_aux_loss_probs(
    dtype: torch.dtype, num_tokens: int, num_experts: int
) -> torch.Tensor:
    offset = (
        torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda")
        * 1e-4
    )
    probs = (
        torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype)
        * 1e-2
    )
    probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    return probs.view(num_tokens, num_experts)


def _fused_expert_bias(expert_bias: Optional[torch.Tensor]) -> torch.Tensor:
    if expert_bias is None:
        return cast(torch.Tensor, None)
    return expert_bias


def _print_perf_result(case_name: str, torch_ms: float, fused_ms: float) -> None:
    speedup = torch_ms / fused_ms
    print(
        f"{case_name}: torch={torch_ms:.6f} ms, fused={fused_ms:.6f} ms, speedup={speedup:.4f}x"
    )


def _perf_assert_message(case_name: str, torch_ms: float, fused_ms: float) -> str:
    return (
        f"{case_name} perf result: torch={torch_ms:.6f} ms, "
        f"fused={fused_ms:.6f} ms, speedup={torch_ms / fused_ms:.4f}x"
    )


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(8192, 256, 8), (16384, 256, 8)],
    ids=["large", "xlarge"],
)
@pytest.mark.parametrize(
    "score_function,use_pre_softmax,enable_bias,num_groups,group_topk,scaling_factor",
    [
        ("softmax", True, False, None, None, None),
        ("softmax", False, False, 8, 4, 1.2),
        ("sigmoid", False, False, None, None, None),
        ("sigmoid", False, True, 8, 4, 1.2),
        ("sqrtsoftplus", False, False, None, None, None),
        ("sqrtsoftplus", False, True, 8, 4, 1.2),
    ],
    ids=[
        "softmax_pre",
        "softmax_grouped",
        "sigmoid_plain",
        "sigmoid_grouped_bias",
        "sqrtsoftplus_plain",
        "sqrtsoftplus_grouped_bias",
    ],
)
def test_fused_topk_router_perf_against_torch(
    dtype,
    num_tokens,
    num_experts,
    topk,
    score_function,
    use_pre_softmax,
    enable_bias,
    num_groups,
    group_topk,
    scaling_factor,
    record_property,
):
    _require_perf_env()

    logits = _make_router_logits(dtype, num_tokens, num_experts, score_function)
    expert_bias = None
    if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"):
        expert_bias = _make_router_bias(num_experts, dtype)

    torch_probs, torch_map = topk_score_function_pytorch(
        logits=logits,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias,
    )
    fused_probs, fused_map = cast(
        Tuple[torch.Tensor, torch.Tensor],
        fused_topk_with_score_function(
            logits=logits,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
            group_topk=group_topk,
            scaling_factor=scaling_factor,
            score_function=score_function,
            expert_bias=_fused_expert_bias(expert_bias),
        ),
    )

    torch_ms = _benchmark_cuda_kernel(
        lambda: topk_score_function_pytorch(
            logits=logits,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
            group_topk=group_topk,
            scaling_factor=scaling_factor,
            score_function=score_function,
            expert_bias=expert_bias,
        )
    )
    fused_ms = _benchmark_cuda_kernel(
        lambda: fused_topk_with_score_function(
            logits=logits,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
            group_topk=group_topk,
            scaling_factor=scaling_factor,
            score_function=score_function,
            expert_bias=_fused_expert_bias(expert_bias),
        )
    )

    case_name = (
        f"topk_router[{score_function}]"
        f"[tokens={num_tokens},experts={num_experts},topk={topk}]"
        f"[pre={use_pre_softmax},groups={group_topk is not None},bias={enable_bias}]"
    )
    record_property("torch_ms", round(torch_ms, 6))
    record_property("fused_ms", round(fused_ms, 6))
    record_property("speedup", round(torch_ms / fused_ms, 6))
    _print_perf_result(case_name, torch_ms, fused_ms)

    assert torch_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    torch.testing.assert_close(torch_probs, fused_probs)
    torch.testing.assert_close(torch_map, fused_map)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(8192, 256, 8), (16384, 256, 8)],
    ids=["large", "xlarge"],
)
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_fused_scores_for_aux_loss_perf_against_torch(
    dtype, num_tokens, num_experts, topk, score_function, record_property
):
    _require_perf_env()

    logits = _make_router_logits(dtype, num_tokens, num_experts, score_function)

    torch_map, torch_scores = compute_scores_for_aux_loss_pytorch(
        logits=logits,
        topk=topk,
        score_function=score_function,
    )
    fused_map, fused_scores = cast(
        Tuple[torch.Tensor, torch.Tensor],
        fused_compute_score_for_moe_aux_loss(
            logits=logits,
            topk=topk,
            score_function=score_function,
        ),
    )

    torch_ms = _benchmark_cuda_kernel(
        lambda: compute_scores_for_aux_loss_pytorch(
            logits=logits,
            topk=topk,
            score_function=score_function,
        )
    )
    fused_ms = _benchmark_cuda_kernel(
        lambda: fused_compute_score_for_moe_aux_loss(
            logits=logits,
            topk=topk,
            score_function=score_function,
        )
    )

    case_name = (
        f"scores_for_aux_loss[{score_function}]"
        f"[tokens={num_tokens},experts={num_experts},topk={topk}]"
    )
    record_property("torch_ms", round(torch_ms, 6))
    record_property("fused_ms", round(fused_ms, 6))
    record_property("speedup", round(torch_ms / fused_ms, 6))
    _print_perf_result(case_name, torch_ms, fused_ms)

    assert torch_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    torch.testing.assert_close(torch_scores, fused_scores)
    torch.testing.assert_close(torch_map, fused_map)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(14234, 256, 4), (28672, 256, 8)],
    ids=["large", "xlarge"],
)
@pytest.mark.parametrize("coeff", [0.01, 0.05])
def test_fused_moe_aux_loss_perf_against_torch(
    dtype, num_tokens, num_experts, topk, coeff, record_property
):
    _require_perf_env()

    probs = _make_aux_loss_probs(dtype, num_tokens, num_experts)
    tokens_per_expert = torch.randint(
        1, 1000, (num_experts,), device="cuda", dtype=torch.int32
    )

    torch_loss = aux_loss_pytorch(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        topk=topk,
        num_experts=num_experts,
        moe_aux_loss_coeff=coeff,
    )
    fused_loss = fused_moe_aux_loss(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        coeff=coeff,
    )

    torch_ms = _benchmark_cuda_kernel(
        lambda: aux_loss_pytorch(
            probs=probs,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=num_tokens,
            topk=topk,
            num_experts=num_experts,
            moe_aux_loss_coeff=coeff,
        )
    )
    fused_ms = _benchmark_cuda_kernel(
        lambda: fused_moe_aux_loss(
            probs=probs,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=num_tokens,
            num_experts=num_experts,
            topk=topk,
            coeff=coeff,
        )
    )

    case_name = f"moe_aux_loss[tokens={num_tokens},experts={num_experts},coeff={coeff}]"
    record_property("torch_ms", round(torch_ms, 6))
    record_property("fused_ms", round(fused_ms, 6))
    record_property("speedup", round(torch_ms / fused_ms, 6))
    _print_perf_result(case_name, torch_ms, fused_ms)

    assert torch_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    torch.testing.assert_close(torch_loss, fused_loss)

Result (H200 cuda13.1)

(1) (test_topk_router[softmax][tokens=8192,experts=256,topk=8][pre=True,groups=False,bias=False])
    before: 0.044309 ms
    after: 0.025076 ms
    speedup: 1.77x

(2) (test_topk_router[sigmoid][tokens=16384,experts=256,topk=8][pre=False,groups=False,bias=False])
    before: 0.077351 ms
    after: 0.042781 ms
    speedup: 1.81x

(3) (test_topk_router[softmax][tokens=16384,experts=256,topk=8][pre=False,groups=True,bias=False])
    before: 0.112017 ms
    after: 0.081958 ms
    speedup: 1.37x

(4) (test_scores_for_aux_loss[sigmoid][tokens=16384,experts=256,topk=8])
    before: 0.083227 ms
    after: 0.045734 ms
    speedup: 1.82x

(5) (test_topk_router[sqrtsoftplus][tokens=16384,experts=256,topk=8][pre=False,groups=True,bias=True])
    before: 0.119349 ms
    after: 0.089024 ms
    speedup: 1.34x

yosh20004 and others added 2 commits March 20, 2026 00:05
Refactor naive_topk_and_mask to track selections with a per-lane mask and reduce across the warp more directly. This keeps the top-k routing path cleaner while preserving the existing interface.

Co-authored-by: Guitar_Players XiaomingFun233 <xiaomingchinafun@outlook.com>
Signed-off-by: yosh20004 <2172622103@qq.com>
@yosh20004 yosh20004 force-pushed the perf/naive-topk-and-mask branch from 66dd207 to 1ced891 Compare March 19, 2026 16:16
Signed-off-by: yosh20005 <2172622103@qq.com>
@yosh20004 yosh20004 force-pushed the perf/naive-topk-and-mask branch from 1ced891 to 28844c1 Compare March 19, 2026 16:19
Signed-off-by: yosh20005 <2172622103@qq.com>
@yosh20004 yosh20004 changed the title Optimize top-k masking in fused router for cleaner routing Optimize naive top-k masking in fused router Mar 19, 2026