fix(autoencoderkl): handle proj_attn→out_proj key mapping in load_old_state_dict #8786

ytl0623 wants to merge 6 commits into Project-MONAI:dev

Conversation

Signed-off-by: ytl0623 <david89062388@gmail.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough

Updated AutoEncoderKL.load_old_state_dict to handle legacy attention projection keys (proj_attn) from old checkpoints.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: 3 passed; 2 did not pass (1 warning, 1 inconclusive).
Actionable comments posted: 4
🧹 Nitpick comments (2)

monai/networks/nets/autoencoderkl.py (2)

677-683: Missing `verbose` parameter in docstring. The `verbose` argument is undocumented. As per coding guidelines, docstrings should describe each parameter.

Proposed fix:

```diff
 """
 Load a state dict from an AutoencoderKL trained with
 [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

 Args:
     old_state_dict: state dict from the old AutoencoderKL model.
+    verbose: if True, print diagnostic information about key mismatches.
 """
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@monai/networks/nets/autoencoderkl.py` around lines 677 - 683, The docstring for load_old_state_dict is missing documentation for the verbose parameter; update the docstring of load_old_state_dict(old_state_dict: dict, verbose=False) to include a short description of verbose (e.g., its type and that it enables printing/logging extra info when True), add it under Args alongside old_state_dict, and ensure formatting matches existing docstring style used in the file.
733: EN DASH in comment. Static analysis flags an ambiguous character: `–` (EN DASH) should be `-` (HYPHEN-MINUS).

Proposed fix:

```diff
-    # new model has no out_proj at all – discard the legacy keys so they
+    # new model has no out_proj at all - discard the legacy keys so they
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@monai/networks/nets/autoencoderkl.py` at line 733, In the comment inside monai/networks/nets/autoencoderkl.py (near the AutoencoderKL-related code) replace the ambiguous EN DASH character in the comment "new model has no out_proj at all – discard the legacy keys so they" with a standard hyphen-minus ("-") so the comment reads "new model has no out_proj at all - discard the legacy keys so they"; ensure the change is only to the comment text and does not alter any code identifiers or logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 735-736: The code unconditionally pops keys proj_w and proj_b from
old_state_dict which can raise KeyError if proj_b is missing; update the removal
to be safe by using old_state_dict.pop(proj_b, None) (or wrap with if proj_b in
old_state_dict: old_state_dict.pop(proj_b)) and similarly ensure safe removal
for proj_w if needed; locate the block referencing old_state_dict, proj_w, and
proj_b in autoencoderkl.py and replace the unconditional pop calls with the
safe-pop variant.
- Around line 730-731: The replacement of weights/biases using torch.eye and
torch.zeros in the block that builds new_state_dict (keys referenced as out_w
and out_b) ignores the target tensor's dtype and device; update the
initialization to create tensors matching new_state_dict[out_w].dtype and
.device (and similarly for out_b) so the identity and zero tensors use the same
dtype/device as the original parameters before calling load_state_dict. Use the
existing new_state_dict[out_w] and new_state_dict[out_b] to obtain .dtype and
.device and pass them to the creation routines (torch.eye / torch.zeros) to
avoid dtype/device mismatch.
- Around line 724-727: The branch that moves proj_w into out_w doesn't handle
the case where proj_b is missing: if proj_w exists but proj_b does not, set
new_state_dict[out_b] to an explicit zero tensor instead of leaving the
initialized bias; implement this in the same block where proj_w is handled (keys
proj_w/proj_b -> out_w/out_b) by creating a zero tensor with the expected bias
shape and matching dtype/device (e.g., using self.state_dict()[out_b].shape and
torch.zeros(..., dtype=old_state_dict[proj_w].dtype,
device=old_state_dict[proj_w].device)) and assign it to new_state_dict[out_b]
when proj_b is absent.
- Around line 718-736: Add focused unit tests for the migration logic that maps
legacy "proj_attn" keys to new "out_proj" keys in autoencoderkl.py: create three
tests that exercise the branches inside the block handling
out_w/out_b/proj_w/proj_b — (1) when proj_attn keys exist in old_state_dict
ensure new_state_dict[out_w]/out_b are copied from old_state_dict[proj_w]/proj_b
and those old keys are popped, (2) when proj_attn keys are absent but out_proj
exists ensure new_state_dict[out_w] is initialized to torch.eye(...) and
new_state_dict[out_b] to torch.zeros(...), and (3) when new model lacks out_proj
but old_state_dict contains proj_attn ensure old_state_dict.pop(proj_w) and
old_state_dict.pop(proj_b) are called so no unexpected keys remain; use small
mocked state dicts and assert exact tensor shapes/values and presence/absence of
keys after running the migration code (referencing new_state_dict,
old_state_dict, proj_attn, out_proj, out_w/out_b/proj_w/proj_b).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8725cf73-eb5c-41e0-a4ca-16e52ecc31ca
📒 Files selected for processing (1)
monai/networks/nets/autoencoderkl.py
Signed-off-by: ytl0623 <david89062388@gmail.com>
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/networks/nets/test_autoencoderkl.py (2)
330-347: Add docstring to helper method. Per coding guidelines, definitions should have Google-style docstrings describing parameters and return values.

📝 Suggested docstring:

```diff
 @staticmethod
 def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict:
+    """Convert new-style state dict keys to legacy naming conventions.
+
+    Args:
+        new_sd: State dict with current key naming.
+        include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`.
+
+    Returns:
+        State dict with legacy key names.
+    """
     old_sd: dict = {}
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` around lines 330 - 347, Add a Google-style docstring to the helper method _new_to_old_sd explaining what the function does, documenting parameters (new_sd: dict of state_dict entries, include_proj_attn: bool controlling whether to map .attn.out_proj to .proj_attn) and the return value (dict mapping new keys to old-style keys), and include a brief note about cloning tensors; place the docstring immediately under the def _new_to_old_sd(...) signature.
357: Replace en-dash (–) with hyphen-minus (-) in assertion messages. Ruff RUF001 flags ambiguous Unicode characters on lines 357, 381, and 405.

```diff
-        self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict – check model config")
+        self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict - check model config")
```

Also applies to: 381, 405.

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` at line 357, Replace the Unicode en-dash (–) with a hyphen-minus (-) in the assertion message strings to satisfy RUF001: update the assertion message in the self.assertGreater call that currently reads "No proj_attn keys in old state dict – check model config" to use "No proj_attn keys in old state dict - check model config", and make the same replacement for the two other assertion messages in the same test (the other similar self.assert* calls at the other occurrences).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/networks/nets/test_autoencoderkl.py`:
- Around line 349-367: The test class references self._MIGRATION_PARAMS but that
class attribute is missing, causing AttributeError; add a class-level attribute
named _MIGRATION_PARAMS on the test class (above the test methods) containing
the default migration configuration used to construct AutoencoderKL in these
tests (include keys required by the model like the minimal args the other tests
expect, e.g., those referenced when calling
AutoencoderKL(**self._MIGRATION_PARAMS) and the variant with "include_fc":
True), so the three migration tests can instantiate models without raising
AttributeError.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 95852eb1-9a5d-41e8-91e9-e112adbd2fa9
📒 Files selected for processing (2)
monai/networks/nets/autoencoderkl.py
tests/networks/nets/test_autoencoderkl.py
✅ Files skipped from review due to trivial changes (1)
- monai/networks/nets/autoencoderkl.py
Actionable comments posted: 2
♻️ Duplicate comments (1)
tests/networks/nets/test_autoencoderkl.py (1)
349-367: ⚠️ Potential issue | 🔴 Critical

`_MIGRATION_PARAMS` is undefined; tests will raise `AttributeError`. Lines 351, 371, and 399 reference `self._MIGRATION_PARAMS`, but no such class attribute exists. Add it above the test methods:

```python
_MIGRATION_PARAMS = {
    "spatial_dims": 2,
    "in_channels": 1,
    "out_channels": 1,
    "channels": (4, 4, 4),
    "latent_channels": 4,
    "attention_levels": (False, False, True),
    "num_res_blocks": 1,
    "norm_num_groups": 4,
}
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` around lines 349 - 367, Define the missing class attribute _MIGRATION_PARAMS used by test_load_old_state_dict_proj_attn_copied_to_out_proj and other tests: add a class-level dict named _MIGRATION_PARAMS with the provided keys (spatial_dims, in_channels, out_channels, channels, latent_channels, attention_levels, num_res_blocks, norm_num_groups) so that AutoencoderKL(**self._MIGRATION_PARAMS) calls succeed; ensure the attribute is placed on the same test class where the test methods live so references like self._MIGRATION_PARAMS resolve at runtime.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/networks/nets/autoencoderkl.py`:
- Around line 724-739: When new_state_dict contains out_w but old_state_dict
does not contain proj_w (i.e., out_proj exists in the new model but legacy
proj_attn keys are missing), initialize out_proj as an identity mapping with
zero bias so loading an old state yields a no-op projection: set
new_state_dict[out_w] to an identity matrix (or appropriately-shaped identity
blocks if not square) matching the shape, dtype and device of
new_state_dict[out_w], and set new_state_dict[out_b] to zeros matching its
shape/dtype/device; implement this branch in the same conditional that checks if
out_w in new_state_dict (alongside the existing path that copies values from
old_state_dict) so
test_load_old_state_dict_missing_proj_attn_initialises_identity passes.
In `@tests/networks/nets/test_autoencoderkl.py`:
- Line 357: Replace the en-dash (–) characters in the test failure messages with
a hyphen-minus (-): change the message "No proj_attn keys in old state dict –
check model config" and the analogous messages used later in the same test file
(the two other assertions referenced in the review) so they read with a
hyphen-minus ("-") instead of an en-dash; locate these messages in
tests/networks/nets/test_autoencoderkl.py by searching for the exact strings
containing "–" (they appear in the assert failure messages) and update them to
use "-" so the test output contains ASCII hyphens.
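The no-op requirement behind the first inline comment above (identity weight, zero bias, matching dtype/device) can be sanity-checked in isolation. Shapes and key names here are illustrative, not taken from the model:

```python
import torch

# A stand-in for freshly initialised out_proj parameters; float64 is chosen
# deliberately to show that dtype must be propagated, not assumed.
target = {
    "out_proj.weight": torch.randn(4, 4, dtype=torch.float64),
    "out_proj.bias": torch.randn(4, dtype=torch.float64),
}

# Identity weight and zero bias, matching the target's dtype and device.
w = target["out_proj.weight"]
target["out_proj.weight"] = torch.eye(w.shape[0], dtype=w.dtype, device=w.device)
target["out_proj.bias"] = torch.zeros_like(target["out_proj.bias"])

# With this initialisation the linear projection is exactly a pass-through.
x = torch.randn(2, 4, dtype=torch.float64)
y = x @ target["out_proj.weight"].T + target["out_proj.bias"]
assert torch.allclose(y, x)
```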
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7a81d04f-1b85-4366-bc2b-592b32e3a1ea
📒 Files selected for processing (2)
monai/networks/nets/autoencoderkl.py
tests/networks/nets/test_autoencoderkl.py
Signed-off-by: ytl0623 <david89062388@gmail.com>
♻️ Duplicate comments (1)
tests/networks/nets/test_autoencoderkl.py (1)
401: ⚠️ Potential issue | 🟡 Minor

Replace en-dash with hyphen-minus in assert messages. Two strings still use `–`, which triggers RUF001.

Proposed fix:

```diff
-        self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found – check model config")
+        self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found - check model config")
 ...
-        self.assertGreater(len(attn_blocks), 0, "No attention blocks found – check model config")
+        self.assertGreater(len(attn_blocks), 0, "No attention blocks found - check model config")
```

Also applies to: 425.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` at line 401, The test assertion messages in tests/networks/nets/test_autoencoderkl.py use an en-dash character in the failure strings (e.g., inside the assertGreater call and the second occurrence around line 425); replace the Unicode en-dash (–) with a standard ASCII hyphen-minus (-) in those assertion message strings so they no longer trigger RUF001, updating the assertGreater message and the other test assertion message accordingly.
🧹 Nitpick comments (1)
tests/networks/nets/test_autoencoderkl.py (1)
369-438: Add docstrings to the three new test methods. They currently have no docstrings, which violates repo guidance for modified Python definitions. As per coding guidelines, "Docstrings should be present for all definitions, describing each variable, return value, and raised exception in the appropriate section of Google-style docstrings."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` around lines 369 - 438, Add Google-style docstrings to the three test methods (test_load_old_state_dict_proj_attn_copied_to_out_proj, test_load_old_state_dict_missing_proj_attn_initialises_identity, test_load_old_state_dict_proj_attn_discarded_when_no_out_proj) describing what each test verifies, listing key parameters used from the local scope (e.g. params, src, old_sd, dst, loaded), the expected outcome, and any exceptions/assertions raised; use the Google sections (Args, Returns — typically None for tests, and Raises for assertion failures) with brief single-line descriptions so each function has a proper docstring per repo guidance.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 78e4a9ac-2545-4b8a-92a4-37c316c98835
📒 Files selected for processing (1)
tests/networks/nets/test_autoencoderkl.py
Signed-off-by: ytl0623 <david89062388@gmail.com>
🧹 Nitpick comments (3)
tests/networks/nets/test_autoencoderkl.py (2)
403-409: Dtype mismatch possible in identity assertion. `torch.eye(n, device=device)` defaults to float32. If the model dtype differs, the assertion fails. Consider matching dtype.

Suggested fix:

```diff
 torch.testing.assert_close(
     loaded[k],
-    torch.eye(n, device=device),
+    torch.eye(n, dtype=loaded[k].dtype, device=device),
     msg=f"{k} should be an identity matrix",
 )
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` around lines 403 - 409, The identity-matrix assertion can fail due to dtype mismatch because torch.eye defaults to float32; update the test loop over out_proj_weights to construct the identity with the same dtype as the loaded tensor (e.g., create the identity using the dtype of loaded[k] or cast one side) before calling torch.testing.assert_close so both tensors share device and dtype; refer to variables out_proj_weights, loaded and the assert call to locate where to make this change.
172-181: Mutable class attribute flagged by static analysis. `_MIGRATION_PARAMS` is a dict (mutable). Low risk here since tests only read it, but a type annotation would silence RUF012.

Optional: add a ClassVar annotation:

```diff
+from typing import ClassVar
+
 class TestAutoEncoderKL(unittest.TestCase):
-    _MIGRATION_PARAMS = {
+    _MIGRATION_PARAMS: ClassVar[dict] = {
         "spatial_dims": 2,
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_autoencoderkl.py` around lines 172 - 181, the mutable dict _MIGRATION_PARAMS is flagged by static analysis; annotate it as a class-level constant to silence RUF012 by adding a ClassVar type annotation (e.g., ClassVar[dict[str, Any]] or ClassVar[Mapping[str, Any]]) and import ClassVar (and Any/Mapping if needed) from typing; update the declaration of _MIGRATION_PARAMS to include the ClassVar[...] annotation so linters recognize it as an intended class-level constant used only for reads.

monai/networks/nets/autoencoderkl.py (1)
719-751: Logic looks correct for the three migration scenarios. Handles: (1) copy `proj_attn` → `out_proj`, (2) identity/zero init when `proj_attn` is missing, (3) discard legacy keys when the model lacks `out_proj`. dtype/device are properly propagated.

Minor: line 736 uses an en-dash (–) instead of a hyphen (-) per static analysis.

Fix en-dash:

```diff
-    # No legacy proj_attn – initialize out_proj to identity/zero
+    # No legacy proj_attn - initialize out_proj to identity/zero
```

🤖 Prompt for AI Agents:
Verify each finding against the current code and only fix it if needed. In `@monai/networks/nets/autoencoderkl.py` around lines 719 - 751, Replace the en-dash in the inline comment that explains discarding legacy keys with a normal hyphen; locate the comment near the legacy-key handling where proj_w/proj_b are popped (in autoencoderkl.py within the block using new_state_dict, old_state_dict and identifiers proj_w/proj_b) and change the “–” character to a "-" so the comment uses a standard hyphen.
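The dtype pitfall flagged in the first nitpick above is easy to reproduce in isolation; a minimal illustration, not tied to the test file:

```python
import torch

# Pretend the loaded identity weight is float64 (e.g. a double-precision model).
loaded = torch.eye(3, dtype=torch.float64)

# torch.eye defaults to float32, so assert_close fails on dtype alone,
# even though the values are numerically identical.
try:
    torch.testing.assert_close(loaded, torch.eye(3))
    dtype_mismatch_raised = False
except AssertionError:
    dtype_mismatch_raised = True

# Building the expected tensor with the loaded tensor's dtype makes it pass.
torch.testing.assert_close(loaded, torch.eye(3, dtype=loaded.dtype))
```

This is why the suggested fix passes `dtype=loaded[k].dtype` into `torch.eye`.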
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 559affb7-c993-4aaa-a252-16bdab10bc27
📒 Files selected for processing (2)
monai/networks/nets/autoencoderkl.py
tests/networks/nets/test_autoencoderkl.py
Signed-off-by: ytl0623 <david89062388@gmail.com>
Signed-off-by: ytl0623 <david89062388@gmail.com>
Fixes #8544
Description

Map `proj_attn` → `out_proj` when both exist, initialise `out_proj` to identity/zero when only the new model has it, and silently discard `proj_attn` when only the old checkpoint has it.

Types of changes

- In-house tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated and built with the `make html` command in the `docs/` folder.