diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index b5a282a340..11b4fcfc9e 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: Args: old_state_dict: state dict from the old AutoencoderKL model. + verbose: if True, print diagnostic information about key mismatches. """ new_state_dict = self.state_dict() @@ -715,13 +716,39 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") - # old version did not have a projection so set these to the identity - new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( - new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] - ) - new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros( - new_state_dict[f"{block}.attn.out_proj.bias"].shape - ) + out_w = f"{block}.attn.out_proj.weight" + out_b = f"{block}.attn.out_proj.bias" + proj_w = f"{block}.proj_attn.weight" + proj_b = f"{block}.proj_attn.bias" + + if out_w in new_state_dict: + if proj_w in old_state_dict: + new_state_dict[out_w] = old_state_dict.pop(proj_w) + if proj_b in old_state_dict: + new_state_dict[out_b] = old_state_dict.pop(proj_b) + else: + new_state_dict[out_b] = torch.zeros( + new_state_dict[out_b].shape, + dtype=new_state_dict[out_b].dtype, + device=new_state_dict[out_b].device, + ) + else: + # No legacy proj_attn - initialize out_proj to identity/zero + new_state_dict[out_w] = torch.eye( + new_state_dict[out_w].shape[0], + dtype=new_state_dict[out_w].dtype, + device=new_state_dict[out_w].device, + ) + new_state_dict[out_b] = torch.zeros( + new_state_dict[out_b].shape, + dtype=new_state_dict[out_b].dtype, + device=new_state_dict[out_b].device, + ) + elif proj_w in old_state_dict: + # new model has no out_proj at all - discard the legacy keys so they + # don't surface as "unexpected keys" during load_state_dict + old_state_dict.pop(proj_w) + old_state_dict.pop(proj_b, None) # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index bbe2840164..af0c55d6ec 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -169,6 +169,17 @@ class TestAutoEncoderKL(unittest.TestCase): + _MIGRATION_PARAMS = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): net = AutoencoderKL(**input_param).to(device) @@ -327,6 +338,96 @@ def test_compatibility_with_monai_generative(self): net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) + @staticmethod + def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict: + """Convert new-style state dict keys to legacy naming conventions. + + Args: + new_sd: State dict with current key naming. + include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`. + + Returns: + State dict with legacy key names. + """ + old_sd: dict = {} + for k, v in new_sd.items(): + if ".attn.to_q." in k: + old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone() + elif ".attn.to_k." in k: + old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone() + elif ".attn.to_v." in k: + old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone() + elif ".attn.out_proj." in k: + if include_proj_attn: + old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone() + elif "postconv" in k: + old_sd[k.replace("postconv", "conv")] = v.clone() + else: + old_sd[k] = v.clone() + return old_sd + + @skipUnless(has_einops, "Requires einops") + def test_load_old_state_dict_proj_attn_copied_to_out_proj(self): + params = {**self._MIGRATION_PARAMS, "include_fc": True} + src = AutoencoderKL(**params).to(device) + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=True) + + # record the tensor values that were stored under proj_attn + expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k} + self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict - check model config") + + dst = AutoencoderKL(**params).to(device) + dst.load_old_state_dict(old_sd) + + for new_key, expected_val in expected.items(): + torch.testing.assert_close( + dst.state_dict()[new_key], expected_val.to(device), msg=f"Weight mismatch for {new_key}" + ) + + @skipUnless(has_einops, "Requires einops") + def test_load_old_state_dict_missing_proj_attn_initialises_identity(self): + params = {**self._MIGRATION_PARAMS, "include_fc": True} + src = AutoencoderKL(**params).to(device) + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False) + + dst = AutoencoderKL(**params).to(device) + dst.load_old_state_dict(old_sd) + loaded = dst.state_dict() + + out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k] + out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k] + self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found - check model config") + + for k in out_proj_weights: + n = loaded[k].shape[0] + torch.testing.assert_close( + loaded[k], torch.eye(n, dtype=loaded[k].dtype, device=device), msg=f"{k} should be an identity matrix" + ) + for k in out_proj_biases: + torch.testing.assert_close(loaded[k], torch.zeros_like(loaded[k]), msg=f"{k} should be all-zeros") + + @skipUnless(has_einops, "Requires einops") + def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): + params = {**self._MIGRATION_PARAMS, "include_fc": False} + src = AutoencoderKL(**params).to(device) + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False) + + # inject synthetic proj_attn keys (mimic an old checkpoint) + attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")] + self.assertGreater(len(attn_blocks), 0, "No attention blocks found - check model config") + for block in attn_blocks: + ch = old_sd[f"{block}.to_q.weight"].shape[0] + old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch) + old_sd[f"{block}.proj_attn.bias"] = torch.randn(ch) + + dst = AutoencoderKL(**params).to(device) + dst.load_old_state_dict(old_sd) + + loaded = dst.state_dict() + self.assertFalse( + any("out_proj" in k for k in loaded), "out_proj should not exist in a model built with include_fc=False" + ) + if __name__ == "__main__": unittest.main()