Skip to content
Open
189 changes: 189 additions & 0 deletions benchmarks/benchmark_rht_cast_swizzle_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Benchmark NVFP4 RHT cast-fusion with vs without fused GEMM-swizzled SF output.

For each shape we measure two paths and two builds:

* path = "quant_only": just NVFP4Quantizer(x)
* path = "quant_plus_swizzle": NVFP4Quantizer(x) + tex.swizzle_scales_for_gemm_(t)
(this is what te.Linear -> tex.generic_gemm does right before the
cuBLAS LT NVFP4 GEMM dispatch)

* build = "baseline": optimize_for_gemm=False
-> quant kernel emits compact SF;
tex.swizzle_scales_for_gemm_ launches the standalone
swizzle_{row,col}_scaling_kernel pass before GEMM.
* build = "swizzle_fusion": optimize_for_gemm=True
-> quant kernel emits GEMM-swizzled SF directly (via the
kEnableSwizzleSFOutput compile-time switch in
row_cast_col_hadamard_transform_cast_fusion.cu);
tex.swizzle_scales_for_gemm_ early-returns and the standalone
swizzle pass disappears from the timeline.

The wall-clock delta on the "quant_plus_swizzle" path is the production
saving of this PR.
"""

import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te # noqa: F401 must be first per te-python-import-order
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer


def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
    """Build the NVFP4 quantizer configuration used by every benchmark cell.

    The quantizer is constructed with both rowwise and columnwise output
    enabled, RHT with post-RHT amax and a random sign mask switched on, and
    amax reduction disabled (no reduction group).

    Args:
        optimize_for_gemm: Value stored on the quantizer's
            ``optimize_for_gemm`` attribute. Per the module docstring,
            ``False`` is the "baseline" configuration and ``True`` is the
            "swizzle_fusion" configuration.

    Returns:
        The configured ``NVFP4Quantizer`` instance.
    """
    quantizer_kwargs = {
        "fp4_dtype": tex.DType.kFloat4E2M1,
        "rowwise": True,
        "columnwise": True,
        "with_amax_reduction": False,
        "amax_reduction_group": None,
        "with_rht": True,
        "with_post_rht_amax": True,
        "with_random_sign_mask": True,
    }
    quantizer = NVFP4Quantizer(**quantizer_kwargs)
    # The flag is set after construction rather than passed to the constructor.
    quantizer.optimize_for_gemm = optimize_for_gemm
    return quantizer


def _bench(stmt: str, globals_dict: dict, min_run_time: float) -> float:
    """Time ``stmt`` and return the median wall-clock per call in microseconds.

    Args:
        stmt: Python statement(s) handed to ``benchmark.Timer``.
        globals_dict: Names made available to ``stmt`` while it runs.
        min_run_time: Forwarded to ``Timer.blocked_autorange`` as the minimum
            measurement time, in seconds.

    Returns:
        The measurement's median, scaled by 1e6 to microseconds.
    """
    timer = benchmark.Timer(stmt=stmt, globals=globals_dict, num_threads=1)
    measurement = timer.blocked_autorange(min_run_time=min_run_time)
    # The original docstring reports microseconds, hence the 1e6 scale factor.
    return measurement.median * 1e6


def run_shape(shape, min_run_time: float):
    """Benchmark one ``(M, K)`` shape across both paths and both quantizers.

    Four cells are timed in a fixed order: quant-only (baseline, then
    swizzle-fusion) followed by quant + ``tex.swizzle_scales_for_gemm_``
    (baseline, then swizzle-fusion). A one-line summary is printed.

    Args:
        shape: Pair ``(M, K)``; both values must be multiples of 16.
        min_run_time: Minimum measurement time per cell in seconds, forwarded
            to ``_bench``.

    Returns:
        A dict with the shape, ``M``, ``K``, the four median timings in
        microseconds, the microseconds saved on the quant+swizzle path and
        the corresponding speedup ratio.
    """
    M, K = shape
    assert not (M % 16 != 0 or K % 16 != 0), "Shape must be divisible by 16"

    x = torch.randn([M, K], dtype=torch.bfloat16, device="cuda")
    q_base = make_quantizer(optimize_for_gemm=False)
    q_swf = make_quantizer(optimize_for_gemm=True)

    def median_us(stmt, quantizer, extra_globals=None):
        # Every cell shares the same input tensor and min_run_time; only the
        # statement, the quantizer and (optionally) extra globals differ.
        env = {"q": quantizer, "x": x}
        if extra_globals:
            env.update(extra_globals)
        return _bench(stmt=stmt, globals_dict=env, min_run_time=min_run_time)

    quant_stmt = "q(x)"
    # quant_plus_swizzle path (this is what te.Linear actually runs)
    quant_swizzle_stmt = "t = q(x); tex.swizzle_scales_for_gemm_(t)"

    quant_only_base_us = median_us(quant_stmt, q_base)
    quant_only_swf_us = median_us(quant_stmt, q_swf)
    quant_plus_swizzle_base_us = median_us(quant_swizzle_stmt, q_base, {"tex": tex})
    quant_plus_swizzle_swf_us = median_us(quant_swizzle_stmt, q_swf, {"tex": tex})

    saved_us = quant_plus_swizzle_base_us - quant_plus_swizzle_swf_us
    if quant_plus_swizzle_swf_us > 0:
        speedup = quant_plus_swizzle_base_us / quant_plus_swizzle_swf_us
    else:
        # Avoid dividing by a non-positive timing.
        speedup = float("inf")

    summary = (
        f" shape={shape}: quant_only base={quant_only_base_us:.2f}us, "
        f"SUT={quant_only_swf_us:.2f}us | "
        f"quant+swizzle base={quant_plus_swizzle_base_us:.2f}us, "
        f"SUT={quant_plus_swizzle_swf_us:.2f}us "
        f"-> saved {saved_us:.2f}us ({speedup:.2f}x)"
    )
    print(summary)

    # Key order is kept as-is because it determines the CSV column order.
    record = {"shape": shape, "M": M, "K": K}
    record["quant_only_base_us"] = quant_only_base_us
    record["quant_only_swf_us"] = quant_only_swf_us
    record["quant_plus_swizzle_base_us"] = quant_plus_swizzle_base_us
    record["quant_plus_swizzle_swf_us"] = quant_plus_swizzle_swf_us
    record["saved_us"] = saved_us
    record["speedup"] = speedup
    return record


# Nsight Compute Profiling Command (for verifying the swizzle kernel disappears):
# ncu -f -o swizzle_fusion --set=full \
# --kernel-name "regex:swizzle_(row|col)_scaling_kernel|cast_col_hadamard_transform_cast_fusion" \
# -s 5 -c 10 python benchmarks/benchmark_rht_cast_swizzle_fusion.py --profile


if __name__ == "__main__":
    # Command-line interface: profiling switch, per-cell time budget, CSV path.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--profile",
        action="store_true",
        help="Run only one shape for use with ncu/nsys; longer min_run_time",
    )
    cli.add_argument(
        "--min-run-time",
        type=float,
        default=2.0,
        help="Minimum total measured time per cell in seconds (benchmark.Timer)",
    )
    cli.add_argument(
        "--csv",
        type=str,
        default="benchmark_rht_cast_swizzle_fusion.csv",
        help="CSV output path",
    )
    opts = cli.parse_args()

    # Default sweep: production-class shapes with the requested time budget.
    shapes = [
        (8192, 5120),
        (8192, 10240),
        (8192, 2560),
        (8192, 11328),
        (8192, 3584),
        (5120, 8192),
        (10240, 8192),
        (2560, 8192),
        (11328, 8192),
        (3584, 8192),
        (4096, 16384),
        (14336, 16384),
    ]
    min_run_time = opts.min_run_time
    if opts.profile:
        # Profiling mode: a single shape, measured for at least 5 seconds.
        print("Profiling mode enabled (single shape).")
        shapes = [(8192, 4096)]
        min_run_time = max(5.0, opts.min_run_time)

    banner = (
        "NVFP4 RHT cast-fusion: swizzle-fusion (optimize_for_gemm=True) vs baseline. "
        f"min_run_time={min_run_time}s per cell, BF16 input, "
        "rowwise+columnwise SF, RHT=True+post_rht_amax."
    )
    print(banner)

    records = []
    for current_shape in shapes:
        print(f"Running {current_shape} ...")
        records.append(run_shape(current_shape, min_run_time))

    # Widen pandas' display limits so the full table prints without truncation.
    pd.set_option("display.max_columns", None)
    pd.set_option("display.width", 200)
    results = pd.DataFrame(records)
    print()
    print(results.to_string(index=False))
    results.to_csv(opts.csv, index=False)
    print(f"\nWrote {opts.csv}")
127 changes: 127 additions & 0 deletions benchmarks/profile_rht_cast_swizzle_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Profile that the dedicated swizzle kernels (swizzle_{row,col}_scaling_kernel
in transformer_engine/common/swizzle/swizzle.cu) disappear from the timeline
when NVFP4 RHT cast-fusion emits SF in the GEMM-swizzled layout directly
(optimize_for_gemm=True).

Test setup:
- NVFP4 + RHT + post-RHT amax (same as te.Linear sets up internally)
- rowwise=True AND columnwise=True (covers BOTH swizzle_row_scaling_kernel
and swizzle_col_scaling_kernel; this is what te.Linear's input quantizer
needs during training because the rowwise tensor is used by the fwd GEMM
and the columnwise tensor is used by the dgrad GEMM)
- tex.swizzle_scales_for_gemm_(t) is what te.Linear -> tex.generic_gemm
calls just before the cuBLAS LT NVFP4 GEMM dispatch
"""

import torch
import transformer_engine.pytorch as te # noqa: F401 must be first
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer


def make_quantizer(optimize_for_gemm: bool) -> NVFP4Quantizer:
    """Create the NVFP4 quantizer used by both profiled configurations.

    Rowwise and columnwise output are both enabled, RHT with post-RHT amax
    and a random sign mask is switched on, and amax reduction is disabled.

    Args:
        optimize_for_gemm: Value assigned to the quantizer's
            ``optimize_for_gemm`` attribute after construction.

    Returns:
        The configured ``NVFP4Quantizer``.
    """
    settings = dict(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=True,
        with_post_rht_amax=True,
        with_random_sign_mask=True,
    )
    quantizer = NVFP4Quantizer(**settings)
    quantizer.optimize_for_gemm = optimize_for_gemm
    return quantizer


import re

# Match ONLY the standalone swizzle pass kernels in
# transformer_engine/common/swizzle/swizzle.cu — NOT RHT cast-fusion kernels
# whose mangled name happens to contain "Swizzle" because of the
# `template <..., bool kEnableSwizzleSFOutput, ...>` parameter substring.
#
# The pattern is case-sensitive and requires the full
# `[multi_tensor_][un]swizzle_{row,col}_scaling_kernel` spelling, e.g.
# `swizzle_row_scaling_kernel` or `multi_tensor_unswizzle_col_scaling_kernel`.
# It is applied with `.search`, so it matches anywhere inside a profiler
# event name rather than only at the start.
STANDALONE_SWIZZLE_RE = re.compile(
    r"(?:multi_tensor_(?:un)?swizzle|(?:un)?swizzle)_(?:row|col)_scaling_kernel"
)


