Commit 407501e

Fixing memory leak in Joint QKV Attention Bridge (#1229)
* Fix a bug where deepcopy was unintentionally copying all weights in every layer of the Joint QKV Attention Bridge
* Additional model testing to confirm the fix did not negatively impact models
1 parent c39dbbc commit 407501e
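The bug in miniature: copy.deepcopy on a component that stores a bound method also deep-copies the method's __self__, so each block clone dragged along its own copy of the adapter and the HF model it references. A standalone sketch of that behavior (Adapter and Block are illustrative stand-ins, not code from this commit):

import copy

class Adapter:
    def __init__(self):
        # Stands in for an architecture adapter holding the full HF model.
        self.weights = [0.0] * 1_000_000

    def split_qkv(self, component):
        return component

class Block:
    def __init__(self, split_fn):
        # Storing a bound method keeps a reference to its __self__ (the adapter).
        self.split_qkv_matrix = split_fn

adapter = Adapter()
blocks = [Block(adapter.split_qkv) for _ in range(3)]
clones = [copy.deepcopy(b) for b in blocks]

# Each clone received its own deep copy of the adapter and its weights:
for c in clones:
    assert c.split_qkv_matrix.__self__ is not adapter

CPython's copy module deep-copies a bound method by deep-copying its __self__ and rebinding, which is why the duplication scaled with the number of layers.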

6 files changed: 401 additions & 20 deletions

tests/benchmarks/test_boot_memory.py

Lines changed: 215 additions & 0 deletions

@@ -0,0 +1,215 @@
#!/usr/bin/env python3
"""Memory benchmark: TransformerBridge.boot_transformers vs HookedTransformer.from_pretrained.

Run with: python -m pytest tests/benchmarks/test_boot_memory.py -v -s
Or directly: python tests/benchmarks/test_boot_memory.py [model_name]
"""

import gc
import os
import subprocess
import sys

import pytest


def get_rss_mb():
    """Get current process RSS in MB."""
    try:
        import psutil

        return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    except ImportError:
        try:
            with open(f"/proc/{os.getpid()}/status") as f:
                for line in f:
                    if line.startswith("VmRSS:"):
                        return int(line.split()[1]) / 1024
        except FileNotFoundError:
            pass
        try:
            result = subprocess.run(
                ["ps", "-o", "rss=", "-p", str(os.getpid())],
                capture_output=True,
                text=True,
            )
            return int(result.stdout.strip()) / 1024
        except Exception:
            return 0.0


def profile_hooked_transformer(
    model_name, fold_ln=False, fold_value_biases=False, center_writing_weights=False
):
    """Profile HookedTransformer.from_pretrained RSS at each stage."""
    import torch

    _ = torch.set_grad_enabled(False)
    checkpoints = []

    gc.collect()
    checkpoints.append(("baseline", get_rss_mb()))

    from transformer_lens import HookedTransformer

    gc.collect()
    checkpoints.append(("after import", get_rss_mb()))

    model = HookedTransformer.from_pretrained(
        model_name,
        fold_ln=fold_ln,
        fold_value_biases=fold_value_biases,
        center_writing_weights=center_writing_weights,
    )
    gc.collect()
    checkpoints.append(("after from_pretrained", get_rss_mb()))

    param_mb = sum(p.nelement() * p.element_size() for p in model.parameters()) / 1024 / 1024
    checkpoints.append(("param_size_mb", param_mb))

    del model
    gc.collect()
    checkpoints.append(("after del model", get_rss_mb()))

    return checkpoints


def profile_transformer_bridge(
    model_name, fold_ln=False, fold_value_biases=False, center_writing_weights=False
):
    """Profile TransformerBridge.boot_transformers RSS at each stage."""
    import torch

    _ = torch.set_grad_enabled(False)
    checkpoints = []

    gc.collect()
    checkpoints.append(("baseline", get_rss_mb()))

    from transformer_lens.model_bridge import TransformerBridge

    gc.collect()
    checkpoints.append(("after import", get_rss_mb()))

    bridge = TransformerBridge.boot_transformers(model_name)
    gc.collect()
    checkpoints.append(("after boot_transformers", get_rss_mb()))

    bridge.enable_compatibility_mode(
        fold_ln=fold_ln,
        fold_value_biases=fold_value_biases,
        center_writing_weights=center_writing_weights,
    )
    gc.collect()
    checkpoints.append(("after enable_compatibility_mode", get_rss_mb()))

    param_mb = sum(p.nelement() * p.element_size() for p in bridge.parameters()) / 1024 / 1024
    checkpoints.append(("param_size_mb", param_mb))

    del bridge
    gc.collect()
    checkpoints.append(("after del bridge", get_rss_mb()))

    return checkpoints


def run_in_subprocess(func_name, model_name, **kwargs):
    """Run a profiling function in a fresh subprocess for clean RSS readings."""
    kwargs_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
    script = f"""
import sys
sys.path.insert(0, '.')
from tests.benchmarks.test_boot_memory import {func_name}
results = {func_name}({model_name!r}, {kwargs_str})
for name, val in results:
    print(f"{{name}}\\t{{val:.1f}}")
"""
    result = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        cwd=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    )
    if result.returncode != 0:
        print(f"STDERR:\n{result.stderr}", file=sys.stderr)
        raise RuntimeError(f"{func_name} subprocess failed (exit {result.returncode})")

    checkpoints = {}
    for line in result.stdout.strip().split("\n"):
        if "\t" in line:
            name, val = line.split("\t", 1)
            checkpoints[name] = float(val)
    return checkpoints


MEMORY_BENCHMARK_MODELS = ["gpt2"]
_BENCH_KWARGS = dict(fold_ln=False, fold_value_biases=False, center_writing_weights=False)


class TestBootMemory:
    """Ensure TransformerBridge memory stays within bounds relative to HookedTransformer."""

    @pytest.mark.parametrize("model_name", MEMORY_BENCHMARK_MODELS)
    def test_bridge_memory_within_bounds(self, model_name):
        """TransformerBridge RSS must not exceed 4x parameter size."""
        results = run_in_subprocess("profile_transformer_bridge", model_name, **_BENCH_KWARGS)

        param_mb = results["param_size_mb"]
        net_rss = results["after enable_compatibility_mode"] - results["baseline"]
        max_allowed = param_mb * 4

        print(f"\n  TransformerBridge({model_name}):")
        print(f"    Param size:  {param_mb:>8.1f} MB")
        print(f"    Net RSS:     {net_rss:>8.1f} MB ({net_rss / param_mb:.1f}x params)")
        print(f"    Max allowed: {max_allowed:>8.1f} MB (4x params)")

        assert net_rss < max_allowed, (
            f"TransformerBridge RSS ({net_rss:.0f} MB) exceeds 4x param size "
            f"({max_allowed:.0f} MB) for {model_name}. Ratio: {net_rss / param_mb:.1f}x"
        )

    @pytest.mark.parametrize("model_name", MEMORY_BENCHMARK_MODELS)
    def test_bridge_vs_hooked_transformer_ratio(self, model_name):
        """TransformerBridge must use no more than 2x the RSS of HookedTransformer."""
        ht_results = run_in_subprocess("profile_hooked_transformer", model_name, **_BENCH_KWARGS)
        bridge_results = run_in_subprocess(
            "profile_transformer_bridge", model_name, **_BENCH_KWARGS
        )

        ht_net = ht_results["after from_pretrained"] - ht_results["baseline"]
        bridge_net = bridge_results["after enable_compatibility_mode"] - bridge_results["baseline"]
        ratio = bridge_net / ht_net if ht_net > 0 else float("inf")

        print(f"\n  Memory comparison ({model_name}):")
        print(f"    HookedTransformer: {ht_net:>8.1f} MB")
        print(f"    TransformerBridge: {bridge_net:>8.1f} MB")
        print(f"    Ratio:             {ratio:>8.1f}x")

        assert ratio < 2.0, (
            f"TransformerBridge uses {ratio:.1f}x more memory than HookedTransformer "
            f"for {model_name} (Bridge: {bridge_net:.0f} MB, HT: {ht_net:.0f} MB). Expected < 2.0x."
        )


if __name__ == "__main__":
    model_name = sys.argv[1] if len(sys.argv) > 1 else "gpt2"
    print(f"Memory benchmark for: {model_name}")
    print("=" * 60)

    print("\nHookedTransformer.from_pretrained:")
    ht = run_in_subprocess("profile_hooked_transformer", model_name, **_BENCH_KWARGS)
    for name, val in ht.items():
        print(f"  {name:<35s} {val:>8.1f} MB")

    print("\nTransformerBridge.boot_transformers:")
    bridge = run_in_subprocess("profile_transformer_bridge", model_name, **_BENCH_KWARGS)
    for name, val in bridge.items():
        print(f"  {name:<35s} {val:>8.1f} MB")

    print("\n" + "=" * 60)
    ht_net = ht["after from_pretrained"] - ht["baseline"]
    bridge_net = bridge["after enable_compatibility_mode"] - bridge["baseline"]
    print(f"HookedTransformer net: {ht_net:>8.1f} MB")
    print(f"TransformerBridge net: {bridge_net:>8.1f} MB")
    print(f"Ratio: {bridge_net / ht_net:>8.1f}x")
    print(f"Param size: {bridge['param_size_mb']:>8.1f} MB")

tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py

Lines changed: 45 additions & 0 deletions
@@ -1,5 +1,7 @@
 """Unit tests for Joint QKV Attention bridge."""
 
+import copy
+
 import torch
 
 from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
@@ -485,3 +487,46 @@ def __init__(self):
             bridge,
             position_embeddings=position_embeddings,
         )
+
+    def test_deepcopy_does_not_copy_bound_method_self(self):
+        """Deepcopy shares split_qkv_matrix and config instead of copying them."""
+
+        class FakeAdapter:
+            def __init__(self):
+                self.heavy_data = torch.randn(100, 100)
+
+            def split_qkv(self, component):
+                return (torch.nn.Linear(4, 4), torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        adapter = FakeAdapter()
+        bridge = JointQKVAttentionBridge(
+            name="attn",
+            config=TestConfig(),
+            split_qkv_matrix=adapter.split_qkv,
+        )
+
+        clone = copy.deepcopy(bridge)
+
+        assert clone.split_qkv_matrix is bridge.split_qkv_matrix
+        assert clone.split_qkv_matrix.__self__ is adapter
+        assert clone.config is bridge.config
+
+    def test_deepcopy_produces_independent_hooks(self):
+        """Deepcopy produces independent HookPoint and LinearBridge instances."""
+
+        class TestConfig:
+            n_heads = 2
+            d_model = 4
+
+        bridge = JointQKVAttentionBridge(name="attn", config=TestConfig())
+        clone = copy.deepcopy(bridge)
+
+        assert clone.hook_in is not bridge.hook_in
+        assert clone.hook_out is not bridge.hook_out
+        assert clone.q is not bridge.q
+        assert clone.k is not bridge.k
+        assert clone.v is not bridge.v

tests/unit/model_bridge/test_component_setup.py

Lines changed: 43 additions & 0 deletions
@@ -4,6 +4,7 @@
 from types import SimpleNamespace
 
 import pytest
+import torch
 import torch.nn as nn
 
 from tests.mocks.architecture_adapter import MockArchitectureAdapter, mock_model_adapter
@@ -21,6 +22,9 @@
     MLPBridge,
     NormalizationBridge,
 )
+from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
+    JointQKVAttentionBridge,
+)
 
 
 class TestComponentSetup:
@@ -286,6 +290,45 @@ def _create_fresh_mock_model(self):
 
         return model
 
+    def test_setup_blocks_bridge_no_weight_duplication_with_bound_method(self):
+        """Block deepcopy must not duplicate adapter via bound split_qkv_matrix."""
+        adapter = MockArchitectureAdapter()
+
+        class FakeAdapterWithSplitFn:
+            def __init__(self):
+                self.component_mapping = {}
+                self.heavy_tensor = torch.randn(256, 256)
+
+            def split_qkv(self, component):
+                return (
+                    nn.Linear(10, 10, bias=False),
+                    nn.Linear(10, 10, bias=False),
+                    nn.Linear(10, 10, bias=False),
+                )
+
+        fake_adapter = FakeAdapterWithSplitFn()
+
+        blocks_template = BlockBridge(
+            name="blocks",
+            submodules={
+                "ln1": NormalizationBridge(name="ln1", config={}),
+                "attn": JointQKVAttentionBridge(
+                    name="attn",
+                    config=SimpleNamespace(n_heads=1, d_model=10),
+                    split_qkv_matrix=fake_adapter.split_qkv,
+                ),
+            },
+        )
+        fake_adapter.component_mapping["blocks"] = blocks_template
+
+        mock_model = self._create_fresh_mock_model()
+        bridged_blocks = setup_blocks_bridge(blocks_template, adapter, mock_model)
+
+        for block in bridged_blocks:
+            fn = block.attn.split_qkv_matrix
+            assert hasattr(fn, "__self__")
+            assert fn.__self__ is fake_adapter
+
     @pytest.fixture
     def mock_model_adapter(self):
         """Create a mock model for testing."""

transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py

Lines changed: 30 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 This module contains the bridge component for attention layers that use a fused qkv matrix.
 """
+import copy
 from typing import Any, Callable, Dict, Optional
 
 import einops
@@ -108,6 +109,35 @@ def __init__(
         # Exclude stale qkv combined weights from state_dict after splitting.
         self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict)
 
+    def __deepcopy__(self, memo):
+        """Share split_qkv_matrix and config across clones instead of copying.
+
+        split_qkv_matrix may be a bound method of the architecture adapter,
+        which transitively references the full HF model. Without this override,
+        deepcopy duplicates the entire model per block (~1GB x N_layers).
+        """
+        saved_split_fn = self.split_qkv_matrix
+        saved_config = self.config
+
+        self.split_qkv_matrix = None  # type: ignore[assignment]
+        self.config = None
+        try:
+            # Remove override from defining class (not subclass) to avoid recursion.
+            owner = JointQKVAttentionBridge
+            override = owner.__dict__["__deepcopy__"]
+            del owner.__deepcopy__
+            try:
+                clone = copy.deepcopy(self, memo)
+            finally:
+                owner.__deepcopy__ = override  # type: ignore[method-assign]
+        finally:
+            self.split_qkv_matrix = saved_split_fn
+            self.config = saved_config
+
+        clone.split_qkv_matrix = saved_split_fn
+        clone.config = saved_config
+        return clone
+
     @staticmethod
     def _filter_qkv_state_dict(
         module: torch.nn.Module,
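For reference, the delete-and-restore dance above is what prevents infinite recursion: copy.deepcopy would otherwise re-enter __deepcopy__ on the same instance. A minimal standalone sketch of the same pattern (the Node class is hypothetical, not part of this commit):

import copy

class Node:
    def __init__(self, payload, shared):
        self.payload = payload
        self.shared = shared  # heavyweight object that clones should share

    def __deepcopy__(self, memo):
        saved = self.shared
        self.shared = None  # hide the shared object from the default copier
        try:
            override = Node.__dict__["__deepcopy__"]
            del Node.__deepcopy__  # deepcopy(self) below must not re-enter here
            try:
                clone = copy.deepcopy(self, memo)  # default structural copy
            finally:
                Node.__deepcopy__ = override  # always restore the override
        finally:
            self.shared = saved  # always restore the original instance
        clone.shared = saved  # the clone shares, rather than copies, the object
        return clone

big = object()
node = Node(payload=[1, 2, 3], shared=big)
clone = copy.deepcopy(node)
assert clone.payload is not node.payload  # deep-copied as usual
assert clone.shared is big  # shared across clones

Fetching the override via owner.__dict__, as the commit does, matters for subclasses: it removes and restores the method on the defining class rather than on whatever class the instance happens to have.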
