diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 656d7e2cf..a961233b9 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -236,6 +236,13 @@ def __init__(self, *args, **kwargs): "Higher values may speed up tuning but require more memory. " "Recommended to keep at 1 for stability with large models.", ) + tuning.add_argument( + "--nblocks_overlap", + default=0, + type=int, + help="Number of overlapping blocks between adjacent nblocks windows. " + "For CBQ-style CBD, use --nblocks 2 --nblocks_overlap 1.", + ) tuning.add_argument( "--scale_dtype", default=None, @@ -703,6 +710,7 @@ def tune(args): lr=args.lr, minmax_lr=args.minmax_lr, nblocks=args.nblocks, + nblocks_overlap=args.nblocks_overlap, to_quant_block_names=args.to_quant_block_names, scale_dtype=args.scale_dtype, ) diff --git a/auto_round/algorithms/quantization/sign_round/config.py b/auto_round/algorithms/quantization/sign_round/config.py index 2aad5e08e..0e97cb472 100644 --- a/auto_round/algorithms/quantization/sign_round/config.py +++ b/auto_round/algorithms/quantization/sign_round/config.py @@ -40,6 +40,7 @@ def __init__( lr_scheduler=None, momentum: float = 0.0, nblocks: int = 1, + nblocks_overlap: int = 0, enable_minmax_tuning: bool = True, enable_norm_bias_tuning: bool = False, gradient_accumulate_steps: int = 1, @@ -71,6 +72,9 @@ def __init__( self.lr_scheduler = lr_scheduler self.nblocks = nblocks + self.nblocks_overlap = nblocks_overlap + if self.nblocks == 1: + self.nblocks_overlap = 0 self.momentum = momentum self.enable_alg_ext = enable_alg_ext @@ -104,5 +108,9 @@ def check_configs(self) -> None: raise ValueError("`iters` must be non-negative") if self.nblocks <= 0: raise ValueError("`nblocks` must be positive") + if self.nblocks_overlap < 0: + raise ValueError("`nblocks_overlap` must be non-negative") + if self.nblocks > 1 and self.nblocks_overlap >= self.nblocks: + raise ValueError("`nblocks_overlap` must be smaller than `nblocks`") if self.gradient_accumulate_steps <= 0: raise ValueError("`gradient_accumulate_steps` must be positive") diff --git a/auto_round/algorithms/quantization/sign_round/quantizer.py b/auto_round/algorithms/quantization/sign_round/quantizer.py index 2297d9d44..72b102869 100644 --- a/auto_round/algorithms/quantization/sign_round/quantizer.py +++ b/auto_round/algorithms/quantization/sign_round/quantizer.py @@ -144,6 +144,7 @@ def quantize_block( *, loss_device: Union[str, torch.device], mid_iter_mem_check: bool = False, + finalize_module_names: Optional[list[str]] = None, **kwargs, ) -> dict: """Apply the AutoRound optimization algorithm to a block. @@ -216,7 +217,7 @@ def quantize_block( f"layers in the block" ) logger.info(dump_info) - unwrapper_block(block, {}) + unwrapper_block(block, {}, unwrap_filter=self._get_unwrap_filter(finalize_module_names)) return {} if self.lr_scheduler is None: @@ -314,7 +315,7 @@ def quantize_block( if len(unquantized_layer_names) != 0: logger.info(f"Unquantized layers: {unquantized_layer_names}") with torch.no_grad(): - unwrapper_block(block, best_params) + unwrapper_block(block, best_params, unwrap_filter=self._get_unwrap_filter(finalize_module_names)) if self.config.is_act_nv_fp: # enable moe experts act_max automatic generation for WrapperWALayer @@ -323,6 +324,22 @@ def quantize_block( logger.infoclean(dump_info) return best_params + @staticmethod + def _get_unwrap_filter(finalize_module_names: Optional[list[str]] = None): + if finalize_module_names is None: + return None + + def _should_unwrap(_name, module): + global_name = getattr(getattr(module, "orig_layer", None), "global_name", None) + if global_name is None: + return True + return any( + global_name == block_name or global_name.startswith(f"{block_name}.") + for block_name in finalize_module_names + ) + + return _should_unwrap + def quantize_layer_outside_block( self, layer_name: str, diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 6f5dc2fda..a918763dd 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -218,7 +218,8 @@ def __init__( kwargs.pop("enable_alg_ext", None) kwargs.pop("vlm", None) amp = kwargs.pop("amp", True) - nblocks = kwargs.pop("nblocks", 1) + nblocks = kwargs.pop("nblocks", getattr(self.quantize_config, "nblocks", 1)) + nblocks_overlap = kwargs.pop("nblocks_overlap", getattr(self.quantize_config, "nblocks_overlap", 0)) disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True) enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False) @@ -270,6 +271,15 @@ def __init__( set_seed(self.seed) self.nblocks = nblocks + self.nblocks_overlap = nblocks_overlap + if self.nblocks <= 0: + raise ValueError("`nblocks` must be positive") + if self.nblocks == 1: + self.nblocks_overlap = 0 + if self.nblocks_overlap < 0: + raise ValueError("`nblocks_overlap` must be non-negative") + if self.nblocks_overlap >= self.nblocks: + raise ValueError("`nblocks_overlap` must be smaller than `nblocks`") self.enable_torch_compile = enable_torch_compile diff --git a/auto_round/compressors/config.py b/auto_round/compressors/config.py index dd028b6cd..7ef280192 100644 --- a/auto_round/compressors/config.py +++ b/auto_round/compressors/config.py @@ -42,6 +42,7 @@ def __init__( lr_scheduler: Callable = None, minmax_lr: float = None, nblocks: int = 1, + nblocks_overlap: int = 0, to_quant_block_names: Union[str, list, None] = None, scale_dtype: str = "fp16", # scheme @@ -122,6 +123,7 @@ def __init__( lr_scheduler=lr_scheduler, minmax_lr=minmax_lr, nblocks=nblocks, + nblocks_overlap=nblocks_overlap, to_quant_block_names=to_quant_block_names, scale_dtype=scale_dtype, ) @@ -257,6 +259,7 @@ class TuningExtraConfig(BaseExtraConfig): lr_scheduler: Callable = None minmax_lr: float = None nblocks: int = 1 + nblocks_overlap: int = 0 to_quant_block_names: Union[str, list, None] = None scale_dtype: str = "fp16" diff --git a/auto_round/compressors/data_driven.py b/auto_round/compressors/data_driven.py index 10ac8332f..32acc5fe9 100644 --- a/auto_round/compressors/data_driven.py +++ b/auto_round/compressors/data_driven.py @@ -59,6 +59,7 @@ memory_monitor, mv_module_from_gpu, set_amax_for_all_moe_layers, + set_module, to_device, to_dtype, wrap_block_forward_positional_to_kwargs, @@ -416,6 +417,118 @@ def quantize_block( finally: self.model_context.is_mllm = orig_is_mllm + @staticmethod + def _clone_fp_overlap_module(module: torch.nn.Module) -> torch.nn.Module: + fp_module = copy.deepcopy(module) + fp_module.requires_grad_(False) + return mv_module_from_gpu(fp_module) + + def _cache_fp_overlap_modules( + self, + model: torch.nn.Module, + block_names: list[str], + fp_overlap_modules: dict[str, torch.nn.Module], + ) -> None: + for block_name in block_names: + if block_name in fp_overlap_modules: + continue + fp_overlap_modules[block_name] = self._clone_fp_overlap_module(get_module(model, block_name)) + + def _swap_in_fp_overlap_modules( + self, + model: torch.nn.Module, + block_names: list[str], + fp_overlap_modules: dict[str, torch.nn.Module], + ) -> list[tuple[str, torch.nn.Module, torch.nn.Module]]: + swapped_modules = [] + for block_name in block_names: + fp_module = fp_overlap_modules.get(block_name) + if fp_module is None: + continue + current_module = get_module(model, block_name) + replacement = copy.deepcopy(fp_module) + set_module(model, block_name, replacement) + swapped_modules.append((block_name, current_module, replacement)) + return swapped_modules + + @staticmethod + def _restore_swapped_modules( + model: torch.nn.Module, + swapped_modules: list[tuple[str, torch.nn.Module, torch.nn.Module]], + ) -> None: + for block_name, current_module, _ in reversed(swapped_modules): + set_module(model, block_name, current_module) + + @staticmethod + def _build_window_block(modules: list[torch.nn.Module]) -> torch.nn.Module: + if len(modules) == 1: + return modules[0] + return WrapperMultiblock(modules) + + def _get_fp_window_outputs( + self, + model: torch.nn.Module, + window_names: list[str], + fp_overlap_modules: dict[str, torch.nn.Module], + input_ids, + input_others, + bs: int, + loss_device, + overlap_advance: int, + ): + swapped_modules = self._swap_in_fp_overlap_modules(model, window_names, fp_overlap_modules) + try: + modules = [get_module(model, block_name) for block_name in window_names] + fp_block = self._build_window_block(modules) + materialize_model_(fp_block) + convert_module_to_hp_if_necessary(fp_block, self.model_context.amp_dtype, self.compress_context.device) + fp_block = fp_block.to(self.compress_context.device) + reference_output = self.quantizer._get_block_outputs( + fp_block, + input_ids, + input_others, + bs, + device_override=loss_device, + ) + overlap_input = None + overlap_block = self._get_overlap_advance_block(fp_block, overlap_advance) + if overlap_block is not None: + overlap_input = self.quantizer._get_block_outputs( + overlap_block, + input_ids, + input_others, + bs, + device_override=loss_device, + ) + return reference_output, overlap_input + finally: + self._restore_swapped_modules(model, swapped_modules) + for _, _, replacement in swapped_modules: + mv_module_from_gpu(replacement) + + def _get_overlap_advance_block(self, block: torch.nn.Module, advance_blocks: int) -> Optional[torch.nn.Module]: + if advance_blocks <= 0: + return None + if isinstance(block, WrapperMultiblock): + modules = list(block.layers[:advance_blocks]) + if len(modules) == 1: + return modules[0] + return WrapperMultiblock(modules) + if advance_blocks == 1: + return block + return None + + @staticmethod + def _get_future_window_names(block_names: list[str], block_starts: list[int], step_idx: int, nblocks: int) -> set: + future_names = set() + for future_start in block_starts[step_idx + 1 :]: + future_names.update(block_names[future_start : min(future_start + nblocks, len(block_names))]) + return future_names + + @staticmethod + def _is_module_in_blocks(module_name: str, block_names: list[str]) -> bool: + return any(module_name == block_name or module_name.startswith(f"{block_name}.") for block_name in block_names) + def _quantize_blocks( self, model: torch.nn.Module, @@ -457,15 +570,19 @@ def _quantize_blocks( for k in extra_keys: input_others[k] = input_ids.pop(k) + overlap = self.nblocks_overlap if nblocks > 1 else 0 + stride = nblocks - overlap + block_starts = self._get_block_window_starts(block_names, nblocks) if pbar is None: - pbar = tqdm(range(0, len(block_names), nblocks)) + pbar = tqdm(block_starts) + fp_overlap_modules = {} - for i in range(0, len(block_names), nblocks): + for step_idx, i in enumerate(block_starts): if input_others_extra_blocks and block_names[i] in input_others_extra_blocks: input_others = input_others_extra_blocks[block_names[i]] _, input_others = self._preprocess_block_inputs(input_others) input_others_extra_blocks.pop(block_names[i]) - if i != 0: + if step_idx != 0: pbar.update(1) if nblocks == 1: n = block_names[i] @@ -476,6 +593,10 @@ def _quantize_blocks( pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}") modules = [get_module(model, n) for n in names] m = WrapperMultiblock(modules) + window_names = [n] if nblocks == 1 else names + overlap_advance = stride if overlap > 0 and step_idx + 1 < len(block_starts) else 0 + future_window_names = self._get_future_window_names(block_names, block_starts, step_idx, nblocks) + final_window_names = [name for name in window_names if name not in future_window_names] if self.compress_context.low_cpu_mem_usage: if nblocks == 1: @@ -508,6 +629,12 @@ def _quantize_blocks( m = m.to(self.compress_context.device) card_0_in_high_risk, loss_device = False, self.compress_context.device + self._cache_fp_overlap_modules( + model, + [name for name in window_names if name in future_window_names], + fp_overlap_modules, + ) + if len(self.compress_context.device_list) > 1 and not self.model_context.is_diffusion: from accelerate.hooks import AlignDevicesHook, add_hook_to_module @@ -518,24 +645,29 @@ def _quantize_blocks( # ── Infrastructure: collect reference output and act_max ────────── bs = self.quantizer.batch_size * self.quantizer.infer_bs_coeff - if q_input is None: - hook_handles = self.quantizer.register_calibration_hooks(m) - reference_output = self.quantizer._get_block_outputs( - m, input_ids, input_others, bs, device_override=loss_device - ) - for h in hook_handles: - h.remove() - else: - reference_output = self.quantizer._get_block_outputs( - m, input_ids, input_others, bs, device_override=loss_device + reference_output, overlap_input = self._get_fp_window_outputs( + model, + window_names, + fp_overlap_modules, + input_ids, + input_others, + bs, + loss_device, + overlap_advance, + ) + hook_handles = self.quantizer.register_calibration_hooks(m) + if hook_handles: + calib_input = q_input if q_input is not None else input_ids + self.quantizer._get_block_outputs( + m, + calib_input, + input_others, + bs, + save_output=False, + device_override=loss_device, ) - hook_handles = self.quantizer.register_calibration_hooks(m) - if hook_handles: - self.quantizer._get_block_outputs( - m, q_input, input_others, bs, save_output=False, device_override=loss_device - ) - for h in hook_handles: - h.remove() + for h in hook_handles: + h.remove() # ── Infrastructure: swap q_input ────────────────────────────────── if q_input is not None: @@ -554,6 +686,7 @@ def _quantize_blocks( reference_output, loss_device=loss_device, mid_iter_mem_check=mid_iter_mem_check, + finalize_module_names=final_window_names, ) # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── @@ -562,7 +695,8 @@ def _quantize_blocks( # ── Infrastructure: collect q_outputs if needed ─────────────────── if self.quantizer.enable_quanted_input: - q_input = self.quantizer._get_block_outputs(m, input_ids, input_others, bs) + q_output_block = self._get_overlap_advance_block(m, overlap_advance) or m + q_input = self.quantizer._get_block_outputs(q_output_block, input_ids, input_others, bs) else: q_input = None @@ -577,7 +711,7 @@ def _quantize_blocks( # from the current block's reference output, while q_input (when # enabled) is only used as the quantized-input companion for the # next block. - next_input_ids = reference_output + next_input_ids = overlap_input if overlap_input is not None else reference_output clear_memory( input_ids if input_ids is not next_input_ids else None, device_list=self.compress_context.device_list ) @@ -586,15 +720,22 @@ def _quantize_blocks( # ── Infrastructure: immediate_pack / shard write ────────────────── if self.compress_context.is_immediate_packing: for _n, _mod in m.named_modules(): - if hasattr(_mod, "bits") and check_to_quantized(_mod): + global_name = getattr(_mod, "global_name", None) + if ( + global_name is not None + and self._is_module_in_blocks(global_name, final_window_names) + and hasattr(_mod, "bits") + and check_to_quantized(_mod) + ): from auto_round.compressors.utils import immediate_pack as _immediate_pack - _immediate_pack(_mod.global_name, self.quantizer.layer_config) + _immediate_pack(global_name, self.quantizer.layer_config) input_ids = next_input_ids if self.compress_context.is_immediate_saving: - self.shard_writer.write(m, is_finalize=False) + for name in final_window_names: + self.shard_writer.write(name=name, is_finalize=False) if self.compress_context.low_cpu_mem_usage and not self.compress_context.is_immediate_saving: if nblocks == 1: @@ -602,6 +743,9 @@ def _quantize_blocks( else: for name in names: self._offloader(model, name, overwrite=True) + for name in list(fp_overlap_modules): + if name not in future_window_names: + del fp_overlap_modules[name] if pbar is not None: pbar.update(1) @@ -618,6 +762,19 @@ def _quantize_blocks( clear_memory(device_list=self.compress_context.device_list) + def _get_block_window_starts(self, block_names: list, nblocks: int) -> list[int]: + overlap = self.nblocks_overlap if nblocks > 1 else 0 + stride = nblocks - overlap + block_starts = [] + block_idx = 0 + while block_idx < len(block_names): + remaining = len(block_names) - block_idx + if block_idx > 0 and overlap > 0 and remaining <= overlap: + break + block_starts.append(block_idx) + block_idx += stride + return block_starts + def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: """Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound. Returns: @@ -695,10 +852,8 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: self.compress_context.low_cpu_mem_usage = False else: self.compress_context.low_cpu_mem_usage = False - if len(all_blocks) > 1: - pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks)) - else: - pbar = tqdm(range(0, len(all_blocks[0]), self.nblocks)) # move the alg warning outside pbar + block_window_count = sum(len(self._get_block_window_starts(block, self.nblocks)) for block in all_blocks) + pbar = tqdm(range(block_window_count)) # move the alg warning outside pbar start_time = time.time() for block_names in all_blocks: diff --git a/auto_round/compressors/entry.py b/auto_round/compressors/entry.py index 91ffa3b1a..6caf6ace6 100644 --- a/auto_round/compressors/entry.py +++ b/auto_round/compressors/entry.py @@ -400,6 +400,7 @@ def _pop_config_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str ) auto_round_only_keys = ( "nblocks", + "nblocks_overlap", "enable_alg_ext", "lr_scheduler", "not_use_best_mse", diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index bd6a78570..8277ff7d5 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -772,7 +772,14 @@ def wrapper_block( """ quantized_layers = [] unquantized_layers = [] + wrapped_prefixes = [] for n, m in block.named_modules(): + if any(prefix == "" or n == prefix or n.startswith(f"{prefix}.") for prefix in wrapped_prefixes): + continue + if hasattr(m, "orig_layer"): + quantized_layers.append(n) + wrapped_prefixes.append(n) + continue if type(m) in SUPPORTED_LAYER_TYPES: if not check_to_quantized(m): unquantized_layers.append(n) @@ -787,6 +794,7 @@ def wrapper_block( ) set_module(block, n, new_m) quantized_layers.append(n) + wrapped_prefixes.append(n) elif enable_norm_bias_tuning: if "norm" in m.__class__.__name__.lower(): @@ -794,6 +802,7 @@ def wrapper_block( wrapper_layer_class = NORM_MAPPING[m.__class__.__name__] new_m = wrapper_layer_class(m, device=device) set_module(block, n, new_m) + wrapped_prefixes.append(n) elif "RMSNorm" in m.__class__.__name__: logger.warning_once( f"use LlamaRMSNorm to wrap {m.__class__.__name__}, please check the correctness yourself" @@ -801,6 +810,7 @@ def wrapper_block( wrapper_layer_class = NORM_MAPPING["LlamaRMSNorm"] new_m = wrapper_layer_class(m, device=device) set_module(block, n, new_m) + wrapped_prefixes.append(n) else: logger.warning_once(f"{m.__class__.__name__} is not supported") return quantized_layers, unquantized_layers @@ -837,7 +847,7 @@ def unwrapper_layer(model, layer, layer_name, best_params): @torch.no_grad() -def unwrapper_block(block, best_params): +def unwrapper_block(block, best_params, unwrap_filter=None): """Unwraps the WrapperLinear and WrapperTransformerConv1d modules in the given block. Args: @@ -852,5 +862,10 @@ def unwrapper_block(block, best_params): best_param = best_params[n] else: best_param = None - orig_layer = m.unwrapper(best_param) - set_module(block, n, orig_layer) + if unwrap_filter is None or unwrap_filter(n, m): + orig_layer = m.unwrapper(best_param) + set_module(block, n, orig_layer) + elif best_param is not None: + for key, value in best_param.items(): + if key in m.params: + m.params[key].data.copy_(value.to(m.params[key].device))