From 4f7c0c2b94c2435ee9d2567b7ecbd318263e9cc3 Mon Sep 17 00:00:00 2001 From: Brendan Long Date: Sat, 28 Mar 2026 18:28:55 -0700 Subject: [PATCH] Update RMSNorm.forward to accept 2D or 3D input shape RMSNorm.forward() has a jaxtyping hint expecting a 3D tensor (batch, pos, length), but _apply_qk_norm was reshaping to 2D (batch*pos*heads, d_head), causing a BeartypeCallHintParamViolation on Gemma 3 models. Update the type hint to also allow 2D since the code only cares about the last dimension. Co-Authored-By: Claude Opus 4.6 (1M context) --- transformer_lens/components/rms_norm.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py index 7f45b6e55..e6eced523 100644 --- a/transformer_lens/components/rms_norm.py +++ b/transformer_lens/components/rms_norm.py @@ -12,6 +12,13 @@ from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +# RMSNorm operates on the last dimension and supports both 2D and 3D inputs. +# The 2D case arises when callers (e.g. QK normalization) reshape before normalizing. +RMSNormInput = Union[ + Float[torch.Tensor, "batch pos length"], + Float[torch.Tensor, "batch_pos length"], +] + class RMSNorm(nn.Module): def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None): @@ -34,9 +41,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[i self.hook_scale = HookPoint() # [batch, pos, 1] self.hook_normalized = HookPoint() # [batch, pos, length] - def forward( - self, x: Float[torch.Tensor, "batch pos length"] - ) -> Float[torch.Tensor, "batch pos length"]: + def forward(self, x: RMSNormInput) -> RMSNormInput: if self.cfg.dtype not in [torch.float32, torch.float64]: x = x.to(torch.float32) scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(