def dump_kernel_counts(prof, label: str) -> dict:
    """Print and return per-name counts of the CUDA events in a profile.

    Events whose ``device_type`` is not CUDA are ignored. Names are printed
    in descending order of count (ties keep first-seen order), names longer
    than 110 characters are truncated for display, and names matching
    ``STANDALONE_SWIZZLE_RE`` are flagged and summed into a trailing total.

    Args:
        prof: Object exposing ``events()``; each event must provide
            ``device_type`` and ``name``.
        label: Heading printed above the table.

    Returns:
        Dict mapping each (untruncated) CUDA event name to its count.
    """
    print(f"\n=== {label} ===")

    kernel_names = [
        ev.name for ev in prof.events() if ev.device_type == torch.autograd.DeviceType.CUDA
    ]
    counts: dict[str, int] = {}
    for kernel in kernel_names:
        counts[kernel] = counts.get(kernel, 0) + 1

    swizzle_launches = 0
    # sorted() is stable, so reverse=True keeps first-seen order among ties.
    by_count_desc = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    for kernel, launches in by_count_desc:
        is_standalone = STANDALONE_SWIZZLE_RE.search(kernel)
        if is_standalone:
            swizzle_launches += launches
        marker = " <-- STANDALONE SWIZZLE PASS" if is_standalone else ""
        # Truncate long mangled CUTLASS names for readability
        if len(kernel) > 110:
            short = kernel[:107] + "..."
        else:
            short = kernel
        print(f" {launches:4d} {short}{marker}")
    print(f" -- standalone swizzle kernel total: {swizzle_launches}")
    return counts


def profile_path(optimize_for_gemm: bool, x: torch.Tensor, n_iters: int = 20):
    """Profile repeated quantize + ``swizzle_scales_for_gemm_`` calls on ``x``.

    Three un-profiled warm-up iterations run first, followed by a device
    synchronize; then ``n_iters`` iterations are recorded with only CUDA
    activity enabled in the profiler.

    Args:
        optimize_for_gemm: Forwarded to ``make_quantizer``.
        x: Input tensor passed to the quantizer on every iteration.
        n_iters: Number of profiled iterations.

    Returns:
        The ``torch.profiler.profile`` object holding the recorded events.
    """
    quantizer = make_quantizer(optimize_for_gemm=optimize_for_gemm)

    # warm-up
    warmup_iters = 3
    for _ in range(warmup_iters):
        t = quantizer(x)
        tex.swizzle_scales_for_gemm_(t)
    torch.cuda.synchronize()

    cuda_only = [torch.profiler.ProfilerActivity.CUDA]
    with torch.profiler.profile(activities=cuda_only) as prof:
        for _ in range(n_iters):
            t = quantizer(x)
            tex.swizzle_scales_for_gemm_(t)
        # Drain outstanding GPU work before the profiler context closes.
        torch.cuda.synchronize()
    return prof


def main():
    """Profile the baseline and fused configurations and print a verdict.

    Both configurations are profiled on the same random BF16 input. The
    kernel tables are dumped, then the number of events matching
    ``STANDALONE_SWIZZLE_RE`` is compared: PASS requires at least one in the
    baseline and none with ``optimize_for_gemm=True``.
    """
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Shape that hits the production RHT cast-fusion fast-path
    # (rows % 64 == 0, cols % 128 == 0, BF16, SM100/110).
    M, N = 8192, 4096
    x = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

    print(f"Shape: M={M}, N={N}, dtype=bf16, RHT=True, post_rht_amax=True")
    print(f"iters: 20 (after 3 warm-up)")

    def swizzle_launches(kernel_counts):
        # Total events whose name matches a standalone swizzle kernel.
        return sum(c for n, c in kernel_counts.items() if STANDALONE_SWIZZLE_RE.search(n))

    counts_baseline = dump_kernel_counts(
        profile_path(optimize_for_gemm=False, x=x),
        "BASELINE: optimize_for_gemm=False (separate swizzle kernel)",
    )
    counts_swf = dump_kernel_counts(
        profile_path(optimize_for_gemm=True, x=x),
        "SUT: optimize_for_gemm=True (quant emits swizzled SF directly)",
    )

    print("\n=== VERDICT ===")
    base_swizzle = swizzle_launches(counts_baseline)
    swf_swizzle = swizzle_launches(counts_swf)
    print(f" baseline standalone swizzle kernel launches: {base_swizzle}")
    print(f" SUT standalone swizzle kernel launches: {swf_swizzle}")

    swizzle_pass_removed = base_swizzle > 0 and swf_swizzle == 0
    if not swizzle_pass_removed:
        print(
            " FAIL: expected baseline > 0 and SUT == 0; check whether SUT actually "
            "set with_gemm_swizzled_scales=True on the output tensor"
        )
        return
    print(
        " PASS: standalone swizzle pass disappears from timeline under optimize_for_gemm=True"
    )


# Run the profiling comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading
Loading