Skip to content

[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly#3011

Open
cael-ling wants to merge 8 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-rht-cast-fusion-swizzled-sf-output
Open

[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly#3011
cael-ling wants to merge 8 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-rht-cast-fusion-swizzled-sf-output

Conversation

@cael-ling
Copy link
Copy Markdown
Contributor

@cael-ling cael-ling commented May 19, 2026

Description

Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a kEnableSwizzleSFOutput switch for that, but the framework never set the matching with_gemm_swizzled_scales flag on
NVFP4 outputs -- it was a false with a TODO. This PR wires it through and saves ~25 us per quantize on LLaMA-class shapes (1.18x – 1.36x on the quant + swizzle path that te.Linear runs).

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Kernel side (transformer_engine/common/hadamard_transform/):

  • row_cast_col_hadamard_transform_cast_fusion.cu &
    group_row_cast_col_hadamard_transform_cast_fusion.cu: drive the
    existing kEnableSwizzleSFOutput template parameter from
    output.with_gemm_swizzled_scales. The grouped kernel additionally
    NVTE_CHECKs the flag is consistent across all tensors in a group
    (it honours a single boolean).
  • The graph-safe grouped variant already had this wired correctly --
    no change.

Framework side (transformer_engine/pytorch/csrc/):

  • New static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)
    mirroring the dispatch-time eligibility check in
    NVFP4Quantizer::quantize_impl (rows%64==0 && cols%128==0 && SM100/110).
  • NVFP4Quantizer::create_tensor, NVFP4Quantizer::convert_and_update_tensor,
    and bulk_allocate_nvfp4_tensors now set
    with_gemm_swizzled_scales = optimize_for_gemm && with_rht && shape_eligible.
    For the grouped allocator the flag is True only if every tensor in
    the group is eligible.
  • Belt-and-suspenders NVTE_CHECK(!out.with_gemm_swizzled_scales) at
    the entry of quantize_with_rht_unfused_helper. The framework gate
    already keeps user code from tripping it; this only fires if a future
    low-level caller bypasses the gate.

Performance

SM100a, bf16 input, rowwise + columnwise SF, RHT + post-RHT amax.
Per-quantize wall-clock median via torch.utils.benchmark.Timer.blocked_autorange.
quant + swizzle = quantizer(x); tex.swizzle_scales_for_gemm_(t) --
exactly what te.Linear runs before its GEMM.

shape baseline SUT saved speedup note
(8192, 5120) 108.6 us 81.9 us 26.6 us 1.33x eligible
(8192, 10240) 107.8 us 90.2 us 17.5 us 1.19x eligible
(8192, 2560) 107.7 us 79.9 us 27.8 us 1.35x eligible
(8192, 11328) 236.3 us 236.3 us 0.0 us 1.00x ineligible
(8192, 3584) 106.0 us 78.6 us 27.4 us 1.35x eligible
(5120, 8192) 101.2 us 76.0 us 25.3 us 1.33x eligible
(10240, 8192) 107.8 us 90.4 us 17.4 us 1.19x eligible
(2560, 8192) 101.4 us 74.9 us 26.4 us 1.35x eligible
(11328, 8192) 114.4 us 93.2 us 21.2 us 1.23x eligible
(3584, 8192) 101.6 us 74.9 us 26.7 us 1.36x eligible
(4096, 16384) 100.2 us 75.0 us 25.2 us 1.34x eligible
(14336, 16384) 232.1 us 197.5 us 34.6 us 1.18x eligible
  • 11/12 shapes get 1.18x – 1.36x on the quant + swizzle path.
  • The single ineligible shape (8192, 11328) shows 1.00x as expected;
    the gate clamped, the unfused fallback ran, and the result is byte-
    identical to baseline (no regression, no crash).
  • quant_only is unchanged on all shapes within noise -- writing
    swizzled SF inside the cast-fusion kernel is essentially free; the
    entire win comes from eliminating the standalone swizzle pass.
    Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py (also has a
    --profile mode for ncu / nsys).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…ectly

Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two
standalone swizzle kernels (rowwise + columnwise) whose only job was to
move scale factors into the layout cuBLAS LT consumes. The cast-fusion
kernel already had a `kEnableSwizzleSFOutput` switch for that, but the
framework never set the matching `with_gemm_swizzled_scales` flag on
NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through.

Changes:
* Single + grouped Hadamard cast-fusion kernels: drive
  `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`.
* NVFP4Quantizer create_tensor / convert_and_update_tensor /
  bulk_allocate_nvfp4_tensors: set the flag when
  `optimize_for_gemm && with_rht && shape eligible`, with eligibility
  in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion
  (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites.
* Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper
  in case a future low-level caller bypasses the gate.

The shape gate is part of this PR (not a follow-up) because LLaMA-class
shapes like (8192, 11328) have K%128==64. Without the gate the framework
would set the flag, dispatch would fall to the unfused path that can't
emit swizzled SF, and the process would abort. With the gate, ineligible
shapes silently fall back to the original code path.

Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median,
`quant + swizzle` path -- what te.Linear actually runs):

  (8192,  5120)    108.6 ->  81.9 us   1.33x   eligible
  (8192, 11328)    236.3 -> 236.3 us   1.00x   ineligible, gate clamped
  (11328, 8192)    114.4 ->  93.2 us   1.23x   eligible
  (14336,16384)    232.1 -> 197.5 us   1.18x   eligible

11/12 production-class shapes get 1.18x - 1.36x. The one ineligible
shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged
across all shapes -- the savings come entirely from eliminating the
standalone swizzle pass, not from a faster quant kernel.

Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py

Tests:
* new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py:
  byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases
  verifying the shape gate clamps correctly and that quantizer(x) on an
  ineligible shape does not raise.
* tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added
  optimize_for_gemm parametrization for the legacy grouped path.
* test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe
  variant already had the wiring).

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner May 19, 2026 03:49
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 19, 2026

