
[Common][PyTorch] Add z_loss_weight to parallel_cross_entropy#2707

Open
bassoy wants to merge 2 commits into NVIDIA:main from bassoy:add_zloss_to_parallel_cross_entropy

Conversation


@bassoy bassoy commented Feb 26, 2026

Description

Adds z-loss regularization to parallel_cross_entropy. Z-loss penalizes
large logit magnitudes by adding z_loss_weight * log(Z)^2 per token to the
loss, where log(Z) = log(sum(exp(logits))) is the log-sum-exp (see
ST-MoE, arxiv.org/abs/2202.08906).
This stabilizes training by keeping logits in a numerically well-behaved range.
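For reference, the per-token objective can be written as a plain PyTorch sketch (an unfused reference, not the Triton kernel; the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, target, z_loss_weight=1e-4):
    # Standard per-token cross entropy plus the ST-MoE z-loss penalty.
    ce = F.cross_entropy(logits, target, reduction="none")
    lse = torch.logsumexp(logits.float(), dim=-1)  # log(Z) = log(sum(exp(logits)))
    return ce + z_loss_weight * lse.square()

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
loss = ce_with_z_loss(logits, target)  # one loss value per token
```

With `z_loss_weight=0.0` this reduces exactly to plain `F.cross_entropy(..., reduction="none")`.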

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

The Triton kernel already computes m (the running max) and d (the sum of shifted
exponentials) as part of the online softmax, so lse = m + log(d) is a free
byproduct: no extra data movement is required.

  • Forward: compute lse = m + log(d) using variables and add
    z_loss_weight * lse^2 to the per-token loss. Reuse lse in the
    label-smoothing path (smooth_loss previously recomputed m + log(d)).
  • Backward: scale the softmax gradient by (1 + 2 * z_loss_weight * lse),
    derived from d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i).
  • Dead-code elimination: z_loss_weight is a tl.constexpr, enabling Triton
    to eliminate all z-loss branches at compile time when z_loss_weight=0.0.
  • API: parallel_cross_entropy(..., z_loss_weight=0.0) is backward
    compatible, so the default behavior is unchanged.
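Without label smoothing, the backward rule in the second bullet can be checked against autograd with a small standalone sketch (illustrative shapes, not the kernel code):

```python
import torch
import torch.nn.functional as F

z = 1e-3
logits = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
target = torch.randint(0, 8, (3,))

# Autograd gradient of per-token CE + z-loss (no label smoothing).
ce = F.cross_entropy(logits, target, reduction="none")
lse = torch.logsumexp(logits, dim=-1)
(ce + z * lse.square()).sum().backward()

# Manual gradient: softmax scaled by (1 + 2*z*lse), followed by the
# usual subtraction of 1 at the target token.
with torch.no_grad():
    p = torch.softmax(logits, dim=-1)
    manual = p * (1.0 + 2.0 * z * lse.unsqueeze(-1))
    manual[torch.arange(3), target] -= 1.0

assert torch.allclose(logits.grad, manual)
```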

Tests

Z-loss tests extend the existing test infrastructure rather than introducing
separate helpers. generate_infra accepts an optional z_loss_weight
parameter: when > 0, it builds a PyTorch reference function that computes
F.cross_entropy + z_loss_weight * lse^2 in FP32. one_iteration_test
passes z_loss_weight through to parallel_cross_entropy. This keeps all
z-loss tests in the same pattern as the existing suite.

The existing z_loss_weight == 0.0 path in generate_infra is untouched
(torch.nn.CrossEntropyLoss), so there is no risk to existing tests.

6 new tests, 14 total pass (A40, single GPU):

| Test | What it verifies |
| --- | --- |
| test_z_loss | FP32 loss and gradients match the PyTorch reference (5 random iterations, random swap_dim) |
| test_z_loss_bfloat16 | Same as above with BF16 input (3 iterations) |
| test_z_loss_with_ignore_idx | Z-loss + ignored tokens: loss and gradients correct (5 iterations) |
| test_z_loss_zero_weight | z_loss_weight=0.0 produces bit-identical results to the default (no z-loss) |
| test_z_loss_reduced | Z-loss + reduce_loss=True: reduced loss and gradients correct (5 iterations) |
| test_z_loss_label_smoothing | Z-loss + label_smoothing=0.1: both features interact correctly (3 iterations) |

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from 7f11aa2 to 70d3f84 Compare March 20, 2026 21:49
@bassoy bassoy changed the title [Common][PyTorch] Add z_loss_weight and log_sum_exp output to parallel_cross_entropy [Common][PyTorch] Add z_loss_weight to parallel_cross_entropy Mar 20, 2026
@bassoy bassoy marked this pull request as ready for review March 20, 2026 21:57

greptile-apps bot commented Mar 20, 2026

Greptile Summary

This PR adds z-loss regularization (z_loss_weight * log(Z)² per token) to parallel_cross_entropy in a backward-compatible way. The implementation reuses the m (running max) and d (sum of shifted exponentials) already computed by the online-softmax kernel to obtain lse = m + log(d) at zero extra memory-bandwidth cost, and applies the correct combined gradient softmax * (1 + 2*z*lse) - eps by multiplying before the label-smoothing eps subtraction. Using z_loss_weight as a tl.constexpr enables compile-time dead-code elimination when the weight is 0.0, preserving bit-identical behaviour for existing callers.

  • The Triton kernel forward and backward math is correct: the z-loss forward term z_loss_weight * lse² and the z-loss backward term 2 * z_loss_weight * lse * softmax(x_i) are both applied correctly, including the target-token correction and the reduce_loss normalization.
  • The backward return tuple now correctly contains 8 elements (one gradient + 7 Nones) to match the extended forward signature.
  • ignore_idx not forwarded to CrossEntropyLoss: In generate_infra, the newly-added ignore_idx parameter is correctly forwarded in the z_loss_weight > 0 reference closure, but the z_loss_weight == 0.0 branch creates torch.nn.CrossEntropyLoss without ignore_index=ignore_idx, creating an asymmetry that will silently produce wrong test references if a non-default ignore index is ever used with z_loss_weight=0.0.
  • Non-finite z_loss_weight bypasses the input validation: z_loss_weight < 0 does not catch float('nan') (silently disables z-loss) or float('inf') (immediately produces NaN losses), both of which should be rejected explicitly.

Confidence Score: 4/5

  • PR is safe to merge with minor fixes: the core Triton kernel math is correct, but two small gaps in the test harness and input validation should be addressed first.
  • The kernel implementation is mathematically sound (forward loss, backward gradient, reduce_loss normalization, label-smoothing interaction all verified). API is backward compatible. Two P1 issues remain: the ignore_idx asymmetry in generate_infra and the non-finite z_loss_weight gap in the validation guard. Neither affects the current test suite, but both are easy to hit in follow-on work.
  • The fixes touch transformer_engine/pytorch/cross_entropy.py (non-finite guard) and tests/pytorch/test_parallel_cross_entropy.py (missing ignore_index in CrossEntropyLoss).

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/triton/cross_entropy.py | Adds z_loss_weight as a tl.constexpr, computes lse = m + log(d) after the online-softmax reduction, and correctly applies the z-loss term both in the forward loss (lse²) and in the backward gradient (softmax * (1 + 2*z*lse) - eps) by multiplying before the eps subtraction. The dead-code elimination via constexpr is well-documented; no new correctness issues identified in the kernel itself. |
| transformer_engine/pytorch/cross_entropy.py | Threads z_loss_weight through CrossEntropyFunction.forward to the Triton forward call, adds the correct count of None returns in backward, adds a non-negative guard in parallel_cross_entropy, and documents the compile-time-constant behaviour. However, the guard does not reject NaN or Inf values. |
| transformer_engine/pytorch/triton/cross_entropy.py | Minimal, correct change: adds z_loss_weight as an optional parameter defaulting to 0.0 in cross_entropy_forward and passes it straight through to cross_entropy_kernel. |
| tests/pytorch/test_parallel_cross_entropy.py | Adds 7 new z-loss tests covering FP32, BF16, ignore-index, reduce_loss, label-smoothing, zero-weight, and the combined reduce+ignore path. The reference closure correctly forwards ignore_idx to F.cross_entropy and z_pen masking, but the z_loss_weight==0.0 branch of generate_infra does not forward ignore_idx to torch.nn.CrossEntropyLoss. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant parallel_cross_entropy
    participant CrossEntropyFunction
    participant cross_entropy_forward
    participant online_softmax_kernel
    participant cross_entropy_kernel
    participant cross_entropy_backward
    participant element_mul_kernel

    Caller->>parallel_cross_entropy: inp, target, z_loss_weight
    parallel_cross_entropy->>parallel_cross_entropy: validate z_loss_weight >= 0
    parallel_cross_entropy->>CrossEntropyFunction: .apply(inp, target, ..., z_loss_weight)
    CrossEntropyFunction->>cross_entropy_forward: inp, target, ..., z_loss_weight
    cross_entropy_forward->>online_softmax_kernel: compute m, d, X_y per row
    online_softmax_kernel-->>cross_entropy_forward: m_d_X_y tensor
    cross_entropy_forward->>cross_entropy_kernel: m_d_X_y, z_loss_weight (constexpr)
    Note over cross_entropy_kernel: lse = m + log(d)<br/>loss += z_loss_weight * lse²<br/>grad = softmax*(1+2*z*lse) - eps
    cross_entropy_kernel-->>cross_entropy_forward: loss_1d, grad in _input
    cross_entropy_forward-->>CrossEntropyFunction: loss, _input (with grad)
    CrossEntropyFunction->>CrossEntropyFunction: ctx.save_for_backward(_input)
    CrossEntropyFunction-->>Caller: loss tensor

    Caller->>CrossEntropyFunction: .backward(grad_output)
    CrossEntropyFunction->>cross_entropy_backward: _input (saved grad), grad_output
    cross_entropy_backward->>element_mul_kernel: scale grad by grad_output
    element_mul_kernel-->>cross_entropy_backward: scaled grad
    cross_entropy_backward-->>CrossEntropyFunction: inp grad
    CrossEntropyFunction-->>Caller: (inp_grad, None×7)

Last reviewed commit: "[pre-commit.ci] auto..."

Comment on lines +195 to +197
# Z-loss gradient: d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i).
if z_loss_weight > 0:
    X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse)
P1 Incorrect z-loss gradient when label_smoothing > 0

At this point X_block = (softmax(x_i) - eps) / N (or softmax - eps without reduction). Multiplying the combined CE gradient by (1 + 2 * z_loss_weight * lse) expands to:

(softmax - eps)/N * (1 + 2*z*lse)
= (softmax - eps)/N + (softmax - eps)/N * 2*z*lse

But the correct z-loss gradient is purely softmax/N * 2 * z * lse — the z-loss term should be additive on top of the CE gradient, not multiplicative against the entire (softmax - eps) expression. The error introduced is -eps/N * 2 * z_loss_weight * lse per element.

For typical training settings (label_smoothing=0.1, V=64000, z_loss_weight=0.001, lse≈11) the error is on the order of 3e-8, which is below float32 precision for large vocabularies — explaining why test_z_loss_label_smoothing still passes. However, for small vocabularies (e.g. V=32) the error becomes measurable and the implementation is mathematically incorrect.

The correct approach is to add the z-loss gradient additively, using the pre-eps softmax value:

        if reduce_loss:
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps
        # Z-loss gradient: 2 * z_loss_weight * lse * softmax(x_i), additive to CE gradient.
        if z_loss_weight > 0:
            softmax_i = tl.exp(X_block_fp32 - m) / d  # pure softmax, before subtracting eps
            if reduce_loss:
                X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse) / n_non_ignore
            else:
                X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse)

where X_block_fp32 is the logit block before the CE computation (currently loaded at the top of the loop).
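The size of the claimed error can be reproduced outside the kernel with an autograd sketch; the modeling of the kernel's multiplicative path (eps = label_smoothing / V per class, target-token correction of 1 - label_smoothing) is an assumption for illustration, not the kernel code itself:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, z, ls = 32, 1e-3, 0.1
logits = torch.randn(1, V, dtype=torch.float64, requires_grad=True)
target = torch.tensor([3])

# True gradient of smoothed CE + z-loss via autograd.
ce = F.cross_entropy(logits, target, reduction="none", label_smoothing=ls)
lse = torch.logsumexp(logits, dim=-1)
(ce + z * lse.square()).sum().backward()
true_grad = logits.grad

# Model of the multiplicative application (assumed representation).
with torch.no_grad():
    p = torch.softmax(logits, dim=-1)
    eps = ls / V
    mult = (p - eps) * (1.0 + 2.0 * z * lse)
    mult[0, target] -= 1.0 - ls
    err = mult - true_grad
    # The per-element error is the constant -eps * 2 * z * lse.
    assert torch.allclose(err, (-eps * 2.0 * z * lse).expand_as(err))
```

With V=32 the error is a few parts in 1e-5 per element, measurable in float32; with V=64000 the same expression shrinks below float32 resolution, consistent with the passing label-smoothing test.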

Comment on lines +204 to +215
def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)
    loss_base = self.test_loss_func(self.input_test.clone(), self.tar_test)
    loss_zero = self.test_loss_func(self.input_test.clone(), self.tar_test, z_loss_weight=0.0)
    assert torch.equal(
        loss_base, loss_zero
    ), "z_loss_weight=0.0 must be bit-identical to the default"
    self.input_test = None
    self.input_ref = None
    self.tar_test = None
    self.tar_ref = None
P2 test_z_loss_zero_weight only validates the forward pass

The test clones the input tensor but never calls .requires_grad_(True), so no gradient is accumulated and the backward path is never exercised. The Triton kernel eliminates the z-loss branches at compile time via tl.constexpr, so validating that the gradient is also bit-identical for z_loss_weight=0.0 would strengthen the regression value of this test.

Consider adding backward verification:

def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)

    inp_base = self.input_test.clone().requires_grad_(True)
    inp_zero = self.input_test.clone().requires_grad_(True)

    loss_base = self.test_loss_func(inp_base, self.tar_test)
    loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0)

    assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default"

    loss_base.sum().backward()
    loss_zero.sum().backward()

    assert torch.equal(inp_base.grad, inp_zero.grad), \
        "Gradients with z_loss_weight=0.0 must be bit-identical to the default"

    self.input_test = self.input_ref = self.tar_test = self.tar_ref = None

dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
ignore_idx: int = -100,
is_cg_capturable: bool = False,
z_loss_weight: float = 0.0,
P2 No validation that z_loss_weight is non-negative

A negative value is mathematically well-formed but semantically inverts the regularization (rewarding large logit magnitudes). Given the docstring describes this as a "regularization weight", an early guard against negative values would make the API safer and the intent explicit:

Suggested change
z_loss_weight: float = 0.0,
z_loss_weight: float = 0.0,

Consider adding before the CrossEntropyFunction.apply(...) call:

if z_loss_weight < 0.0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from 056ce5f to bb47312 Compare March 20, 2026 22:19
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
z_loss_weight: tl.constexpr,
P2 tl.constexpr specialization will recompile for every unique float value

z_loss_weight is declared tl.constexpr, which means Triton compiles a separate kernel for each unique Python float value passed. The PR describes this as intentional for dead-code elimination when z_loss_weight=0.0, and for a fixed training hyperparameter that's fine. However, if callers ever want to anneal or schedule z_loss_weight across training steps (e.g. a warmup from 0 → 0.001), every distinct float encountered will trigger a fresh JIT compilation, stalling the training loop.

Consider documenting this behaviour in the docstring of both cross_entropy_kernel and parallel_cross_entropy:

z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.
    This value is used as a Triton compile-time constant (tl.constexpr); varying it across
    calls will trigger kernel recompilation. Use a fixed value during training.
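If a schedule is ever needed, one caller-side mitigation is to snap the scheduled weight to a small fixed set of values so that only a bounded number of kernel specializations are ever compiled. The helper below is hypothetical, not part of this PR:

```python
def quantize_z_loss_weight(w, allowed=(0.0, 1e-4, 5e-4, 1e-3)):
    # Snap a scheduled z-loss weight to the nearest allowed value so that
    # at most len(allowed) Triton specializations are ever JIT-compiled.
    return min(allowed, key=lambda a: abs(a - w))
```

A warmup from 0 to 1e-3 then passes at most four distinct constexpr values to the kernel instead of one per training step.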

inp = inp.float()
ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing)
z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
z_pen[tar == -100] = 0.0
P2 Ignore index hardcoded to -100 in reference function

z_pen[tar == -100] = 0.0 hardcodes the ignore index rather than closing over the ignore_idx value that will be used in the actual parallel_cross_entropy call. Currently all tests in the suite pass ignore_idx=-100 (the default), so this is consistent — but if a future test exercises a non-default ignore index, the reference would silently zero the wrong tokens and the comparison would yield a false positive.

Consider parameterising the closure:

Suggested change
z_pen[tar == -100] = 0.0
z_pen[tar == ignore_idx] = 0.0

And updating generate_infra to accept ignore_idx: int = -100 so it can be forwarded.

@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from bb47312 to ed83839 Compare March 20, 2026 22:29
@bassoy bassoy marked this pull request as draft March 20, 2026 22:31
@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from ed83839 to da503ed Compare March 20, 2026 22:33
@bassoy bassoy marked this pull request as ready for review March 20, 2026 22:36
Comment on lines +249 to +259
def test_z_loss_reduced(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            z_loss_weight=0.001,
        )
P2 Missing test combination: reduce_loss=True + z-loss + ignore_idx

The test suite covers reduce_loss=True + z-loss (test_z_loss_reduced) and reduce_loss=False + z-loss + ignore_idx (test_z_loss_with_ignore_idx), but no test exercises all three together. This is the most semantically interesting combination: n_non_ignore is used to normalize both the loss value (in Python) and the gradient (in the Triton kernel, line 198), so an incorrect interaction would only appear when tokens are actually masked and reduction is active.

Consider adding:

def test_z_loss_reduced_with_ignore_idx(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            ignore_idx=True,
            z_loss_weight=0.001,
        )

Comment on lines +32 to +40
def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing)
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    loss = ce + z_pen
    if reduce_loss:
        loss = loss.sum() / (tar != ignore_idx).sum()
    return loss
P2 F.cross_entropy does not receive the parameterized ignore_index

generate_infra now accepts an ignore_idx parameter and correctly uses it to zero out z_pen, but F.cross_entropy is called without ignore_index=ignore_idx. PyTorch's default is -100, so all current tests pass since generate_input always uses -100. However, if a future test passes a non-default ignore_idx, the CE component of the reference would still ignore -100 while the real kernel would ignore the custom index, producing a silently incorrect reference loss and false-passing gradient tests.

def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(
        inp, tar, reduction="none",
        label_smoothing=label_smoothing,
        ignore_index=ignore_idx,   # <-- forward the parameter
    )
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    ...

Add z-loss regularization (z_loss_weight * log(Z)^2 per token) to the
Triton cross-entropy kernel. The z_loss_weight parameter is a
tl.constexpr, so it is dead-code-eliminated when set to 0.0.

Forward: adds z_loss_weight * lse^2 to per-token loss.
Backward: scales softmax gradient by (1 + 2 * z_loss_weight * lse).

Tests: extend existing test infrastructure with z_loss_weight parameter
in generate_infra and one_iteration_test. Z-loss tests cover FP32,
BF16, ignore_idx, and zero-weight identity.

Signed-off-by: Cem Bassoy <cem.bassoy@deepl.com>
@bassoy bassoy force-pushed the add_zloss_to_parallel_cross_entropy branch from b337f49 to 13ffa35 Compare March 20, 2026 22:55
Comment on lines +27 to +29
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
P1 ignore_idx not forwarded to CrossEntropyLoss

generate_infra now accepts an ignore_idx parameter, and the z_loss_weight > 0 branch correctly passes it to both F.cross_entropy(..., ignore_index=ignore_idx) and z_pen[tar == ignore_idx]. However, the z_loss_weight == 0.0 branch silently falls back to PyTorch's default (-100), ignoring the parameter:

self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
    # ignore_index=ignore_idx is missing here
)

All current tests happen to pass ignore_idx=-100 (the default), so there is no visible failure now. But if any test ever calls generate_infra(..., z_loss_weight=0.0, ignore_idx=42), the reference would still ignore token id -100 instead of 42, producing a silently incorrect reference and a false-passing comparison.

Suggested change
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing,
    reduction="mean" if reduce_loss else "none",
    ignore_index=ignore_idx,
)

Comment on lines +152 to +153
if z_loss_weight < 0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")
P1 Non-finite z_loss_weight bypasses validation

The current guard only rejects strictly negative values. float('nan') satisfies nan < 0 == False (NaN comparisons always return False) and float('inf') satisfies inf < 0 == False, so both slip through undetected:

  • z_loss_weight=float('nan'): the Triton tl.constexpr comparison nan > 0 evaluates to False at compile time, so the z-loss branch is silently skipped — users see no z-loss even though they passed a non-zero value.
  • z_loss_weight=float('inf'): the branch is taken, and loss += inf * lse * lse will produce inf/nan losses, immediately destabilising training.

Consider expanding the guard to also exclude non-finite values:

Suggested change
if z_loss_weight < 0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")
if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")):
    raise ValueError(f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}")
