Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions tests/benchmarks/test_boot_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
#!/usr/bin/env python3
"""Memory benchmark: TransformerBridge.boot_transformers vs HookedTransformer.from_pretrained.

Run with: python -m pytest tests/benchmarks/test_boot_memory.py -v -s
Or directly: python tests/benchmarks/test_boot_memory.py [model_name]
"""

import gc
import os
import subprocess
import sys

import pytest


def get_rss_mb():
    """Return the current process's resident set size (RSS) in MB.

    Probes three sources in order: psutil (if installed), Linux
    ``/proc/<pid>/status`` (VmRSS, reported in kB), and finally the
    portable ``ps`` command (RSS in kB). Returns 0.0 if every probe
    fails so callers can still run and simply report zero deltas.
    """
    try:
        import psutil

        return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    except ImportError:
        pass

    try:
        # Linux: VmRSS is reported in kB.
        with open(f"/proc/{os.getpid()}/status") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    return int(line.split()[1]) / 1024
    except OSError:
        # Catch any OSError (not just FileNotFoundError): /proc may exist
        # but be unreadable; fall through to the `ps` probe either way.
        pass

    try:
        result = subprocess.run(
            ["ps", "-o", "rss=", "-p", str(os.getpid())],
            capture_output=True,
            text=True,
        )
        # Only parse stdout when ps actually succeeded; otherwise stdout is
        # empty and int() would raise.
        if result.returncode == 0:
            return int(result.stdout.strip()) / 1024
    except (OSError, ValueError):
        pass
    return 0.0


def profile_hooked_transformer(
    model_name, fold_ln=False, fold_value_biases=False, center_writing_weights=False
):
    """Profile HookedTransformer.from_pretrained RSS at each stage."""
    import torch

    _ = torch.set_grad_enabled(False)

    stages = []

    def snapshot(label):
        # Collect garbage first so each reading reflects live objects only.
        gc.collect()
        stages.append((label, get_rss_mb()))

    snapshot("baseline")

    from transformer_lens import HookedTransformer

    snapshot("after import")

    model = HookedTransformer.from_pretrained(
        model_name,
        fold_ln=fold_ln,
        fold_value_biases=fold_value_biases,
        center_writing_weights=center_writing_weights,
    )
    snapshot("after from_pretrained")

    # Raw parameter footprint in MB, for comparison against RSS deltas.
    total_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    stages.append(("param_size_mb", total_bytes / 1024 / 1024))

    del model
    snapshot("after del model")

    return stages


def profile_transformer_bridge(
    model_name, fold_ln=False, fold_value_biases=False, center_writing_weights=False
):
    """Profile TransformerBridge.boot_transformers RSS at each stage."""
    import torch

    _ = torch.set_grad_enabled(False)

    records = []

    def record(stage):
        # GC before every reading so RSS reflects reachable objects only.
        gc.collect()
        records.append((stage, get_rss_mb()))

    record("baseline")

    from transformer_lens.model_bridge import TransformerBridge

    record("after import")

    bridge = TransformerBridge.boot_transformers(model_name)
    record("after boot_transformers")

    bridge.enable_compatibility_mode(
        fold_ln=fold_ln,
        fold_value_biases=fold_value_biases,
        center_writing_weights=center_writing_weights,
    )
    record("after enable_compatibility_mode")

    # Raw parameter footprint in MB, for comparison against RSS deltas.
    n_bytes = sum(p.nelement() * p.element_size() for p in bridge.parameters())
    records.append(("param_size_mb", n_bytes / 1024 / 1024))

    del bridge
    record("after del bridge")

    return records