Greptile Summary

This PR wires the existing kEnableSwizzleSFOutput kernel switch through the framework by setting with_gemm_swizzled_scales = optimize_for_gemm && with_rht && shape_eligible in create_tensor, convert_and_update_tensor, and bulk_allocate_nvfp4_tensors, replacing a hardcoded false with a TODO. A new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion centralizes the shape/SM-arch gate, and a belt-and-suspenders NVTE_CHECK in the unfused fallback prevents silent SF-layout corruption if the gate is ever bypassed.

  • Kernel side: both single-tensor and grouped cast-fusion .cu files now read with_gemm_swizzled_scales from their output tensor(s) instead of a constant false; the grouped kernel adds consistency NVTE_CHECKs across all group members.
  • Framework side: cast.cpp gains NVTE_CHECK loops enforcing uniform optimize_for_gemm/with_rht across a group before computing the shared with_gemm_swizzled_scales flag; shape eligibility is checked per-tensor with for_grouped_kernel=true (128-row alignment).
  • Tests & benchmarks: a new single-tensor test file covers byte-equal SF fidelity and the shape-gate clamping behavior; the group test is parametrized over optimize_for_gemm; two benchmark scripts reproduce the 1.18×–1.36× quant + swizzle speedup reported in the PR.

Confidence Score: 5/5

Safe to merge; the change correctly wires an existing kernel switch through a well-guarded framework path with no silent fallback.

The core logic — shape/arch gating, group consistency checks, and the unfused-path NVTE_CHECK — is correct and comprehensively tested. The two style nits have no runtime impact. The previously noted missing skipif guard on test_nvfp4_rht_swizzle_fusion_shape_gate is the only item that could cause false CI failures on non-SM100/110 runners, but it does not affect correctness of the production code path.

