Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ def __init__(self, *args, **kwargs):
"Higher values may speed up tuning but require more memory. "
"Recommended to keep at 1 for stability with large models.",
)
tuning.add_argument(
"--nblocks_overlap",
default=0,
type=int,
help="Number of overlapping blocks between adjacent nblocks windows. "
"For CBQ-style CBD, use --nblocks 2 --nblocks_overlap 1.",
)
tuning.add_argument(
"--scale_dtype",
default=None,
Expand Down Expand Up @@ -703,6 +710,7 @@ def tune(args):
lr=args.lr,
minmax_lr=args.minmax_lr,
nblocks=args.nblocks,
nblocks_overlap=args.nblocks_overlap,
to_quant_block_names=args.to_quant_block_names,
scale_dtype=args.scale_dtype,
)
Expand Down
8 changes: 8 additions & 0 deletions auto_round/algorithms/quantization/sign_round/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
lr_scheduler=None,
momentum: float = 0.0,
nblocks: int = 1,
nblocks_overlap: int = 0,
enable_minmax_tuning: bool = True,
enable_norm_bias_tuning: bool = False,
gradient_accumulate_steps: int = 1,
Expand Down Expand Up @@ -71,6 +72,9 @@ def __init__(
self.lr_scheduler = lr_scheduler

self.nblocks = nblocks
self.nblocks_overlap = nblocks_overlap
if self.nblocks == 1:
self.nblocks_overlap = 0
self.momentum = momentum
self.enable_alg_ext = enable_alg_ext

Expand Down Expand Up @@ -104,5 +108,9 @@ def check_configs(self) -> None:
raise ValueError("`iters` must be non-negative")
if self.nblocks <= 0:
raise ValueError("`nblocks` must be positive")
if self.nblocks_overlap < 0:
raise ValueError("`nblocks_overlap` must be non-negative")
if self.nblocks > 1 and self.nblocks_overlap >= self.nblocks:
raise ValueError("`nblocks_overlap` must be smaller than `nblocks`")
if self.gradient_accumulate_steps <= 0:
raise ValueError("`gradient_accumulate_steps` must be positive")
21 changes: 19 additions & 2 deletions auto_round/algorithms/quantization/sign_round/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def quantize_block(
*,
loss_device: Union[str, torch.device],
mid_iter_mem_check: bool = False,
finalize_module_names: Optional[list[str]] = None,
**kwargs,
) -> dict:
"""Apply the AutoRound optimization algorithm to a block.
Expand Down Expand Up @@ -216,7 +217,7 @@ def quantize_block(
f"layers in the block"
)
logger.info(dump_info)
unwrapper_block(block, {})
unwrapper_block(block, {}, unwrap_filter=self._get_unwrap_filter(finalize_module_names))
return {}

if self.lr_scheduler is None:
Expand Down Expand Up @@ -314,7 +315,7 @@ def quantize_block(
if len(unquantized_layer_names) != 0:
logger.info(f"Unquantized layers: {unquantized_layer_names}")
with torch.no_grad():
unwrapper_block(block, best_params)
unwrapper_block(block, best_params, unwrap_filter=self._get_unwrap_filter(finalize_module_names))

if self.config.is_act_nv_fp:
# enable moe experts act_max automatic generation for WrapperWALayer
Expand All @@ -323,6 +324,22 @@ def quantize_block(
logger.infoclean(dump_info)
return best_params

@staticmethod
def _get_unwrap_filter(finalize_module_names: Optional[list[str]] = None):
    """Build the predicate handed to ``unwrapper_block`` as ``unwrap_filter``.

    Args:
        finalize_module_names: Block names used to select modules. ``None``
            means no filtering is requested.

    Returns:
        ``None`` when ``finalize_module_names`` is ``None``. Otherwise a
        ``(name, module) -> bool`` callable that returns ``True`` when the
        module's ``orig_layer.global_name`` is missing, equals one of the
        given block names, or starts with one of them followed by ``"."``.
    """
    if finalize_module_names is None:
        return None

    def _should_unwrap(_name, module):
        wrapped_layer = getattr(module, "orig_layer", None)
        layer_global_name = getattr(wrapped_layer, "global_name", None)
        # A module without a resolvable global name is always accepted.
        if layer_global_name is None:
            return True
        for block_name in finalize_module_names:
            if layer_global_name == block_name:
                return True
            # The trailing dot keeps "layers.1" from matching "layers.10".
            if layer_global_name.startswith(f"{block_name}."):
                return True
        return False

    return _should_unwrap

def quantize_layer_outside_block(
self,
layer_name: str,
Expand Down
12 changes: 11 additions & 1 deletion auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def __init__(
kwargs.pop("enable_alg_ext", None)
kwargs.pop("vlm", None)
amp = kwargs.pop("amp", True)
nblocks = kwargs.pop("nblocks", 1)
nblocks = kwargs.pop("nblocks", getattr(self.quantize_config, "nblocks", 1))
nblocks_overlap = kwargs.pop("nblocks_overlap", getattr(self.quantize_config, "nblocks_overlap", 0))
disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)

Expand Down Expand Up @@ -270,6 +271,15 @@ def __init__(
set_seed(self.seed)

self.nblocks = nblocks
self.nblocks_overlap = nblocks_overlap
if self.nblocks <= 0:
raise ValueError("`nblocks` must be positive")
if self.nblocks == 1:
self.nblocks_overlap = 0
if self.nblocks_overlap < 0:
raise ValueError("`nblocks_overlap` must be non-negative")
if self.nblocks_overlap >= self.nblocks:
raise ValueError("`nblocks_overlap` must be smaller than `nblocks`")

self.enable_torch_compile = enable_torch_compile

Expand Down
3 changes: 3 additions & 0 deletions auto_round/compressors/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
lr_scheduler: Callable = None,
minmax_lr: float = None,
nblocks: int = 1,
nblocks_overlap: int = 0,
to_quant_block_names: Union[str, list, None] = None,
scale_dtype: str = "fp16",
# scheme
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
lr_scheduler=lr_scheduler,
minmax_lr=minmax_lr,
nblocks=nblocks,
nblocks_overlap=nblocks_overlap,
to_quant_block_names=to_quant_block_names,
scale_dtype=scale_dtype,
)
Expand Down Expand Up @@ -257,6 +259,7 @@ class TuningExtraConfig(BaseExtraConfig):
lr_scheduler: Callable = None
minmax_lr: float = None
nblocks: int = 1
nblocks_overlap: int = 0
to_quant_block_names: Union[str, list, None] = None
scale_dtype: str = "fp16"

Expand Down
Loading
Loading