def run_in_subprocess(func_name, model_name, **kwargs):
    """Run a profiling function in a fresh subprocess for clean RSS readings.

    Args:
        func_name: Name of a profiling function defined in this module.
        model_name: Model identifier forwarded to the profiling function.
        **kwargs: Extra keyword arguments (simple literals only) forwarded on.

    Returns:
        Dict mapping checkpoint name -> MB value parsed from the subprocess's
        tab-separated stdout.

    Raises:
        RuntimeError: If the subprocess exits non-zero (its stderr is echoed).
    """
    # Forward kwargs via the repr of the dict itself; this stays valid Python
    # even when kwargs is empty (the old ", {kwargs_str}" left a dangling
    # trailing comma in that case).
    script = f"""
import sys
sys.path.insert(0, '.')
from tests.benchmarks.test_boot_memory import {func_name}
results = {func_name}({model_name!r}, **{kwargs!r})
for name, val in results:
    print(f"{{name}}\\t{{val:.1f}}")
"""
    # Repo root is three directories up from tests/benchmarks/<this file>.
    repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    result = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        cwd=repo_root,
    )
    if result.returncode != 0:
        print(f"STDERR:\n{result.stderr}", file=sys.stderr)
        raise RuntimeError(f"{func_name} subprocess failed (exit {result.returncode})")

    checkpoints = {}
    for line in result.stdout.strip().split("\n"):
        if "\t" not in line:
            continue
        name, val = line.split("\t", 1)
        try:
            checkpoints[name] = float(val)
        except ValueError:
            # Skip unrelated tab-containing output (e.g. library download/log
            # lines); previously this raised and aborted the whole benchmark.
            continue
    return checkpoints


# Models the memory benchmark is parametrized over.
MEMORY_BENCHMARK_MODELS = ["gpt2"]
# Shared loader kwargs: all weight-processing flags off so HookedTransformer
# and TransformerBridge are profiled doing comparable work.
_BENCH_KWARGS = dict(fold_ln=False, fold_value_biases=False, center_writing_weights=False)


class TestBootMemory:
    """Ensure TransformerBridge memory stays within bounds relative to HookedTransformer."""

    @pytest.mark.parametrize("model_name", MEMORY_BENCHMARK_MODELS)
    def test_bridge_memory_within_bounds(self, model_name):
        """TransformerBridge RSS must not exceed 4x parameter size."""
        prof = run_in_subprocess("profile_transformer_bridge", model_name, **_BENCH_KWARGS)

        param_mb = prof["param_size_mb"]
        net_rss = prof["after enable_compatibility_mode"] - prof["baseline"]
        max_allowed = 4 * param_mb

        print(f"\n TransformerBridge({model_name}):")
        print(f" Param size: {param_mb:>8.1f} MB")
        print(f" Net RSS: {net_rss:>8.1f} MB ({net_rss / param_mb:.1f}x params)")
        print(f" Max allowed: {max_allowed:>8.1f} MB (4x params)")

        assert net_rss < max_allowed, (
            f"TransformerBridge RSS ({net_rss:.0f} MB) exceeds 4x param size "
            f"({max_allowed:.0f} MB) for {model_name}. Ratio: {net_rss / param_mb:.1f}x"
        )

    @pytest.mark.parametrize("model_name", MEMORY_BENCHMARK_MODELS)
    def test_bridge_vs_hooked_transformer_ratio(self, model_name):
        """TransformerBridge must use no more than 2x the RSS of HookedTransformer."""
        ht_prof = run_in_subprocess("profile_hooked_transformer", model_name, **_BENCH_KWARGS)
        bridge_prof = run_in_subprocess(
            "profile_transformer_bridge", model_name, **_BENCH_KWARGS
        )

        ht_net = ht_prof["after from_pretrained"] - ht_prof["baseline"]
        bridge_net = bridge_prof["after enable_compatibility_mode"] - bridge_prof["baseline"]
        # A zero HT baseline delta (RSS probe unavailable) maps to +inf so the
        # assertion below fails loudly rather than dividing by zero.
        if ht_net > 0:
            ratio = bridge_net / ht_net
        else:
            ratio = float("inf")

        print(f"\n Memory comparison ({model_name}):")
        print(f" HookedTransformer: {ht_net:>8.1f} MB")
        print(f" TransformerBridge: {bridge_net:>8.1f} MB")
        print(f" Ratio: {ratio:>8.1f}x")

        assert ratio < 2.0, (
            f"TransformerBridge uses {ratio:.1f}x more memory than HookedTransformer "
            f"for {model_name} (Bridge: {bridge_net:.0f} MB, HT: {ht_net:.0f} MB). Expected < 2.0x."
        )


