40 changes: 37 additions & 3 deletions src/pruna/algorithms/base/pruna_base.py
@@ -16,6 +16,7 @@

import functools
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Iterable

from transformers import Pipeline
@@ -355,8 +356,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any:
Any
The model after the algorithm has been applied.
"""
if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config._prepare_saving:
save_dir = smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR
if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config.prepare_saving:
save_dir = self.get_save_before_smash_dir(smash_config)
save_pruna_model(model, save_dir, smash_config)

# save algorithms to reapply after loading
@@ -369,7 +370,23 @@

prefix = self.algorithm_name + "_"
wrapped_config = SmashConfigPrefixWrapper(smash_config, prefix)
return self._apply(model, wrapped_config)
result = self._apply(model, wrapped_config)

self.post_apply_hook(model, smash_config)
return result

def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None:
"""
Run side effects after ``_apply`` returns; a no-op by default, intended for subclasses to override.

Parameters
----------
model : Any
The model the algorithm was applied to.
smash_config : SmashConfig
The SmashConfig object.
"""
return

def get_compatible_algorithms(self) -> list[str]:
"""
@@ -447,6 +464,23 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]:
"""
return _expand_tags_into_algorithm_names(self.disjointly_compatible_after)

@staticmethod
def get_save_before_smash_dir(smash_config: SmashConfig) -> Path:
"""
Get the directory used to cache the model before smashing.

Parameters
----------
smash_config : SmashConfig
The SmashConfig providing the cache directory.

Returns
-------
Path
The absolute path of ``SAVE_BEFORE_SMASH_CACHE_DIR`` under the cache directory.
"""
return (smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR).resolve()


def wrap_handle_imports(func):
"""
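The net effect in pruna_base.py is a template-method seam: apply() now keeps the result of _apply() and then invokes post_apply_hook(), which does nothing by default and exists for subclasses to override. A minimal sketch of a subclass using the hook, assuming the import paths that appear in this diff; the subclass name and log call are illustrative, not part of the change:

    from typing import Any

    from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
    from pruna.config.smash_config import SmashConfig
    from pruna.logging.logger import pruna_logger


    class LoggingAlgorithm(PrunaAlgorithmBase):
        """Hypothetical algorithm that only logs after application."""

        algorithm_name = "logging_algorithm"

        def _apply(self, model: Any, smash_config: Any) -> Any:
            return model  # placeholder: a real algorithm transforms the model here

        def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None:
            # Invoked by apply() after _apply() returns; the base implementation
            # is a no-op, so overriding is purely opt-in.
            pruna_logger.info("Finished applying %s", self.algorithm_name)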
29 changes: 24 additions & 5 deletions src/pruna/algorithms/global_utils/recovery/perp_recoverer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta
from typing import Any, Dict

import torch
@@ -23,19 +24,19 @@
from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner
from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr
from pruna.algorithms.global_utils.recovery.utils import get_trainable_parameters
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.engine.model_checks import (
is_causal_lm,
is_flux_pipeline,
is_sana_pipeline,
is_sd_pipeline,
is_sdxl_pipeline,
)
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.save import refresh_saved_model
from pruna.logging.logger import pruna_logger


class PERPRecoverer(PrunaAlgorithmBase):
class PERPRecoverer(PrunaAlgorithmBase, metaclass=ABCMeta):
"""
General purpose PERP recoverer using norm, head and bias finetuning and optionally HuggingFace's LoRA.

@@ -52,7 +53,7 @@ class PERPRecoverer(PrunaAlgorithmBase):
"""

group_tags: list[AlgorithmTag] = [AlgorithmTag.RECOVERER] # type: ignore[attr-defined]
save_fn = SAVE_FUNCTIONS.pickled
save_fn = None
references: dict[str, str] = {
"GitHub": "https://github.com/huggingface/peft",
"Paper": "https://arxiv.org/pdf/2312.15230",
@@ -63,7 +64,6 @@

def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distillation: bool) -> None:
self.task_name = task_name
self.tokenizer_required = task_name == "text_to_text" # type: ignore[misc]

if not use_lora and not use_in_place:
raise ValueError("Arguments use_lora and use_in_place cannot both be False, please use one of the two.")
@@ -89,6 +89,11 @@ def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distillation: bool) -> None:

super().__init__() # self.adapters need to be set before calling get_hyperparameters

@property
def tokenizer_required(self) -> bool:
"""Overwritten ``tokenizer_required`` property."""
return self.task_name == "text_to_text"

def get_hyperparameters(self) -> list:
"""
Configure all algorithm-specific hyperparameters with ConfigSpace.
@@ -181,6 +186,20 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) ->
adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_")
adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed)

def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None:
"""
Refresh the save-before-apply cache after recovery has been applied.

Parameters
----------
model : Any
The model.
smash_config : SmashConfig
The SmashConfig object.
"""
if smash_config.prepare_saving:
refresh_saved_model(model, self.get_save_before_smash_dir(smash_config), smash_config)

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Recover performances from a given model with a given config.
5 changes: 5 additions & 0 deletions src/pruna/config/smash_config.py
@@ -119,6 +119,11 @@ def __init__(
raise ValueError(f"Unsupported configuration type: {type(configuration)}")
self.config_space: ConfigurationSpace = self._configuration.config_space

@property
def prepare_saving(self) -> bool:
"""Getter of _prepare_saving as an object's internal data."""
return self._prepare_saving

@classmethod
def from_list(
cls,
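The new property gives callers (such as the post_apply_hook above) a public read path while keeping the attribute itself private. A tiny sketch, assuming SmashConfig accepts the device argument used in the tests below:

    from pruna.config.smash_config import SmashConfig

    config = SmashConfig(device="cpu")
    print(config.prepare_saving)     # public, read-only view of _prepare_saving
    # config.prepare_saving = True   # would raise AttributeError: no setter is defined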
35 changes: 32 additions & 3 deletions src/pruna/engine/save.py
@@ -355,9 +355,7 @@ def save_model_hqq(model: Any, model_path: str | Path, smash_config: SmashConfig
# save pipeline info so we can call transformers.pipeline at load time
save_pipeline_info(model, str(model_path))
# pipeline loading requires a safetensor file so we save a fake, lightweight one
save_model(
torch.nn.Linear(1, 1), str(model_path / "model.safetensors"), metadata={"format": "pt"}
)
save_model(torch.nn.Linear(1, 1), str(model_path / "model.safetensors"), metadata={"format": "pt"})

save_model_hqq(model.model, model_path, smash_config)
elif is_janus_llamagen_ar(model):
@@ -470,6 +468,37 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis
smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name)


def refresh_saved_model(model: Any, model_path: Path, smash_config: SmashConfig) -> None:
"""
Refresh the saved save-before-apply model so that the cache reflects the current model state.

Recovery modifies weights in-place without changing the model's serialization
format. If a prior algorithm used ``save_before_apply`` (caching the model before
its transformation), the cached snapshot is now stale because recovery changed
the weights. This function refreshes that cache so the saved model includes the
recovered weights.

Parameters
----------
model : Any
The model whose current state should be saved.
model_path : Path
The path of the cached model to refresh.
smash_config : SmashConfig
The SmashConfig object containing the save and load functions.
"""
if not model_path.exists():
return None

ori_save_fns = smash_config.save_fns[:]
smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name]
# Re-save with recovered weights
shutil.rmtree(model_path, ignore_errors=True)
save_pruna_model(model, model_path, smash_config)
# Restore save_fns
smash_config.save_fns = ori_save_fns


def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
"""
Reapply the model.
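refresh_saved_model temporarily drops the save_before_apply marker from save_fns so the re-save captures the model exactly as it now is, then restores the original list. A minimal sketch of the flow, assuming the save_fns and cache_dir behavior shown elsewhere in this diff; the subdirectory name is illustrative:

    from pathlib import Path

    import torch

    from pruna.config.smash_config import SmashConfig
    from pruna.engine.save import SAVE_FUNCTIONS, refresh_saved_model, save_pruna_model

    config = SmashConfig(device="cpu")
    cache_dir = Path(config.cache_dir) / "before_smash"  # illustrative subdirectory

    model = torch.nn.Linear(3, 2)
    save_pruna_model(model, cache_dir, config)  # snapshot taken before "smashing"
    config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name)

    model.weight.data.fill_(0.5)  # stands in for recovery mutating weights in place
    refresh_saved_model(model, cache_dir, config)  # re-save current state, then restore save_fns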
52 changes: 44 additions & 8 deletions tests/engine/test_save.py
@@ -1,18 +1,17 @@
import os
import pytest
import torch
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM
from pruna.config.smash_config import SmashConfig

from pruna import smash
from pruna.engine.save import save_pruna_model
from pruna.engine.save import save_pruna_model_to_hub
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.load import load_pruna_model
from pruna.config.smash_config import SmashConfig
from diffusers import DiffusionPipeline
from pruna.engine.load import load_pruna_model
from pruna.engine.pruna_model import PrunaModel
from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub


@pytest.mark.slow
@@ -29,6 +28,7 @@ def test_save_llm_to_hub() -> None:
)
pruna_model.push_to_hub(upload_repo_id, private=False)


@pytest.mark.slow
@pytest.mark.cpu
def test_save_diffusers_to_hub() -> None:
@@ -160,3 +160,39 @@ def test_push_to_hub_path_types(tmp_path) -> None:
private=True
)
assert mock_upload.called


@pytest.mark.cpu
def test_perp_post_apply_hook_round_trip(tmp_path) -> None:
"""Test whether PERPRecoverer saves the correct model and load from it."""
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer

class FakeRecoverer(PERPRecoverer):
algorithm_name = "test_fake_recoverer"

def __init__(self): # noqa: D107
pass

def _apply(self, model, smash_config): # noqa: D401
model.weight.data.fill_(0.99)
return model

model = torch.nn.Linear(3, 2)
model.weight.data.fill_(0.1)
config = SmashConfig(device="cpu")

save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config)
save_pruna_model(model, save_dir, config)
config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name)

model = FakeRecoverer().apply(model, config)

save_path = tmp_path / "final_model"
save_pruna_model(model, save_path, config)

loaded_model, _ = load_pruna_model(save_path)
loaded_model = loaded_model.cpu()
assert torch.allclose(
loaded_model.weight, torch.full_like(loaded_model.weight, 0.99)
), "Recovered weights should survive save/load through save_before_apply"