Reproduction
After training any unlearning method with LoRA and gradient checkpointing enabled, inspect the saved adapter:
from safetensors.torch import load_file
sd = load_file("saves/unlearn/<method>/<run>/adapter_model.safetensors")
b_norms = {k: v.norm().item() for k, v in sd.items() if "lora_B" in k}
nonzero = sum(1 for n in b_norms.values() if n > 1e-6)
print(f"lora_B nonzero: {nonzero}/{len(b_norms)}")
# Expected: all non-zero
# Actual: 1/<total> (only lm_head.lora_B is non-zero)
Expected behavior
Expected: when training unlearning methods with LoRA and gradient_checkpointing enabled, all LoRA adapters should receive gradients and train.
Actual: only the lm_head LoRA adapter is trained; all transformer-layer LoRA matrices (q_proj, k_proj, v_proj, o_proj, gate_proj, down_proj, up_proj) remain at their zero initialization throughout training.
Cause
PyTorch's gradient checkpointing skips the backward pass through a checkpointed block if none of the block's input tensors has requires_grad=True. In a PEFT LoRA model the base weights are frozen, so the hidden states entering the transformer blocks have requires_grad=False by default. Without calling model.enable_input_require_grads(), no gradient ever reaches the LoRA parameters inside those blocks.
lm_head is outside the checkpointed transformer blocks and connects directly to the loss, so its lora_B receives gradients normally, which is why it alone gets trained.
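The skip can be reproduced in isolation. The sketch below (not from the repository) checkpoints a single Linear layer as a stand-in for a transformer block: when its input does not require grad, the checkpoint output does not either, so backward never enters the block and its weights get no gradient.

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(4, 4)  # stand-in for a transformer block with trainable params

# Input without requires_grad, like hidden states in a frozen-base LoRA model:
x = torch.randn(2, 4)
out = checkpoint(block, x, use_reentrant=True)
# PyTorch even warns "None of the inputs have requires_grad=True" here.
print(out.requires_grad)  # False -> backward skips the block; block.weight gets no grad

# Input with requires_grad, which is what enable_input_require_grads arranges:
x2 = torch.randn(2, 4).requires_grad_()
out2 = checkpoint(block, x2, use_reentrant=True)
out2.sum().backward()
print(block.weight.grad is not None)  # True: gradients now flow through the checkpoint
```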
Fix
Add the following guard in src/train.py right after get_model():
model, tokenizer = get_model(model_cfg)
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
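For context, transformers' enable_input_require_grads works by registering a forward hook on the input embeddings that flips requires_grad on their output, giving every downstream checkpointed block a grad-requiring input. A simplified standalone sketch of that mechanism (the hook body below mirrors the idea, not the library's exact code):

```python
import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)  # frozen base weights, as in a LoRA setup

tokens = torch.tensor([[1, 2, 3]])
print(emb(tokens).requires_grad)  # False: output of a fully frozen embedding

# The kind of hook enable_input_require_grads registers (simplified):
def make_inputs_require_grad(module, inp, out):
    out.requires_grad_(True)

emb.register_forward_hook(make_inputs_require_grad)
print(emb(tokens).requires_grad)  # True: checkpointed blocks now see a grad path
```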