From cea0bb9a53045ead400fac8f2d9396a8d19c3c88 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:06:01 -0600 Subject: [PATCH 1/8] refactor: add prepare_saving as config property and get util function: get_save_before_smash_dir --- src/pruna/algorithms/base/pruna_base.py | 10 ++++++++-- src/pruna/config/smash_config.py | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py index 0784069b..0cce1e68 100644 --- a/src/pruna/algorithms/base/pruna_base.py +++ b/src/pruna/algorithms/base/pruna_base.py @@ -16,6 +16,7 @@ import functools from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, Iterable from transformers import Pipeline @@ -355,8 +356,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: Any The model after the algorithm has been applied. """ - if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config._prepare_saving: - save_dir = smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR + if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config.prepare_saving: + save_dir = self.get_save_before_smash_dir(smash_config) save_pruna_model(model, save_dir, smash_config) # save algorithms to reapply after loading @@ -447,6 +448,11 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]: """ return _expand_tags_into_algorithm_names(self.disjointly_compatible_after) + @staticmethod + def get_save_before_smash_dir(smash_config: SmashConfig) -> Path: + """Get the save directory for the algorithm caches.""" + return smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR + def wrap_handle_imports(func): """ diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 0acc1e12..81429b48 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -119,6 +119,11 @@ def __init__( raise ValueError(f"Unsupported configuration type: 
{type(configuration)}") self.config_space: ConfigurationSpace = self._configuration.config_space + @property + def prepare_saving(self): + """Return the value of the private ``_prepare_saving`` attribute (read-only).""" + return self._prepare_saving + @classmethod def from_list( cls, From 67880c11facbd1bf719cc6723368722cc096abc3 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:28:06 -0600 Subject: [PATCH 2/8] refactor: save_fns won't append any new save_fn; save recovered model's weights for "save_before_apply" algos --- .../global_utils/recovery/perp_recoverer.py | 46 +++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py index 4b65d9df..8d7941c0 100644 --- a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py +++ b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import shutil from typing import Any, Dict import torch @@ -23,7 +24,7 @@ from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr from pruna.algorithms.global_utils.recovery.utils import get_trainable_parameters -from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper from pruna.engine.model_checks import ( is_causal_lm, is_flux_pipeline, @@ -31,7 +32,7 @@ is_sd_pipeline, is_sdxl_pipeline, ) -from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model from pruna.logging.logger import pruna_logger @@ -52,7 +53,7 @@ class PERPRecoverer(PrunaAlgorithmBase): """ group_tags: list[AlgorithmTag] = [AlgorithmTag.RECOVERER] # type: ignore[attr-defined] - save_fn = SAVE_FUNCTIONS.pickled + save_fn = None references: dict[str, str] = { "GitHub": "https://github.com/huggingface/peft", "Paper": "https://arxiv.org/pdf/2312.15230", @@ -181,6 +182,45 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_") adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed) + def apply(self, model: Any, smash_config: SmashConfig) -> Any: + """ + Apply the recovery algorithm and refresh the save cache if needed. + + Recovery modifies weights in-place without changing the model's serialization + format. If a prior algorithm used ``save_before_apply`` (caching the model before + its transformation), the cached snapshot is now stale because recovery changed + the weights. This override refreshes that cache so the already saved model includes + the recovered weights. + + Parameters + ---------- + model : Any + The model to apply the algorithm to. 
+ smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + + Returns + ------- + Any + The model after recovery has been applied. + """ + result = super().apply(model, smash_config) + + if smash_config.prepare_saving: + save_dir = self.get_save_before_smash_dir(smash_config) + if not save_dir.exists(): + return result + + ori_save_fns = smash_config.save_fns[:] + smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] + # Re-save with recovered weights + shutil.rmtree(save_dir, ignore_errors=True) + save_dir.mkdir(parents=True) + save_pruna_model(model, save_dir, smash_config) + # Restore save_fns + smash_config.save_fns = ori_save_fns + return result + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ Recover performances from a given model with a given config. From c86401d648b17848c819be0c91b75c9b84e915c0 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:31:43 -0600 Subject: [PATCH 3/8] test: add unit tests to verify main ideas --- tests/engine/test_save.py | 81 +++++++++++++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 8 deletions(-) diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 2cd6f3e1..1191836a 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -1,18 +1,18 @@ import os -import pytest -import torch +import shutil from pathlib import Path from unittest.mock import patch + +import pytest +import torch +from diffusers import DiffusionPipeline from transformers import AutoModelForCausalLM -from pruna.config.smash_config import SmashConfig + from pruna import smash -from pruna.engine.save import save_pruna_model -from pruna.engine.save import save_pruna_model_to_hub -from pruna.engine.save import SAVE_FUNCTIONS -from pruna.engine.load import load_pruna_model from pruna.config.smash_config import SmashConfig -from diffusers import DiffusionPipeline +from 
pruna.engine.load import PICKLED_FILE_NAME, load_pruna_model from pruna.engine.pruna_model import PrunaModel +from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub @pytest.mark.slow @@ -160,3 +160,68 @@ def test_push_to_hub_path_types(tmp_path) -> None: private=True ) assert mock_upload.called + + +@pytest.mark.cpu +def test_recovery_save_fn_is_none() -> None: + """Test that recovery algorithms use save_fn=None, preserving the prior algorithm's save format.""" + from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer + + assert PERPRecoverer.save_fn is None + + +@pytest.mark.cpu +def test_recovery_does_not_add_to_save_fns(tmp_path) -> None: + """Test that recovery's apply() does not append to save_fns when save_fn is None.""" + + config = SmashConfig() + config.save_fns = ["hqq"] # simulate a prior algorithm's save_fn + + save_fn = None + + # PrunaAlgorithmBase apply logic + if save_fn is not None and save_fn != SAVE_FUNCTIONS.reapply: + config.save_fns.append(save_fn.name) + + assert config.save_fns == ["hqq"], "Recovery should not add to save_fns" + + +@pytest.mark.cpu +def test_recovery_refresh_save_cache(tmp_path) -> None: + """Test that recovery refreshes a stale save_before_apply cache with recovered weights.""" + from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase + + model = AutoModelForCausalLM.from_pretrained("yujiepan/opt-tiny-random") + + config = SmashConfig(device="cpu") + + # Simulate a save_before_apply algorithm having run before recovery: + # 1. Save original (pre-transformation) model to cache + save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config) + save_dir.mkdir(parents=True) + save_pruna_model(model, save_dir, config) + + # 2. Mark save_before_apply in save_fns (as the algorithm would) + config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name) + + # 3. 
Simulate the transformation (e.g., half) + recovery modifying weights + model.lm_head.weight.data.fill_(0.99) # "recovered" weights + + # 4. Simulate what recovery's apply() does: refresh the stale cache + ori_save_fns = config.save_fns[:] + config.save_fns = [fn for fn in config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] + shutil.rmtree(save_dir, ignore_errors=True) + save_dir.mkdir(parents=True) + save_pruna_model(model, save_dir, config) + config.save_fns = ori_save_fns + + # 5. Verify the cache was refreshed: save_before_apply should copy updated files + save_path = tmp_path / "final_model" + save_pruna_model(model, save_path, config) + + # Load and verify the recovered weights survived the round-trip + loaded_model, _ = load_pruna_model(save_path) + loaded_model = loaded_model.cpu() + assert torch.allclose( + loaded_model.lm_head.weight, torch.full_like(loaded_model.lm_head.weight, 0.99) + ), "Recovered weights should survive save/load through save_before_apply" From 10627a395d30bb95ad7a785546d763853f0cec1a Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Mon, 6 Apr 2026 23:39:30 -0600 Subject: [PATCH 4/8] fix: remove unused imports from test_save.py --- tests/engine/test_save.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 1191836a..34fd30b6 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -10,7 +10,7 @@ from pruna import smash from pruna.config.smash_config import SmashConfig -from pruna.engine.load import PICKLED_FILE_NAME, load_pruna_model +from pruna.engine.load import load_pruna_model from pruna.engine.pruna_model import PrunaModel from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub From e5aaccd2e2a918da39d82334439640aac0a34ce4 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Wed, 8 Apr 2026 09:31:00 -0600 Subject: [PATCH 5/8] docs: add docstring for get_save_before_smash_dir --- 
src/pruna/algorithms/base/pruna_base.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py index 0cce1e68..fa9a0405 100644 --- a/src/pruna/algorithms/base/pruna_base.py +++ b/src/pruna/algorithms/base/pruna_base.py @@ -450,8 +450,20 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]: @staticmethod def get_save_before_smash_dir(smash_config: SmashConfig) -> Path: - """Get the save directory for the algorithm caches.""" - return smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR + """ + Get the save directory for the algorithm caches. + + Parameters + ---------- + smash_config : SmashConfig + The SmashConfig to check the cache directory against. + + Returns + ------- + Path + The absolute path of "SAVE_BEFORE_SMASH_CACHE_DIR". + """ + return (smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR).resolve() def wrap_handle_imports(func): From 799bd31cfce4de3ef90668f0117d9a62cf4cbaf0 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Wed, 13 May 2026 03:56:21 -0700 Subject: [PATCH 6/8] refactor: move cache refresh logic into post_apply_hook func --- src/pruna/algorithms/base/pruna_base.py | 18 ++++++- .../global_utils/recovery/perp_recoverer.py | 52 ++++--------------- src/pruna/engine/save.py | 35 +++++++++++-- tests/engine/test_save.py | 32 +----------- 4 files changed, 62 insertions(+), 75 deletions(-) diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py index fa9a0405..16512372 100644 --- a/src/pruna/algorithms/base/pruna_base.py +++ b/src/pruna/algorithms/base/pruna_base.py @@ -370,7 +370,23 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: prefix = self.algorithm_name + "_" wrapped_config = SmashConfigPrefixWrapper(smash_config, prefix) - return self._apply(model, wrapped_config) + result = self._apply(model, wrapped_config) + + self.post_apply_hook(model, smash_config) + return 
result + + def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None: + """ + Post apply hook called after _apply returns to run side effects after the algorithm applies. + + Parameters + ---------- + model : Any + The model applied with the algorithm. + smash_config : SmashConfig + The SmashConfig object. + """ + return def get_compatible_algorithms(self) -> list[str]: """ diff --git a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py index 8d7941c0..20f86a52 100644 --- a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py +++ b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import shutil +from abc import ABCMeta from typing import Any, Dict import torch @@ -32,11 +32,11 @@ is_sd_pipeline, is_sdxl_pipeline, ) -from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model +from pruna.engine.save import refresh_saved_model from pruna.logging.logger import pruna_logger -class PERPRecoverer(PrunaAlgorithmBase): +class PERPRecoverer(PrunaAlgorithmBase, metaclass=ABCMeta): """ General purpose PERP recoverer using norm, head and bias finetuning and optionally HuggingFace's LoRA. 
@@ -64,7 +64,6 @@ class PERPRecoverer(PrunaAlgorithmBase): def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distillation: bool) -> None: self.task_name = task_name - self.tokenizer_required = task_name == "text_to_text" # type: ignore[misc] if not use_lora and not use_in_place: raise ValueError("Arguments use_lora and use_in_place cannot both be False, please use one of the two.") @@ -90,6 +89,11 @@ def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distil super().__init__() # self.adapters need to be set before calling get_hyperparameters + @property + def tokenizer_required(self) -> bool: + """Overwritten ``tokenizer_required`` property.""" + return self.task_name == "text_to_text" + def get_hyperparameters(self) -> list: """ Configure all algorithm-specific hyperparameters with ConfigSpace. @@ -182,44 +186,10 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_") adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed) - def apply(self, model: Any, smash_config: SmashConfig) -> Any: - """ - Apply the recovery algorithm and refresh the save cache if needed. - - Recovery modifies weights in-place without changing the model's serialization - format. If a prior algorithm used ``save_before_apply`` (caching the model before - its transformation), the cached snapshot is now stale because recovery changed - the weights. This override refreshes that cache so the already saved model includes - the recovered weights. - - Parameters - ---------- - model : Any - The model to apply the algorithm to. - smash_config : SmashConfig - The SmashConfig object containing the save and load functions. - - Returns - ------- - Any - The model after recovery has been applied. 
- """ - result = super().apply(model, smash_config) - + def post_apply_hook(self, model: Any, smash_config: SmashConfig): + """Override to run side effects after the algorithm has been applied.""" if smash_config.prepare_saving: - save_dir = self.get_save_before_smash_dir(smash_config) - if not save_dir.exists(): - return result - - ori_save_fns = smash_config.save_fns[:] - smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] - # Re-save with recovered weights - shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir(parents=True) - save_pruna_model(model, save_dir, smash_config) - # Restore save_fns - smash_config.save_fns = ori_save_fns - return result + refresh_saved_model(model, self.get_save_before_smash_dir(smash_config), smash_config) def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 27101b31..35e2bbdb 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -355,9 +355,7 @@ def save_model_hqq(model: Any, model_path: str | Path, smash_config: SmashConfig # save pipeline info so we can call transformers.pipeline at load time save_pipeline_info(model, str(model_path)) # pipeline loading requires a safetensor file so we save a fake, lightweight one - save_model( - torch.nn.Linear(1, 1), str(model_path / "model.safetensors"), metadata={"format": "pt"} - ) + save_model(torch.nn.Linear(1, 1), str(model_path / "model.safetensors"), metadata={"format": "pt"}) save_model_hqq(model.model, model_path, smash_config) elif is_janus_llamagen_ar(model): @@ -470,6 +468,37 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name) +def refresh_saved_model(model: Any, model_path: Path, smash_config: SmashConfig) -> None: + """ + Refresh the saved save-before-apply model, and the cache will reflect the current model 
state. + + Recovery modifies weights in-place without changing the model's serialization + format. If a prior algorithm used ``save_before_apply`` (caching the model before + its transformation), the cached snapshot is now stale because recovery changed + the weights. This function refreshes that cache so the already saved model includes + the recovered weights. + + Parameters + ---------- + model : Any + The model to apply the algorithm to. + model_path: Path + The model path to be saved. + smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + """ + if not model_path.exists(): + return None + + ori_save_fns = smash_config.save_fns[:] + smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] + # Re-save with recovered weights + shutil.rmtree(model_path, ignore_errors=True) + save_pruna_model(model, model_path, smash_config) + # Restore save_fns + smash_config.save_fns = ori_save_fns + + def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ Reapply the model. 
diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 34fd30b6..1872deee 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -1,9 +1,6 @@ import os -import shutil -from pathlib import Path -from unittest.mock import patch - import pytest +import shutil import torch from diffusers import DiffusionPipeline from transformers import AutoModelForCausalLM @@ -29,6 +26,7 @@ def test_save_llm_to_hub() -> None: ) pruna_model.push_to_hub(upload_repo_id, private=False) + @pytest.mark.slow @pytest.mark.cpu def test_save_diffusers_to_hub() -> None: @@ -162,30 +160,6 @@ def test_push_to_hub_path_types(tmp_path) -> None: assert mock_upload.called -@pytest.mark.cpu -def test_recovery_save_fn_is_none() -> None: - """Test that recovery algorithms use save_fn=None, preserving the prior algorithm's save format.""" - from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer - - assert PERPRecoverer.save_fn is None - - -@pytest.mark.cpu -def test_recovery_does_not_add_to_save_fns(tmp_path) -> None: - """Test that recovery's apply() does not append to save_fns when save_fn is None.""" - - config = SmashConfig() - config.save_fns = ["hqq"] # simulate a prior algorithm's save_fn - - save_fn = None - - # PrunaAlgorithmBase apply logic - if save_fn is not None and save_fn != SAVE_FUNCTIONS.reapply: - config.save_fns.append(save_fn.name) - - assert config.save_fns == ["hqq"], "Recovery should not add to save_fns" - - @pytest.mark.cpu def test_recovery_refresh_save_cache(tmp_path) -> None: """Test that recovery refreshes a stale save_before_apply cache with recovered weights.""" @@ -198,7 +172,6 @@ def test_recovery_refresh_save_cache(tmp_path) -> None: # Simulate a save_before_apply algorithm having run before recovery: # 1. Save original (pre-transformation) model to cache save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config) - save_dir.mkdir(parents=True) save_pruna_model(model, save_dir, config) # 2. 
Mark save_before_apply in save_fns (as the algorithm would) @@ -211,7 +184,6 @@ def test_recovery_refresh_save_cache(tmp_path) -> None: ori_save_fns = config.save_fns[:] config.save_fns = [fn for fn in config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] shutil.rmtree(save_dir, ignore_errors=True) - save_dir.mkdir(parents=True) save_pruna_model(model, save_dir, config) config.save_fns = ori_save_fns From bf726dd9c68fac07093a32bf9f392f4b6589b9a6 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Wed, 13 May 2026 04:28:06 -0700 Subject: [PATCH 7/8] test: update recover unit tests --- tests/engine/test_save.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 1872deee..7a531ae7 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -1,6 +1,8 @@ import os +from pathlib import Path +from unittest.mock import patch + import pytest -import shutil import torch from diffusers import DiffusionPipeline from transformers import AutoModelForCausalLM @@ -161,39 +163,36 @@ def test_push_to_hub_path_types(tmp_path) -> None: @pytest.mark.cpu -def test_recovery_refresh_save_cache(tmp_path) -> None: - """Test that recovery refreshes a stale save_before_apply cache with recovered weights.""" +def test_perp_post_apply_hook_round_trip(tmp_path) -> None: + """Test whether PERPRecoverer saves the correct model and load from it.""" from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase + from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer + + class FakeRecoverer(PERPRecoverer): + algorithm_name = "test_fake_recoverer" + + def __init__(self): # noqa: D107 + pass - model = AutoModelForCausalLM.from_pretrained("yujiepan/opt-tiny-random") + def _apply(self, model, smash_config): # noqa: D401 + model.weight.data.fill_(0.99) + return model + model = torch.nn.Linear(3, 2) + model.weight.data.fill_(0.1) config = 
SmashConfig(device="cpu") - # Simulate a save_before_apply algorithm having run before recovery: - # 1. Save original (pre-transformation) model to cache save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config) save_pruna_model(model, save_dir, config) - - # 2. Mark save_before_apply in save_fns (as the algorithm would) config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name) - # 3. Simulate the transformation (e.g., half) + recovery modifying weights - model.lm_head.weight.data.fill_(0.99) # "recovered" weights - - # 4. Simulate what recovery's apply() does: refresh the stale cache - ori_save_fns = config.save_fns[:] - config.save_fns = [fn for fn in config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name] - shutil.rmtree(save_dir, ignore_errors=True) - save_pruna_model(model, save_dir, config) - config.save_fns = ori_save_fns + model = FakeRecoverer().apply(model, config) - # 5. Verify the cache was refreshed: save_before_apply should copy updated files save_path = tmp_path / "final_model" save_pruna_model(model, save_path, config) - # Load and verify the recovered weights survived the round-trip loaded_model, _ = load_pruna_model(save_path) loaded_model = loaded_model.cpu() assert torch.allclose( - loaded_model.lm_head.weight, torch.full_like(loaded_model.lm_head.weight, 0.99) + loaded_model.weight, torch.full_like(loaded_model.weight, 0.99) ), "Recovered weights should survive save/load through save_before_apply" From 8975c4906cdcada18ac8af792599154d4af0b2a0 Mon Sep 17 00:00:00 2001 From: Zihao Xue Date: Thu, 14 May 2026 09:35:27 -0700 Subject: [PATCH 8/8] style: docstring styles --- .../global_utils/recovery/perp_recoverer.py | 11 ++++++++++- src/pruna/engine/save.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py index 20f86a52..b7f0f5b4 100644 --- 
a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py +++ b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py @@ -187,7 +187,16 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed) def post_apply_hook(self, model: Any, smash_config: SmashConfig): - """Override to run side effects after the algorithm has been applied.""" + """ + Override to run side effects after the algorithm has been applied. + + Parameters + ---------- + model : Any + The model. + smash_config : SmashConfig + The SmashConfig configuration to apply. + """ if smash_config.prepare_saving: refresh_saved_model(model, self.get_save_before_smash_dir(smash_config), smash_config) diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 35e2bbdb..c2df1265 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -482,7 +482,7 @@ def refresh_saved_model(model: Any, model_path: Path, smash_config: SmashConfig) ---------- model : Any The model to apply the algorithm to. - model_path: Path + model_path : Path The model path to be saved. smash_config : SmashConfig The SmashConfig object containing the save and load functions.