Optimize naive top-k masking in fused router (#2783)
yosh20004 wants to merge 4 commits into NVIDIA:main from
Conversation
Greptile Summary
This PR refactors the fused router's naive top-k masking. Key points:
Confidence Score: 2/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["naive_topk_and_mask called\n(all 32 warp lanes)"] --> B["assert data_size ≤ 1024\n⚠ stripped in release builds"]
B --> C["for k = 0 .. topk-1"]
C --> D["Step 1: Per-lane local max\nfor i = lane_id; i < data_size; i += 32\n check local_mask bit_idx\n select unmasked scores[i]"]
D --> E["Step 2: __shfl_down_sync tree reduction\ns = 16 → 8 → 4 → 2 → 1\nLane 0 accumulates global max"]
E --> F["Broadcast from lane 0\n__shfl_sync(0xffffffff, global_max_idx, 0)\n__shfl_sync(0xffffffff, global_max_val, 0)"]
F --> G["Step 3: Lane 0 writes\ntopk_indices[k] = global_max_idx\ntopk_scores[k] = global_max_val"]
G --> H["Step 4: Owning lane sets bit\nif global_max_idx % 32 == lane_id\n local_mask |= 1u << local_bit_pos"]
H --> C
C -->|"k == topk"| I["__syncwarp()\nreturn"]
|
| int bit_idx = 0; | ||
| for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { | ||
| CompType cur_val = 0.0f; | ||
| if constexpr (std::is_same_v<CompType, double>) { |
There was a problem hiding this comment.
if constexpr checks CompType, not T — double branch is permanently dead
The condition std::is_same_v<CompType, double> checks the alias CompType, which is unconditionally using CompType = float (line 19 of this file). This condition is therefore always false, making the entire double branch (__double_as_longlong, __longlong_as_double, 0xFFF0000000000000ULL) unreachable dead code.
The intent appears to be std::is_same_v<T, double> so that the branch activates when the function is instantiated with a double-precision type. However, note that even after that fix, the else branch also casts via static_cast<CompType>(scores[i]) (i.e., to float), so true double-precision correctness would still need a broader refactor. For now, the dead branch adds confusion with no benefit.
| if constexpr (std::is_same_v<CompType, double>) { | |
| if constexpr (std::is_same_v<T, double>) { |
| for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { | ||
| CompType cur_val = 0.0f; | ||
| if constexpr (std::is_same_v<CompType, double>) { | ||
| uint64_t mask = -(uint64_t)((local_mask >> bit_idx) & 1u); | ||
| uint64_t x_bits = __double_as_longlong(static_cast<CompType>(scores[i])); | ||
| uint64_t result_bits = (~mask & x_bits) | (mask & 0xFFF0000000000000ULL); | ||
| cur_val = __longlong_as_double(result_bits); | ||
| } else { | ||
| uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u); | ||
| uint32_t x_bits = __float_as_uint(static_cast<CompType>(scores[i])); | ||
| uint32_t result_bits = (~full_mask & x_bits) | (full_mask & 0xFF800000u); | ||
| cur_val = __uint_as_float(result_bits); | ||
| } | ||
| if (cur_val > local_max_val) { | ||
| local_max_val = cur_val; | ||
| local_max_idx = i; | ||
| } | ||
| bit_idx++; | ||
| } | ||
| // Warp shuffle between threads | ||
| for (int s = 16; s > 0; s /= 2) { | ||
| auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); | ||
| auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); | ||
| if (shuffled_val > val) { | ||
| val = shuffled_val; | ||
| index = shuffled_index; | ||
|
|
||
| // 2) Warp reduction to find global max and index. | ||
| CompType global_max_val = local_max_val; | ||
| int global_max_idx = local_max_idx; | ||
| for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { | ||
| CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s); | ||
| int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s); | ||
| if (shuffled_val > global_max_val) { | ||
| global_max_val = shuffled_val; | ||
| global_max_idx = shuffled_idx; | ||
| } | ||
| } | ||
| global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0); | ||
| global_max_val = __shfl_sync(0xffffffff, global_max_val, 0); | ||
|
|
||
| // 3) Write top-k result. | ||
| if (lane_id == 0) { | ||
| topk_indices[k] = index; | ||
| topk_scores[k] = val; | ||
| topk_indices[k] = global_max_idx; | ||
| topk_scores[k] = static_cast<T>(global_max_val); | ||
| } | ||
|
|
||
| // 4) Mark selected element in owning lane's local mask. | ||
| if (global_max_idx >= 0 && (global_max_idx % kThreadsPerWarp) == lane_id) { | ||
| int local_bit_pos = global_max_idx / kThreadsPerWarp; | ||
| if (local_bit_pos < 32) { | ||
| local_mask |= (1u << local_bit_pos); | ||
| } |
There was a problem hiding this comment.
32-bit local_mask silently breaks correctness for data_size > 1024
local_mask is a uint32_t, giving each lane exactly 32 bits to track up to 32 local elements. Each lane owns element indices at lane_id, lane_id + 32, lane_id + 64, …, so the maximum trackable data_size is 32 * kThreadsPerWarp = 1024.
When data_size > 1024, two related problems arise:
1. Undefined-behavior shift in the inner loop (line 226 / 221):
uint32_t full_mask = -(uint32_t)((local_mask >> bit_idx) & 1u);
bit_idx increments to 32+ when data_size > 1024. Right-shifting a uint32_t by ≥ 32 is undefined behaviour in C++. On CUDA PTX the hardware clamps the result to 0, so the element is effectively never masked — but that itself causes the second problem.
2. Silent double-selection via the guard on line 261:
if (local_bit_pos < 32) {
local_mask |= (1u << local_bit_pos);
}
When global_max_idx / kThreadsPerWarp ≥ 32, the guard silently skips the bit-set, leaving the already-selected element unmasked. In the next top-k iteration the same element is eligible for selection again, producing duplicate indices in topk_indices.
The original is_masked lambda had no such limit because it scanned the full topk_indices array. While current usage (num_experts ≤ a few hundred) keeps data_size well under 1024, the silent failure mode is dangerous and should at minimum be guarded with a compile-time or runtime assertion:
// At function entry or as a static_assert at the call site:
assert(data_size <= static_cast<int>(sizeof(local_mask) * 8 * kThreadsPerWarp) &&
       "local_mask too small for data_size > 1024");
Or switch to a uint64_t mask (doubling the safe range to 2048) and add the assertion for anything larger.
| CompType global_max_val = local_max_val; | ||
| int global_max_idx = local_max_idx; | ||
| for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { | ||
| CompType shuffled_val = __shfl_down_sync(0xffffffff, global_max_val, s); | ||
| int shuffled_idx = __shfl_down_sync(0xffffffff, global_max_idx, s); | ||
| if (shuffled_val > global_max_val) { | ||
| global_max_val = shuffled_val; | ||
| global_max_idx = shuffled_idx; | ||
| } | ||
| } | ||
| global_max_idx = __shfl_sync(0xffffffff, global_max_idx, 0); | ||
| global_max_val = __shfl_sync(0xffffffff, global_max_val, 0); |
There was a problem hiding this comment.
__shfl_down_sync reduction does not guarantee a stable winner on ties
The original code used a butterfly (__shfl_xor_sync) reduction pattern where all threads naturally converge to the same result in 5 rounds. The new __shfl_down_sync tree reduction still correctly delivers the global maximum to lane 0 (then broadcast via __shfl_sync), but the tie-breaking behaviour when two lanes hold the same maximum value is different from the original: with shuffled_val > global_max_val (strict >), the lower-indexed lane's index is preserved on a tie, which subtly differs from the XOR-butterfly order.
This is unlikely to matter for a floating-point router where exact ties are rare, but it is a behavioural change not called out in the PR description. If determinism across refactors matters for reproducibility testing, this should be documented or a consistent tie-breaking rule (e.g., prefer smaller index) should be explicitly enforced in both the value comparison and the index comparison:
if (shuffled_val > global_max_val ||
(shuffled_val == global_max_val && shuffled_idx < global_max_idx)) {
global_max_val = shuffled_val;
global_max_idx = shuffled_idx;
}
Performance test script
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
from typing import Any, Callable, Optional, Tuple, cast
import pytest
import torch
from transformer_engine.pytorch.router import (
fused_compute_score_for_moe_aux_loss,
fused_moe_aux_loss,
fused_topk_with_score_function,
)
# Seed both the CPU RNG and (when a GPU is present) the CUDA RNG so every
# benchmark below sees identical random inputs from run to run.
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
def _cuda() -> Any:
return getattr(torch, "cuda", None)
def _require_perf_env() -> None:
    """Skip the calling test unless CUDA is usable and perf tests are enabled.

    The TE_RUN_PERF_TESTS flag defaults to "1", so perf tests run unless the
    variable is explicitly set to something else.
    """
    cuda = _cuda()
    if cuda is None or not cuda.is_available():
        pytest.skip("CUDA is not available.")
    if os.getenv("TE_RUN_PERF_TESTS", "1") != "1":
        pytest.skip("Set TE_RUN_PERF_TESTS=1 to enable router performance tests.")
def _benchmark_cuda_kernel(
    fn: Callable[[], object], warmup: int = 500, iters: int = 2000
) -> float:
    """Time ``fn`` with CUDA events and return the mean latency in milliseconds.

    Runs ``warmup`` untimed calls first so compilation and cache effects do not
    pollute the measurement, then times ``iters`` calls between two events.
    """
    cuda = _cuda()
    begin = cuda.Event(enable_timing=True)
    finish = cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    cuda.synchronize()
    begin.record()
    for _ in range(iters):
        fn()
    finish.record()
    # Wait for the stream to drain so elapsed_time() reads a completed event.
    cuda.synchronize()
    return begin.elapsed_time(finish) / iters
def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Group-limited top-k: restrict selection to the best expert groups.

    Each group is ranked by the sum of its ``topk // group_topk`` largest
    scores; experts outside the ``group_topk`` winning groups are masked to
    -inf before the final per-token top-k. Returns ``(probs, top_indices)``.
    """
    experts_per_group = num_experts // num_groups
    per_group_k = topk // group_topk

    # Score each group by the sum of its strongest members.
    grouped = scores.view(num_tokens, num_groups, experts_per_group)
    group_scores = grouped.topk(per_group_k, dim=-1)[0].sum(dim=-1)

    # Build a 0/1 mask over groups, then broadcast it out to expert columns.
    winning_groups = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, winning_groups, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, experts_per_group)
        .reshape(num_tokens, -1)
    )

    # Experts in losing groups can never be selected.
    masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)
    return probs, top_indices
def topk_score_function_pytorch(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for ``fused_topk_with_score_function``.

    Returns ``(topk_masked_gates, topk_map)``: a dense
    (num_tokens, num_experts) tensor holding the routing probabilities at the
    selected experts (zeros elsewhere), and a boolean map marking the
    selected experts.

    Raises ValueError for an unknown ``score_function``.
    """
    num_tokens, num_experts = logits.shape

    def compute_topk(
        scores: torch.Tensor,
        topk_value: int,
        num_groups_value: Optional[int] = None,
        group_topk_value: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Use the group-limited variant only when group_topk is provided.
        if group_topk_value:
            assert num_groups_value is not None
            return group_limited_topk(
                scores=scores,
                topk=topk_value,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups_value,
                group_topk=group_topk_value,
            )
        return torch.topk(scores, k=topk_value, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            # Softmax over all experts first, then top-k of the probabilities.
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            # Top-k of the raw logits first, then softmax over the k winners only.
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function in ("sigmoid", "sqrtsoftplus"):
        if score_function == "sigmoid":
            scores = torch.sigmoid(logits.float()).type_as(logits)
        else:
            scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits)
        if expert_bias is not None:
            # Bias influences which experts win, but the returned probabilities
            # are gathered from the unbiased scores.
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(
                scores_for_routing, topk, num_groups, group_topk
            )
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        # Normalize the selected scores to sum to 1; epsilon guards a zero row.
        probs = (
            scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
        )
    else:
        raise ValueError(f"Invalid score_function: {score_function}")
    if scaling_factor:
        probs = probs * scaling_factor
    # Scatter the k probabilities back into a dense (tokens, experts) layout.
    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return topk_masked_gates, topk_map
def compute_scores_for_aux_loss_pytorch(
    logits: torch.Tensor, topk: int, score_function: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference routing map + scores for the aux-loss comparison tests.

    sigmoid/sqrtsoftplus scores are normalized per token; softmax is already
    a distribution. Returns ``(routing_map, scores)`` where ``routing_map``
    is a boolean (num_tokens, num_experts) top-k mask.

    Raises ValueError for an unknown ``score_function``.
    """
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function in ("sigmoid", "sqrtsoftplus"):
        if score_function == "sigmoid":
            raw = torch.sigmoid(logits.float())
        else:
            raw = torch.nn.functional.softplus(logits.float()).sqrt()
        # Normalize per token; epsilon guards an all-zero row.
        scores = raw / (raw.sum(dim=-1, keepdim=True) + 1e-20)
    else:
        raise ValueError(f"Invalid score_function: {score_function}")
    _, top_indices = torch.topk(scores, k=topk, dim=1)
    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return routing_map, scores
def aux_loss_pytorch(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    topk: int,
    num_experts: int,
    moe_aux_loss_coeff: float,
) -> torch.Tensor:
    """Reference MoE load-balancing auxiliary loss.

    Dot product of the per-expert probability mass with the per-expert token
    counts, scaled by num_experts * coeff / (topk * total_num_tokens^2).
    """
    mass_per_expert = probs.sum(dim=0)
    scale = num_experts * moe_aux_loss_coeff / (
        topk * total_num_tokens * total_num_tokens
    )
    return torch.sum(mass_per_expert * tokens_per_expert) * scale
def _make_router_logits(
dtype: torch.dtype, num_tokens: int, num_experts: int, score_function: str
) -> torch.Tensor:
if score_function in ("sigmoid", "sqrtsoftplus"):
offset = (
torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda")
* 1e-4
)
logits = (
torch.arange(
-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype
)
* 1e-2
)
return logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
logits = (
torch.arange(
-num_tokens * num_experts // 2,
num_tokens * num_experts // 2,
device="cuda",
dtype=dtype,
)
* 1e-4
)
return logits.view(num_tokens, num_experts)
def _make_router_bias(num_experts: int, dtype: torch.dtype) -> torch.Tensor:
bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
return torch.flip(bias, dims=[0])
def _make_aux_loss_probs(
dtype: torch.dtype, num_tokens: int, num_experts: int
) -> torch.Tensor:
offset = (
torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda")
* 1e-4
)
probs = (
torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype)
* 1e-2
)
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
return probs.view(num_tokens, num_experts)
def _fused_expert_bias(expert_bias: Optional[torch.Tensor]) -> torch.Tensor:
if expert_bias is None:
return cast(torch.Tensor, None)
return expert_bias
def _print_perf_result(case_name: str, torch_ms: float, fused_ms: float) -> None:
    """Print one benchmark row: raw timings plus the torch/fused speedup."""
    ratio = torch_ms / fused_ms
    print(
        f"{case_name}: torch={torch_ms:.6f} ms, fused={fused_ms:.6f} ms, speedup={ratio:.4f}x"
    )
def _perf_assert_message(case_name: str, torch_ms: float, fused_ms: float) -> str:
return (
f"{case_name} perf result: torch={torch_ms:.6f} ms, "
f"fused={fused_ms:.6f} ms, speedup={torch_ms / fused_ms:.4f}x"
)
# Compare the fused top-k router kernel against the PyTorch reference across a
# grid of routing configurations: correctness first, then timed speedup.
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(8192, 256, 8), (16384, 256, 8)],
    ids=["large", "xlarge"],
)
@pytest.mark.parametrize(
    "score_function,use_pre_softmax,enable_bias,num_groups,group_topk,scaling_factor",
    [
        ("softmax", True, False, None, None, None),
        ("softmax", False, False, 8, 4, 1.2),
        ("sigmoid", False, False, None, None, None),
        ("sigmoid", False, True, 8, 4, 1.2),
        ("sqrtsoftplus", False, False, None, None, None),
        ("sqrtsoftplus", False, True, 8, 4, 1.2),
    ],
    ids=[
        "softmax_pre",
        "softmax_grouped",
        "sigmoid_plain",
        "sigmoid_grouped_bias",
        "sqrtsoftplus_plain",
        "sqrtsoftplus_grouped_bias",
    ],
)
def test_fused_topk_router_perf_against_torch(
    dtype,
    num_tokens,
    num_experts,
    topk,
    score_function,
    use_pre_softmax,
    enable_bias,
    num_groups,
    group_topk,
    scaling_factor,
    record_property,
):
    _require_perf_env()
    logits = _make_router_logits(dtype, num_tokens, num_experts, score_function)
    expert_bias = None
    # Expert bias only applies to the sigmoid/sqrtsoftplus score paths.
    if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"):
        expert_bias = _make_router_bias(num_experts, dtype)
    # Correctness: run both implementations once and compare at the end.
    torch_probs, torch_map = topk_score_function_pytorch(
        logits=logits,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias,
    )
    fused_probs, fused_map = cast(
        Tuple[torch.Tensor, torch.Tensor],
        fused_topk_with_score_function(
            logits=logits,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
            group_topk=group_topk,
            scaling_factor=scaling_factor,
            score_function=score_function,
            expert_bias=_fused_expert_bias(expert_bias),
        ),
    )
    # Timing: mean ms per call for the reference and the fused kernel.
    torch_ms = _benchmark_cuda_kernel(
        lambda: topk_score_function_pytorch(
            logits=logits,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
            group_topk=group_topk,
            scaling_factor=scaling_factor,
            score_function=score_function,
            expert_bias=expert_bias,
        )
    )
    fused_ms = _benchmark_cuda_kernel(
        lambda: fused_topk_with_score_function(
            logits=logits,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
            group_topk=group_topk,
            scaling_factor=scaling_factor,
            score_function=score_function,
            expert_bias=_fused_expert_bias(expert_bias),
        )
    )
    case_name = (
        f"topk_router[{score_function}]"
        f"[tokens={num_tokens},experts={num_experts},topk={topk}]"
        f"[pre={use_pre_softmax},groups={group_topk is not None},bias={enable_bias}]"
    )
    # record_property surfaces timings in the JUnit XML for CI tracking.
    record_property("torch_ms", round(torch_ms, 6))
    record_property("fused_ms", round(fused_ms, 6))
    record_property("speedup", round(torch_ms / fused_ms, 6))
    _print_perf_result(case_name, torch_ms, fused_ms)
    assert torch_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    torch.testing.assert_close(torch_probs, fused_probs)
    torch.testing.assert_close(torch_map, fused_map)
# Compare the fused aux-loss score kernel against the PyTorch reference:
# correctness check, then timed speedup per score function.
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(8192, 256, 8), (16384, 256, 8)],
    ids=["large", "xlarge"],
)
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_fused_scores_for_aux_loss_perf_against_torch(
    dtype, num_tokens, num_experts, topk, score_function, record_property
):
    _require_perf_env()
    logits = _make_router_logits(dtype, num_tokens, num_experts, score_function)
    # Correctness: run both implementations once and compare at the end.
    torch_map, torch_scores = compute_scores_for_aux_loss_pytorch(
        logits=logits,
        topk=topk,
        score_function=score_function,
    )
    fused_map, fused_scores = cast(
        Tuple[torch.Tensor, torch.Tensor],
        fused_compute_score_for_moe_aux_loss(
            logits=logits,
            topk=topk,
            score_function=score_function,
        ),
    )
    # Timing: mean ms per call for the reference and the fused kernel.
    torch_ms = _benchmark_cuda_kernel(
        lambda: compute_scores_for_aux_loss_pytorch(
            logits=logits,
            topk=topk,
            score_function=score_function,
        )
    )
    fused_ms = _benchmark_cuda_kernel(
        lambda: fused_compute_score_for_moe_aux_loss(
            logits=logits,
            topk=topk,
            score_function=score_function,
        )
    )
    case_name = (
        f"scores_for_aux_loss[{score_function}]"
        f"[tokens={num_tokens},experts={num_experts},topk={topk}]"
    )
    # record_property surfaces timings in the JUnit XML for CI tracking.
    record_property("torch_ms", round(torch_ms, 6))
    record_property("fused_ms", round(fused_ms, 6))
    record_property("speedup", round(torch_ms / fused_ms, 6))
    _print_perf_result(case_name, torch_ms, fused_ms)
    assert torch_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    torch.testing.assert_close(torch_scores, fused_scores)
    torch.testing.assert_close(torch_map, fused_map)
# Compare the fused MoE aux-loss kernel against the PyTorch reference:
# correctness check, then timed speedup per coefficient.
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize(
    "num_tokens,num_experts,topk",
    [(14234, 256, 4), (28672, 256, 8)],
    ids=["large", "xlarge"],
)
@pytest.mark.parametrize("coeff", [0.01, 0.05])
def test_fused_moe_aux_loss_perf_against_torch(
    dtype, num_tokens, num_experts, topk, coeff, record_property
):
    _require_perf_env()
    probs = _make_aux_loss_probs(dtype, num_tokens, num_experts)
    # Random integer token counts; the module-level seed keeps this reproducible.
    tokens_per_expert = torch.randint(
        1, 1000, (num_experts,), device="cuda", dtype=torch.int32
    )
    # Correctness: run both implementations once and compare at the end.
    torch_loss = aux_loss_pytorch(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        topk=topk,
        num_experts=num_experts,
        moe_aux_loss_coeff=coeff,
    )
    fused_loss = fused_moe_aux_loss(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        coeff=coeff,
    )
    # Timing: mean ms per call for the reference and the fused kernel.
    torch_ms = _benchmark_cuda_kernel(
        lambda: aux_loss_pytorch(
            probs=probs,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=num_tokens,
            topk=topk,
            num_experts=num_experts,
            moe_aux_loss_coeff=coeff,
        )
    )
    fused_ms = _benchmark_cuda_kernel(
        lambda: fused_moe_aux_loss(
            probs=probs,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=num_tokens,
            num_experts=num_experts,
            topk=topk,
            coeff=coeff,
        )
    )
    case_name = f"moe_aux_loss[tokens={num_tokens},experts={num_experts},coeff={coeff}]"
    # record_property surfaces timings in the JUnit XML for CI tracking.
    record_property("torch_ms", round(torch_ms, 6))
    record_property("fused_ms", round(fused_ms, 6))
    record_property("speedup", round(torch_ms / fused_ms, 6))
    _print_perf_result(case_name, torch_ms, fused_ms)
    assert torch_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    assert fused_ms > 0, _perf_assert_message(case_name, torch_ms, fused_ms)
    torch.testing.assert_close(torch_loss, fused_loss)
Result (H200 cuda13.1)
Refactor naive_topk_and_mask to track selections with a per-lane mask and reduce across the warp more directly. This keeps the top-k routing path cleaner while preserving the existing interface. Co-authored-by: Guitar_Players XiaomingFun233 <xiaomingchinafun@outlook.com> Signed-off-by: yosh20004 <2172622103@qq.com>
for more information, see https://pre-commit.ci
66dd207 to
1ced891
Compare
Signed-off-by: yosh20005 <2172622103@qq.com>
1ced891 to
28844c1
Compare
Signed-off-by: yosh20005 <2172622103@qq.com>
Description
Refactor the fused router's naive top-k selection path to track already-selected entries with a per-lane mask
Fixes # (issue)
Type of change
Changes
Refactors naive_topk_and_mask in transformer_engine/common/fused_router/utils.h to track per-lane selections with a local bitmask.
Checklist: