[Common][PyTorch] Add z_loss_weight to parallel_cross_entropy #2707
bassoy wants to merge 2 commits into NVIDIA:main
Conversation
Greptile Summary: This PR adds z-loss regularization (`z_loss_weight * log(Z)^2` per token) to `parallel_cross_entropy`.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant parallel_cross_entropy
    participant CrossEntropyFunction
    participant cross_entropy_forward
    participant online_softmax_kernel
    participant cross_entropy_kernel
    participant cross_entropy_backward
    participant element_mul_kernel
    Caller->>parallel_cross_entropy: inp, target, z_loss_weight
    parallel_cross_entropy->>parallel_cross_entropy: validate z_loss_weight >= 0
    parallel_cross_entropy->>CrossEntropyFunction: .apply(inp, target, ..., z_loss_weight)
    CrossEntropyFunction->>cross_entropy_forward: inp, target, ..., z_loss_weight
    cross_entropy_forward->>online_softmax_kernel: compute m, d, X_y per row
    online_softmax_kernel-->>cross_entropy_forward: m_d_X_y tensor
    cross_entropy_forward->>cross_entropy_kernel: m_d_X_y, z_loss_weight (constexpr)
    Note over cross_entropy_kernel: lse = m + log(d)<br/>loss += z_loss_weight * lse²<br/>grad = softmax*(1+2*z*lse) - eps
    cross_entropy_kernel-->>cross_entropy_forward: loss_1d, grad in _input
    cross_entropy_forward-->>CrossEntropyFunction: loss, _input (with grad)
    CrossEntropyFunction->>CrossEntropyFunction: ctx.save_for_backward(_input)
    CrossEntropyFunction-->>Caller: loss tensor
    Caller->>CrossEntropyFunction: .backward(grad_output)
    CrossEntropyFunction->>cross_entropy_backward: _input (saved grad), grad_output
    cross_entropy_backward->>element_mul_kernel: scale grad by grad_output
    element_mul_kernel-->>cross_entropy_backward: scaled grad
    cross_entropy_backward-->>CrossEntropyFunction: inp grad
    CrossEntropyFunction-->>Caller: (inp_grad, None×7)
```

Last reviewed commit: "[pre-commit.ci] auto..."
```python
# Z-loss gradient: d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i).
if z_loss_weight > 0:
    X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse)
```
Incorrect z-loss gradient when `label_smoothing > 0`
At this point X_block = (softmax(x_i) - eps) / N (or softmax - eps without reduction). Multiplying the combined CE gradient by (1 + 2 * z_loss_weight * lse) expands to:
(softmax - eps)/N * (1 + 2*z*lse)
= (softmax - eps)/N + (softmax - eps)/N * 2*z*lse
But the correct z-loss gradient is purely softmax/N * 2 * z * lse — the z-loss term should be additive on top of the CE gradient, not multiplicative against the entire (softmax - eps) expression. The error introduced is -eps/N * 2 * z_loss_weight * lse per element.
For typical training settings (label_smoothing=0.1, V=64000, z_loss_weight=0.001, lse≈11) the error is on the order of 3e-8, which is below float32 precision for large vocabularies — explaining why test_z_loss_label_smoothing still passes. However, for small vocabularies (e.g. V=32) the error becomes measurable and the implementation is mathematically incorrect.
The correct approach is to add the z-loss gradient additively, using the pre-eps softmax value:

```python
if reduce_loss:
    X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
    X_block = tl.exp(X_block - m) / d - eps

# Z-loss gradient: 2 * z_loss_weight * lse * softmax(x_i), additive to CE gradient.
if z_loss_weight > 0:
    softmax_i = tl.exp(X_block_fp32 - m) / d  # pure softmax, before subtracting eps
    if reduce_loss:
        X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse) / n_non_ignore
    else:
        X_block = X_block + softmax_i * (2.0 * z_loss_weight * lse)
```

where `X_block_fp32` is the logit block before the CE computation (currently loaded at the top of the loop).
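As a sanity check on the math, the additive form of the z-loss gradient can be confirmed against autograd in plain PyTorch. This is purely an illustrative sketch, not part of the PR; the shapes and constants are arbitrary.

```python
import torch

# Hypothetical numeric check: autograd on z * lse^2 should match
# 2 * z * lse * softmax(x) elementwise, with no (softmax - eps) coupling.
torch.manual_seed(0)
V, z = 32, 0.001                       # small vocab makes any error visible
x = torch.randn(V, dtype=torch.float64, requires_grad=True)

lse = torch.logsumexp(x, dim=-1)       # log(Z)
(z * lse**2).backward()

expected = 2.0 * z * lse.detach() * torch.softmax(x.detach(), dim=-1)
assert torch.allclose(x.grad, expected, atol=1e-12)
```

Because the z-loss gradient contains only the pure softmax, any multiplicative formulation that drags `eps` along is off by exactly the `-eps * 2 * z * lse` term described above.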
```python
def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)
    loss_base = self.test_loss_func(self.input_test.clone(), self.tar_test)
    loss_zero = self.test_loss_func(self.input_test.clone(), self.tar_test, z_loss_weight=0.0)
    assert torch.equal(
        loss_base, loss_zero
    ), "z_loss_weight=0.0 must be bit-identical to the default"
    self.input_test = None
    self.input_ref = None
    self.tar_test = None
    self.tar_ref = None
```
test_z_loss_zero_weight only validates the forward pass
The test clones the input tensor but never calls .requires_grad_(True), so no gradient is accumulated and the backward path is never exercised. The Triton kernel eliminates the z-loss branches at compile time via tl.constexpr, so validating that the gradient is also bit-identical for z_loss_weight=0.0 would strengthen the regression value of this test.
Consider adding backward verification:

```python
def test_z_loss_zero_weight(self):
    self.generate_infra(False, 0)
    self.generate_input(torch.float32, False, False)
    inp_base = self.input_test.clone().requires_grad_(True)
    inp_zero = self.input_test.clone().requires_grad_(True)
    loss_base = self.test_loss_func(inp_base, self.tar_test)
    loss_zero = self.test_loss_func(inp_zero, self.tar_test, z_loss_weight=0.0)
    assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default"
    loss_base.sum().backward()
    loss_zero.sum().backward()
    assert torch.equal(inp_base.grad, inp_zero.grad), \
        "Gradients with z_loss_weight=0.0 must be bit-identical to the default"
    self.input_test = self.input_ref = self.tar_test = self.tar_ref = None
```

```python
    dist_process_group: Optional[torch.distributed.ProcessGroup] = None,
    ignore_idx: int = -100,
    is_cg_capturable: bool = False,
    z_loss_weight: float = 0.0,
```
No validation that `z_loss_weight` is non-negative

A negative value is mathematically well-formed but semantically inverts the regularization (rewarding large logit magnitudes). Given the docstring describes this as a "regularization weight", an early guard against negative values would make the API safer and the intent explicit.
Consider adding before the `CrossEntropyFunction.apply(...)` call:

```python
if z_loss_weight < 0.0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")
```
```python
    n_non_ignore,
    reduce_loss: tl.constexpr,
    label_smoothing: tl.constexpr,
    z_loss_weight: tl.constexpr,
```
tl.constexpr specialization will recompile for every unique float value
z_loss_weight is declared tl.constexpr, which means Triton compiles a separate kernel for each unique Python float value passed. The PR describes this as intentional for dead-code elimination when z_loss_weight=0.0, and for a fixed training hyperparameter that's fine. However, if callers ever want to anneal or schedule z_loss_weight across training steps (e.g. a warmup from 0 → 0.001), every distinct float encountered will trigger a fresh JIT compilation, stalling the training loop.
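If a schedule is ever needed despite the constexpr specialization, one mitigation is to quantize the annealed weight so only a handful of kernel variants are ever compiled. The sketch below is purely illustrative; the helper name and defaults are invented, not part of the PR.

```python
def quantized_z_weight(step: int, warmup_steps: int = 1000,
                       target: float = 1e-3, levels: int = 4) -> float:
    """Ramp z_loss_weight from 0 to `target` in `levels` discrete jumps,
    so at most levels + 1 distinct constexpr values (and hence kernel
    compilations) ever reach the Triton kernel."""
    if step >= warmup_steps:
        return target
    return round(step / warmup_steps * levels) / levels * target
```

With `levels=4`, only the five values `{0, 0.25, 0.5, 0.75, 1} * target` are ever passed, bounding JIT recompilation regardless of training length.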
Consider documenting this behaviour in the docstring of both `cross_entropy_kernel` and `parallel_cross_entropy`:

```
z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token.
    This value is used as a Triton compile-time constant (tl.constexpr); varying it across
    calls will trigger kernel recompilation. Use a fixed value during training.
```

```python
inp = inp.float()
ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing)
z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
z_pen[tar == -100] = 0.0
```
Ignore index hardcoded to `-100` in reference function
z_pen[tar == -100] = 0.0 hardcodes the ignore index rather than closing over the ignore_idx value that will be used in the actual parallel_cross_entropy call. Currently all tests in the suite pass ignore_idx=-100 (the default), so this is consistent — but if a future test exercises a non-default ignore index, the reference would silently zero the wrong tokens and the comparison would yield a false positive.
Consider parameterising the closure, replacing:

```python
z_pen[tar == -100] = 0.0
```

with:

```python
z_pen[tar == ignore_idx] = 0.0
```

And updating `generate_infra` to accept `ignore_idx: int = -100` so it can be forwarded.
```python
def test_z_loss_reduced(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            z_loss_weight=0.001,
        )
```
Missing test combination: `reduce_loss=True` + z-loss + `ignore_idx`
The test suite covers reduce_loss=True + z-loss (test_z_loss_reduced) and reduce_loss=False + z-loss + ignore_idx (test_z_loss_with_ignore_idx), but no test exercises all three together. This is the most semantically interesting combination: n_non_ignore is used to normalize both the loss value (in Python) and the gradient (in the Triton kernel, line 198), so an incorrect interaction would only appear when tokens are actually masked and reduction is active.
Consider adding:
```python
def test_z_loss_reduced_with_ignore_idx(self):
    self.generate_iters(5)
    self.generate_infra(True, 0, z_loss_weight=0.001)
    for i in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice([True, False]),
            label_smoothing=0,
            reduce_loss=True,
            ignore_idx=True,
            z_loss_weight=0.001,
        )
```

```python
def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(inp, tar, reduction="none", label_smoothing=label_smoothing)
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    loss = ce + z_pen
    if reduce_loss:
        loss = loss.sum() / (tar != ignore_idx).sum()
    return loss
```
F.cross_entropy does not receive the parameterized ignore_index
generate_infra now accepts an ignore_idx parameter and correctly uses it to zero out z_pen, but F.cross_entropy is called without ignore_index=ignore_idx. PyTorch's default is -100, so all current tests pass since generate_input always uses -100. However, if a future test passes a non-default ignore_idx, the CE component of the reference would still ignore -100 while the real kernel would ignore the custom index, producing a silently incorrect reference loss and false-passing gradient tests.
```python
def ref_with_zloss(inp, tar):
    inp = inp.float()
    ce = F.cross_entropy(
        inp, tar, reduction="none",
        label_smoothing=label_smoothing,
        ignore_index=ignore_idx,  # <-- forward the parameter
    )
    z_pen = z_loss_weight * torch.square(torch.logsumexp(inp, dim=-1))
    z_pen[tar == ignore_idx] = 0.0
    ...
```

Add z-loss regularization (z_loss_weight * log(Z)^2 per token) to the Triton cross-entropy kernel. The z_loss_weight parameter is a tl.constexpr, so it is dead-code-eliminated when set to 0.0. Forward: adds z_loss_weight * lse^2 to per-token loss. Backward: scales softmax gradient by (1 + 2 * z_loss_weight * lse). Tests: extend existing test infrastructure with z_loss_weight parameter in generate_infra and one_iteration_test. Z-loss tests cover FP32, BF16, ignore_idx, and zero-weight identity.

Signed-off-by: Cem Bassoy <cem.bassoy@deepl.com>
```python
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
```
ignore_idx not forwarded to CrossEntropyLoss
`generate_infra` now accepts an `ignore_idx` parameter, and the `z_loss_weight > 0` branch correctly passes it to both `F.cross_entropy(..., ignore_index=ignore_idx)` and `z_pen[tar == ignore_idx]`. However, the `z_loss_weight == 0.0` branch silently falls back to PyTorch's default (-100), ignoring the parameter:

```python
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
    # ignore_index=ignore_idx is missing here
)
```

All current tests happen to pass ignore_idx=-100 (the default), so there is no visible failure now. But if any test ever calls generate_infra(..., z_loss_weight=0.0, ignore_idx=42), the reference would still ignore token id -100 instead of 42, producing a silently incorrect reference and a false-passing comparison.
Suggested change (forward `ignore_idx`):

```python
self.ref_loss_func = torch.nn.CrossEntropyLoss(
    label_smoothing=label_smoothing,
    reduction="mean" if reduce_loss else "none",
    ignore_index=ignore_idx,
)
```
```python
if z_loss_weight < 0:
    raise ValueError(f"z_loss_weight must be non-negative, got {z_loss_weight}")
```
Non-finite `z_loss_weight` bypasses validation
The current guard only rejects strictly negative values. `float('nan') < 0` evaluates to `False` (NaN comparisons always return `False`), and `float('inf') < 0` is likewise `False`, so both slip through undetected:

- `z_loss_weight=float('nan')`: the Triton `tl.constexpr` comparison `nan > 0` evaluates to `False` at compile time, so the z-loss branch is silently skipped; users see no z-loss even though they passed a non-zero value.
- `z_loss_weight=float('inf')`: the branch is taken, and `loss += inf * lse * lse` will produce `inf`/`nan` losses, immediately destabilising training.
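These comparison semantics can be verified directly in plain Python; `math.isfinite` is one way to cover both pathological cases in a single check (shown here only as an illustration of the semantics, not as code from the PR):

```python
import math

nan, inf = float("nan"), float("inf")

# NaN compares False against everything, so a `< 0` guard never fires:
assert not (nan < 0) and not (nan >= 0)
# inf is non-negative, so it also passes a `< 0` guard:
assert not (inf < 0)
# A single finiteness check rejects both:
assert not math.isfinite(nan) and not math.isfinite(inf)
assert math.isfinite(0.0) and math.isfinite(1e-3)
```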
Consider expanding the guard to also exclude non-finite values:

```python
if not (z_loss_weight >= 0.0 and z_loss_weight != float("inf")):
    raise ValueError(f"z_loss_weight must be a finite non-negative number, got {z_loss_weight}")
```
Description
Adds z-loss regularization to `parallel_cross_entropy`. Z-loss penalizes large logit magnitudes by adding `z_loss_weight * log(Z)^2` per token to the loss, where `log(Z) = log(sum(exp(logits)))` is the log-sum-exp (see ST-MoE, arxiv.org/abs/2202.08906). This stabilizes training by keeping logits in a numerically well-behaved range.
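The formula above can be mirrored in a few lines of eager PyTorch. This is only a reference sketch of the per-token math, not the Triton kernel; the function name is illustrative.

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, targets, z_loss_weight=1e-3):
    # logits: [tokens, vocab], targets: [tokens]
    ce = F.cross_entropy(logits.float(), targets, reduction="none")
    lse = torch.logsumexp(logits.float(), dim=-1)  # log(Z) per token
    return ce + z_loss_weight * lse**2             # per-token loss
```

With `z_loss_weight=0.0` the result reduces exactly to plain cross-entropy, which is what the zero-weight identity test checks.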
Type of change
Changes

- The Triton kernel already computes `m` (running max) and `d` (sum of shifted exponentials) as part of the online softmax. `lse = m + log(d)` is a free byproduct, so no extra data movement is required.
- Compute `lse = m + log(d)` from the existing variables and add `z_loss_weight * lse^2` to the per-token loss. Reuse `lse` in the label-smoothing path (`smooth_loss` previously recomputed `m + log(d)`).
- Backward: scale the softmax gradient by `(1 + 2 * z_loss_weight * lse)`, derived from `d/dx_i[z_loss_weight * lse^2] = 2 * z_loss_weight * lse * softmax(x_i)`.
- `z_loss_weight` is `tl.constexpr`, enabling Triton to eliminate all z-loss branches at compile time when `z_loss_weight=0.0`.
- `parallel_cross_entropy(..., z_loss_weight=0.0)` is backward compatible; default behavior is unchanged.
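The "free byproduct" claim follows from the online-softmax recurrence, sketched here in plain Python as an illustration (this is not the Triton kernel):

```python
import math

def online_logsumexp(xs):
    """One pass over xs keeping a running max m and a shifted-exponential
    sum d; logsumexp(xs) falls out at the end as m + log(d)."""
    m, d = float("-inf"), 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the accumulated sum to the new max, then add this term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m + math.log(d)
```

Since the kernel must maintain `m` and `d` for the softmax anyway, the z-loss term costs only the final `m + log(d)` combination per row.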
Tests
Z-loss tests extend the existing test infrastructure rather than introducing separate helpers.

- `generate_infra` accepts an optional `z_loss_weight` parameter: when > 0, it builds a PyTorch reference function that computes `F.cross_entropy + z_loss_weight * lse^2` in FP32.
- `one_iteration_test` passes `z_loss_weight` through to `parallel_cross_entropy`. This keeps all z-loss tests in the same pattern as the existing suite.
- The existing `z_loss_weight == 0.0` path in `generate_infra` is untouched (`torch.nn.CrossEntropyLoss`), so there is no risk to existing tests.

6 new tests, 14 total pass (A40, single GPU):

- `test_z_loss`
- `test_z_loss_bfloat16`
- `test_z_loss_with_ignore_idx`
- `test_z_loss_zero_weight`: `z_loss_weight=0.0` produces bit-identical results to the default (no z-loss)
- `test_z_loss_reduced`: `reduce_loss=True`; reduced loss and gradients correct (5 iterations)
- `test_z_loss_label_smoothing`: `label_smoothing=0.1`; both features interact correctly (3 iterations)

Checklist: