PEFT LoRA transformer layers not trained when gradient_checkpointing=True — enable_input_require_grads() missing #178

@Anya-wUw

Description

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task
  • My own task or dataset (give details below)

Reproduction

After training any LoRA unlearning with gradient checkpointing enabled, inspect the adapter:

from safetensors.torch import load_file

sd = load_file("saves/unlearn/<method>/<run>/adapter_model.safetensors")
b_norms = {k: v.norm().item() for k, v in sd.items() if "lora_B" in k}
nonzero = sum(1 for n in b_norms.values() if n > 1e-6)
print(f"lora_B nonzero: {nonzero}/{len(b_norms)}")
# Expected: all non-zero
# Actual:   1/<total>  (only lm_head.lora_B is non-zero)

Expected behavior

When training unlearning methods with LoRA, all LoRA matrices should receive gradient updates, whether or not gradient_checkpointing is enabled. Instead, with gradient_checkpointing enabled, only the lm_head LoRA adapter is trained; every transformer-layer LoRA matrix (q_proj, k_proj, v_proj, o_proj, gate_proj, down_proj, up_proj) remains at its zero initialization throughout training.

Cause
PyTorch's (reentrant) gradient checkpointing skips the backward pass through a checkpointed block when none of the block's input tensors requires grad: the checkpointed output then carries no grad_fn, so autograd never recomputes the block. Under PEFT LoRA the base model is frozen, so the hidden states entering the transformer blocks have requires_grad=False by default. Without calling model.enable_input_require_grads(), no gradient ever reaches the LoRA parameters inside those blocks.

lm_head is outside the checkpointed transformer blocks and connects directly to the loss, so its lora_B receives gradients normally, which is why it alone gets trained.
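
The failure mode can be reproduced in isolation with plain torch.utils.checkpoint. In this minimal sketch, `inner` and `head` are hypothetical stand-ins: `inner` plays the role of a LoRA matrix inside a checkpointed transformer block, `head` plays the role of lm_head outside it:

```python
import warnings

import torch
from torch.utils.checkpoint import checkpoint

warnings.filterwarnings("ignore")  # silence the "none of the inputs require grad" warning
torch.manual_seed(0)

inner = torch.nn.Linear(4, 4)  # stands in for a LoRA matrix inside a checkpointed block
head = torch.nn.Linear(4, 1)   # stands in for lm_head, which is never checkpointed

def run(input_requires_grad):
    inner.zero_grad()
    head.zero_grad()
    x = torch.randn(2, 4).requires_grad_(input_requires_grad)
    out = checkpoint(lambda t: inner(t), x, use_reentrant=True)
    head(out).sum().backward()
    return inner.weight.grad

grad_without = run(False)  # hidden states as a frozen PEFT base model produces them
grad_with = run(True)      # what enable_input_require_grads() arranges

print(grad_without)           # None: backward skipped the checkpointed block
print(grad_with is not None)  # True: the block now participates in backward
```

head.weight.grad is populated in both runs, matching the observation above that only lm_head's lora_B trains.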

Fix

Add one line in src/train.py after get_model():

model, tokenizer = get_model(model_cfg)

if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
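
For reference, enable_input_require_grads() in transformers works by registering a forward hook on the input embeddings that forces their output to require grad. A standalone sketch of the same mechanism (the Embedding here is a toy stand-in for a frozen base model's input embeddings):

```python
import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)  # frozen base weights, as under PEFT LoRA

before = emb(torch.tensor([0, 1])).requires_grad  # False: frozen weight, plain output

def make_inputs_require_grads(module, inputs, output):
    # Same idea as the hook transformers installs: force the embedding
    # output (the first hidden states) to require grad, so downstream
    # checkpointed blocks see at least one grad-requiring input.
    output.requires_grad_(True)

handle = emb.register_forward_hook(make_inputs_require_grads)
after = emb(torch.tensor([0, 1])).requires_grad  # True: hook flipped the flag

print(before, after)
handle.remove()
```

This is why the one-line fix is enough: once the first hidden states require grad, every downstream checkpointed block runs its backward pass and the LoRA parameters inside it receive gradients.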