if __name__ == "__main__":
    # CLI entry point: python tests/benchmarks/test_boot_memory.py [model_name]
    model_name = sys.argv[1] if len(sys.argv) > 1 else "gpt2"
    print(f"Memory benchmark for: {model_name}")
    print("=" * 60)

    print("\nHookedTransformer.from_pretrained:")
    ht = run_in_subprocess("profile_hooked_transformer", model_name, **_BENCH_KWARGS)
    for name, val in ht.items():
        print(f" {name:<35s} {val:>8.1f} MB")

    print("\nTransformerBridge.boot_transformers:")
    bridge = run_in_subprocess("profile_transformer_bridge", model_name, **_BENCH_KWARGS)
    for name, val in bridge.items():
        print(f" {name:<35s} {val:>8.1f} MB")

    print("\n" + "=" * 60)
    ht_net = ht["after from_pretrained"] - ht["baseline"]
    bridge_net = bridge["after enable_compatibility_mode"] - bridge["baseline"]
    # Guard the division like test_bridge_vs_hooked_transformer_ratio does:
    # a zero HT delta (RSS probe unavailable) previously raised
    # ZeroDivisionError here.
    ratio = bridge_net / ht_net if ht_net > 0 else float("inf")
    print(f"HookedTransformer net: {ht_net:>8.1f} MB")
    print(f"TransformerBridge net: {bridge_net:>8.1f} MB")
    print(f"Ratio: {ratio:>8.1f}x")
    print(f"Param size: {bridge['param_size_mb']:>8.1f} MB")
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests for Joint QKV Attention bridge."""

import copy

import torch

from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
Expand Down Expand Up @@ -485,3 +487,46 @@ def __init__(self):
bridge,
position_embeddings=position_embeddings,
)

def test_deepcopy_does_not_copy_bound_method_self(self):
"""Deepcopy shares split_qkv_matrix and config instead of copying them."""

class FakeAdapter:
def __init__(self):
self.heavy_data = torch.randn(100, 100)

def split_qkv(self, component):
return (torch.nn.Linear(4, 4), torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

class TestConfig:
n_heads = 2
d_model = 4

adapter = FakeAdapter()
bridge = JointQKVAttentionBridge(
name="attn",
config=TestConfig(),
split_qkv_matrix=adapter.split_qkv,
)

clone = copy.deepcopy(bridge)

assert clone.split_qkv_matrix is bridge.split_qkv_matrix
assert clone.split_qkv_matrix.__self__ is adapter
assert clone.config is bridge.config

def test_deepcopy_produces_independent_hooks(self):
"""Deepcopy produces independent HookPoint and LinearBridge instances."""

class TestConfig:
n_heads = 2
d_model = 4

bridge = JointQKVAttentionBridge(name="attn", config=TestConfig())
clone = copy.deepcopy(bridge)

assert clone.hook_in is not bridge.hook_in
assert clone.hook_out is not bridge.hook_out
assert clone.q is not bridge.q
assert clone.k is not bridge.k
assert clone.v is not bridge.v
43 changes: 43 additions & 0 deletions tests/unit/model_bridge/test_component_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from types import SimpleNamespace

import pytest
import torch
import torch.nn as nn

from tests.mocks.architecture_adapter import MockArchitectureAdapter, mock_model_adapter
Expand All @@ -21,6 +22,9 @@
MLPBridge,
NormalizationBridge,
)
from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
JointQKVAttentionBridge,
)


class TestComponentSetup:
Expand Down Expand Up @@ -286,6 +290,45 @@ def _create_fresh_mock_model(self):

return model

    def test_setup_blocks_bridge_no_weight_duplication_with_bound_method(self):
        """Block deepcopy must not duplicate adapter via bound split_qkv_matrix.

        Builds a blocks template whose attention bridge holds a method bound to
        ``fake_adapter`` (which owns a 256x256 tensor), runs the block setup,
        and asserts every resulting block still references the *same* adapter
        instance rather than a deep copy of it.
        """
        adapter = MockArchitectureAdapter()

        class FakeAdapterWithSplitFn:
            def __init__(self):
                self.component_mapping = {}
                # Deliberately heavy payload: a deepcopy of the adapter would
                # duplicate this tensor, which is what this test guards against.
                self.heavy_tensor = torch.randn(256, 256)

            def split_qkv(self, component):
                return (
                    nn.Linear(10, 10, bias=False),
                    nn.Linear(10, 10, bias=False),
                    nn.Linear(10, 10, bias=False),
                )

        fake_adapter = FakeAdapterWithSplitFn()

        blocks_template = BlockBridge(
            name="blocks",
            submodules={
                "ln1": NormalizationBridge(name="ln1", config={}),
                "attn": JointQKVAttentionBridge(
                    name="attn",
                    config=SimpleNamespace(n_heads=1, d_model=10),
                    # Bound method: its __self__ transitively holds heavy_tensor.
                    split_qkv_matrix=fake_adapter.split_qkv,
                ),
            },
        )
        # NOTE(review): the mapping is registered on fake_adapter, but
        # setup_blocks_bridge below is called with the separate `adapter` —
        # confirm which adapter's component_mapping the setup actually reads.
        fake_adapter.component_mapping["blocks"] = blocks_template

        mock_model = self._create_fresh_mock_model()
        bridged_blocks = setup_blocks_bridge(blocks_template, adapter, mock_model)

        # Each per-layer block must still hold the original bound method,
        # i.e. no block received a deep-copied adapter.
        for block in bridged_blocks:
            fn = block.attn.split_qkv_matrix
            assert hasattr(fn, "__self__")
            assert fn.__self__ is fake_adapter

@pytest.fixture
def mock_model_adapter(self):
"""Create a mock model for testing."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