No files require special attention for correctness; tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py has a known missing hardware skip guard on one test function addressed in earlier review threads.

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/quantizer.cpp Core wiring: new is_eligible_for_rht_cast_fusion static helper replaces inline eligibility check; create_tensor and convert_and_update_tensor now compute with_gemm_swizzled_scales from optimize_for_gemm && with_rht && eligibility; quantize_with_rht_unfused_helper gets a belt-and-suspenders NVTE_CHECK. Logic is correct; minor inconsistent py::cast usage in convert_and_update_tensor.
transformer_engine/pytorch/csrc/extensions/cast.cpp Replaces the false TODO placeholder for with_gemm_swizzled_scales in bulk_allocate_nvfp4_tensors; adds NVTE_CHECK loops enforcing optimize_for_gemm and with_rht consistency across the group; per-tensor shape eligibility loop uses for_grouped_kernel=true (128-row alignment). Logic is correct and matches the grouped kernel's NVTE_CHECK.
transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu Removes the TODO and wires use_swizzle_sf_output from output_.with_gemm_swizzled_scales. Minimal, correct change enabling the existing kEnableSwizzleSFOutput template switch.
transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu Removes the false hardcode and reads use_swizzle_sf_output from output_list[0]->with_gemm_swizzled_scales; adds a consistency NVTE_CHECK across all group members and a non-empty guard.
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py New test file verifying byte-equal SF output for the swizzle-fusion path. test_nvfp4_rht_quantize_swizzle_fusion has the required skipif guard; test_nvfp4_rht_swizzle_fusion_shape_gate is missing the guard and will error on non-SM100/110 hardware.
tests/pytorch/nvfp4/test_nvfp4_group_quantize.py Adds optimize_for_gemm parametrization to the group quantize test; correctly gates on with_rht=True, N%128==0, and M>=512 before exercising the grouped fused path.
benchmarks/benchmark_rht_cast_swizzle_fusion.py New benchmark measuring baseline vs. swizzle-fusion wall-clock on the quant + swizzle path; well-structured with documented measurement methodology.
benchmarks/profile_rht_cast_swizzle_fusion.py Profiling script verifying the standalone swizzle kernels disappear from the timeline; import re placed after the make_quantizer function (PEP 8 violation noted in previous review).
transformer_engine/pytorch/csrc/common.h Adds the is_eligible_for_rht_cast_fusion static declaration. Docstring could note the missing dtype constraint; code is safe given RHT enforces bf16-only at dispatch time.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["quantizer(x) called\n(optimize_for_gemm=True, with_rht=True)"] --> B["create_tensor / convert_and_update_tensor\ncompute with_gemm_swizzled_scales"]
    B --> C{"optimize_for_gemm && with_rht\n&& is_eligible(shape)?\n(rows%64==0, cols%128==0, SM100/110)"}
    C -- "Yes" --> D["with_gemm_swizzled_scales = True"]
    C -- "No" --> E["with_gemm_swizzled_scales = False"]
    D --> F["quantize_impl: eligible_for_rht_cast_fusion =\nbf16 dtype && is_eligible(shape)"]
    E --> F
    F -- "Eligible" --> G["hadamard_transform_cast_fusion\nkEnableSwizzleSFOutput=True\nSF in GEMM-swizzled layout"]
    F -- "Ineligible" --> H["quantize_with_rht_unfused_helper\nNVTE_CHECK(!with_gemm_swizzled_scales)"]
    G --> I["swizzle_scales_for_gemm_ early-returns"]
    H --> J["compact SF emitted\nstandalone swizzle kernels run"]
    I --> K["cuBLAS LT NVFP4 GEMM"]
    J --> K
Loading

Reviews (6): Last reviewed commit: "Add license header to profile_rht_cast_s..." | Re-trigger Greptile

Comment on lines +751 to +753
const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm &&
quantizer_cpp_list[0]->with_rht &&
all_tensors_rht_cast_fusion_eligible;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 optimize_for_gemm and with_rht read only from first quantizer without validation

