Fix tests broken by a local GPU #1219

Open
brendanlong wants to merge 2 commits into TransformerLensOrg:dev from brendanlong:brendan/fix-gpu-tests

Conversation

@brendanlong
Contributor

@brendanlong brendanlong commented Mar 28, 2026

Description

Two of the tests fail if you have a local GPU.

test_apertus failed because it creates CPU tensors and then compares them to tensors created on the default device. I fixed this by updating the model config to use the CPU as well.
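The failure mode can be reproduced in isolation. A minimal sketch, assuming only that torch is installed (the tensor values here are made up for illustration, not taken from the real test):

```python
import torch

# torch.testing.assert_close compares the `device` attribute as well as the
# values, so a CPU tensor never matches a CUDA tensor even when the numbers
# are identical. Pinning both sides to the CPU (as the fix does via the
# model config) removes the dependence on the default device.
expected = torch.tensor(0.8, device="cpu")
actual = torch.tensor(0.8, device="cpu")  # the fix: create on "cpu" explicitly
torch.testing.assert_close(actual, expected)  # same device, same values
```

On a machine with a GPU where the default device has been switched to CUDA, the same comparison fails with exactly the "attribute 'device' do not match" error shown in the log below.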

test_hooked_encoder has a CUDA-only test (test_cuda) which expects an undefined mlm_tokens fixture. Presumably it passes in CI because CUDA isn't available. I updated this to use tokens instead.

Before:

> poetry run pytest tests/unit/pretrained_weight_conversions/test_apertus.py tests/acceptance/test_hooked_encoder.py 
========================================================= test session starts =========================================================
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/brendanlong/workspace/TransformerLens
configfile: pyproject.toml
plugins: jaxtyping-0.2.19, doctestplus-1.3.0, nbval-0.10.0, cov-5.0.0, anyio-4.5.2, typeguard-4.4.0
collected 29 items                                                                                                                    

tests/unit/pretrained_weight_conversions/test_apertus.py ........F...                                                           [ 41%]
tests/acceptance/test_hooked_encoder.py ................E                                                                       [100%]

=============================================================== ERRORS ================================================================
_____________________________________________________ ERROR at setup of test_cuda _____________________________________________________
file /home/brendanlong/workspace/TransformerLens/tests/acceptance/test_hooked_encoder.py, line 224
  @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
  def test_cuda(mlm_tokens):
E       fixture 'mlm_tokens' not found
>       available fixtures: anyio_backend, anyio_backend_name, anyio_backend_options, cache, capfd, capfdbinary, caplog, capsys, capsysbinary, cov, doctest_namespace, huggingface_bert, monkeypatch, no_cover, our_bert, pytestconfig, record_property, record_testsuite_property, record_xml_attribute, recwarn, tmp_path, tmp_path_factory, tmpdir, tmpdir_factory, tokenizer, tokens
>       use 'pytest --fixtures [testpath]' for help on them.

/home/brendanlong/workspace/TransformerLens/tests/acceptance/test_hooked_encoder.py:224
============================================================== FAILURES ===============================================================
___________________________________ TestApertusWeightConversion.test_xielu_params_fallback_defaults ___________________________________

self = <test_apertus.TestApertusWeightConversion object at 0x753ad8da8080>

    def test_xielu_params_fallback_defaults(self):
        """When activation params aren't found, defaults should be used."""
        cfg = make_cfg()
        model = make_mock_model(cfg, has_act_fn=False)
        # Also remove alternate attribute paths
        layer = model.model.layers[0]
        if hasattr(layer.mlp, "act"):
            del layer.mlp.act
        if hasattr(layer.mlp, "alpha_p"):
            del layer.mlp.alpha_p
    
        sd = convert_apertus_weights(model, cfg)
>       torch.testing.assert_close(
            sd["blocks.0.mlp.act_fn.alpha_p"], torch.tensor(0.8, dtype=cfg.dtype)
        )
E       AssertionError: The values for attribute 'device' do not match: cuda:0 != cpu.

tests/unit/pretrained_weight_conversions/test_apertus.py:157: AssertionError
---------------------------------------------------------- Captured log call ----------------------------------------------------------
WARNING  transformer_lens.pretrained.weight_conversions.apertus:apertus.py:117 XIeLU activation parameters not found in layer 0, using defaults
========================================================== warnings summary ===========================================================
../../.cache/pypoetry/virtualenvs/transformer-lens-MODGY9P9-py3.12/lib/python3.12/site-packages/jaxtyping/import_hook.py:106: 52 warnings
  /home/brendanlong/.cache/pypoetry/virtualenvs/transformer-lens-MODGY9P9-py3.12/lib/python3.12/site-packages/jaxtyping/import_hook.py:106: DeprecationWarning: ast.Str is deprecated and will be removed in Python 3.14; use ast.Constant instead
    elif isinstance(child, ast.Expr) and isinstance(child.value, ast.Str):

transformer_lens/SVDInterpreter.py:26
  /home/brendanlong/workspace/TransformerLens/transformer_lens/SVDInterpreter.py:26: InstrumentationWarning: @typechecked only supports instrumenting functions wrapped with @classmethod, @staticmethod or @property -- not typechecking transformer_lens.SVDInterpreter.SVDInterpreter.get_singular_vectors
    @typechecked

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================= short test summary info =======================================================
FAILED tests/unit/pretrained_weight_conversions/test_apertus.py::TestApertusWeightConversion::test_xielu_params_fallback_defaults - AssertionError: The values for attribute 'device' do not match: cuda:0 != cpu.
ERROR tests/acceptance/test_hooked_encoder.py::test_cuda
========================================= 1 failed, 27 passed, 53 warnings, 1 error in 15.02s =========================================

After:

> poetry run pytest tests/unit/pretrained_weight_conversions/test_apertus.py tests/acceptance/test_hooked_encoder.py 
========================================================= test session starts =========================================================
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/brendanlong/workspace/TransformerLens
configfile: pyproject.toml
plugins: jaxtyping-0.2.19, doctestplus-1.3.0, nbval-0.10.0, cov-5.0.0, anyio-4.5.2, typeguard-4.4.0
collected 29 items                                                                                                                    

tests/unit/pretrained_weight_conversions/test_apertus.py ............                                                           [ 41%]
tests/acceptance/test_hooked_encoder.py .................                                                                       [100%]

========================================================== warnings summary ===========================================================
../../.cache/pypoetry/virtualenvs/transformer-lens-MODGY9P9-py3.12/lib/python3.12/site-packages/jaxtyping/import_hook.py:106: 52 warnings
  /home/brendanlong/.cache/pypoetry/virtualenvs/transformer-lens-MODGY9P9-py3.12/lib/python3.12/site-packages/jaxtyping/import_hook.py:106: DeprecationWarning: ast.Str is deprecated and will be removed in Python 3.14; use ast.Constant instead
    elif isinstance(child, ast.Expr) and isinstance(child.value, ast.Str):

transformer_lens/SVDInterpreter.py:26
  /home/brendanlong/workspace/TransformerLens/transformer_lens/SVDInterpreter.py:26: InstrumentationWarning: @typechecked only supports instrumenting functions wrapped with @classmethod, @staticmethod or @property -- not typechecking transformer_lens.SVDInterpreter.SVDInterpreter.get_singular_vectors
    @typechecked

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================== 29 passed, 53 warnings in 15.71s ===================================================

I'm also able to run the full CI suite:

> poetry run pytest tests/unit tests/integration tests/acceptance
[...]
94 passed, 30 skipped, 59 warnings in 327.31s (0:05:27)

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@brendanlong
Contributor Author

The CI test failure seems to be caused by HuggingFace rate limiting, which is confusing since I thought you had HF_TOKEN set. Either way it's not caused by my changes and should work if you re-run them.

@brendanlong brendanlong changed the base branch from main to dev March 28, 2026 23:40
@brendanlong brendanlong force-pushed the brendan/fix-gpu-tests branch from cf43c73 to d3a844f Compare April 3, 2026 04:08
brendanlong and others added 2 commits April 2, 2026 21:11
Tensor equality includes the device, so set device="cpu" so that
weight tensors always match the expected values, even if there's a GPU
they could be created on.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The test_cuda function referenced a fixture named mlm_tokens which was
never defined, causing a fixture-not-found error. Changed to use the
existing tokens fixture which provides the same MLM-style tokenized input.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@brendanlong brendanlong force-pushed the brendan/fix-gpu-tests branch from d3a844f to a476e09 Compare April 3, 2026 04:11
      "unembed.b_U",
  ]:
-     assert sd[key].device.type == cfg.device.type, f"{key} on wrong device"
+     assert sd[key].device.type == cfg.device, f"{key} on wrong device"
Contributor Author


I had to remove .type since cfg.device is a str now. There's a pre-existing type issue here since cfg.device has type str | None, so the old code shouldn't have passed type-checking.
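A minimal sketch of the distinction, assuming only that torch is installed: a torch.device exposes a .type attribute ("cpu" or "cuda"), while a plain string does not, so the right-hand side of the comparison changes depending on which form cfg.device holds.

```python
import torch

t = torch.zeros(2)  # created on the CPU

# When cfg.device is a torch.device, both sides can use .type (old code);
# when it is a str like "cpu", the right-hand side must drop .type (new code).
assert t.device.type == "cpu"                     # str comparison (new code)
assert t.device.type == torch.device("cpu").type  # device comparison (old code)
```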

Contributor Author


See #1230

brendanlong added a commit to brendanlong/TransformerLens that referenced this pull request Apr 3, 2026
This is either a torch.device or a string like "cpu", but it was
typed as just `Optional[str]`. This fixes the type to
`Optional[Union[str, torch.device]]` and updates all of the
downstream places that needed to change.

Found while working on TransformerLensOrg#1219
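A sketch of what the widened annotation looks like in practice, assuming only that torch is installed. The `normalize_device` helper below is hypothetical, not code from the PR; it just shows that torch.device() accepts either form, so downstream code can normalize to one type before comparing:

```python
from typing import Optional, Union

import torch

# Widened annotation as described in the commit message above.
DeviceLike = Optional[Union[str, torch.device]]

def normalize_device(device: DeviceLike) -> Optional[torch.device]:
    """Hypothetical helper: normalize either spelling to a torch.device."""
    if device is None:
        return None
    # torch.device() accepts both a str and an existing torch.device.
    return torch.device(device)

assert normalize_device("cpu") == torch.device("cpu")
assert normalize_device(torch.device("cuda:0")).type == "cuda"
assert normalize_device(None) is None
```

Constructing a torch.device object for "cuda:0" does not require a GPU, so the example runs on any machine.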