This module contains the bridge component for attention layers that use a fused qkv matrix.
"""
import copy
from typing import Any, Callable, Dict, Optional

import einops
Expand Down Expand Up @@ -108,6 +109,35 @@ def __init__(
# Exclude stale qkv combined weights from state_dict after splitting.
self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict)

    def __deepcopy__(self, memo):
        """Share split_qkv_matrix and config across clones instead of copying.

        split_qkv_matrix may be a bound method of the architecture adapter,
        which transitively references the full HF model. Without this override,
        deepcopy duplicates the entire model per block (~1GB x N_layers).
        """
        # Stash the references that must be shared, then hide them so the
        # default deepcopy below cannot traverse into the adapter/model graph.
        saved_split_fn = self.split_qkv_matrix
        saved_config = self.config

        self.split_qkv_matrix = None  # type: ignore[assignment]
        self.config = None
        try:
            # Remove override from defining class (not subclass) to avoid recursion.
            owner = JointQKVAttentionBridge
            override = owner.__dict__["__deepcopy__"]
            del owner.__deepcopy__
            try:
                # Re-enter deepcopy with the default protocol; `memo` is passed
                # through so already-copied sub-objects stay deduplicated.
                clone = copy.deepcopy(self, memo)
            finally:
                # Restore the class-level override even if copying raised.
                owner.__deepcopy__ = override  # type: ignore[method-assign]
        finally:
            # Restore this instance's attributes unconditionally.
            self.split_qkv_matrix = saved_split_fn
            self.config = saved_config

        # NOTE(review): temporarily mutating the class attribute makes this
        # non-reentrant and not thread-safe if two deepcopies run concurrently
        # — confirm callers are single-threaded.
        clone.split_qkv_matrix = saved_split_fn
        clone.config = saved_config
        return clone

@staticmethod
def _filter_qkv_state_dict(
module: torch.nn.Module,
Expand Down
Loading
Loading