with_gemm_swizzled_scales is derived exclusively from quantizer_cpp_list[0], so if any later quantizer in the group has a different optimize_for_gemm or with_rht value, its tensors are silently allocated with the wrong SF layout. The shape-eligibility loop below correctly iterates every tensor, but there is no matching check that all quantizers agree on optimize_for_gemm/with_rht. The split-quantize path at line 1276 documents this assumption explicitly (// Assume all quantizers have identical config); the same note or an NVTE_CHECK loop here would make the contract visible and consistent.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with_gemm_swizzled_scales was derived from quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking that other quantizers in the group agreed; if any later quantizer had a different value, its tensors would be silently allocated with the wrong SF layout.

Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:

  • adds an explicit comment block calling out the group-wide
    identical-config assumption and which fields this PR enforces
    vs. which are pre-existing;
  • adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
    and with_rht across the group (the two fields the
    with_gemm_swizzled_scales gate depends on), with error messages
    that print the offending tensor index and the disagreeing values;
  • extracts the [0] reads into group_optimize_for_gemm and
    group_with_rht locals so the same value feeds both the check
    and the gate.

Reviewer feedback: with_gemm_swizzled_scales was derived from
quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking
that other quantizers in the group agreed; if any later quantizer
had a different value, its tensors would be silently allocated with
the wrong SF layout.
Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:
  * adds an explicit comment block calling out the group-wide
    identical-config assumption and which fields this PR enforces
    vs. which are pre-existing;
  * adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
    and with_rht across the group (the two fields the
    with_gemm_swizzled_scales gate depends on), with error messages
    that print the offending tensor index and the disagreeing values;
  * extracts the [0] reads into group_optimize_for_gemm and
    group_with_rht locals so the same value feeds both the check
    and the gate.
Other from-[0] reads (rowwise_usage, row_scaled_nvfp4,
columnwise_usage, scaling_mode, dtype) are pre-existing assumptions
and remain out of scope for this PR.
Signed-off-by: Cael Ling <caell@nvidia.com>
Comment on lines +722 to +732
// Quantization parameters. Like the NVFP4 split-quantize path
// (see split_quantize_nvfp4_impl in this file), we assume all
// quantizers in the group share an identical config and read
// group-wide flags from quantizer_cpp_list[0]. The grouped RHT
// cast-fusion kernel honours a single with_gemm_swizzled_scales
// boolean across the whole group, so optimize_for_gemm and with_rht
// must in particular agree across all quantizers; the NVTE_CHECK
// loop below enforces that for the fields the swizzled-SF gate
// depends on. (The other group-wide reads from [0] -- rowwise_usage,
// row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype -- are
// pre-existing assumptions and out of scope for this PR.)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super happy about the style of those comments - they reference multiple other files, and while right now the comment matches the reality, it will easily drift. We should concentrate on commenting the invariants and assumptions needed for this file only.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit removed the prose about dispatch internals and caller responsibilities.

Comment on lines +744 to +753
// Only the RHT cast-fusion quant kernel supports direct swizzled SF
// emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 ->
// quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject
// a swizzled-flagged output, so we gate on with_rht to avoid silent
// data corruption / hard aborts on non-RHT paths. Additionally we
// require *all* tensors in the group to be shape-eligible for RHT
// cast-fusion, because the grouped kernel honours a single boolean
// and the unfused fallback rejects swizzled output (see NVTE_CHECK
// at group_row_cast_col_hadamard_transform_cast_fusion.cu and
// quantize_with_rht_unfused_helper).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit removed the prose about dispatch internals and caller responsibilities.

Comment on lines +377 to +383
* Matches the dispatch logic in NVFP4Quantizer::quantize_impl.
* The dtype check (BF16) is implicit -- with_rht=True requires
* BF16 input by construction, so callers gate on with_rht first.
* When false, the dispatch falls back to quantize_with_rht_unfused
* which cannot emit GEMM-swizzled SF; framework gates that opt
* into with_gemm_swizzled_scales must therefore also check this
* to avoid mismatched-flag aborts in the fallback path.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this is mostly talking about the internal implementation choices rather than what that function actually does (which is covered by the first sentence).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit removed the prose about dispatch internals and caller responsibilities.

* into with_gemm_swizzled_scales must therefore also check this
* to avoid mismatched-flag aborts in the fallback path.
*/
static bool is_eligible_for_rht_cast_fusion(size_t rows, size_t cols);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it take arbitrary shape rather than assuming it will be 2D?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Changed the signature to take the full tensor shape (const std::vector<size_t>& shape) and moved the get_2d_dims(...) flatten inside the function. All four call sites (create_tensor, convert_and_update_tensor, quantize_impl, and the grouped path in cast.cpp) now pass the shape directly without pre-flattening. The bulk loop in cast.cpp also no longer calls get_2d_dims per iteration since the function takes care of it.

Comment on lines +1764 to +1767
// Must mirror the eligibility check in NVFP4Quantizer::quantize_impl
// (search for "eligible_for_rht_cast_fusion" in this file). The dtype
// check (BF16) is implicit: with_rht is only valid for BF16 input by
// construction.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it have to mirror the check in that other function? Considering that both of these functions are in the same file and in the same class, can't we just call one from the other to keep a single source of truth?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, the proper fix is to call one from the other. quantize_impl now delegates the shape/arch predicate to NVFP4Quantizer::is_eligible_for_rht_cast_fusion(...) instead of re-inlining the same check. The BF16 dtype guard stays as an explicit && at the call site because it's specific to quantize_impl (the allocation callers don't have an input tensor to check). I also replaced the hand-rolled rows = product(input.shape[:-1]) loop with get_2d_dims(input.shape()) so the flattening rule isn't duplicated either. The shape/arch eligibility now has a single source of truth.

// neither of which supports emitting SF in the GEMM-swizzled layout (their
// backing kernels NVTE_CHECK reject swizzled-flagged output). Surface a clean
// error here instead of letting it abort deep inside the kernel with an
// opaque message. JAX hard-asserts eligibility upfront; PyTorch matches that
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we mention JAX in the PyTorch source files?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit dropped the JAX reference and the surrounding narration. The remaining 2-line comment just explains why this NVTE_CHECK is here.

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 19, 2026

Please also handle the convert_and_update_tensor path since it also needs changes.

bool all_tensors_rht_cast_fusion_eligible = true;
for (size_t i = 0; i < num_tensors; ++i) {
const auto [rows, cols] = get_2d_dims(shape_list[i]);
if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The grouped kernel that supports the swizzle will only run for rows being divisible by 128, but this function will allow tensors divisible by 64.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — this was a real bug, not just an over-permissive style.

Before this fix, is_eligible_for_rht_cast_fusion(shape) used a single row-alignment constraint of rows % 64 == 0 (the single-tensor RHT cast-fusion kernel's entry check at row_cast_col_hadamard_transform_cast_fusion.cu:1161). The bulk-allocation path in cast.cpp was calling this same lax check, so shapes like rows in {64, 192, 320, ...} — all satisfying % 64 == 0 — would pass eligibility, get with_gemm_swizzled_scales=True, and then hard-abort inside the grouped kernel whose entry asserts first_logical_dim % 128 == 0
(graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385).

The fix adds a for_grouped_kernel parameter on is_eligible_for_rht_cast_fusion so callers select the constraint
that matches the kernel they will actually invoke:

  • false (default): rows % 64 == 0, single-tensor kernel
  • true: rows % 128 == 0, grouped kernel

The bulk-allocation caller in cast.cpp passes /*for_grouped_kernel=*/true; the three single-tensor callers
(create_tensor, convert_and_update_tensor, quantize_impl) keep the default false. Shapes with rows in {64, 192, 320, ...} now correctly fail the grouped-path eligibility and fall back to the unfused path instead of reaching the grouped kernel.

// (search for "eligible_for_rht_cast_fusion" in this file). The dtype
// check (BF16) is implicit: with_rht is only valid for BF16 input by
// construction.
return rows % 64 == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 &&
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the rows % 64 == 0 a requirement here rather than rows % 128 == 0?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 64 here is correct for the single-tensor cast-fusion kernel — its entry check is NVTE_CHECK(M % 64 == 0, ...) at row_cast_col_hadamard_transform_cast_fusion.cu:1161. The 128 you're thinking of is the grouped kernel's stricter requirement at graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385.

cael-ling and others added 2 commits May 19, 2026 20:14
Functional fix:
- `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT
  eligibility check (`rows % 64 == 0`), but the grouped kernel asserts
  `first_logical_dim % 128 == 0` at entry. Shapes with rows in
  {64, 192, 320, ...} would pass eligibility, set
  `with_gemm_swizzled_scales=True`, and then hard-abort inside the
  grouped kernel with an opaque NVTE_CHECK message. Adding a
  `for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion`
  selects the correct row alignment: 64 for the single-tensor kernel,
  128 for the grouped variant. Only the bulk-allocation caller passes
  `true`; the three single-tensor callers keep the default `false`.
Refactors:
- `is_eligible_for_rht_cast_fusion` now takes the full tensor shape
  (`std::vector<size_t>`) and flattens internally with `get_2d_dims`,
  so the four call sites no longer pre-flatten and duplicate the
  flatten rule.
- `quantize_impl` delegates the shape/arch eligibility to
  `is_eligible_for_rht_cast_fusion` instead of inlining the same
  predicate, and its hand-rolled `rows = product(shape[:-1])` loop is
  replaced with `get_2d_dims(input.shape())`. The shape/arch
  eligibility now has a single source of truth.
Comment cleanups:
- Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`,
  `create_tensor`, `convert_and_update_tensor`, and
  `quantize_with_rht_unfused_helper`. Removed cross-references to
  other functions/files, code narration of subsequent lines, the JAX
  reference in PyTorch source, and the "see X for rationale" pattern.
- Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single
  brief sentence.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling
Copy link
Copy Markdown
Contributor Author

Please also handle the convert_and_update_tensor path since it also needs changes.

Done. Both create_tensor and convert_and_update_tensor now have the same 2-line comment on the gating; removed the previous "See NVFP4Quantizer::create_tensor for the rationale" cross-reference. I also trimmed create_tensor's long rationale block (which referenced specific .cu/.cuh filenames and quantize_with_rht_unfused's internal behavior) in the same pass, so the two functions are consistent.

@cael-ling cael-ling requested a review from ptrendx May 21, 2026 01:17

// Swizzled SF is only valid when the RHT cast-fusion path runs;
// other quantize paths reject it.
const bool with_gemm_swizzled_scales = this->optimize_for_gemm && this->with_rht &&
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is then set for the out_cpp TensorWrapper (at the end of this function), but not in the actual Python object. See handing of this in the MXFP8 quantizer:

  tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales;

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — fixed in the latest commit (pushed), mirroring the MXFP8 quantizer

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 3, 2026
@cael-ling cael-ling requested a review from ptrendx June 3, 2026 09:55
ptrendx
ptrendx previously approved these changes Jun 3, 2026
@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented Jun 3, 2026

/te-ci pytorch

…tput

Signed-off-by: cael-ling <caell@nvidia.com>
ptrendx
ptrendx previously approved these changes Jun 4, 2026
Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ptrendx June 4, 2026 05:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants