diff --git a/.claude/skills/adapt-new-diffusion-model/SKILL.md b/.claude/skills/adapt-new-diffusion-model/SKILL.md index 89b36c274..78e9b88ca 100644 --- a/.claude/skills/adapt-new-diffusion-model/SKILL.md +++ b/.claude/skills/adapt-new-diffusion-model/SKILL.md @@ -52,15 +52,12 @@ def is_diffusion_model(model_or_path): ``` If your model doesn't have `model_index.json`, either create one in the model -directory or pass diffusion-specific options through new-architecture -`ExtraConfig` / `AutoRound` kwargs: +directory or pass diffusion-specific options through `AutoRound` kwargs: ```python -from auto_round.compressors.config import ExtraConfig - ar = AutoRound( model, - extra_config=ExtraConfig(num_inference_steps=5), + num_inference_steps=5, ) ``` diff --git a/.claude/skills/add-vlm-model/SKILL.md b/.claude/skills/add-vlm-model/SKILL.md index 8d91c2af2..b6aa0a369 100644 --- a/.claude/skills/add-vlm-model/SKILL.md +++ b/.claude/skills/add-vlm-model/SKILL.md @@ -104,7 +104,7 @@ The new architecture routes multimodal calibration through: If your model works with an existing template/processor, prefer passing `template=...`, `processor=...`, or `image_processor=...` through `AutoRound` / -`ExtraConfig` instead of adding compressor code. +kwargs instead of adding compressor code. ## Step 3: Add Calibration Template diff --git a/.gitignore b/.gitignore index af7863bd1..e44765636 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ tmp_autoround/ ut_log_dir/ CLAUDE.local.md docs/plan/ +.codegraph/ diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 18808c518..000000000 --- a/AGENTS.md +++ /dev/null @@ -1,73 +0,0 @@ -# AGENTS.md - -This file provides guidance to ANY AGENTS when working with code in this repository. - -## Project - -AutoRound — post-training quantization for LLMs/VLMs using sign-gradient descent. Publishes as `auto-round` (GPU/CPU) and `auto-round-hpu` (Intel Gaudi). 
- -## Build & Install - -```bash -# From source (GPU/CPU) — --no-build-isolation is required when PyTorch is already installed -pip install --no-build-isolation -e . - -# HPU variant -BUILD_HPU_ONLY=1 pip install --no-build-isolation . -# or: python setup.py hpu install - -# XPU variant — install Intel PyTorch first -pip install torch --index-url https://download.pytorch.org/whl/xpu -pip install --no-build-isolation . -``` - -## Testing - -```bash -# CPU tests (most common during development) -pytest test/test_cpu/ -x -q - -# Single test -pytest test/test_cpu/ -k "test_name" -x -q - -# Hardware-specific -pytest test/test_cuda/ -pytest test/test_hpu/ --mode=lazy # or --mode=compile -pytest test/test_xpu/ -``` - -Test fixtures create tiny models (OPT-125M, Qwen-0.6B) at session scope — first run downloads them. - -## Code Style - -- **Line length: 120** (non-default) — enforced by black, isort, ruff, pylint -- **Formatter: black** (profile used by isort) -- **Import sorting: isort** with `profile=black`, first-party: `auto_round`, `auto_round_extension` -- **Linter: ruff** — rules `E4, E7, E9, F, NPY, FURB`; E501/E402/F401/F403 are intentionally ignored -- **License header**: Apache 2.0 auto-inserted into `.py/.yaml/.yml/.sh` under `auto_round/` and `auto_round_extension/` -- Pre-commit config: `.pre-commit-config.yaml` - -## Commit & PR Conventions - -- Conventional commits: `feat:`, `fix:`, `chore:`, `docs:`, `refactor:`, `test:` -- PRs target `main`, squash-merged -- **CN docs rule**: any change to a `.md` file must include a matching update to its `_CN` counterpart (e.g., `README.md` → `README_CN.md`) - -## Key Environment Variables - -- `BUILD_HPU_ONLY=1` — build HPU package variant -- `AR_USE_MODELSCOPE=1` — use ModelScope instead of HuggingFace for model downloads -- `FORCE_BF16=1` — force BF16 in tests (used in CI) - -## Source Layout - -- `auto_round/` — core library (AutoRound class, sign-SGD, exporters, eval, data types) -- `auto_round_extension/` — 
hardware backends (CUDA, HPU, IPEX/XPU, Triton, ARK, vLLM) -- `test/` — tests organized by hardware: `test_cpu/`, `test_cuda/`, `test_hpu/`, `test_xpu/` -- `examples/` — usage examples for different model types - -## Gotchas - -- `setup.py` forces `CC=CXX=g++` at import time -- Version is computed dynamically from git tags — untagged commits produce dev versions -- Some test dependencies (AutoAWQ, GPTQModel, llama-cpp) require manual git installs — see comments in `test/test_cuda/requirements.txt` diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 656d7e2cf..17f07e078 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -11,880 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse -import json -import os -import re -import sys - -import torch - -from auto_round.auto_scheme import AutoScheme -from auto_round.compressors.base import BaseCompressor -from auto_round.eval.eval_cli import EvalArgumentParser, eval, eval_task_by_task -from auto_round.eval.evaluation import run_model_evaluation -from auto_round.schemes import PRESET_SCHEMES, preset_name_to_scheme -from auto_round.utils import ( - clear_memory, - get_device_and_parallelism, - get_model_dtype, - parse_layer_config_arg, -) - -RECIPES = { - "default": {"batch_size": 8, "iters": 200, "seqlen": 2048, "nsamples": 128, "lr": None}, - "best": {"batch_size": 8, "iters": 1000, "seqlen": 2048, "nsamples": 512, "lr": None}, - "light": {"batch_size": 8, "iters": 50, "seqlen": 2048, "nsamples": 128, "lr": 5e-3}, - "rtn": {"batch_size": 8, "iters": 0, "seqlen": 2048, "nsamples": 1, "lr": None, "disable_opt_rtn": True}, - "opt_rtn": {"batch_size": 8, "iters": 0, "seqlen": 2048, "nsamples": 128, "lr": None, "disable_opt_rtn": False}, -} - - -class BasicArgumentParser(argparse.ArgumentParser): - - def __init__(self, *args, **kwargs): - 
super().__init__(*args, **kwargs) - self.add_argument( - "model", - default=None, - nargs="?", - help="Path to the pre-trained model or model identifier from huggingface.co/models. " - "Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'", - ) - basic = self.add_argument_group("Basic Arguments") - basic.add_argument( - "--model_name", - "--model", - "--model_name_or_path", - default="facebook/opt-125m", - help="Path to the pre-trained model or model identifier from huggingface.co/models. " - "Examples: 'facebook/opt-125m', 'bert-base-uncased', or local path like '/path/to/model'", - ) - basic.add_argument("--model_dtype", default=None, help="model dtype used to load the pre-trained model") - basic.add_argument( - "--platform", - default="hf", - help="Platform to load the pre-trained model. Options: [hf, model_scope]." - " hf stands for huggingface and model_scope stands for model scope.", - ) - basic.add_argument( - "--scheme", - default="W4A16", - type=str, - # choices=["W4A16", "W2A16", "W3A16", "W8A16", "MXFP4", "MXFP8", "NVFP4", "FPW8A16", "FP8_STATIC"], - help="Quantization scheme to use. " - "W4A16: 4-bit weights with 16-bit activations (default). " - "Other options include W2A16, W3A16, W8A16, W8A8 for different bit widths, " - "and MXFP4/MXFP8/NVFP4 for different data type.", - ) - basic.add_argument( - "--algorithm", - default=None, - type=str.lower, - choices=["auto_round", "rtn", "awq"], - help="Quantization algorithm to use. " - "auto_round: SignSGD-based optimization (default when iters > 0). " - "rtn: Round-to-nearest (default when iters == 0). " - "awq: Activation-Aware Weight Quantization (AWQ smoothing + RTN).", - ) - basic.add_argument( - "--batch_size", - "--train_bs", - "--bs", - default=None, - type=int, - help="The batch size for tuning/calibration." 
- "Larger batch sizes may improve stability but require more memory.", - ) - basic.add_argument( - "--avg_bits", "--target_bits", default=None, type=float, help="for auto scheme, number of avg weight bits" - ) - basic.add_argument( - "--options", default=None, type=str, help="for auto scheme, options for auto scheme, e.g. 'W4A16,W8A16'" - ) - - basic.add_argument( - "--iters", - "--iter", - default=None, - type=int, - help="Number of iterations to tune each block. " - "More iterations may lead to better quantization quality but take longer.", - ) - basic.add_argument( - "--seqlen", - "--seq_len", - default=None, - type=int, - help="Sequence length of the calibration samples" - "Longer sequences capture more context but use more memory.", - ) - basic.add_argument( - "--nsamples", - "--nsample", - default=None, - type=int, - help="Number of calibration samples to use for quantization.", - ) - basic.add_argument( - "--device_map", - "--device", - "--devices", - default="0", - type=str, - help="The device to be used for tuning. " - "Currently, device settings support CPU, GPU, and HPU." - "The default is set to cuda:0," - "allowing for automatic detection and switch to HPU or CPU." - "set --device 0,1,2 to use multiple cards.", - ) - basic.add_argument( - "--dataset", - default="NeelNanda/pile-10k", - type=str, - help="Calibration dataset for quantization. " - "Should be a dataset from huggingface datasets or local path. ", - ) - basic.add_argument("--seed", default=42, type=int, help="Random seed for reproducibility.") - basic.add_argument("--adam", action="store_true", help="Use Adam optimizer instead of SignSGD.") - basic.add_argument( - "--low_gpu_mem_usage", - action="store_true", - help="Enable memory-efficient mode by offloading intermediate features to CPU. 
" - "Useful when working with large models that don't fit in GPU memory.", - ) - basic.add_argument( - "--low_cpu_mem_usage", - action="store_true", - help=( - "Deprecated: low CPU memory mode is enabled by default. " - "This flag is kept only for backward compatibility and has no effect " - "beyond explicitly re-enabling the default behavior." - ), - ) - basic.add_argument( - "--disable_low_cpu_mem_usage", - action="store_true", - help=("Disable low CPU memory mode. " "Use this flag to turn off the default low CPU memory behavior."), - ) - basic.add_argument( - "--format", - "--formats", - default="auto_round", - type=str, - help="Output format for the quantized model." - "'auto_round' is the recommended format" - "use command `auto_round list format` to show all supported formats with support scheme.", - ) - basic.add_argument( - "--output_dir", - default="./tmp_autoround", - type=str, - help="Directory to save the quantized model and related files", - ) - basic.add_argument( - "--not_use_best_mse", - action="store_true", - help="Disable using the iteration with best MSE loss during tuning.", - ) - basic.add_argument( - "--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. " - ) - basic.add_argument( - "--disable_trust_remote_code", - action="store_true", - help="Disable trusting remote code when loading models. " - "Use for security if you don't trust the model source.", - ) - tuning = self.add_argument_group("Tuning Arguments") - tuning.add_argument( - "--ignore_scale_zp_bits", - action="store_true", - help="for auto scheme whether ignore scale zp bits calculation ", - ) - tuning.add_argument( - "--lr", - default=None, - type=float, - help="Learning rate for tuning. " "If None, automatically sets to 1.0/iters. ", - ) - tuning.add_argument( - "--minmax_lr", - default=None, - type=float, - help="Learning rate specifically for min-max tuning. " "If None, uses the same value as --lr. 
", - ) - tuning.add_argument( - "--momentum", - default=0, - type=float, - help="Momentum factor for the optimizer. Default is 0 (no momentum).", - ) - tuning.add_argument( - "--gradient_accumulate_steps", - default=1, - type=int, - help="Number of steps to accumulate gradients before updating weights. " - "Effectively increases batch size without requiring more GPU memory. " - "Useful for large models with limited memory.", - ) - tuning.add_argument( - "--nblocks", - default=1, - type=int, - help="Number of blocks to tune simultaneously. " - "Higher values may speed up tuning but require more memory. " - "Recommended to keep at 1 for stability with large models.", - ) - tuning.add_argument( - "--scale_dtype", - default=None, - choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], - help="Data type for quantization scales. " - "fp16/bf16: lower memory, fp32: higher precision. " - "Choose based on your hardware support and accuracy requirements.", - ) - tuning.add_argument( - "--disable_amp", - action="store_true", - help="Disable Automatic Mixed Precision (AMP). " - "AMP speeds up training but may affect numerical stability in some cases.", - ) - tuning.add_argument( - "--disable_minmax_tuning", - action="store_true", - help="Disable weight min-max range tuning. " - "Not recommended as it may significantly reduce quantization accuracy.", - ) - tuning.add_argument( - "--enable_norm_bias_tuning", action="store_true", help="Enable normalization layer bias tuning. " - ) - tuning.add_argument( - "--disable_quanted_input", - action="store_true", - help="Use original (non-quantized) inputs for each block instead of" - " quantized outputs from previous blocks. ", - ) - tuning.add_argument( - "--to_quant_block_names", - default=None, - type=str, - help="Specific blocks to quantize, separated by commas. " - "Example: 'block1,block2,block3'. 
" - "If None, all blocks will be quantized.", - ) - tuning.add_argument( - "--enable_alg_ext", - action="store_true", - help="Enable experimental algorithms that may provide better quantization results. " - "These are newer methods that might improve accuracy but are less tested.", - ) - tuning.add_argument( - "--disable_deterministic_algorithms", - action="store_true", - help="deprecated, disable torch deterministic algorithms.", - ) - tuning.add_argument( - "--enable_deterministic_algorithms", - action="store_true", - help="Enable PyTorch deterministic algorithms for reproducible results. ", - ) - group_opt_rtn = tuning.add_mutually_exclusive_group() - group_opt_rtn.add_argument( - "--disable_opt_rtn", - action="store_const", - const=True, - dest="disable_opt_rtn", - default=None, - help="Disable optimization for RTN (Round-To-Nearest) mode when iters=0. " - "RTN is fast but less accurate; keeping optimization enabled is recommended.", - ) - group_opt_rtn.add_argument( - "--enable_opt_rtn", - action="store_const", - const=False, - dest="disable_opt_rtn", - help="Enable optimization for RTN mode when iters=0.", - ) - group_model_free = tuning.add_mutually_exclusive_group() - group_model_free.add_argument( - "--model_free", - action="store_true", - help="Force model-free quantization mode. " - "Downloads and quantizes safetensors files directly using RTN, " - "without loading the full model into memory. " - "Only supports auto_round output format.", - ) - group_model_free.add_argument( - "--disable_model_free", - action="store_true", - help="Disable the automatic model-free routing that activates when " - "--iters 0 --disable_opt_rtn is combined with a supported INT WOQ scheme. 
" - "Use this to force the regular AutoRound flow.", - ) - - awq_group = self.add_argument_group("AWQ Arguments") - awq_group.add_argument( - "--duo_scaling", - default=True, - type=lambda s: "both" if s.lower() == "both" else s.lower() in ("true", "1", "yes"), - help="Whether to use both activations and weights for AWQ scaling. " - "Options: true/false/'both'. 'both' searches both modes and picks the best. " - "(default: True).", - ) - awq_group.add_argument( - "--n_grid", - default=20, - type=int, - help="Number of grid points for AWQ scaling ratio search (default: 20).", - ) - - scheme = self.add_argument_group("Scheme Arguments") - scheme.add_argument("--bits", default=None, type=int, help="Number of bits for weight quantization. ") - scheme.add_argument( - "--group_size", - default=None, - type=lambda s: int(s) if s.lstrip("-").isdigit() else tuple([int(x.strip()) for x in s.split(",")]), - help="Group size for weight quantization.", - ) - scheme.add_argument("--asym", action="store_true", help="Use asymmetric quantization instead of symmetric.") - scheme.add_argument( - "--act_asym", action="store_true", help="Use asymmetric quantization for activation instead of symmetric." - ) - scheme.add_argument( - "--data_type", - "--dtype", - default=None, - help="Data type for quantization. Options: 'int' for integer, 'mx_fp' for mixed floating-point, etc.", - ) - scheme.add_argument( - "--act_bits", - default=None, - type=int, - help="Number of bits for activation quantization. " - "Activation quantization significantly impacts performance and accuracy.", - ) - scheme.add_argument( - "--act_group_size", - default=None, - type=int, - help="Group size for activation quantization. " "Similar to weight group size but for activations.", - ) - scheme.add_argument( - "--act_data_type", "--act_dtype", default=None, type=str, help="Data type for activation quantization. 
" - ) - scheme.add_argument( - "--disable_act_dynamic", action="store_true", help="Use static instead of dynamic activation quantization. " - ) - scheme.add_argument( - "--layer_config", - default=None, - type=str, - help="Per-layer quantization config for missing tensors (e.g., MTP layers) as a JSON string. " - "Keys are name prefixes, values are config dicts with optional bits/group_size/sym. " - 'Example: "{mtp:{bits:8,data_type:int},mtp.fc:{bits:16}}". ' - "These settings are saved to extra_config and override the global quantization config.", - ) - scheme.add_argument( - "--shared_layers", - type=str, - nargs="+", - action="append", - default=None, - help="[mix-precision] ensure that listed layers are using same data type for quantization", - ) - scheme.add_argument( - "--quant_lm_head", - action="store_true", - help="Quantize the lm_head. " "Usually kept in higher precision for better output quality.", - ) - scheme.add_argument( - "--ignore_layers", - "--fp_layers", - default="", - type=str, - help="List of layer names to keep in original precision (not quantized). " - "Useful for preserving critical layers. Separate multiple names with commas.", - ) - scheme.add_argument( - "--static_kv_dtype", - default=None, - type=str, - choices=["fp8", "float8_e4m3fn"], - help="Data type for static quantize key and value. ", - ) - - scheme.add_argument( - "--static_attention_dtype", - default=None, - type=str, - choices=["fp8", "float8_e4m3fn"], - help="Data type for static quantize attention. ", - ) - scheme.add_argument( - "--rotation_type", - default=None, - type=str, - choices=["hadamard", "random_hadamard"], - help="Research feature: applies a rotation (e.g., Hadamard) to reduce activation/weight outliers", - ) - gguf = self.add_argument_group("Double Quant Arguments") - gguf.add_argument( - "--super_group_size", default=None, type=int, help="Super group size for double quantization." 
- ) - gguf.add_argument( - "--super_bits", - default=None, - type=int, - help="Number of bits for scale and zero-point quantization in double quantization. ", - ) - - ## ======================= eval ======================= - eval_args = self.add_argument_group("eval arguments") - eval_args.add_argument( - "--tasks", - "--task", - nargs="?", - const="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1," - "openbookqa,boolq,arc_easy,arc_challenge", - default=None, - help="LM-Evaluation-Harness tasks to run. " - "Specify specific tasks like 'mmlu,wikitext' for custom evaluation.", - ) - eval_args.add_argument("--eval_bs", default=None, type=int, help="Batch size for evaluation.") - eval_args.add_argument( - "--limit", - type=float, - default=None, - metavar="N|0=0.4.2", - "lm-eval is required for evaluation, please install it with `pip install 'lm-eval>=0.4.2'`", - ) - - from auto_round.utils import detect_device, get_library_version, logger - - if args.low_cpu_mem_usage: - logger.warning( - "`low_cpu_mem_usage` is deprecated and is now enabled by default. " - "To disable it, use `--disable_low_cpu_mem_usage`." - ) - - if args.format is None: - args.format = "auto_round" - - formats = args.format.lower().replace(" ", "").split(",") - from auto_round.utils import SUPPORTED_FORMATS - - for format in formats: - if format not in SUPPORTED_FORMATS: - raise ValueError(f"{format} is not supported, we only support {SUPPORTED_FORMATS}") - - if "auto_gptq" in args.format and args.asym is True: - logger.warning( - "the auto_gptq kernel has issues with asymmetric quantization. 
" - "It is recommended to use sym quantization or --format='auto_round'" - ) - - if "marlin" in args.format and args.asym is True: - raise RuntimeError("marlin backend only supports sym quantization, please remove --asym") - - device_str, use_auto_mapping = get_device_and_parallelism(args.device_map) - - if args.enable_torch_compile: - logger.info( - "`torch.compile` is enabled to reduce tuning costs. " - "If it causes issues, you can disable it by removing `--enable_torch_compile` argument." - ) - - model_name = args.model - if model_name[-1] == "/": - model_name = model_name[:-1] - logger.info(f"start to quantize {model_name}") - - from auto_round import AutoRound - - if "bloom" in model_name: - args.low_gpu_mem_usage = False - - if args.quant_lm_head: - for format in formats: - # MLX (native ``mlx`` and ``auto_round:mlx``) supports per-layer - # mixed-bit quantization including lm_head; treat it the same as - # auto_round / fake here. - if "auto_round" not in format and "fake" not in format and "mlx" not in format: - auto_round_formats = [s for s in SUPPORTED_FORMATS if s.startswith("auto_round") or s == "mlx"] - raise ValueError( - f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}" - ) - - enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False - sym = None # the default value should be None now - if args.asym: # if the scheme is asym, how to set it to sym is an issue - sym = False - act_dynamic = None - if args.disable_act_dynamic: - act_dynamic = False - scheme = args.scheme.upper() - if scheme not in PRESET_SCHEMES: - raise ValueError(f"{scheme} is not supported. only {PRESET_SCHEMES.keys()} are supported ") - if args.disable_deterministic_algorithms: - logger.warning( - "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated," - " please use enable_deterministic_algorithms instead. 
" - ) - - from auto_round.compressors import ( - DiffusionExtraConfig, - ExtraConfig, - MLLMExtraConfig, - SchemeExtraConfig, - TuningExtraConfig, - ) - - extra_config = ExtraConfig() - tuning_config = TuningExtraConfig( - amp=not args.disable_amp, - disable_opt_rtn=args.disable_opt_rtn, - enable_alg_ext=args.enable_alg_ext, - enable_minmax_tuning=not args.disable_minmax_tuning, - enable_norm_bias_tuning=args.enable_norm_bias_tuning, - enable_quanted_input=not args.disable_quanted_input, - enable_deterministic_algorithms=args.enable_deterministic_algorithms, - lr=args.lr, - minmax_lr=args.minmax_lr, - nblocks=args.nblocks, - to_quant_block_names=args.to_quant_block_names, - scale_dtype=args.scale_dtype, - ) - scheme_config = SchemeExtraConfig( - bits=args.bits, - group_size=args.group_size, - sym=sym, - data_type=args.data_type, - act_bits=args.act_bits, - act_group_size=args.act_group_size, - act_data_type=args.act_data_type, - act_dynamic=act_dynamic, - act_sym=None if not args.asym else False, - super_bits=args.super_bits, - super_group_size=args.super_group_size, - quant_lm_head=args.quant_lm_head, - ignore_layers=args.ignore_layers, - static_kv_dtype=args.static_kv_dtype, - static_attention_dtype=args.static_attention_dtype, - ) - mllm_config = MLLMExtraConfig( - quant_nontext_module=args.quant_nontext_module, extra_data_dir=args.extra_data_dir, template=args.template - ) - diffusion_config = DiffusionExtraConfig( - guidance_scale=args.guidance_scale, - num_inference_steps=args.num_inference_steps, - generator_seed=args.generator_seed, - ) - extra_config.tuning_config = tuning_config - extra_config.scheme_config = scheme_config - extra_config.mllm_config = mllm_config - extra_config.diffusion_config = diffusion_config - - layer_config = {} - if args.layer_config: - layer_config = parse_layer_config_arg(args.layer_config) - args.layer_config = layer_config - - low_cpu_mem_usage = True - if args.disable_low_cpu_mem_usage: - low_cpu_mem_usage = False - - if 
args.avg_bits is not None: - if args.options is None: - raise ValueError("please set --options for auto scheme") - if enable_torch_compile: - logger.warning( - "`enable_torch_compile=True` with AutoScheme may cause compile errors " - "on some models. If so, try removing `--enable_torch_compile`." - ) - scheme = AutoScheme( - options=args.options, - avg_bits=args.avg_bits, - shared_layers=args.shared_layers, - ignore_scale_zp_bits=args.ignore_scale_zp_bits, - low_gpu_mem_usage=True, # force it to be True as it uses much smaller vram but similar time cost - low_cpu_mem_usage=low_cpu_mem_usage, - ) - rot_config = None - if args.rotation_type: - from auto_round.algorithms.transforms.rotation.config import RotationConfig - - rot_config = RotationConfig(hadamard_type=args.rotation_type) - - autoround: BaseCompressor = AutoRound( - model=model_name, - platform=args.platform, - format=args.format, - scheme=scheme, - dataset=args.dataset, - iters=args.iters, - seqlen=args.seqlen, - nsamples=args.nsamples, - batch_size=args.batch_size, - gradient_accumulate_steps=args.gradient_accumulate_steps, - low_gpu_mem_usage=args.low_gpu_mem_usage, - low_cpu_mem_usage=low_cpu_mem_usage, - device_map=args.device_map, - enable_torch_compile=enable_torch_compile, - seed=args.seed, - not_use_best_mse=args.not_use_best_mse, - enable_adam=args.adam, - extra_config=extra_config, - layer_config=layer_config, - model_dtype=args.model_dtype, - momentum=args.momentum, - trust_remote_code=not args.disable_trust_remote_code, - rotation_config=rot_config, - model_free=args.model_free, - disable_model_free=args.disable_model_free, - algorithm=getattr(args, "algorithm", None), - **( - {"duo_scaling": args.duo_scaling, "n_grid": args.n_grid} - if getattr(args, "algorithm", None) == "awq" - else {} - ), - ) - - # ======================= Quantize and save model ======================= - # Export directory is now derived automatically inside quantize_and_save via - # BaseCompressor._get_export_dir(), so 
we only need to pass the base output_dir. - model, folders = autoround.quantize_and_save(args.output_dir, format=args.format) # pylint: disable=E1101 - tokenizer = autoround.tokenizer # pylint: disable=E1101 - clear_memory() - - # ======================= Model evaluation ======================= - run_model_evaluation(model, tokenizer, autoround, folders, formats, device_str, args) - - -def setup_eval_parser(): - parser = EvalArgumentParser() - args = parser.parse_args() - return args - - -def run_eval(): - from auto_round.logger import logger - from auto_round.utils import is_gguf_model, is_mllm_model - - args = setup_eval_parser() - assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set." - - if args.model is None: - args.model = args.model_name - if "llama" in args.model.lower() and not args.add_bos_token: - logger.warning("set add_bos_token=True for llama model.") - args.add_bos_token = True - if not is_gguf_model(args.model) and is_mllm_model(args.model): - args.mllm = True - - if args.eval_task_by_task: - eval_task_by_task( - model=args.model, - device=args.device_map, - limit=args.limit, - tasks=args.tasks, - batch_size=args.eval_bs, - trust_remote_code=not args.disable_trust_remote_code, - eval_model_dtype=args.eval_model_dtype, - add_bos_token=args.add_bos_token, - ) - else: - eval(args) - - -def run(): - if "list" in sys.argv or "--list" in sys.argv: - if "list" in sys.argv: - sys.argv.remove("list") - if "--list" in sys.argv: - sys.argv.remove("--list") - list_item() - exit() - if "--eval" in sys.argv or "eval" in sys.argv: - if "--eval" in sys.argv: - sys.argv.remove("--eval") - if "eval" in sys.argv: - sys.argv.remove("eval") - run_eval() - else: - start() - - -def run_best(): - start("best") - - -def run_light(): - start("light") - - -def run_rtn(): - start("rtn") - - -def run_opt_rtn(): - start("opt_rtn") - - -def run_mllm(): - run() +# Thin shim — all logic lives in auto_round.cli.main. 
+# This file exists solely to satisfy setup.cfg console_scripts entry points. +from auto_round.cli.main import run, run_best, run_eval, run_light, run_mllm, run_opt_rtn, run_rtn # noqa: F401 if __name__ == "__main__": run() diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py deleted file mode 100644 index e6137d1a5..000000000 --- a/auto_round/alg_ext.py +++ /dev/null @@ -1,872 +0,0 @@ -# Copyright (c) 2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import types -from contextlib import nullcontext -from functools import lru_cache, partial -from typing import Any, Callable, Union - -import torch -import transformers -from torch import autocast -from torch.functional import F - -from auto_round import AutoRound -from auto_round.compressors.utils import check_need_act_calibration, is_nv_fp, is_wint4aint4 -from auto_round.data_type.int import search_scales -from auto_round.data_type.mxfp import MXFP_FORMAT_CACHE, quant_element -from auto_round.data_type.nvfp import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, ref_nvfp4_quant -from auto_round.data_type.utils import floor_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste -from auto_round.logger import logger -from auto_round.utils import SUPPORTED_LAYER_TYPES, check_to_quantized, compile_func, get_reciprocal, set_module -from auto_round.wrapper import NORM_MAPPING, WrapperLinear - -__all__ = ["wrapper_autoround"] - - -def wrapper_autoround(cls: AutoRound): - 
cls._register_act_max_hook = types.MethodType(_register_act_max_hook_ext, cls) - if ( - cls.sym - and cls.enable_alg_ext - and cls.super_group_size is None - and ((cls.data_type.startswith("int")) or cls.data_type.startswith("mx") or cls.data_type.startswith("nv")) - ): - if cls.bits > 2 and (not cls.data_type.startswith("mx") or not cls.data_type.startswith("nv")): - logger.warning_once( - "algorithm extension has only undergone limited validation on " - "W2A16,INT4, MXFP4 and NVFP4; use with caution." - ) - cls._get_loss = types.MethodType(_get_loss_ext, cls) - setattr(cls, "wrapper_block", wrapper_block_v2) - if cls.data_type.endswith("dq"): - setattr(cls, "wrapper_block", dq_wrapper_block) - - -def get_abs_top_percent_mask(x: torch.Tensor, percent: float = 1.0): - """ - Return a mask for the top `percent` absolute values in x and its inverse. - - Args: - x (torch.Tensor): Input tensor. - percent (float): Percentage of elements to select (0~100). - - Returns: - mask (torch.BoolTensor): True for top `percent` abs elements. - inv_mask (torch.BoolTensor): Inverse of mask. 
- """ - flat = x.view(-1) - k = max(1, int(flat.numel() * percent / 1000)) - _, idx = torch.topk(torch.abs(flat), k) - - mask = torch.zeros_like(flat, dtype=torch.bool) - mask[idx] = True - mask = mask.view_as(x) - return mask, ~mask - - -def _get_loss_ext( - self: AutoRound, - output_q: torch.Tensor, - current_output: torch.Tensor, - indices: torch.Tensor, - mse_loss: Callable, - device: Union[str, torch.device] = "cpu", -): - _, mask = get_abs_top_percent_mask(torch.abs(output_q - current_output)) - autocast_ctx = autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype) if self.amp else nullcontext() - if self.attention_mask: - tmp_attention_mask = [self.attention_mask[i] for i in indices] - tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) - tmp_attention_mask.unsqueeze_(-1) - - with autocast_ctx: - loss = torch.mean( - (torch.abs(output_q.to(torch.float32) - current_output.to(torch.float32)) * tmp_attention_mask * mask) - ** 2 - ) - else: - with autocast_ctx: - loss = torch.mean((torch.abs(output_q.to(torch.float32) - current_output.to(torch.float32)) * mask) ** 2) - - return loss - - -def quant_tensor_sym( - tensor, - bits=4, - group_size=-1, - v=0, - min_scale=1.0, - max_scale=1.0, - scale_dtype=torch.float16, - tensor_min=None, - tensor_max=None, - q_scale_thresh=1e-5, - init_scale=None, - **kwargs, -): - """Quantize and de-quantize tensor symmetrically (full-range, llama.cpp style). - - ``maxq`` is computed via ``int(2.0 ** (bits - 1))`` so it stays a plain - Python int constant inside the inductor graph and Triton never tries to - lower a ``2 ** SymInt`` through ``libdevice.pow(fp32, i64)``. 
def mx_init(tensor, bits, qw=None):
    """Grid-search a per-row multiplier for the MXFP shared-exponent max value.

    Each row of ``tensor`` (last dimension) is quantized with ``qdq_mxfp``
    using ``max_val * s`` for every candidate ``s`` in 0.50, 0.51, ..., 1.50.
    The candidate with the lowest weighted squared error is kept per row; the
    un-scaled result (``s == 1``) is the baseline.

    The candidates are generated from integer steps. The previous
    ``while``/``+= 0.01`` loop skipped its increment inside the
    ``tmp_scale == 1.0`` guard, so it would spin forever if the accumulated
    float ever hit exactly 1.0, and otherwise re-evaluated the baseline.

    Args:
        tensor (torch.Tensor): Weights already reshaped so the last dimension
            is one quantization group.
        bits (int): Bit width; selects the ``"mx_fp<bits>"`` format entry.
        qw (torch.Tensor | float | None): Weighting applied to the squared
            error (e.g. an importance matrix). ``None`` means unweighted.

    Returns:
        torch.Tensor: Best multiplier per row, shaped like the keep-dim row max.

    Raises:
        KeyError: If ``"mx_fp<bits>"`` is not in ``MXFP_FORMAT_CACHE``.
    """
    if qw is None:
        qw = 1.0  # unweighted error; the old default crashed on ``* None``
    data_type = "mx_fp" + str(bits)
    ebits, mbits, emax, max_norm, _ = MXFP_FORMAT_CACHE[data_type]
    tensor = tensor.to(torch.float32)
    max_val, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)

    # Baseline: quantization error with the unmodified max value.
    qdq_t = qdq_mxfp(tensor, max_val, max_norm, emax, ebits, mbits)
    best_loss = torch.sum((qdq_t - tensor) ** 2 * qw, dim=-1)
    scales = torch.ones_like(max_val)

    for step in range(50, 151):
        if step == 100:
            continue  # multiplier 1.0 is the baseline computed above
        tmp_scale = step / 100
        tmp_qdq_t = qdq_mxfp(tensor, max_val * tmp_scale, max_norm, emax, ebits, mbits)
        loss = torch.sum((tmp_qdq_t - tensor) ** 2 * qw, dim=-1)
        replace_id = loss < best_loss
        scales[replace_id] = tmp_scale
        best_loss[replace_id] = loss[replace_id]
    return scales
def nv_init(tensor, bits, qw=None):
    """Grid-search a per-group scale multiplier for NVFP4 quantization.

    ``tensor`` is quantized with ``nv_fp4`` (fixed ``bits=4``, ``group_size=16``)
    once per candidate multiplier in 0.50, 0.51, ..., 1.50. For every group the
    multiplier with the lowest weighted squared error is kept; multiplier 1.0
    is the baseline.

    The candidates are generated from integer steps. The previous
    ``while``/``+= 0.01`` loop skipped its increment inside the
    ``tmp_scale == 1.0`` guard, so it would spin forever if the accumulated
    float ever hit exactly 1.0, and otherwise re-evaluated the baseline.

    Args:
        tensor (torch.Tensor): Weights to analyse; converted to float32.
        bits (int): Accepted for signature parity with the other init helpers;
            the search itself always calls ``nv_fp4`` with ``bits=4``.
        qw (torch.Tensor | float | None): Weighting applied to the squared
            error (e.g. an importance matrix). ``None`` means unweighted.

    Returns:
        torch.Tensor: Best multiplier per group, shaped like the scale tensor
        returned by ``nv_fp4``.
    """
    if qw is None:
        qw = 1.0  # unweighted error; the old default crashed on ``* None``
    tensor = tensor.to(torch.float32)

    # Baseline run; ``dummy_scale`` is only used as a shape/dtype template.
    qdq_t, dummy_scale, _ = nv_fp4(tensor, bits=4, group_size=16, v=0, max_scale=1.0)
    best_loss = torch.sum((qdq_t - tensor) ** 2 * qw, dim=-1)
    scales = torch.ones_like(dummy_scale)

    for step in range(50, 151):
        if step == 100:
            continue  # multiplier 1.0 is the baseline computed above
        scales_new = torch.ones_like(dummy_scale) * (step / 100)
        tmp_qdq_t, _, _ = nv_fp4(tensor, bits=4, group_size=16, v=0, max_scale=scales_new)
        loss = torch.sum((tmp_qdq_t - tensor) ** 2 * qw, dim=-1)
        replace_id = loss < best_loss
        scales[replace_id] = scales_new[replace_id]
        best_loss[replace_id] = loss[replace_id]
    return scales
class WrapperLinearV2(WrapperLinear):
    """Wrapper for linear/conv1d layers that tunes on top of a searched initial scale.

    Extends ``WrapperLinear`` by computing ``init_scale`` from the layer weight
    (optionally weighted by the layer's ``imatrix``) and by selecting the weight
    quantization function from the layer's ``data_type`` prefix
    (``int`` / ``mx`` / ``nv``).

    Args:
        orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d).
        enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
        enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term.
        device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
    """

    def __init__(
        self,
        orig_layer,
        enable_minmax_tuning=True,
        enable_norm_bias_tuning=False,
        device="cpu",
        enable_round_tuning=True,
        enable_torch_compile=False,
        disable_opt_rtn=True,  # TODO does not support it
        **kwargs,
    ):
        """Initialize the wrapper; all arguments are forwarded to ``WrapperLinear``.

        Args:
            orig_layer (torch.nn.Module): The original layer to wrap.
            enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
            enable_norm_bias_tuning (bool): Whether to enable normalization and tuning for the bias term.
            device (str): The computation device, such as 'cpu' or 'cuda'.
            enable_round_tuning (bool): Forwarded unchanged to the base class.
            enable_torch_compile (bool): When True the selected weight quant
                function is passed through ``compile_func``.
            disable_opt_rtn (bool): Forwarded unchanged to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(
            orig_layer=orig_layer,
            enable_minmax_tuning=enable_minmax_tuning,
            enable_norm_bias_tuning=enable_norm_bias_tuning,
            device=device,
            enable_round_tuning=enable_round_tuning,
            enable_torch_compile=enable_torch_compile,
            disable_opt_rtn=disable_opt_rtn,
            **kwargs,
        )

    def _init_tuning_params_and_quant_func(self):
        """Set up ``init_scale`` and the weight quantization function.

        Runs the base-class setup first, then searches an initial scale from
        the (group-reshaped) weight, drops the layer's ``imatrix`` attribute,
        and picks ``weight_quant_func`` by data-type prefix.

        Raises:
            ValueError: If the layer's ``data_type`` does not start with
                ``int``, ``mx`` or ``nv``.
        """
        super()._init_tuning_params_and_quant_func()

        orig_weight = getattr(self.orig_layer, "get_weight", lambda: self.orig_layer.weight)()
        weight_reshape, _, _ = reshape_pad_tensor_by_group_size(orig_weight.data, self.orig_layer.group_size)
        if hasattr(self.orig_layer, "imatrix"):  # MOE model may have no imatrix
            # Expand the imatrix so it lines up elementwise with the group-reshaped weight.
            imatrix = self.orig_layer.imatrix.reshape(1, -1)
            imatrix = reshape_pad_tensor_by_group_size(imatrix, self.orig_layer.group_size, val=1e-5)[0].view(1, -1)
            imatrix = imatrix.expand(weight_reshape.numel() // imatrix.numel(), -1)
            imatrix = imatrix.reshape(weight_reshape.shape)
            imatrix = imatrix.to(orig_weight.device)
        else:
            imatrix = 1.0

        data_type = self.orig_layer.data_type
        if data_type.startswith("int"):
            self.init_scale = search_scales(weight_reshape, self.orig_layer.bits, imatrix)
            # Keep the sign of each scale but push its magnitude away from zero.
            self.init_scale = torch.where(
                self.init_scale < 0,
                torch.clamp(self.init_scale, max=-self.q_scale_thresh),
                torch.clamp(self.init_scale, min=self.q_scale_thresh),
            )
        elif data_type.startswith("mx"):
            self.init_scale = mx_init(weight_reshape, self.orig_layer.bits, imatrix)
        elif data_type.startswith("nv"):
            self.init_scale = nv_init(weight_reshape, self.orig_layer.bits, imatrix)
        else:
            self.init_scale = 1.0

        # Assign first so the delattr also succeeds on layers that never had an imatrix.
        self.orig_layer.imatrix = None
        delattr(self.orig_layer, "imatrix")

        # self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
        if data_type.startswith("int"):
            self.weight_quant_func = quant_tensor_sym
        elif data_type.startswith("mx"):
            self.weight_quant_func = quant_mx
        elif data_type.startswith("nv"):
            self.weight_quant_func = nv_fp4
        else:
            logger.error("unsupported dtype")
            # Raise instead of calling exit(): library code must not terminate
            # the host interpreter, and callers can now handle the failure.
            raise ValueError(f"unsupported dtype: {data_type!r}")
        self.data_type = data_type
        if self.enable_torch_compile:
            self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

    def _qdq_weight(self, value, min_scale, max_scale):
        """Quantizes and dequantizes weights with tuning parameters.

        Args:
            value (torch.Tensor): Value added for rounding for tuning.
            min_scale (torch.Tensor): Minimum scale for the min value of quantization.
            max_scale (torch.Tensor): Maximum scale for the max value of quantization.

        Returns:
            tuple: Quantized weight, scale, and zero point. For ``bits >= 16``
            the original weight is returned with ``None`` scale and zero point.
        """
        if self.orig_layer.bits >= 16:
            return self.orig_layer.weight, None, None
        # In-place clamp of the tunable scales to [0, 2].
        min_scale.data.clamp_(0, 2.0)  # TODO changed
        max_scale.data.clamp_(0, 2.0)
        weight = self.orig_layer.weight
        if weight.device.type == "meta":
            # Weight is not materialized on this layer; fetch it through get_weight().
            weight = self.orig_layer.get_weight().to(self.device)
        is_conv1d = isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D)
        if is_conv1d:
            weight = weight.t()

        quant_kwargs = {}
        if hasattr(self.orig_layer, "super_bits"):
            quant_kwargs["super_bits"] = self.orig_layer.super_bits
            quant_kwargs["super_group_size"] = self.orig_layer.super_group_size

        weight_q, scale, zp = self.weight_quant_func(
            weight,
            bits=self.orig_layer.bits,
            group_size=self.orig_layer.group_size,
            v=value,
            min_scale=min_scale,
            max_scale=max_scale,
            scale_dtype=self.orig_layer.scale_dtype,
            tensor_min=self.weight_min,
            tensor_max=self.weight_max,
            data_type=self.data_type,
            q_scale_thresh=self.q_scale_thresh,
            global_scale=getattr(self, "weight_global_scale", None),
            init_scale=self.init_scale,
            **quant_kwargs,
        )
        weight_q = weight_q.to(weight.dtype)
        if is_conv1d:
            # Undo the transpose applied before quantization.
            weight_q = weight_q.t()
        return weight_q, scale, zp
def _register_act_max_hook_ext(self, model):
    """Register forward hooks that collect ``imatrix`` and ``act_max`` statistics.

    Args:
        self: The object this function is bound to; the code reads
            ``act_group_size``, ``act_data_type``, ``supported_types`` and
            ``layer_config`` from it.
        model: Module whose ``named_modules()`` are scanned.

    Returns:
        list: Handles of every forward hook registered here.
    """

    def get_act_max_hook(module, input, output):
        """Track the running abs-max of the module input in ``module.act_max``."""
        if isinstance(input, (tuple, list)):
            input = input[0]
        if input.numel() == 0:
            return  # as no needs for act_max update
        input, _, _ = reshape_pad_tensor_by_group_size(input, self.act_group_size)
        act_max = torch.max(torch.abs(input), dim=-1).values
        if not hasattr(module, "act_max") or module.act_max.numel() == 0:
            # First observation: store it as-is.
            module.act_max = act_max
        else:
            act_max = act_max.to(module.act_max.device)
            if is_nv_fp(self.act_data_type):  ## for nvfp per-tensor input_global_scale calculation usage
                # Collapse old and new statistics into a single scalar maximum.
                module.act_max = torch.max(torch.tensor([act_max.max(), module.act_max.max()], device=act_max.device))
            else:
                # Elementwise running maximum.
                module.act_max = torch.max(act_max, module.act_max)

    def get_imatrix_hook(module, input, output):
        """Accumulate the sum of squared inputs per last-dim index in ``module.imatrix``."""
        input = input[0] if isinstance(input, (tuple, list)) else input
        # Flatten every leading dimension so the reduction runs over all rows.
        flattened = input.reshape(-1, input.shape[-1]).to(torch.float32)
        squared = torch.sum(torch.pow(flattened, 2), dim=0).to(torch.float32)

        if not hasattr(module, "imatrix"):
            module.imatrix = squared
        else:
            module.imatrix += squared.to(module.imatrix.device)

    hook_handles = []

    for n, m in model.named_modules():
        if isinstance(m, self.supported_types) and check_to_quantized(m):
            if not is_wint4aint4(self):  # INT4 no imatrix is much better
                hook = m.register_forward_hook(get_imatrix_hook)
                hook_handles.append(hook)

        # Layer carries its own activation settings: decide from the module attributes.
        if (
            hasattr(m, "act_dynamic")
            and check_need_act_calibration(m.act_dynamic, m.act_data_type, m.act_bits)
            and check_to_quantized(m)
        ):
            hook = m.register_forward_hook(get_act_max_hook)
            hook_handles.append(hook)
            continue

        # for whole model, RTN
        # Otherwise fall back to the per-layer entry in ``self.layer_config``.
        if n in self.layer_config:
            config = self.layer_config[n]
            act_dynamic = config.get("act_dynamic", True)
            act_data_type = config.get("act_data_type", None)
            act_bits = config.get("act_bits", 16)
            if (
                config["bits"] <= 8
                and check_need_act_calibration(act_dynamic, act_data_type, act_bits)
                and check_to_quantized(config)
            ):
                hook = m.register_forward_hook(get_act_max_hook)
                hook_handles.append(hook)
                continue
    return hook_handles
def _dq_sym_qdq(tensor, scale, bits, v=0):
    """Symmetric double-quant quantize/dequantize using a precomputed ``scale``.

    The tensor is padded to groups of 16, viewed as
    ``(n_blocks, 16, QK_K // 16)`` blocks, rounded onto the signed grid
    ``[-maxq, maxq - 1]`` and mapped back with ``scale``.

    ``maxq`` is derived from the float power ``2.0 ** (bits - 1)`` and then cast
    to ``int`` so that SymInt-driven shifts are not lowered through
    ``libdevice.pow``.

    Args:
        tensor (torch.Tensor): Weight to quantize.
        scale (torch.Tensor): Precomputed scale; multiplied elementwise (with
            broadcasting) against the blocked tensor.
        bits (int): Bit width; also selects the ``q<bits>_k`` GGML block size.
        v (torch.Tensor | number): Rounding perturbation added before rounding.

    Returns:
        torch.Tensor: Quantize-dequantized tensor in the original shape and dtype.
    """
    from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K

    sub_group = 16
    groups_per_block = 16
    half_range = int(2.0 ** (bits - 1))

    padded, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, sub_group)
    out_dtype = padded.dtype
    padded = padded.to(torch.float32)

    block_size, _ = GGML_QUANT_SIZES[f"q{bits}_k"]
    n_blocks = padded.nelement() // block_size
    block_shape = (n_blocks, groups_per_block, QK_K // groups_per_block)
    blocks = padded.reshape(block_shape)

    # The perturbation must be blocked the same way as the weight when it is a tensor.
    if isinstance(v, torch.Tensor):
        v_r = reshape_pad_tensor_by_group_size(v, sub_group)[0].reshape(block_shape)
    else:
        v_r = v

    zp = torch.full_like(scale, half_range)
    # Round on the signed grid, then shift by +maxq so ``levels - zp`` recovers the signed value.
    levels = round_ste(blocks * get_reciprocal(scale) + v_r).clip(-half_range, half_range - 1) + half_range
    dequant = (scale * (levels - zp)).to(out_dtype)
    return revert_tensor_by_pad(dequant, orig_shape=orig_shape, pad_len=pad_len)
- - Args: - orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d). - enable_minmax_tuning (bool): Whether to enable min-max scale tuning. - enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term. - device (str): Device on which to run computations (e.g., 'cpu' or 'cuda'). - """ - - def __init__( - self, - orig_layer, - enable_minmax_tuning=True, - enable_norm_bias_tuning=False, - device="cpu", - enable_round_tuning=True, - enable_torch_compile=False, - disable_opt_rtn=True, - **kwargs, - ): - """Initializes the WrapperLinear module. - - Args: - orig_layer (torch.nn.Module): The original layer to wrap. - enable_minmax_tuning (bool): Whether to enable min-max scale tuning. - enable_norm_bias_tuning (bool): Whether to enable normalization and tuning for the bias term. - device (str): The computation device, such as 'cpu' or 'cuda'. - """ - super(DQWrapperLinear, self).__init__( - orig_layer=orig_layer, - enable_minmax_tuning=enable_minmax_tuning, - enable_norm_bias_tuning=enable_norm_bias_tuning, - device=device, - enable_round_tuning=enable_round_tuning, - enable_torch_compile=enable_torch_compile, - disable_opt_rtn=disable_opt_rtn, - **kwargs, - ) - self.prev_scale = None - self.prev_wmin = None - self.prev_d_scale = None - self.prev_d_wmin = None - - def _init_tuning_params_and_quant_func(self): - """Initializes tuning parameters and quantization functions. - - This method sets up required parameters and functions for weight quantization, - activation quantization, and bias/normalization. - """ - super()._init_tuning_params_and_quant_func() - p_dtype = torch.float32 - # ``search_func`` is kept un-compiled because it contains data-dependent - # control flow (imatrix branches, iterative searches), while - # ``weight_quant_func`` is the compilable pure-math part. 
- self.search_func = None - self._dq_kind = None - self._is_dq_path = False - if hasattr(self.orig_layer, "super_group_size") and self.orig_layer.super_group_size is not None: - self._is_dq_path = True - from auto_round.data_type.gguf import search_gguf_scale_min_asym, search_gguf_scale_min_sym - - if self.orig_layer.data_type == "int_asym_dq": - self.search_func = search_gguf_scale_min_asym - self.weight_quant_func = _dq_asym_qdq - self._dq_kind = "asym" - else: - self.search_func = search_gguf_scale_min_sym - self.weight_quant_func = _dq_sym_qdq - self._dq_kind = "sym" - elif self.orig_layer.sym: - from auto_round.data_type.int import quant_tensor_sym - - self.weight_quant_func = quant_tensor_sym - else: - from auto_round.data_type.int import quant_tensor_asym - - self.weight_quant_func = quant_tensor_asym - self.data_type = self.orig_layer.data_type - if self.enable_act_quant: - from auto_round.data_type.gguf import quant_tensor_gguf_asym_dq as _gguf_asym_dq - from auto_round.data_type.gguf import quant_tensor_gguf_sym_dq as _gguf_sym_dq - - self.act_quant_func = _gguf_asym_dq if self.orig_layer.act_data_type == "int_asym_dq" else _gguf_sym_dq - if self.enable_torch_compile: - self.act_quant_func = compile_func(self.act_quant_func, self.device) - self._init_params("act_max_scale", p_dtype, (1), 1.0, not self.orig_layer.act_dynamic) - - if self.enable_torch_compile: - self.weight_quant_func = compile_func(self.weight_quant_func, self.device) - - @torch.no_grad() - def _run_search(self, weight, v): - """Run the per-format scale/wmin search separately from the quant func. - - Uses the search routines from ``auto_round.data_type.gguf`` and forwards - the tuning perturbation ``v``. Returns the parameters to feed into the - (compilable) ``weight_quant_func``. 
- """ - from auto_round.data_type.gguf import double_quant_tensor_sym_rtn - from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K - - bits = self.orig_layer.bits - scale_dtype = self.orig_layer.scale_dtype - imatrix = getattr(self.orig_layer, "imatrix", None) - - if self._dq_kind == "asym": - group_size = 16 if bits == 2 else 32 - t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) - v_r = v - if isinstance(v, torch.Tensor): - v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) - scale, wmin, d_scale, d_wmin = self.search_func( - t, - bits=bits, - scale_dtype=scale_dtype, - imatrix=imatrix, - split_num=1, - v=v_r, - ) - # Search funcs are decorated with ``@torch.inference_mode()``; their - # outputs are inference tensors and cannot be saved for backward. - # Clone to detach from inference mode so autograd may use them. - return { - "scale": scale.clone(), - "wmin": wmin.clone(), - "d_scale": d_scale.clone(), - "d_wmin": d_wmin.clone(), - } - - # sym path - group_size = 16 - super_group_size = 16 - t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) - ggml_type = f"q{bits}_k" - block_size, _ = GGML_QUANT_SIZES[ggml_type] - n_blocks = t.nelement() // block_size - t = t.reshape(n_blocks, super_group_size, QK_K // super_group_size) - v_r = v - if isinstance(v, torch.Tensor): - v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) - v_r = v_r.reshape(n_blocks, super_group_size, QK_K // super_group_size) - super_bits = 6 if bits == 3 else 8 - scale = self.search_func(t, bits, imatrix, scale_dtype, split_num=1, v=v_r) - scale = scale.to(scale_dtype) - scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) - scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) - scale = scale.unsqueeze(-1) - # Clone to escape inference-mode tensors (see asym branch comment). 
- return {"scale": scale.clone(), "d_scale": d_scale.clone()} - - def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, iter=None): - """Quantizes and dequantizes weights with tuning parameters. - - Args: - value (torch.Tensor): Value added for rounding for tuning. - min_scale (torch.Tensor): Minimum scale for the min value of quantization. - max_scale (torch.Tensor): Maximum scale for the max value of quantization. - - Returns: - tuple: Quantized weight, scale, and zero point. - """ - if self.orig_layer.bits >= 16: - return self.orig_layer.weight, None, None - min_scale.data.clamp_(0.5, 1.5) - max_scale.data.clamp_(0.5, 1.5) - weight = self.orig_layer.weight - if weight.device.type == "meta": - weight = self.orig_layer.get_weight().to(self.device) - if isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D): - weight = weight.t() - - if self._is_dq_path: - # Split search (data-dependent, un-compiled) from quant math (compilable). - iter_v = self.cur_iter if (iter is None and hasattr(self, "cur_iter")) else iter - if iter_v is None: - iter_v = 0 - need_search = (iter_v % 10 == 0) or (iter_v == -1) or (self.prev_scale is None) - if need_search: - params = self._run_search(weight, value) - self.prev_scale = params["scale"] - self.prev_d_scale = params["d_scale"] - if self._dq_kind == "asym": - self.prev_wmin = params["wmin"] - self.prev_d_wmin = params["d_wmin"] - else: - params = { - "scale": self.prev_scale.detach(), - "d_scale": self.prev_d_scale.detach(), - } - if self._dq_kind == "asym": - params["wmin"] = self.prev_wmin.detach() - params["d_wmin"] = self.prev_d_wmin.detach() - - bits = self.orig_layer.bits - if self._dq_kind == "asym": - group_size = 16 if bits == 2 else 32 - weight_q = self.weight_quant_func( - weight, - params["scale"], - params["wmin"], - bits, - group_size, - v=value, - ) - scale_out = {"scale": params["scale"], "d_scale": params["d_scale"]} - zp_out = {"wmin": params["wmin"], "d_wmin": params["d_wmin"]} - else: - weight_q 
def dq_wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
    """Replace the quantizable layers of ``block`` with ``DQWrapperLinear`` wrappers.

    Supported layers that should not be quantized are only recorded. When
    ``enable_norm_bias_tuning`` is set, modules whose class name contains
    "norm" are wrapped with the matching ``NORM_MAPPING`` entry.

    Args:
        block: The input block containing linear and conv1d layers to be wrapped.
        enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
        enable_norm_bias_tuning: A boolean indicating whether norm layers are wrapped too.
        device: Device forwarded to every wrapper.
        **kwargs: Extra keyword arguments forwarded to ``DQWrapperLinear``.

    Returns:
        tuple: ``(quantized_layers, unquantized_layers)`` lists of module names.
    """
    wrapped_names = []
    skipped_names = []
    for name, module in block.named_modules():
        if isinstance(module, SUPPORTED_LAYER_TYPES):
            if check_to_quantized(module):
                wrapper = DQWrapperLinear(
                    module,
                    enable_minmax_tuning=enable_minmax_tuning,
                    enable_norm_bias_tuning=enable_norm_bias_tuning,
                    device=device,
                    **kwargs,
                )
                set_module(block, name, wrapper)
                wrapped_names.append(name)
            else:
                skipped_names.append(name)
            continue

        if not enable_norm_bias_tuning:
            continue
        cls_name = module.__class__.__name__
        if "norm" not in cls_name.lower():
            continue

        if cls_name in NORM_MAPPING:
            norm_wrapper_cls = NORM_MAPPING[cls_name]
        elif "RMSNorm" in cls_name:
            # Unknown RMSNorm variant: fall back to the Llama wrapper and warn.
            logger.warning_once(f"use LlamaRMSNorm to wrap {cls_name}, please check the correctness yourself")
            norm_wrapper_cls = NORM_MAPPING["LlamaRMSNorm"]
        else:
            logger.warning_once(f"{cls_name} is not supported")
            continue
        set_module(block, name, norm_wrapper_cls(module, device=device))

    return wrapped_names, skipped_names
class BasePipelineMember:
    """Common base for every member of a quantization pipeline.

    Holds the shared contexts wired in by ``bind`` and provides no-op defaults
    for the run-level lifecycle hooks so subclasses override only what they need.
    """

    # Shared state, populated by ``bind``.
    model_context = None
    compress_context = None
    _scheme_context_fields = set(QuantizationScheme.get_attributes())

    # Declared scheme fields (annotations only, no class-level values).
    bits: int | None
    group_size: int | tuple | None
    sym: bool | None
    data_type: str | None
    act_bits: int | None
    act_group_size: int | None
    act_sym: bool | None
    act_data_type: str | None
    act_dynamic: bool | None
    super_bits: int | None
    super_group_size: int | None
    super_sym: bool | None
    scale_dtype: str | None

    def __init__(self, config=None):
        self.config = config
        # Falls back to None when the config is missing or carries no scheme.
        self.scheme = getattr(config, "scheme", None)

    @classmethod
    def from_config(cls, config):
        """Build an instance of the implementation class registered for ``config``."""
        implementation = resolve_pipeline_member(config)
        return implementation(config)

    def bind(self, compressor) -> None:
        """Copy the shared contexts (and scheme context, if any) from ``compressor``."""
        self.model_context = compressor.model_context
        self.compress_context = compressor.compress_context
        self.scheme = getattr(compressor, "scheme_context", None)

    def prepare_run(self, run_ctx) -> None:
        """Model-level setup hook, run once before block iteration; no-op by default."""
        return None

    def get_act_calib_policy(self, ctx):
        """Return this member's activation calibration policy (default: skip)."""
        # Imported lazily, inside the method rather than at module import time.
        from auto_round.algorithms.pipeline import ActCalibPolicy, CalibTiming, InputSource

        return ActCalibPolicy(when=CalibTiming.SKIP, source=InputSource.FP_CACHE)

    @contextmanager
    def block_forward_hooks(self, ctx):
        """Context manager for algorithm-specific reference-forward hooks; yields no hooks by default."""
        yield []

    def finalize_run(self, run_ctx) -> None:
        """Model-level teardown hook, run once after all blocks; no-op by default."""
        return None
+- Single-algorithm use is expressed as + ``QuantizationPipeline(preprocessors=[], block_quantizer=q)``, which is + semantically identical to the current direct-quantizer path. +""" + +from __future__ import annotations + +from collections import defaultdict +from contextlib import ExitStack +from dataclasses import dataclass, field, fields +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch + +from auto_round.compressors.utils import block_forward + +if TYPE_CHECKING: # avoid circular imports at runtime + import torch + + from auto_round.algorithms.base import BasePipelineMember + from auto_round.algorithms.quantization.base import BaseQuantizer + from auto_round.algorithms.quantization.config import QuantizationConfig + from auto_round.algorithms.transforms.base import BaseWeightTransformer + +# --------------------------------------------------------------------------- +# Policy +# --------------------------------------------------------------------------- + + +class CalibTiming(IntEnum): + """When to run the activation calibration hook forward pass. + + SKIP: + No activation calibration needed (e.g. AWQ W4A16 pure weight-only smooth). + WITH_REFERENCE: + Register act-calib hook and run together with the FP reference forward + (RTN / AutoRound default). When ``source == InputSource.QUANTIZED_INPUT``, a + separate hook-only forward is run with ``quantized_input`` instead. + AFTER_PREPROCESS: + Run a dedicated forward pass *after* all ``pre_quantize_block`` calls + complete (i.e. after smooth/rotation transforms are applied). Used for + AWQ W8A8 / static-activation quantization where the activation + distribution changes after smoothing. + """ + + SKIP = 0 + WITH_REFERENCE = 1 + AFTER_PREPROCESS = 2 + + +class InputSource(IntEnum): + """Which tensor to use as the block input for the act-calib forward pass. 
+ + FP_CACHE: + The original FP calibration activations cached at the start of the run + (``enable_quanted_input=False`` path, RTN default). + QUANTIZED_INPUT: + The quantized output of the previous block used as the next block's input + (``enable_quanted_input=True`` path, SignRound / AutoRound). + """ + + FP_CACHE = 0 + QUANTIZED_INPUT = 1 + + +def get_algorithm_class(config: Any): + """Return the registered implementation class for a quantization config.""" + from auto_round.algorithms.registry import normalize_algorithm_config, resolve_pipeline_member + + try: + return resolve_pipeline_member(normalize_algorithm_config(config)) + except ValueError: + return None + + +def is_preprocessor_config(config: Any) -> bool: + """Whether *config* resolves to a BaseWeightTransformer implementation.""" + from auto_round.algorithms.transforms.base import BaseWeightTransformer + + alg_cls = get_algorithm_class(config) + return alg_cls is not None and issubclass(alg_cls, BaseWeightTransformer) + + +def is_block_quantizer_config(config: Any) -> bool: + """Whether *config* resolves to a BaseQuantizer implementation.""" + from auto_round.algorithms.quantization.base import BaseQuantizer + + alg_cls = get_algorithm_class(config) + return alg_cls is not None and issubclass(alg_cls, BaseQuantizer) + + +def split_quantization_configs(configs: list[Any]) -> tuple[list[Any], list[Any]]: + """Split configs into preprocessor and block-quantizer config lists.""" + preprocessors = [] + block_quantizers = [] + for config in configs: + if is_preprocessor_config(config): + preprocessors.append(config) + elif is_block_quantizer_config(config): + block_quantizers.append(config) + return preprocessors, block_quantizers + + +def _quantization_configs(configs: list[Any]) -> list[Any]: + """Return configs that participate in quantization shared-field resolution.""" + from auto_round.algorithms.quantization.config import QuantizationConfig + + return [config for config in configs if 
isinstance(config, QuantizationConfig)] + + +def _public_config_attrs(config: Any) -> dict[str, Any]: + """Public, data-like attrs used for structural shared config resolution.""" + return { + key: value + for key, value in vars(config).items() + if key != "scheme" and not key.startswith("_") and not callable(value) + } + + +def _format_shared_config_values(field: str, values: list[tuple[Any, Any]]) -> str: + parts = [f"{type(config).__name__}.{field}={value!r}" for config, value in values] + return ", ".join(parts) + + +def _resolve_shared_scheme_values(configs: list[Any]) -> None: + from auto_round.schemes import QuantizationScheme + + scheme_fields = {field.name for field in fields(QuantizationScheme)} + field_values: dict[str, list[tuple[Any, Any]]] = defaultdict(list) + for config in configs: + for attr_name in getattr(config, "_user_set_scheme_fields", set()): + if attr_name not in scheme_fields: + continue + value = getattr(config, attr_name, None) + if value is not None: + field_values[attr_name].append((config, value)) + + for attr_name, values in field_values.items(): + unique_values = [] + for _, value in values: + if not any(value == existing for existing in unique_values): + unique_values.append(value) + if len(unique_values) > 1: + raise ValueError( + f"Conflicting shared scheme field {attr_name!r}: " + f"{_format_shared_config_values(attr_name, values)}. " + "Use the same value for shared fields or pass scheme arguments through Compressor." + ) + shared_value = unique_values[0] + for config in configs: + if hasattr(config, "scheme") and attr_name not in getattr(config, "_user_set_scheme_fields", set()): + setattr(config.scheme, attr_name, shared_value) + + +def resolve_shared_config_values(configs: list[Any]) -> list[Any]: + """Merge shared public attrs across quantization configs without naming fields. + + A field is shared when at least two quantization configs define the same + public attribute. 
``None`` means "not set" and inherits from the single + non-None value, while conflicting non-None values fail fast. Configs that do + not define a field do not participate in that field. + """ + quant_configs = _quantization_configs(configs) + _resolve_shared_scheme_values(quant_configs) + attrs_by_config = [(config, _public_config_attrs(config)) for config in quant_configs] + field_to_configs: dict[str, list[Any]] = defaultdict(list) + for config, attrs in attrs_by_config: + for attr_name in attrs: + field_to_configs[attr_name].append(config) + + for attr_name, field_configs in field_to_configs.items(): + if len(field_configs) < 2: + continue + + field_attrs = [ + (config, attrs[attr_name]) + for config, attrs in attrs_by_config + if any(config is field_config for field_config in field_configs) + ] + non_none_values = [(config, value) for config, value in field_attrs if value is not None] + unique_values = [] + for _, value in non_none_values: + if not any(value == existing for existing in unique_values): + unique_values.append(value) + if len(unique_values) > 1: + raise ValueError( + f"Conflicting shared config field {attr_name!r}: " + f"{_format_shared_config_values(attr_name, non_none_values)}. " + "Use the same value for shared fields or leave it unset on secondary algorithms." 
+ ) + if len(unique_values) == 1: + shared_value = unique_values[0] + for config in field_configs: + if getattr(config, attr_name) is None: + setattr(config, attr_name, shared_value) + return configs + + +def sync_shared_config_from(source_config: Any, target_configs: list[Any]) -> None: + """Propagate resolved source values to targets that already define matching attrs.""" + source_attrs = _public_config_attrs(source_config) + for target in _quantization_configs(target_configs): + if target is source_config: + continue + target_attrs = _public_config_attrs(target) + for attr_name, source_value in source_attrs.items(): + if attr_name in target_attrs and source_value is not None: + setattr(target, attr_name, source_value) + + +@dataclass +class ActCalibPolicy: + """Activation calibration policy: when and from what inputs to collect act stats. + + Attributes: + when: ``CalibTiming`` — controls the scheduling phase. + source: ``InputSource`` — which tensor feeds the calibration forward. + Ignored when ``when == CalibTiming.SKIP``. + """ + + when: CalibTiming = CalibTiming.WITH_REFERENCE + source: InputSource = InputSource.FP_CACHE + + def __post_init__(self) -> None: + if not isinstance(self.when, CalibTiming): + try: + self.when = CalibTiming(self.when) + except ValueError: + raise ValueError(f"ActCalibPolicy.when must be a CalibTiming, got {self.when!r}") + if not isinstance(self.source, InputSource): + try: + self.source = InputSource(self.source) + except ValueError: + raise ValueError(f"ActCalibPolicy.source must be an InputSource, got {self.source!r}") + + @classmethod + def no_collection(cls) -> "ActCalibPolicy": + """Convenience factory: no activation calibration.""" + return cls(when=CalibTiming.SKIP, source=InputSource.FP_CACHE) + + +def merge_policies(policies: list["ActCalibPolicy"]) -> "ActCalibPolicy": + """Merge act-calib policies from multiple pipeline members. 
+ + Rules: + - ``when``: take the *latest* timing (``SKIP < WITH_REFERENCE < AFTER_PREPROCESS``). + - ``source``: when two policies share the same ``when`` but differ in ``source``, + that is a compatibility conflict → raise ``ValueError`` fail-fast. + - Policies with ``when == CalibTiming.SKIP`` do **not** contribute a source + constraint. + + Returns: + A merged :class:`ActCalibPolicy`. + """ + if not policies: + return ActCalibPolicy.no_collection() + + merged_when = max(policies, key=lambda p: p.when).when + + if merged_when == CalibTiming.SKIP: + return ActCalibPolicy.no_collection() + + contributing = [p for p in policies if p.when == merged_when] + sources = {p.source for p in contributing} + if len(sources) > 1: + raise ValueError( + f"Incompatible act-calib policies: multiple algorithms request " + f"when={merged_when.name!r} but with different input sources: " + f"{[s.name for s in sources]}. " + "Use a compatible combination of algorithms or file an issue." + ) + + return ActCalibPolicy(when=merged_when, source=contributing[0].source) + + +# --------------------------------------------------------------------------- +# Context dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class BlockIO: + """Owns per-block calibration inputs, outputs, and batch forward mechanics.""" + + fp_inputs: Any + input_others: dict + quantized_inputs: Any = None + reference_outputs: Any = None + quantized_outputs: Any = None + active_source: InputSource = InputSource.FP_CACHE + batch_dim: int = 0 + seqlen: int = 2048 + shared_cache_keys: tuple = () + + @property + def has_quantized_inputs(self) -> bool: + return self.quantized_inputs is not None + + def get_inputs(self, source: InputSource): + if source == InputSource.QUANTIZED_INPUT: + return self.quantized_inputs + return self.fp_inputs + + def set_reference_outputs(self, outputs) -> None: + self.reference_outputs = outputs + + def set_quantized_inputs(self, inputs) 
-> None: + self.quantized_inputs = inputs + + def select_inputs(self, source: InputSource, indices): + input_ids = self.get_inputs(source) + if input_ids is None: + raise ValueError(f"Input source {source.name} is unavailable for this block.") + return self._select_inputs(input_ids, self.input_others, indices) + + def forward_batch(self, block, quantizer, indices, *, source: InputSource, device, cache_device="cpu"): + input_ids, input_others = self.select_inputs(source, indices) + output = quantizer._resolve_block_forward()( + block, + input_ids, + input_others, + quantizer.model_context.amp, + quantizer.model_context.amp_dtype, + device, + ) + return output.to(cache_device) + + @torch.no_grad() + def collect_outputs(self, block, quantizer, *, source: InputSource, batch_size: int, save: bool = True): + input_ids = self.get_inputs(source) + if input_ids is None: + raise ValueError(f"Input source {source.name} is unavailable for this block.") + outputs = [] + for start in range(0, len(input_ids), batch_size): + end = min(len(input_ids), start + batch_size) + indices = torch.arange(start, end).to(torch.long) + output = self.forward_batch( + block, + quantizer, + indices, + source=source, + device=quantizer.compress_context.device, + cache_device=quantizer.compress_context.cache_device, + ) + if save: + if quantizer.batch_size == 1: + outputs.append(output) + else: + outputs.extend(list(torch.split(output, 1, dim=self.batch_dim))) + quantizer.compress_context.clear_memory() + return outputs + + def select_reference_outputs(self, indices, *, device=None): + if self.reference_outputs is None: + raise ValueError("Reference outputs have not been collected for this block.") + output = torch.cat([self.reference_outputs[i] for i in indices], dim=self.batch_dim) + return output.to(device) if device is not None else output + + def count_input_elements(self, indices, *, source: InputSource | None = None) -> int: + source = self.active_source if source is None else source + 
input_ids = self.get_inputs(source) + current_input_ids = [input_ids[i] for i in indices] + return sum(t.numel() for t in current_input_ids) + + def preprocess_block_inputs(self, input_ids, input_others: dict, block): + return input_ids, input_others + + def _select_inputs(self, input_ids, input_others: dict, indices): + if isinstance(input_ids, list): + current_input_ids = [input_ids[i] for i in indices] + current_input_ids = torch.cat(current_input_ids, dim=self.batch_dim) + elif isinstance(input_ids, dict): + current_input_ids = {} + for key in input_ids.keys(): + current_input_ids[key] = torch.cat([input_ids[key][i] for i in indices], dim=self.batch_dim) + else: + raise TypeError(f"Unsupported input container type: {type(input_ids).__name__}") + + current_input_others = {"positional_inputs": input_others["positional_inputs"]} + for key in input_others.keys(): + if "positional_inputs" in key: + continue + if key in self.shared_cache_keys: + val = input_others[key] + if isinstance(val, list) and len(val) == 1: + current_input_others[key] = val[0] + elif isinstance(val, list) and len(val) > 1: + idx = int(indices[0]) if len(indices) == 1 else 0 + current_input_others[key] = val[idx] if idx < len(val) else val[0] + else: + current_input_others[key] = val + elif not isinstance(input_others[key], (str, bool, type(None))): + current_input_others[key] = [input_others[key][i] for i in indices] + if len(current_input_others[key]) == 1: + current_input_others[key] = current_input_others[key][0] + else: + current_input_others[key] = torch.cat(current_input_others[key], dim=self.batch_dim) + else: + current_input_others[key] = input_others[key] + return current_input_ids, current_input_others + + +class DiffusionBlockIO(BlockIO): + """BlockIO variant for diffusion blocks with dict inputs and tuple outputs.""" + + def __init__(self, *args, output_config=None, **kwargs): + super().__init__(*args, **kwargs) + self.output_config = output_config or ["hidden_states"] + + 
@torch.no_grad() + def collect_outputs(self, block, quantizer, *, source: InputSource, batch_size: int, save: bool = True): + input_ids = self.get_inputs(source) + if input_ids is None: + raise ValueError(f"Input source {source.name} is unavailable for this block.") + outputs = defaultdict(list) + nsamples = len(input_ids["hidden_states"]) if isinstance(input_ids, dict) else len(input_ids) + for start in range(0, nsamples, batch_size): + end = min(nsamples, start + batch_size) + indices = torch.arange(start, end).to(torch.long) + tmp_input_ids, tmp_input_others = self.select_inputs(source, indices) + if isinstance(tmp_input_ids, dict): + hidden_states = tmp_input_ids.pop("hidden_states") + tmp_input_others.update(tmp_input_ids) + tmp_input_ids = hidden_states + tmp_output = block_forward( + block, + tmp_input_ids, + tmp_input_others, + quantizer.model_context.amp, + quantizer.model_context.amp_dtype, + quantizer.compress_context.device, + None, + ) + if isinstance(tmp_output, torch.Tensor): + tmp_output = [tmp_output] + assert len(self.output_config) == len(tmp_output) + tmp_output = dict(zip(self.output_config, tmp_output)) + if save: + for name, out in tmp_output.items(): + if quantizer.batch_size == 1: + outputs[name].append(out.to(quantizer.compress_context.cache_device)) + else: + outputs[name].extend( + list(torch.split(out.to(quantizer.compress_context.cache_device), 1, dim=self.batch_dim)) + ) + quantizer.compress_context.clear_memory() + return outputs + + def forward_batch(self, block, quantizer, indices, *, source: InputSource, device, cache_device="cpu"): + input_ids, input_others = self.select_inputs(source, indices) + idx = self.output_config.index("hidden_states") if "hidden_states" in self.output_config else None + if isinstance(input_ids, dict): + hidden_states = input_ids.pop("hidden_states") + input_others.update(input_ids) + input_ids = hidden_states + output = quantizer._resolve_block_forward()( + block, + input_ids, + input_others, + 
quantizer.model_context.amp, + quantizer.model_context.amp_dtype, + device, + idx, + ) + return output.to(cache_device) + + def select_reference_outputs(self, indices, *, device=None): + if self.reference_outputs is None or "hidden_states" not in self.reference_outputs: + raise ValueError("Diffusion reference hidden_states have not been collected for this block.") + output = torch.cat([self.reference_outputs["hidden_states"][i] for i in indices], dim=self.batch_dim) + return output.to(device) if device is not None else output + + def count_input_elements(self, indices, *, source: InputSource | None = None) -> int: + source = self.active_source if source is None else source + input_ids = self.get_inputs(source) + current_input_ids = [input_ids["hidden_states"][i] for i in indices] + return sum(t.numel() for t in current_input_ids) + + def preprocess_block_inputs(self, input_ids, input_others: dict, block): + if not isinstance(input_ids, dict): + return input_ids, input_others + extra_keys = [key for key in list(input_ids.keys()) if key not in self.output_config] + for key in extra_keys: + input_others[key] = input_ids.pop(key) + return input_ids, input_others + + +@dataclass +class RunContext: + """Model-level context shared across the entire quantization run. + + Passed to ``prepare_run()`` and ``finalize_run()`` so that algorithm-specific + preparation/teardown logic can access model topology without reading + private ``Compressor`` fields. 
+ """ + + model: "torch.nn.Module" + all_blocks: list # list[list[str]] – block groups as returned by get_block_names + layer_names: list[str] # quantizable layer names *outside* all blocks + formats: Optional[list] # resolved OutputFormat list (or None) + scheme: Any # str | QuantizationScheme | AutoScheme + alg_configs: list # all algorithm configs in the pipeline + model_context: Any # ModelContext + compress_context: Any # CompressContext + + +@dataclass +class BlockContext: + """Per-block context threaded through the lifecycle hooks. + + Passed to lifecycle methods like ``block_forward_hooks``, + ``pre_quantize_block``, ``post_quantize_block``, etc. + + ``block_names`` preserves the *scheduling group* (which may contain more + than one block when ``nblocks > 1``). Pre-processing algorithms that + only support single-block operation (e.g. AWQ Phase-1) must check + ``len(block_names) == 1`` in ``prepare_block_group`` and raise + ``ValueError`` with a user-readable message. + """ + + model: "torch.nn.Module" + block: "torch.nn.Module" + block_names: list[str] # scheduling group; len > 1 when nblocks > 1 + block_name: str # = block_names[0] for single-block; descriptive label for multi + block_index: int # 0-based index within the current all_blocks group + io: BlockIO + bs: int = 1 + loss_device: Union[str, "torch.device", None] = None + device: Union[str, "torch.device", None] = None + mid_iter_mem_check: bool = False + is_mllm: bool = False # fail-fast gate for algorithms that don't support MLLM + is_diffusion: bool = False # fail-fast gate for algorithms that don't support diffusion + pbar: Any = None + # Names of FP parameters modified in-place by preprocessors (for example, + # smooth source norm weights). Populated by pre_quantize_block; read by Compressor + # during immediate_save to persist FP param changes alongside packed weights. + # Pipelines without FP-param preprocessors leave this empty, no behavior change. 
+ modified_fp_params: list = field(default_factory=list) + + def mark_modified_fp_params(self, param_names: list[str]) -> None: + """Called by preprocessors to declare which FP params were modified in-place.""" + self.modified_fp_params.extend(param_names) + + @property + def input_ids(self): + return self.io.fp_inputs + + @input_ids.setter + def input_ids(self, value) -> None: + self.io.fp_inputs = value + + @property + def input_others(self): + return self.io.input_others + + @property + def reference_output(self): + return self.io.reference_outputs + + @reference_output.setter + def reference_output(self, value) -> None: + self.io.reference_outputs = value + + @property + def quantized_input(self): + return self.io.quantized_inputs + + @quantized_input.setter + def quantized_input(self, value) -> None: + self.io.quantized_inputs = value + + +# --------------------------------------------------------------------------- +# QuantizationPipeline +# --------------------------------------------------------------------------- + + +@dataclass +class QuantizationPipeline: + """An ordered composition of pre-processing quantizers + one block quantizer. + + The ``preprocessors`` list is order-sensitive: algorithms are applied in + the listed order (e.g. ``[Rotation, AWQ]``). There must be **exactly one** + ``block_quantizer`` (the terminal weight-compression step). + + Single-algorithm use: + ``QuantizationPipeline(preprocessors=[], block_quantizer=rtn_quantizer)`` + is semantically equivalent to the current direct-quantizer path; the + compressor's existing ``self.quantizer`` call-sites are transparently + forwarded to ``block_quantizer`` via a ``@property``. 
+ """ + + preprocessors: list["BaseWeightTransformer"] = field(default_factory=list) + block_quantizer: "BaseQuantizer" = None # type: ignore[assignment] + + def __post_init__(self) -> None: + if self.block_quantizer is None: + raise ValueError("QuantizationPipeline requires a non-None block_quantizer.") + from auto_round.algorithms.quantization.base import BaseQuantizer + from auto_round.algorithms.transforms.base import BaseWeightTransformer + + for q in self.preprocessors: + if not isinstance(q, BaseWeightTransformer): + raise TypeError( + f"{type(q).__name__} is listed as a preprocessor but does not " f"inherit BaseWeightTransformer." + ) + if not isinstance(self.block_quantizer, BaseQuantizer): + raise TypeError( + f"{type(self.block_quantizer).__name__} is used as block_quantizer but does not " + f"inherit BaseQuantizer." + ) + + def all(self) -> "list[BasePipelineMember]": + """Return all members in pipeline order: preprocessors then block_quantizer.""" + return [*self.preprocessors, self.block_quantizer] + + @classmethod + def from_configs(cls, configs: list, compressor: Any = None) -> "QuantizationPipeline": + """Construct a ``QuantizationPipeline`` from a list of algorithm config instances. + + Resolution rules: + 1. If no ``QuantizationConfig`` with a ``BaseQuantizer`` is found in *configs*, + a default :class:`RTNConfig` is appended automatically. + 2. If ``compressor`` indicates a diffusion model, :class:`DiffusionMixin` is + dynamically prepended to each algorithm class's MRO before instantiation, + activating diffusion-aware method overrides without touching the class definitions. + 3. Instances of ``BaseWeightTransformer`` go into ``preprocessors`` (in order). + 4. Exactly one ``BaseQuantizer`` becomes ``block_quantizer``. + 5. Multiple block-quantization configs raise ``ValueError``. + 6. If ``compressor`` is provided, every member is bound to it. 
+ """ + from auto_round.algorithms.quantization.base import BaseQuantizer, DiffusionMixin + from auto_round.algorithms.quantization.config import QuantizationConfig + from auto_round.algorithms.transforms.base import BaseWeightTransformer + + is_diffusion = compressor is not None and getattr(compressor.model_context, "is_diffusion", False) + configs = list(configs) + + # Ensure at least one terminal block quantizer is present; fall back to RTN. + _, block_quantizer_configs = split_quantization_configs(configs) + has_quantizer = bool(block_quantizer_configs) + if not has_quantizer: + from auto_round.algorithms.quantization.rtn.config import RTNConfig + + configs = list(configs) + [RTNConfig()] + + configs = resolve_shared_config_values(configs) + + def _resolve_cls(cfg): + alg_cls = get_algorithm_class(cfg) + if alg_cls is None: + raise ValueError(f"Unknown algorithm config type {type(cfg).__name__!r}.") + if is_diffusion and not issubclass(alg_cls, DiffusionMixin): + alg_cls = type(alg_cls.__name__, (DiffusionMixin, alg_cls), {}) + return alg_cls + + preprocessors = [] + block_quantizers = [] + + for cfg in configs: + if not isinstance(cfg, QuantizationConfig): + continue + from auto_round.algorithms.registry import normalize_algorithm_config + + cfg = normalize_algorithm_config(cfg) + alg_cls = _resolve_cls(cfg) + q = alg_cls(cfg) + if compressor is not None: + q.bind(compressor) + if isinstance(q, BaseWeightTransformer): + preprocessors.append(q) + elif isinstance(q, BaseQuantizer): + block_quantizers.append(q) + else: + raise TypeError( + f"Algorithm class {type(q).__name__} must inherit either " "BaseWeightTransformer or BaseQuantizer." + ) + + if len(block_quantizers) > 1: + raise ValueError( + f"QuantizationPipeline allows exactly one block-quantization config, " + f"but got {len(block_quantizers)}: " + f"{[type(q).__name__ for q in block_quantizers]}. " + "Ensure only one of RTNConfig / SignRoundConfig / etc. is in the pipeline." 
+ ) + + seen_preprocessors = set() + for preprocessor in preprocessors: + name = type(preprocessor).__name__ + if name in seen_preprocessors: + raise ValueError( + f"Duplicate preprocessor {name} in QuantizationPipeline. " + "Repeated instances of the same preprocessor are not supported yet." + ) + seen_preprocessors.add(name) + + return cls(preprocessors=preprocessors, block_quantizer=block_quantizers[0]) + + # ── Convenience act-calib helpers ──────────────────────────────────────── + + def get_merged_policy(self, ctx: "BlockContext") -> ActCalibPolicy: + """Compute the merged act-calib policy for the current block.""" + policies = [q.get_act_calib_policy(ctx) for q in self.all()] + return merge_policies(policies) + + def enter_block_forward_hooks(self, ctx: "BlockContext", fwd_stack: ExitStack) -> list: + """Enter all pipeline members' ``block_forward_hooks`` into *fwd_stack*. + + Iterates over all members (preprocessors then block_quantizer) in order, + entering each member's ``block_forward_hooks(ctx)`` context manager into + the provided :class:`contextlib.ExitStack`. + + Returns the hook handles yielded by the terminal ``block_quantizer`` + so the caller can determine whether any act-calib hooks were registered + (needed to decide whether a second forward with quantized inputs is required). + """ + self.enter_preprocessor_hooks(ctx, fwd_stack) + return self.enter_quantizer_hooks(ctx, fwd_stack) + + def enter_preprocessor_hooks(self, ctx: "BlockContext", fwd_stack: ExitStack) -> None: + """Enter preprocessor hooks only. + + Preprocessor hooks collect stats from the FP reference forward. They are + intentionally separate from terminal quantizer hooks so quantizer stats + can be collected from quantized inputs when required by policy. 
+ """ + for pre in self.preprocessors: + fwd_stack.enter_context(pre.block_forward_hooks(ctx)) + + def enter_quantizer_hooks(self, ctx: "BlockContext", fwd_stack: ExitStack) -> list: + """Enter terminal block-quantizer hooks only and return their handles.""" + return fwd_stack.enter_context(self.block_quantizer.block_forward_hooks(ctx)) diff --git a/auto_round/algorithms/quantization/__init__.py b/auto_round/algorithms/quantization/__init__.py index 5b297df1d..d05375224 100644 --- a/auto_round/algorithms/quantization/__init__.py +++ b/auto_round/algorithms/quantization/__init__.py @@ -12,13 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -from auto_round.algorithms.quantization.base import BaseQuantizers +from auto_round.algorithms.base import BasePipelineMember +from auto_round.algorithms.quantization.base import BaseQuantizer, DiffusionMixin, RTNLayerFallbackMixin from auto_round.algorithms.quantization.config import QuantizationConfig -from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.pipeline import ( + ActCalibPolicy, + CalibTiming, + InputSource, + BlockContext, + QuantizationPipeline, + RunContext, + merge_policies, +) +from auto_round.algorithms.quantization.sign_round.config import AdamRoundConfig, SignRoundConfig, SignRoundV2Config from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer from auto_round.algorithms.quantization.sign_roundv2 import SignRoundV2Quantizer from auto_round.algorithms.quantization.adam_round.adam import AdamRoundQuantizer -from auto_round.algorithms.quantization.awq.config import AWQConfig -from auto_round.algorithms.quantization.awq.quantizer import AWQQuantizer -from auto_round.algorithms.quantization.rtn.config import RTNConfig +from auto_round.algorithms.quantization.rtn.config import OptimizedRTNConfig, RTNConfig from auto_round.algorithms.quantization.rtn.quantizer import 
RTNQuantizer, OptimizedRTNQuantizer +from auto_round.algorithms.transforms.base import BaseWeightTransformer + + +def __getattr__(name): + if name == "AWQConfig": + from auto_round.algorithms.transforms.awq.config import AWQConfig + + return AWQConfig + if name == "AWQQuantizer": + from auto_round.algorithms.transforms.awq.quantizer import AWQQuantizer + + return AWQQuantizer + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/auto_round/algorithms/quantization/adam_round/adam.py b/auto_round/algorithms/quantization/adam_round/adam.py index 96835b533..fb394369f 100644 --- a/auto_round/algorithms/quantization/adam_round/adam.py +++ b/auto_round/algorithms/quantization/adam_round/adam.py @@ -15,11 +15,14 @@ import torch +from auto_round.algorithms.quantization.sign_round.config import AdamRoundConfig from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer +from auto_round.algorithms.registry import register_pipeline_member from auto_round.schemes import QuantizationScheme from auto_round.utils import check_is_cpu, htcore, is_hpex_available +@register_pipeline_member(AdamRoundConfig) class AdamRoundQuantizer(SignRoundQuantizer): def __init__(self, config): diff --git a/auto_round/algorithms/quantization/awq/__init__.py b/auto_round/algorithms/quantization/awq/__init__.py deleted file mode 100644 index 14a492441..000000000 --- a/auto_round/algorithms/quantization/awq/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/auto_round/algorithms/quantization/awq/config.py b/auto_round/algorithms/quantization/awq/config.py deleted file mode 100644 index 50340eeda..000000000 --- a/auto_round/algorithms/quantization/awq/config.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from auto_round.algorithms.quantization.config import QuantizationConfig - - -class AWQConfig(QuantizationConfig): - """Configuration for AWQ (Activation-Aware Weight Quantization). - - AWQ protects salient weight channels by analyzing activation patterns and - applying channel-wise scaling to reduce quantization error. The scaling - factors are computed offline via a grid search over calibration data. - After smoothing, standard RTN quantization is applied to the adjusted weights. - - Args: - duo_scaling: Use both activations and weights for the scaling factor. - True: always use duo scaling. False: only activation scaling. - "both": search both modes and pick the best. - n_grid: Number of grid points for the scaling-ratio search. - seqlen: Calibration sequence length. - nsamples: Number of calibration samples. Grid search time scales - linearly with nsamples (each batch triggers one parent forward - per grid point). - batch_size: Batch size for calibration forward passes. 
- apply_smooth: Whether to apply AWQ smoothing (channel scaling). - True: run both smoothing and quantization (default AWQ behavior). - False: skip smoothing, only run RTN quantization. - mappings: Explicit AWQ mappings. Each mapping is a dict with keys - ``smooth_layer`` (str) and ``balance_layers`` (list[str]). - If None, mappings are inferred automatically from the model structure. - **kwargs: Forwarded to ``QuantizationConfig`` (bits, group_size, sym, …). - """ - - _alg_cls = "AWQQuantizer" - - def __init__( - self, - *, - duo_scaling: bool | str = True, - n_grid: int = 20, - seqlen: int = 2048, - nsamples: int = 128, - batch_size: int = 8, - apply_smooth: bool = True, - mappings: list[dict] = None, - **kwargs, - ): - super().__init__(**kwargs) - - if isinstance(duo_scaling, str) and duo_scaling != "both": - raise ValueError(f"duo_scaling must be True, False, or 'both', got '{duo_scaling}'") - self.duo_scaling = duo_scaling - self.n_grid = n_grid - self.seqlen = seqlen - self.nsamples = nsamples - self.batch_size = batch_size - self.apply_smooth = apply_smooth - self.mappings = mappings - self.infer_bs_coeff = 1 - self.batch_dim = None - - # TODO adjust those args defaults after architecture refactoring: - self.enable_quanted_input = False # AWQ doesn't cascade quantized block outputs - # AWQ uses plain RTN (no iterative optimization) for the quantization step. - self.disable_opt_rtn = True - self.orig_disable_opt_rtn = True diff --git a/auto_round/algorithms/quantization/awq/quantizer.py b/auto_round/algorithms/quantization/awq/quantizer.py deleted file mode 100644 index e88f67cf1..000000000 --- a/auto_round/algorithms/quantization/awq/quantizer.py +++ /dev/null @@ -1,956 +0,0 @@ -# Copyright (c) 2026 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""AWQ (Activation-Aware Weight Quantization) quantizer. - -The algorithm: -1. Collects per-channel activation magnitudes during calibration. -2. For each smooth-balance mapping, performs a grid search over scaling ratios - to find the one that minimizes quantization error (output-based loss). -3. Applies the best channel-wise scaling: balance_layer.weight *= scales, - smooth_layer.weight /= scales (or smooth_layer.bias /= scales if 1-D). - - -Reference implementations: - - AutoAWQ: https://github.com/casper-hansen/AutoAWQ - - llm-compressor: https://github.com/vllm-project/llm-compressor -""" - -from __future__ import annotations - -import contextlib -import inspect -import traceback - -import torch -from tqdm import tqdm - -from auto_round.algorithms.quantization.awq.config import AWQConfig -from auto_round.algorithms.quantization.awq.mappings import ( - ResolvedMapping, - _extract_block_prefix, - resolve_mappings, -) -from auto_round.algorithms.quantization.base import BaseQuantizers -from auto_round.compressors.shard_writer import ShardWriter -from auto_round.compressors.utils import immediate_pack -from auto_round.data_type.utils import ( - get_quant_func, - reshape_pad_tensor_by_group_size, - revert_tensor_by_pad, - update_block_global_scale_if_needed, -) -from auto_round.logger import logger -from auto_round.utils import ( - check_to_quantized, - clear_memory, - convert_module_to_hp_if_necessary, - get_module, - set_amax_for_all_moe_layers, - set_module, -) -from auto_round.wrapper import WrapperLinear -from auto_round.wrapper import 
WrapperMultiblock as _WrapperMultiblock - - -class AWQQuantizer(BaseQuantizers): - """AWQ quantizer: activation-aware channel scaling + delegated quantization. - - AWQ is a pre-quantization transform that applies channel-wise scaling to reduce - quantization error. The actual quantization is delegated to an inner - quantizer (currently RTN by default). - """ - - def __init__(self, config: AWQConfig): - super().__init__(config) - self.duo_scaling = config.duo_scaling - self.n_grid = config.n_grid - self.apply_smooth = config.apply_smooth - - # Populated during calibration - self._user_mappings = config.mappings - self._resolved_mappings: list[ResolvedMapping] = [] - self._block_groups: dict[str, list[ResolvedMapping]] = {} - self._activation_stats: dict[str, list[torch.Tensor]] = {} - # Parent module kwargs cache: parent_module → list of kwargs dicts - self._parent_args_cache: dict[torch.nn.Module, list[dict]] = {} - self._parent_signatures: dict[int, inspect.BoundArguments] = {} - self._smoothing_applied: bool = False - - # Fail fast: validate scheme at construction time - self._check_scheme_compatibility() - - def bind(self, compressor) -> None: - """Wire shared state and validate compressor settings for AWQ compatibility.""" - super().bind(compressor) - # Check for unsupported compressor-level args - nblocks = getattr(compressor, "nblocks", 1) - if nblocks > 1: - logger.warning( - f"AWQ does not support nblocks > 1 (got nblocks={nblocks}). " - f"AWQ smoothing resolves activation hooks and mappings per single block prefix. " - f"Falling back to nblocks=1." 
- ) - compressor.nblocks = 1 - - def __setattr__(self, name, value): - """Trigger model-level AWQ setup when compress_context is assigned.""" - super().__setattr__(name, value) - if name == "compress_context" and value is not None: - self._prepare_model() - - # ── Public API ──────────────────────────────────────────────────────────── - - def resolve_all_mappings(self, model: torch.nn.Module) -> dict[str, list[ResolvedMapping]]: - """Resolve all AWQ mappings and group by block prefix. - - Call this once before the block-by-block loop. The resolved mappings - are stored internally and used by ``register_activation_hooks`` and - ``smooth_block`` for per-block filtering. - - Returns: - Dictionary mapping block prefix (e.g. "model.layers.0") to the - list of resolved mappings for that block. - """ - self._resolved_mappings = resolve_mappings(model, self._user_mappings) - block_groups: dict[str, list[ResolvedMapping]] = {} - for m in self._resolved_mappings: - prefix = _extract_block_prefix(m.smooth_name) - block_groups.setdefault(prefix, []).append(m) - self._block_groups = block_groups - return block_groups - - def apply_smoothing( - self, - model: torch.nn.Module, - device: torch.device | str | None = None, - ) -> None: - """Apply AWQ smoothing to the model weights. - - This is the core AWQ step: for each mapping, find the best scaling - factor and apply it to smooth/balance layers. - - When *device* is provided the model is assumed to reside on CPU and - only the modules needed for each mapping are temporarily moved to - *device* during the grid search. This dramatically reduces peak VRAM - (~14 GB savings for an 8B model). - - Should be called AFTER calibration data has been collected and BEFORE - RTN quantization. 
- """ - if self._smoothing_applied: - return - - # Resolve mappings - self._resolved_mappings = resolve_mappings(model, self._user_mappings) - if not self._resolved_mappings: - logger.warning("No AWQ mappings resolved; skipping smoothing.") - self._smoothing_applied = True - return - - if not self._activation_stats: - logger.error("No activation statistics collected for AWQ smoothing.") - - logger.info(f"Applying AWQ smoothing to {len(self._resolved_mappings)} mappings...") - - for mapping in tqdm(self._resolved_mappings, desc="AWQ Smoothing"): - if mapping.smooth_name not in self._activation_stats: - logger.warning(f"No activation stats for '{mapping.smooth_name}', skipping this layer.") - continue - - act_sum, act_count = self._activation_stats.pop(mapping.smooth_name) - if act_count == 0: - logger.warning(f"Zero activation count for '{mapping.smooth_name}', skipping this layer.") - continue - - # Mean activation per channel (act_count is a plain int now, - # avoiding a CPU tensor allocation per channel). - x_mean = (act_sum / act_count).to(torch.float32) - del act_sum - - with self._align_modules(mapping, device): - best_scales = self._grid_search_scales(mapping, x_mean) - - if best_scales is not None: - self._apply_scales(mapping, best_scales) - - # Eagerly free per-mapping parent cache to reduce RAM - parent = mapping.parent - self._parent_args_cache.pop(parent, None) - - self._activation_stats.clear() - self._parent_args_cache.clear() - self._parent_signatures.clear() - self._smoothing_applied = True - logger.info("AWQ smoothing complete.") - - def smooth_block(self, block_prefix: str) -> None: - """Apply AWQ smoothing for one block's mappings. - - Assumes activation stats and parent kwargs have already been collected - for this block via ``register_activation_hooks`` + block forward. 
- The block's modules are expected to be on the compute device, so - no device alignment is needed (unlike ``apply_smoothing`` which uses - ``_align_modules`` for CPU-offloaded models). - - Args: - block_prefix: Block prefix (e.g. "model.layers.0") identifying - which mappings to smooth. - """ - block_mappings = [m for m in self._resolved_mappings if _extract_block_prefix(m.smooth_name) == block_prefix] - for mapping in block_mappings: - if mapping.smooth_name not in self._activation_stats: - logger.warning(f"No activation stats for '{mapping.smooth_name}', skipping.") - continue - - act_sum, act_count = self._activation_stats.pop(mapping.smooth_name) - if act_count == 0: - logger.warning(f"Zero activation count for '{mapping.smooth_name}', skipping.") - continue - - x_mean = (act_sum / act_count).to(torch.float32) - del act_sum - - # Block is on the compute device — grid search runs in-place - best_scales = self._grid_search_scales(mapping, x_mean) - - if best_scales is not None: - self._apply_scales(mapping, best_scales) - - # Free parent kwargs cache after ALL mappings for this block are done. - # (All mappings in a block typically share the same parent module; - # popping inside the loop would break output-based loss for - # subsequent mappings.) - seen_parents = set() - for mapping in block_mappings: - pid = id(mapping.parent) - if pid not in seen_parents: - seen_parents.add(pid) - self._parent_args_cache.pop(mapping.parent, None) - - # ── Model-level initialization ───────────────────────────────────────────── - - def _prepare_model(self) -> None: - """One-time model-level AWQ setup: check compatibility and resolve mappings. - - Triggered automatically when model_context is assigned (via __setattr__). 
- """ - from auto_round.algorithms.quantization.awq.mappings import check_model_compatibility - - model = self.model_context.model - report = check_model_compatibility(model, self._user_mappings) - for w in report["warnings"]: - logger.warning(w) - if not report["compatible"]: - model_class = report.get("model_class", "unknown") - raise ValueError( - f"AWQ: no smooth-balance mappings could be resolved for " - f"'{model_class}'. Either the model architecture is not " - f"supported for automatic AWQ mapping detection, or the model " - f"has no repeating transformer block structure. " - f"You can provide explicit mappings via " - f"AWQConfig(mappings=[{{'smooth_layer': '', " - f"'balance_layers': ['', ...]}}])." - ) - - self.resolve_all_mappings(model) - - # AWQ caches block I/O on CPU — only one block lives on GPU at a time. - self.compress_context.cache_device = torch.device("cpu") - - def _check_scheme_compatibility(self) -> None: - """Validate the quantization scheme against AWQ inference backends.""" - bits = self.bits - act_bits = self.act_bits or 16 - data_type = self.data_type or "int" - - if "int" not in data_type: - raise ValueError( - f"AWQ requires integer data_type, got '{data_type}'. " - f"AWQ channel scaling is designed for integer quantization " - f"grids. Use algorithm='autoround' for FP8/MXFP quantization." - ) - - if act_bits is not None and act_bits < 16 and act_bits != 8: - raise ValueError( - f"AWQ with act_bits={act_bits} is not supported. " - f"No inference kernel exists for W{bits}A{act_bits} in vllm " - f"or sglang. Supported schemes: W4A16 (canonical AWQ) or " - f"W8A8 (compressed_tensors INT8 backend)." - ) - - if bits == 4 and act_bits >= 16: - pass - elif bits == 8 and act_bits == 8: - logger.info( - "AWQ with W8A8: AWQ smoothing will be applied, followed by " - "INT8 quantization. This is served by vllm's " - "compressed_tensors backend (cutlass INT8 GEMM), not AWQ " - "kernels." 
- ) - elif bits not in (4, 8): - logger.warning( - f"AWQ with bits={bits}: vllm AWQ kernels only support " - f"bits=4 (AWQ/Marlin) natively. bits=8 is supported via " - f"compressed_tensors. Other bit widths may not have " - f"optimized inference kernels." - ) - - # ── Weight quantization (delegated RTN) ───────────────────────────────────── - - def quantize_layer(self, name: str, dtype: torch.dtype = None) -> None: - """Quantize a single layer using RTN after AWQ smoothing has been applied. - - AWQ's quantize_layer is simpler than RTN's because: - - AWQ always uses plain RTN (disable_opt_rtn=True, no imatrix) - - No MoE-specific RTN disabling (AWQ handles MoE via mappings) - - No GGUF special path (AWQ targets AWQ/Marlin inference kernels) - - Args: - name: Fully-qualified module name (e.g. "model.layers.0.self_attn.q_proj"). - dtype: Optional dtype to cast the layer to before quantization. - """ - m = get_module(self.model, name) - if dtype is not None: - m = m.to(dtype) - - m = convert_module_to_hp_if_necessary(m, self.model_context.amp_dtype, self.compress_context.device) - set_module(self.model, name, m) - tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.compress_context.device - - try: - m = m.to(tuning_device) - m = WrapperLinear( - m, - device=tuning_device, - enable_minmax_tuning=False, - enable_norm_bias_tuning=False, - enable_round_tuning=False, - enable_torch_compile=self.compress_context.enable_torch_compile, - disable_opt_rtn=self.config.disable_opt_rtn, - iters=0, - ) - m = m.unwrapper({}) - except torch.OutOfMemoryError: - cuda_error_msg = traceback.format_exc() - m = m.orig_layer if hasattr(m, "orig_layer") else m - try: - logger.error(cuda_error_msg) - logger.warning("AWQ quantize_layer: falling back to CPU.") - m.to("cpu") - m = WrapperLinear( - m, - enable_minmax_tuning=False, - enable_norm_bias_tuning=False, - enable_round_tuning=False, - enable_torch_compile=self.compress_context.enable_torch_compile, - 
disable_opt_rtn=self.config.disable_opt_rtn, - iters=0, - ) - m = m.unwrapper({}) - except Exception: - raise - - set_module(self.model, name, m) - self._immediate_pack_and_save_module(name) - - def quantize_layer_outside_block(self, *args, **kwargs): - """Quantize layers outside blocks (e.g. lm_head). Delegates to quantize_layer.""" - return self.quantize_layer(*args, **kwargs) - - def _immediate_pack_and_save_module(self, module_name: str) -> None: - """Pack and/or save a quantized module if immediate mode is enabled.""" - shard_writer = ShardWriter.get_shard_writer() - to_cpu = self.compress_context.low_gpu_mem_usage - module = get_module(self.model, module_name) - if self.compress_context.is_immediate_packing: - immediate_pack(module_name, self.layer_config) - if to_cpu: - module = module.to("cpu") - packed_module = get_module(self.model, module_name) - set_module(self.model, module_name, packed_module.to("cpu")) - else: - if to_cpu: - module = module.to("cpu") - set_module(self.model, module_name, module) - if self.compress_context.is_immediate_saving: - module = get_module(self.model, module_name) - module.to("cpu") - shard_writer.write(module, module_name, False) - module.to("meta") - - # ── Block-level quantization ────────────────────────────────────────────── - - def quantize_block( - self, block: torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs - ) -> dict: - """AWQ block quantization: collect stats → smooth → quantize. - - The full per-block AWQ pipeline: - 1. Register activation hooks for this block's mappings - 2. Run block forward to collect activation stats + parent kwargs - 3. Apply AWQ smoothing (grid search + scale application) - 4. Collect act_max AFTER smoothing (for activation quantization) - 5. Quantize the smoothed block via RTN - - Args: - block: Module already on the compute device. - input_ids: Calibration inputs for this block. - input_others: Additional kwargs for block forward. 
- reference_output: FP16 block output (kept for interface consistency). - **kwargs: Must include 'block_name' for mapping lookup. - """ - if isinstance(block, _WrapperMultiblock): - raise ValueError( - "AWQ does not support nblocks > 1 (multi-block quantization). " - "Each block must be quantized individually. " - "Please set nblocks=1 when using algorithm='awq'." - ) - - block_name = kwargs.get("block_name") - if block_name is None: - # Infer block_name from resolved mappings by matching the block's modules - for prefix, mappings in self._block_groups.items(): - if any(m.smooth_layer is mod or m.parent is mod for m in mappings for mod in block.modules()): - block_name = prefix - break - if block_name is None: - raise ValueError( - "AWQQuantizer.quantize_block() could not determine block_name. " - "Pass block_name explicitly or ensure resolved mappings cover this block." - ) - - model = self.model - bs = self.batch_size * self.infer_bs_coeff - if self.compress_context.low_gpu_mem_usage: - bs = 1 - logger.info("AWQ: low_gpu_mem_usage enabled, setting inference batch size to 1.") - - # Step 1 & 2: AWQ smoothing (optional) - if self.apply_smooth: - awq_hooks = self.register_activation_hooks(model, block_prefix=block_name) - self._get_block_outputs(block, input_ids, input_others, bs, save_output=False) - for h in awq_hooks: - h.remove() - - self.smooth_block(block_name) - - # Step 3: Collect act_max AFTER smoothing - # AWQ smoothing changes internal activations (LayerNorm output /= scales), - # so act_max must be collected post-smoothing. Reset any pre-smoothing - # act_max values first to avoid stale data persisting via max(). 
- act_max_hooks = self.register_calibration_hooks(block, imatrix=False) - if act_max_hooks: - for _name, m in block.named_modules(): - if hasattr(m, "act_max"): - del m.act_max - self._get_block_outputs(block, input_ids, input_others, bs, save_output=False) - for h in act_max_hooks: - h.remove() - - # Step 4: Quantize the smoothed block (delegated RTN) - # This is equivalent to RTNQuantizer.quantize_block but inlined here - # to avoid inheritance coupling. - update_block_global_scale_if_needed(block, self.data_type, self.group_size) - if ( - self.config.is_act_nv_fp - or self.config.is_static_afp8 - or (self.config.is_wfp8afp8 and not self.config.act_dynamic) - ): - set_amax_for_all_moe_layers(block, attr_name="act_max") - - for _name, m in block.named_modules(): - if hasattr(m, "global_name") and check_to_quantized(m): - self.quantize_layer(m.global_name) - - return {} - - # ── Module alignment (device onload/offload) ───────────────────────────── - - @contextlib.contextmanager - def _align_modules( - self, - mapping, - device: torch.device | str | None, - ): - """Temporarily move mapping-related modules to *device* for grid search. - - If *device* is None (model already on the compute device), this is a - no-op passthrough. - """ - if device is None: - yield - return - - modules_to_move = [mapping.parent, mapping.smooth_layer] + list(mapping.balance_layers) - # Deduplicate (parent may overlap with smooth/balance) - seen = set() - unique = [] - for m in modules_to_move: - if id(m) not in seen: - seen.add(id(m)) - unique.append(m) - - # Onload - for m in unique: - m.to(device) - - try: - yield - finally: - # Offload back to CPU - for m in unique: - m.to("cpu") - clear_memory() - - # ── Activation statistics collection ────────────────────────────────────── - - def register_activation_hooks( - self, - model: torch.nn.Module, - block_prefix: str | None = None, - ) -> list: - """Register activation hooks for AWQ and return hook handles. 
- - Hooks are registered for: - 1. Smooth layers: accumulate abs activation sums/counts for x_mean. - 2. Parent modules: cache forward kwargs for output-based loss - - Args: - model: The model (hooks are registered on the actual module - objects, which are shared with blocks obtained via - ``get_module``). - block_prefix: When provided, only register hooks for mappings - whose smooth layer belongs to this block (e.g. - "model.layers.0"). Used in the block-by-block pipeline. - - Should be called by the compressor before calibration, and handles - should be removed after calibration. - """ - if not self._resolved_mappings: - self._resolved_mappings = resolve_mappings(model, self._user_mappings) - - mappings = self._resolved_mappings - if block_prefix is not None: - mappings = [m for m in mappings if _extract_block_prefix(m.smooth_name) == block_prefix] - - hooks = [] - smooth_names = {m.smooth_name for m in mappings} - - # Hook smooth layers for activation statistics - for name, module in model.named_modules(): - if name not in smooth_names: - continue - - def _make_hook(layer_name): - def hook_fn(mod, args, output): - if isinstance(output, tuple): - x = output[0] - else: - x = output - if x is None or x.numel() == 0: - return - - # Compute abs-sum directly on GPU without materialising a - # full float32 intermediate. flatten→abs→sum produces a - # single [channels] tensor; .cpu() immediately frees the - # GPU allocation (aligned with AR's "immediate CPU move" - # pattern from CalibCompressor hooks). 
- channel_sum = x.detach().float().flatten(0, -2).abs().sum(dim=0).cpu() - count = x.shape[:-1].numel() - - if layer_name not in self._activation_stats: - self._activation_stats[layer_name] = [ - torch.zeros_like(channel_sum), - 0, - ] - self._activation_stats[layer_name][0] += channel_sum - self._activation_stats[layer_name][1] += count - - return hook_fn - - h = module.register_forward_hook(_make_hook(name)) - hooks.append(h) - - # Hook parent modules to cache kwargs for output-based loss - parent_modules_hooked = set() - for mapping in mappings: - parent = mapping.parent - if id(parent) in parent_modules_hooked: - continue - parent_modules_hooked.add(id(parent)) - - self._parent_args_cache[parent] = [] - - def _make_parent_hook(parent_module): - def hook_fn(mod, args, kwargs): - # Cache ALL calibration batches for output-based grid - # search loss, use all calibration data without subsampling. - mod_cls_id = id(type(mod)) - if mod_cls_id not in self._parent_signatures: - self._parent_signatures[mod_cls_id] = inspect.signature(mod.forward) - sig = self._parent_signatures[mod_cls_id] - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - - # Infer the parent's compute dtype so that tensors - # upcasted by nn.LayerNorm (float32) are stored in the - # weight dtype (e.g. bfloat16). This avoids repeated - # per-sample casting in the grid search stage. - param = next(mod.parameters(), None) - w_dtype = param.dtype if param is not None else None - - stored = {} - for k, v in bound.arguments.items(): - if isinstance(v, torch.Tensor): - v = v.detach() - if w_dtype is not None and v.is_floating_point() and v.dtype != w_dtype: - v = v.to(w_dtype) - stored[k] = v - elif isinstance(v, tuple) and any(isinstance(t, torch.Tensor) for t in v): - # Detach tensors in tuples (e.g. position_embeddings - # = (cos, sin)) to release computation graph refs. 
- stored[k] = tuple( - ( - ( - t.detach().to(w_dtype) - if w_dtype and t.is_floating_point() and t.dtype != w_dtype - else t.detach() - ) - if isinstance(t, torch.Tensor) - else t - ) - for t in v - ) - elif hasattr(v, "key_cache"): - # Null out KV cache objects (DynamicCache etc.) - stored[k] = None - else: - stored[k] = v - - self._parent_args_cache[parent_module].append(stored) - - return hook_fn - - h = parent.register_forward_pre_hook(_make_parent_hook(parent), with_kwargs=True) - hooks.append(h) - - return hooks - - # ── Grid search ─────────────────────────────────────────────────────────── - - def _get_grid_search_params(self) -> list[tuple[float, bool]]: - """Get grid search parameters (ratio, duo_scaling). - - Returns: - List of (ratio, use_duo_scaling) tuples for the grid search. - """ - match self.duo_scaling: - # "both": half grid with duo off, half with duo on - case "both": - n_grid = max(int(self.n_grid / 2), 2) - return [ - (grid_idx / (n_grid - 1), duo_scaling) - for grid_idx in range(n_grid) - for duo_scaling in [False, True] - ] - case False: - n_grid = max(self.n_grid, 2) - return [(grid_idx / (n_grid - 1), False) for grid_idx in range(n_grid)] - # True: include identity (0.0, False) as first, then duo points - case True: - n_grid = max(self.n_grid, 3) - return [(0.0, False)] + [(grid_idx / (n_grid - 2), True) for grid_idx in range(n_grid - 1)] - case _: - raise ValueError(f"Found unexpected duo_scaling configuration {self.duo_scaling}") - - @staticmethod - def _compute_layer_means(layers: list[torch.nn.Module], group_size: int) -> torch.Tensor: - """Compute per-channel mean of normalised weights. - - Within each quantization group, weights are normalized by their group max - (so values are on a 0-1 scale), then averaged across all groups to get - per-channel importance. - - Args: - layers: Balance layers whose weights are concatenated. - group_size: Quantization group size. If -1, uses full channel width. 
- - Returns: - Per-channel mean of normalised weights [in_features]. - """ - # Concatenate all balance layer weights [total_out, in_features] - weight = torch.cat([m.weight.detach().float() for m in layers], dim=0) - org_shape = weight.shape - - # Determine effective group size - gs = group_size if group_size > 0 else org_shape[1] - - # Pad when needed so AWQ works with layers whose input width is not a - # multiple of group_size, matching grouped RTN quantization behavior. - weight, _, pad_len = reshape_pad_tensor_by_group_size(weight, gs) - # Normalize within each group: abs / max (0-1 scale per group) - w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6) - # Remove padding, then take mean across output channels. - w_scale = revert_tensor_by_pad(w_scale, orig_shape=org_shape, pad_len=pad_len) - w_mean = w_scale.mean(0) - return w_mean - - @torch.no_grad() - def _grid_search_scales( - self, - mapping: ResolvedMapping, - x_mean: torch.Tensor, - ) -> torch.Tensor | None: - """Find the best scaling ratio via grid search. - - Uses output-based error: - ``L(s) = || fp16_output - Q(W*s) @ (X/s) ||^2`` - - For each candidate scaling ratio, applies scales to balance layer - weights, quantize-dequantizes them, runs all cached calibration - batches through the parent module, and computes the output MSE - against the fp16 baseline. - - Returns: - Best scales tensor, or None if no valid scale was found. 
- """ - device = mapping.balance_layers[0].weight.device - x_mean = x_mean.to(device) - - # Compute normalised weight means - group_size = self.group_size if self.group_size > 0 else -1 - if self.duo_scaling is not False: - w_mean = self._compute_layer_means(mapping.balance_layers, group_size).to(device) - - # Try to run parent module forward for output-based loss - parent_kwargs_list = self._parent_args_cache.get(mapping.parent, []) - use_parent_forward = len(parent_kwargs_list) > 0 - - if use_parent_forward: - # Compute fp16 baseline outputs for loss computation - fp16_outputs = self._run_parent_samples(mapping.parent, parent_kwargs_list) - if not fp16_outputs or all(f.numel() == 0 for f in fp16_outputs): - use_parent_forward = False - - # Save original weights for restoration during grid search - orig_state = {bl: bl.weight.data.clone() for bl in mapping.balance_layers} - - best_error = float("inf") - best_scales = None - best_ratio = -1 - - # Pre-resolve quant function once. - ref_layer = mapping.balance_layers[0] - ref_name = getattr(ref_layer, "global_name", None) or "" - ref_cfg = self.layer_config.get(ref_name, {}) - try: - cached_quant_func, _ = get_quant_func( - ref_cfg.get("data_type", self.data_type), - ref_cfg.get("bits", self.bits), - ref_cfg.get("sym", self.sym), - disable_opt_rtn=True, - group_size=ref_cfg.get("group_size", self.group_size), - iters=0, - ) - except Exception: - cached_quant_func = None - - grid_params = self._get_grid_search_params() - - for ratio, use_duo in grid_params: - # Compute scales - if use_duo: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - scales[torch.isinf(scales)] = 1 - scales[torch.isnan(scales)] = 1 - scales_view = scales.view(1, -1).to(device) - - if use_parent_forward: - # Q(W * s) / s: mutate balance layer weights - for bl in mapping.balance_layers: - 
bl.weight.data.copy_(orig_state[bl] * scales_view) - w_qdq = self._quantize_dequantize_weight( - bl, - bl.weight.data.float(), - quant_func=cached_quant_func, - ) - if w_qdq is not None: - bl.weight.data = (w_qdq / scales_view).to(bl.weight.dtype) - else: - bl.weight.data.copy_(orig_state[bl]) - - # Collect quantized outputs then compute loss - int_w_outputs = self._run_parent_samples(mapping.parent, parent_kwargs_list) - total_loss = self._compute_loss(fp16_outputs, int_w_outputs) - del int_w_outputs - - # Restore original weights - for bl in mapping.balance_layers: - bl.weight.data.copy_(orig_state[bl]) - else: - # Weight-only fallback: || W - Q(W*s)/s ||^2 - total_loss = 0.0 - for bl in mapping.balance_layers: - w_orig = orig_state[bl].to(device) - w_scaled = w_orig * scales_view - w_qdq = self._quantize_dequantize_weight( - bl, - w_scaled, - quant_func=cached_quant_func, - ) - if w_qdq is None: - total_loss = float("inf") - break - w_qdq_unscaled = w_qdq / scales_view - total_loss += (w_orig - w_qdq_unscaled).pow(2).sum().item() - - if total_loss < best_error: - best_error = total_loss - best_scales = scales.clone() - best_ratio = ratio - - if best_ratio < 0: - logger.warning(f"AWQ grid search failed for '{mapping.smooth_name}': " "no finite error found.") - return None - - logger.debug(f"AWQ '{mapping.smooth_name}': best_ratio={best_ratio:.2f}, " f"best_error={best_error:.3e}") - return best_scales - - # ── Parent module forward ───────────────────────────── - - @torch.no_grad() - def _run_parent_samples( - self, - parent: torch.nn.Module, - kwargs_list: list[dict], - ) -> list[torch.Tensor]: - """Run cached samples through the parent module. - - Feeds cached kwargs through the parent forward without any CUDA - synchronisation between batches, so the GPU can pipeline all forwards. - Outputs are kept on-device for loss computation. 
- """ - outputs = [] - for stored_kwargs in kwargs_list: - out = parent(**stored_kwargs) - if isinstance(out, tuple): - out = out[0] - outputs.append(out) - return outputs - - @staticmethod - @torch.no_grad() - def _compute_loss( - fp16_outputs: list[torch.Tensor], - int_w_outputs: list[torch.Tensor], - ) -> float: - """Compute normalised MSE between fp16 and quantized outputs.""" - device = fp16_outputs[0].device - loss = torch.tensor(0.0, device=device) - num_elements = torch.tensor(0, device=device, dtype=torch.long) - - for fp16_out, int_w_out in zip(fp16_outputs, int_w_outputs): - loss += torch.nn.functional.mse_loss( - fp16_out.float(), - int_w_out.to(fp16_out.device).float(), - reduction="sum", - ) - num_elements += fp16_out.numel() - - if num_elements == 0: - return float("inf") - return (loss / num_elements).item() - - def _quantize_dequantize_weight( - self, - layer: torch.nn.Module, - weight: torch.Tensor, - quant_func=None, - ) -> torch.Tensor | None: - """Quantize-dequantize a weight tensor using the layer's config. - - Args: - quant_func: Pre-resolved quantization function. When provided, - skips the ``get_quant_func`` dispatch (avoids redundant - lookups in the inner grid search loop). - - Returns the quantized-then-dequantized weight, or None on failure. 
- """ - layer_name = getattr(layer, "global_name", None) or "" - config = self.layer_config.get(layer_name, {}) - bits = config.get("bits", self.bits) - group_size = config.get("group_size", self.group_size) - sym = config.get("sym", self.sym) - data_type = config.get("data_type", self.data_type) - - if quant_func is None: - try: - quant_func, _ = get_quant_func( - data_type, - bits, - sym, - disable_opt_rtn=True, # AWQ always uses plain RTN - group_size=group_size, - iters=0, # Route to rtn_int_sym, not int_sym - ) - except Exception: - return None - - if quant_func is None: - return None - - try: - qdq_weight, scale, zp = quant_func( - weight, - bits=bits, - group_size=group_size, - ) - return qdq_weight - except Exception: - return None - - # ── Apply scales ────────────────────────────────────────────────────────── - - @torch.no_grad() - def _apply_scales(self, mapping: ResolvedMapping, scales: torch.Tensor) -> None: - """Apply the computed AWQ scales to smooth and balance layers. - - - Balance layers (Linear): weight *= scales (along input channels) - - Smooth layer (LayerNorm/RMSNorm): weight /= scales, bias /= scales - """ - for bl in mapping.balance_layers: - device = bl.weight.device - s = scales.to(device).view(1, -1) - bl.weight.data.mul_(s) - - smooth = mapping.smooth_layer - device = smooth.weight.device - s = scales.to(device) - - if smooth.weight.ndim == 1: - # LayerNorm / RMSNorm: 1-D weight (per-channel) - smooth.weight.data.div_(s) - else: - # Edge case: when smooth layer's out_features != balance layer's - # in_features (e.g. fused qkv_proj smoothing o_proj). Scale the - # last output features (aligned with AutoAWQ). 
- # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123 - smooth.weight.data[-s.size(0) :].div_(s.view(-1, 1)) - - if hasattr(smooth, "bias") and smooth.bias is not None: - smooth.bias.data.div_(s) diff --git a/auto_round/algorithms/quantization/base.py b/auto_round/algorithms/quantization/base.py index e41c1223a..cc3e09bcc 100644 --- a/auto_round/algorithms/quantization/base.py +++ b/auto_round/algorithms/quantization/base.py @@ -11,20 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib import traceback -from collections import defaultdict -from typing import Union +from contextlib import contextmanager import torch +from auto_round.algorithms.base import BasePipelineMember from auto_round.algorithms.quantization.config import QuantizationConfig -from auto_round.algorithms.quantization.utils import register_act_max_hooks +from auto_round.algorithms.transforms.base import BaseWeightTransformer from auto_round.compressors.utils import ( block_forward, + check_need_act_calibration, immediate_pack, ) from auto_round.data_type import QUANT_FUNC_WITH_DTYPE +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size from auto_round.logger import logger from auto_round.utils import ( INNER_SUPPORTED_LAYER_TYPES, @@ -39,19 +40,129 @@ from auto_round.wrapper import WrapperLinear -class BaseQuantizers: - # Class-level attribute declarations for convenient access in quantization methods. - # Scheme-related attrs (layer_config, scale_dtype, has_qlayer_outside_block, etc.) - # are resolved by SchemeMixin in BaseCompressor and synced here after post_init(). 
- model_context = None - compress_context = None - dataset = None - supported_types = SUPPORTED_LAYER_TYPES - inner_supported_types = INNER_SUPPORTED_LAYER_TYPES - enable_alg_ext = False - # Subclasses that support diffusion models should override this with the - # appropriate output key mapping, e.g.: - # DIFFUSION_OUTPUT_CONFIGS = {"FluxTransformerBlock": ["encoder_hidden_states", "hidden_states"]} +class RTNLayerFallbackMixin: + """Default outside-block/layer quantization via RTN. + + Algorithms that want RTN fallback for embeddings, lm_head, or layers outside + transformer blocks should inherit this mixin explicitly. + """ + + @torch.no_grad() + def quantize_layer_via_rtn(self, layer_name: str, disable_opt_rtn: bool | None = None) -> None: + """Quantize one layer with RTN and handle optional immediate pack/save.""" + layer = get_module(self.model, layer_name) + layer = convert_module_to_hp_if_necessary(layer, self.model_context.amp_dtype, self.compress_context.device) + set_module(self.model, layer_name, layer) + tuning_device = layer.tuning_device if hasattr(layer, "tuning_device") else self.compress_context.device + if ( + self.compress_context.is_immediate_packing + and self.compress_context.formats[0].is_gguf() + and not getattr(self.config, "disable_opt_rtn", False) + ): + layer = layer.to(tuning_device) + layer.scale = None + layer.zp = None + else: + try: + if disable_opt_rtn is None: + disable_opt_rtn = bool(getattr(self.config, "disable_opt_rtn", False)) + if ( + not disable_opt_rtn + and getattr(self.config, "orig_disable_opt_rtn", None) is None + and self.model_context.is_moe_model + and "expert" in layer.global_name + and "shared_expert" not in layer.global_name + and self.config.super_bits is None + ): + disable_opt_rtn = True + logger.warning_once( + "MoE layer detected: optimized RTN is disabled for efficiency. " + "Use `--enable_opt_rtn` to force-enable it for MoE layers." 
+ ) + layer = layer.to(tuning_device) + layer = WrapperLinear( + layer, + device=tuning_device, + enable_minmax_tuning=False, + enable_norm_bias_tuning=False, + enable_round_tuning=False, + enable_torch_compile=self.compress_context.enable_torch_compile, + disable_opt_rtn=disable_opt_rtn, + iters=0, + ) + layer = layer.unwrapper({}) + except torch.OutOfMemoryError: + cuda_error_msg = traceback.format_exc() + layer = layer.orig_layer if hasattr(layer, "orig_layer") else layer + try: + logger.error(cuda_error_msg) + logger.warning("falling back to CPU.") + layer.to("cpu") + layer = WrapperLinear( + layer, + enable_minmax_tuning=False, + enable_norm_bias_tuning=False, + enable_round_tuning=False, + enable_torch_compile=self.compress_context.enable_torch_compile, + iters=0, + ) + layer = layer.unwrapper({}) + except Exception: + raise + set_module(self.model, layer_name, layer) + self._immediate_pack_and_save_module(layer_name) + + def _immediate_pack_and_save_module(self, module_name): + from auto_round.compressors.shard_writer import ShardWriter + + shard_writer = ShardWriter.get_shard_writer() + to_cpu = self.compress_context.low_gpu_mem_usage + module = get_module(self.model, module_name) + if self.compress_context.is_immediate_packing: + immediate_pack(module_name, self.layer_config) + if to_cpu: + module = module.to("cpu") + packed_module = get_module(self.model, module_name) + set_module(self.model, module_name, packed_module.to("cpu")) + else: + if to_cpu: + module = module.to("cpu") + set_module(self.model, module_name, module) + if self.compress_context.is_immediate_saving: + module = get_module(self.model, module_name) + module.to("cpu") + shard_writer.write(module, module_name, False) + module.to("meta") + + def quantize_layer_outside_block(self, layer_name: str, input_ids=None, **kwargs): + dtype = kwargs.pop("dtype", None) + if dtype is not None: + layer = get_module(self.model, layer_name) + set_module(self.model, layer_name, layer.to(dtype)) + 
self.quantize_layer_via_rtn(layer_name, **kwargs) + + +class DiffusionMixin: + """Mixin that adds diffusion-model support to a :class:`BaseQuantizer` subclass. + + Attach to any :class:`BaseQuantizer` subclass that needs to quantize + diffusion models (e.g. Flux, StableAudio, Wan): + + class SignRoundQuantizer(DiffusionMixin, BaseQuantizer): ... + + The mixin overrides :meth:`create_block_io` so diffusion-specific input + slicing, output mapping, and hidden_states extraction live in + :class:`DiffusionBlockIO` rather than individual quantizers. + + ``DIFFUSION_OUTPUT_CONFIGS`` maps block class name → ordered list of output + tensor keys. Extend in subclasses to register new diffusion architectures + without overriding any method. + """ + + # Map block class name → list of output tensor keys returned by that block. + # The order of keys must match the order of tensors in the block's return tuple. + # Subclasses can extend this dict to register new architectures without + # overriding any method. 
DIFFUSION_OUTPUT_CONFIGS: dict = { "FluxTransformerBlock": ["encoder_hidden_states", "hidden_states"], "FluxSingleTransformerBlock": ["encoder_hidden_states", "hidden_states"], @@ -61,24 +172,61 @@ class BaseQuantizers: "WanTransformerBlock": ["hidden_states"], } + def _get_output_config(self, block: torch.nn.Module) -> list: + """Return the output key list for *block* from ``DIFFUSION_OUTPUT_CONFIGS``.""" + return self.DIFFUSION_OUTPUT_CONFIGS.get(block.__class__.__name__, ["hidden_states"]) + + def create_block_io(self, input_ids, input_others, quantized_input=None, block=None): + from auto_round.algorithms.pipeline import DiffusionBlockIO, InputSource + + active_source = ( + InputSource.QUANTIZED_INPUT + if quantized_input is not None and self.enable_quanted_input + else InputSource.FP_CACHE + ) + io = DiffusionBlockIO( + fp_inputs=input_ids, + input_others=input_others, + quantized_inputs=quantized_input, + active_source=active_source, + batch_dim=self.batch_dim, + seqlen=self.seqlen, + shared_cache_keys=self.model_context.shared_cache_keys, + output_config=self._get_output_config(block), + ) + io.fp_inputs, io.input_others = io.preprocess_block_inputs(io.fp_inputs, io.input_others, block) + return io + + +class BaseQuantizer(BasePipelineMember): + """Base class for terminal weight-compression algorithms in a QuantizationPipeline. + + Developers adding a new quantization algorithm should inherit from this + class and override at minimum :meth:`quantize_block`. 
+ + For diffusion model support, also inherit :class:`DiffusionMixin`: + ``class MyQuantizer(DiffusionMixin, BaseQuantizer): ...`` + + Lifecycle hooks to override as needed: + - :meth:`prepare_run` – model-level setup (once before all blocks) + - :meth:`get_act_calib_policy` – activation calibration policy + - :meth:`block_forward_hooks` – register act-calib hooks (context manager) + - :meth:`quantize_block` – **must override**: quantize a single block + - :meth:`quantize_layer_outside_block` – quantize layers outside blocks + - :meth:`finalize_run` – model-level teardown (once after all blocks) + """ + + # Class-level attribute declarations for convenient access in quantization methods. + # Scheme-related attrs (layer_config, scale_dtype, has_qlayer_outside_block, etc.) + # are resolved by SchemeMixin in BaseCompressor and synced here after post_init(). + dataset = None + supported_types = SUPPORTED_LAYER_TYPES + inner_supported_types = INNER_SUPPORTED_LAYER_TYPES + enable_alg_ext = False + def __init__(self, config: QuantizationConfig): - self.config = config + super().__init__(config) self.layer_config = None - self.bits = config.bits - self.group_size = config.group_size - self.sym = config.sym - self.data_type = config.data_type - self.act_bits = config.act_bits - self.act_group_size = config.act_group_size - self.act_sym = config.act_sym - self.act_data_type = config.act_data_type - self.act_dynamic = config.act_dynamic - self.super_bits = config.super_bits - self.super_group_size = config.super_group_size - self.scale_dtype = config.scale_dtype - self.ignore_layers = config.ignore_layers - self.quant_lm_head = config.quant_lm_head - self.to_quant_block_names = config.to_quant_block_names # Calibration-time state lives on a shared # :class:`~auto_round.calibration.state.CalibrationState` instance. 
# The compressor wires its own instance here in ``_resolve_scheme``; @@ -136,15 +284,6 @@ def nsamples(self) -> int: def seqlen(self) -> int: return self._calibration_state.seqlen - @classmethod - def from_config(cls, config: QuantizationConfig): - if cls.__name__ == config._alg_cls: - return cls(config) - else: - module = importlib.import_module("auto_round.algorithms.quantization") - alg_cls = getattr(module, config._alg_cls) - return alg_cls(config) - def bind(self, compressor) -> None: """Wire shared state from the owning compressor. @@ -155,6 +294,7 @@ def bind(self, compressor) -> None: """ self.model_context = compressor.model_context self.compress_context = compressor.compress_context + self.scheme = compressor.scheme_context self.scale_dtype = compressor.scale_dtype # Share the compressor's CalibrationState instance. self._calibration_state = compressor._calibration_state @@ -173,23 +313,130 @@ def amp(self): @property def amp_dtype(self): - import torch - return getattr(self.model_context, "amp_dtype", torch.float32) - def register_calibration_hooks(self, model, *, act_max: bool = True, imatrix: bool = True): - return register_act_max_hooks(self, model) if act_max else [] + # ── Activation-calibration hook infrastructure ─────────────────────────────── + + def _register_act_max_hooks(self, model): + """Register per-module act_max tracking hooks for static activation quantization. + + Internal implementation called by :meth:`block_forward_hooks`. + Returns a list of hook handles that the caller must remove when done. 
+ """ + + def collect_act_max(module, input, output): + input = input[0] if isinstance(input, (tuple, list)) else input + if input.numel() == 0: + return + input, _, _ = reshape_pad_tensor_by_group_size(input, self.act_group_size) + act_max = torch.max(torch.abs(input), dim=-1).values + if not hasattr(module, "act_max") or module.act_max.numel() == 0: + module.act_max = act_max + if self.config.is_act_nv_fp: + max_val = act_max.max() + module.act_max = max_val.unsqueeze(0) if max_val.dim() == 0 else max_val + return + + act_max = act_max.to(module.act_max.device) + if self.config.is_act_nv_fp: + max_val = torch.max(act_max.max(), module.act_max.max()) + module.act_max = max_val.unsqueeze(0) if max_val.dim() == 0 else max_val + else: + module.act_max = torch.max(act_max, module.act_max) + + def should_collect(name, module): + if isinstance(module, SUPPORTED_LAYER_TYPES): + return ( + hasattr(module, "act_dynamic") + and check_need_act_calibration(module.act_dynamic, module.act_data_type, module.act_bits) + and check_to_quantized(module) + ) + if name in self.layer_config: + config = self.layer_config[name] + act_dynamic = config.get("act_dynamic", True) + act_data_type = config.get("act_data_type", None) + act_bits = config.get("act_bits", 16) + return ( + config["bits"] <= 8 + and check_need_act_calibration(act_dynamic, act_data_type, act_bits) + and check_to_quantized(config) + ) + return False + + handles = [] + if should_collect("", model): + handles.append(model.register_forward_hook(collect_act_max)) + return handles + for name, module in model.named_modules(): + if name and should_collect(name, module): + handles.append(module.register_forward_hook(collect_act_max)) + return handles + + @contextmanager + def block_forward_hooks(self, ctx): + """Register act-calib forward hooks for the reference forward. + + Implements the :meth:`BasePipelineMember.block_forward_hooks` interface for + terminal quantizers. 
Yields the list of hook handles so the caller can + determine whether any act-calib hooks were registered (used to decide + whether a second forward with quantized inputs is needed). + """ + from auto_round.algorithms.pipeline import CalibTiming + + policy = self.get_act_calib_policy(ctx) + if policy.when == CalibTiming.SKIP: + yield [] + return + handles = self._register_act_max_hooks(ctx.block) + try: + yield handles + finally: + for h in handles: + h.remove() + + def get_act_calib_policy(self, ctx): + """Return the activation calibration policy for this block. + + Default: ``WITH_REFERENCE + FP_CACHE``, or ``QUANTIZED_INPUT`` when + ``enable_quanted_input=True`` and a quantized previous-block output is available. + """ + from auto_round.algorithms.pipeline import ActCalibPolicy, CalibTiming, InputSource + + quantized_input = getattr(ctx, "quantized_input", None) + if quantized_input is not None and self.enable_quanted_input: + return ActCalibPolicy(when=CalibTiming.WITH_REFERENCE, source=InputSource.QUANTIZED_INPUT) + return ActCalibPolicy(when=CalibTiming.WITH_REFERENCE, source=InputSource.FP_CACHE) + + def create_block_io(self, input_ids, input_others, quantized_input=None, block=None): + from auto_round.algorithms.pipeline import BlockIO, InputSource + + active_source = ( + InputSource.QUANTIZED_INPUT + if quantized_input is not None and self.enable_quanted_input + else InputSource.FP_CACHE + ) + return BlockIO( + fp_inputs=input_ids, + input_others=input_others, + quantized_inputs=quantized_input, + active_source=active_source, + batch_dim=self.batch_dim, + seqlen=self.seqlen, + shared_cache_keys=self.model_context.shared_cache_keys, + ) + + # ── Embedding quantization ──────────────────────────────────────────────────── @torch.inference_mode() - def _quantize_embedding_layer(self): + def quantize_embedding_layer(self): """Quantizes embedding layers in the model according to the configuration. 
This method iterates through all modules in the model, identifies embedding - layers specified in `self.quantizer.layer_config`, and applies the appropriate quantization + layers specified in `self.layer_config`, and applies the appropriate quantization function based on bit precision, grouping strategy, and dtype. Returns: - bool: True if the quantization process completes without critical errors. + bool: True if any embedding layer was quantized. """ is_quantized = False for name, module in self.model_context.model.named_modules(): @@ -263,9 +510,9 @@ def _quantize_embedding_layer(self): return is_quantized - def quantize_block( - self, block: torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs - ) -> dict: + # ── Abstract quantization interface ────────────────────────────────────────── + + def quantize_block(self, ctx) -> dict: """Apply the quantization algorithm to a prepared block. This is the **pure-algorithm** entry point called by the Compressor after @@ -273,110 +520,17 @@ def quantize_block( registration, DDP setup) has been completed. Implementations should: - - Perform the algorithm-specific weight/activation quantization on ``block``. + - Perform the algorithm-specific weight/activation quantization on ``ctx.block``. - Return a dict of best parameters (may be empty for zero-shot algorithms). Args: - block: Module already placed on the correct device(s). - input_ids: Calibration inputs on cache_device (None for zero-shot RTN). - input_others: Additional inputs (None for zero-shot RTN). - reference_output: FP reference outputs collected by Compressor - (None for algorithms that don't need a reconstruction loss). - **kwargs: Algorithm-specific keyword arguments (e.g. ``loss_device``, - ``card_0_in_high_risk`` for SignRoundQuantizer). + ctx: Per-block pipeline context. ``ctx.io`` owns calibration inputs, + quantized inputs, and reference outputs. 
Returns: dict: Best quantization parameters found, or ``{}`` if not applicable. """ - raise NotImplementedError("quantize_block must be implemented in subclasses of BaseQuantizers") - - @torch.no_grad() - def quantize_layer_via_rtn(self, layer_name: str, disable_opt_rtn: bool | None = None) -> None: - """Quantize one layer with RTN and handle optional immediate pack/save.""" - layer = get_module(self.model, layer_name) - layer = convert_module_to_hp_if_necessary(layer, self.model_context.amp_dtype, self.compress_context.device) - set_module(self.model, layer_name, layer) - tuning_device = layer.tuning_device if hasattr(layer, "tuning_device") else self.compress_context.device - if ( - self.compress_context.is_immediate_packing - and self.compress_context.formats[0].is_gguf() - and not getattr(self.config, "disable_opt_rtn", False) - ): - layer = layer.to(tuning_device) - layer.scale = None - layer.zp = None - else: - try: - if disable_opt_rtn is None: - disable_opt_rtn = bool(getattr(self.config, "disable_opt_rtn", False)) - if ( - not disable_opt_rtn - and getattr(self.config, "orig_disable_opt_rtn", None) is None - and self.model_context.is_moe_model - and "expert" in layer.global_name - and "shared_expert" not in layer.global_name - and self.config.super_bits is None - ): - disable_opt_rtn = True - logger.warning_once( - "MoE layer detected: optimized RTN is disabled for efficiency. " - "Use `--enable_opt_rtn` to force-enable it for MoE layers." 
- ) - layer = layer.to(tuning_device) - layer = WrapperLinear( - layer, - device=tuning_device, - enable_minmax_tuning=False, - enable_norm_bias_tuning=False, - enable_round_tuning=False, - enable_torch_compile=self.compress_context.enable_torch_compile, - disable_opt_rtn=disable_opt_rtn, - iters=0, - ) - layer = layer.unwrapper({}) - except torch.OutOfMemoryError: - cuda_error_msg = traceback.format_exc() - layer = layer.orig_layer if hasattr(layer, "orig_layer") else layer - try: - logger.error(cuda_error_msg) - logger.warning("falling back to CPU.") - layer.to("cpu") - layer = WrapperLinear( - layer, - enable_minmax_tuning=False, - enable_norm_bias_tuning=False, - enable_round_tuning=False, - enable_torch_compile=self.compress_context.enable_torch_compile, - iters=0, - ) - layer = layer.unwrapper({}) - except Exception: - raise - - set_module(self.model, layer_name, layer) - self._immediate_pack_and_save_module(layer_name) - - def _immediate_pack_and_save_module(self, module_name): - from auto_round.compressors.shard_writer import ShardWriter - - shard_writer = ShardWriter.get_shard_writer() - to_cpu = self.compress_context.low_gpu_mem_usage - module = get_module(self.model, module_name) - if self.compress_context.is_immediate_packing: - immediate_pack(module_name, self.layer_config) - if to_cpu: - module = module.to("cpu") - packed_module = get_module(self.model, module_name) - set_module(self.model, module_name, packed_module.to("cpu")) - else: - if to_cpu: - module = module.to("cpu") - set_module(self.model, module_name, module) - if self.compress_context.is_immediate_saving: - module = get_module(self.model, module_name) - module.to("cpu") - shard_writer.write(module, module_name, False) - module.to("meta") + raise NotImplementedError("quantize_block must be implemented in subclasses of BaseQuantizer") def quantize_layer(self, layer_name: str, **kwargs): """Quantizes a single layer of the model. 
@@ -385,7 +539,7 @@ def quantize_layer(self, layer_name: str, **kwargs): layer_name (str): The name of the layer to quantize. The layer module is retrieved internally via get_module(model, layer_name). """ - raise NotImplementedError("quantize_layer must be implemented in subclasses of BaseQuantizers") + raise NotImplementedError("quantize_layer must be implemented in subclasses of BaseQuantizer") def quantize_layer_outside_block(self, layer_name: str, input_ids=None, **kwargs): """Quantizes a single layer of the model outside of a block. @@ -395,76 +549,7 @@ def quantize_layer_outside_block(self, layer_name: str, input_ids=None, **kwargs retrieved internally via get_module(model, layer_name). input_ids: Optional calibration inputs for data-driven outside-layer quantization. """ - dtype = kwargs.pop("dtype", None) - if dtype is not None: - layer = get_module(self.model, layer_name) - set_module(self.model, layer_name, layer.to(dtype)) - self.quantize_layer_via_rtn(layer_name, **kwargs) - - @torch.no_grad() - def _get_block_outputs( - self, - block: torch.nn.Module, - input_ids, - input_others, - bs: int, - save_output: bool = True, - device_override: Union[torch.device, str, None] = None, - ): - """Compute the output of a block for calibration inputs. - - Shared by SignRoundQuantizer and OptimizedRTNQuantizer. Algorithm-specific - block-forward selection (compile vs. plain) is handled here based on - ``enable_alg_ext`` and act-quantization flags. - - Args: - device_override: Override the target device. Used by diffusion with - multi-device dispatch to pass None so block_forward uses the block's - current device instead of forcing a specific device. 
- """ - diffusion_fn = getattr(self, "_get_diffusion_block_outputs", None) - if getattr(self.model_context, "is_diffusion", False): - device = device_override if device_override is not None else self.compress_context.device - return self._get_diffusion_block_outputs( - block, - input_ids, - input_others, - bs, - device, - self.compress_context.cache_device, - ) - - _bf = self._resolve_block_forward() - - output = [] - nsamples = len(input_ids) - for i in range(0, nsamples, bs): - end_index = min(nsamples, i + bs) - indices = torch.arange(i, end_index).to(torch.long) - tmp_input_ids, tmp_input_others = self._sampling_inputs( - input_ids, - input_others, - indices, - self.seqlen, - self.batch_dim, - share_cache_keys=self.model_context.shared_cache_keys, - ) - tmp_output = _bf( - block, - tmp_input_ids, - tmp_input_others, - self.model_context.amp, - self.model_context.amp_dtype, - self.compress_context.device, - ).to(self.compress_context.cache_device) - if save_output: - if self.batch_size == 1: - output.append(tmp_output) - else: - output.extend(list(torch.split(tmp_output, 1, dim=self.batch_dim))) - self.compress_context.clear_memory() - - return output + raise NotImplementedError("quantize_layer_outside_block must be implemented in subclasses or mixins") def _resolve_block_forward(self): """Resolve and cache the block forward function once. @@ -499,180 +584,13 @@ def _invalidate_block_forward_cache(self): self.__dict__.pop("_resolved_block_forward", None) self.__dict__.pop("_compiled_block_forward", None) - def _get_current_q_output( - self, - block: torch.nn.Module, - input_ids, - input_others: dict, - indices, - device, - cache_device: str = "cpu", - ) -> torch.Tensor: - """Compute block output for a mini-batch selected by *indices* (used during training). - - Handles both LLM and diffusion model block formats. 
Uses the compiled - block_forward when enable_torch_compile is True (same as _get_block_outputs), - matching old-arch behavior where self.block_forward was compiled at init. - """ - current_input_ids, current_input_others = self._sampling_inputs( - input_ids, - input_others, - indices, - seqlen=self.seqlen, - batch_dim=self.batch_dim, - share_cache_keys=self.model_context.shared_cache_keys, - ) - _bf = self._resolve_block_forward() - - if getattr(self.model_context, "is_diffusion", False): - output_config = self.DIFFUSION_OUTPUT_CONFIGS.get(block.__class__.__name__, ["hidden_states"]) - idx = None if "hidden_states" not in output_config else output_config.index("hidden_states") - if isinstance(current_input_ids, dict): - hidden_states = current_input_ids.pop("hidden_states") - current_input_others.update(current_input_ids) - current_input_ids = hidden_states - output_q = _bf( - block, - current_input_ids, - current_input_others, - self.model_context.amp, - self.model_context.amp_dtype, - device, - idx, - ) - else: - output_q = _bf( - block, - current_input_ids, - current_input_others, - self.model_context.amp, - self.model_context.amp_dtype, - device, - ) - return output_q.to(cache_device) - - @classmethod - @torch.no_grad() - def _sampling_inputs( - cls, - input_ids: Union[list[torch.Tensor], dict], - input_others: dict, - indices, - seqlen: int, - batch_dim: int = 0, - share_cache_keys: tuple = (), - ): - """Sample a mini-batch of calibration inputs by indices. - - Shared by SignRoundQuantizer and OptimizedRTNQuantizer. 
- """ - if isinstance(input_ids, list): - current_input_ids = [input_ids[i] for i in indices] - current_input_ids = torch.cat(current_input_ids, dim=batch_dim) - elif isinstance(input_ids, dict): - current_input_ids = defaultdict(list) - for k in input_ids.keys(): - current_input_ids[k].extend([input_ids[k][i] for i in indices]) - current_input_ids[k] = torch.cat(current_input_ids[k], dim=batch_dim) - - current_input_others = {"positional_inputs": input_others["positional_inputs"]} - for key in input_others.keys(): - if "positional_inputs" in key: - continue - if key in share_cache_keys: - # Shared keys are stored once (not per-sample), often wrapped in a - # 1-element list by the caching hook. Unwrap so the model receives - # the raw value (e.g. (cos, sin) tuple, not [(cos, sin)]). - # Exception: if the hook detected that the "shared" key actually varies - # per sample (e.g. position_embeddings in a VLM visual encoder), it - # upgrades the storage to a per-sample list with >1 elements. - val = input_others[key] - if isinstance(val, list) and len(val) == 1: - current_input_others[key] = val[0] - elif isinstance(val, list) and len(val) > 1: - # Per-sample storage for a nominally-shared key that varies across - # calibration samples (e.g. position_embeddings in Qwen2-VL visual - # encoder blocks where each image has a different patch count). 
- idx = int(indices[0]) if len(indices) == 1 else 0 - current_input_others[key] = val[idx] if idx < len(val) else val[0] - else: - current_input_others[key] = val - elif not isinstance(input_others[key], (str, bool, type(None))): - current_input_others[key] = None - if input_others[key] is not None: - current_input_others[key] = [input_others[key][i] for i in indices] - if len(indices) == 1: - current_input_others[key] = current_input_others[key][0] - else: - try: - current_input_others[key] = torch.cat(current_input_others[key], dim=0) - except TypeError as err: - logger.warning_once("Please check the model cache inputs or try setting batch_size to 1.") - else: - current_input_others[key] = input_others[key] + def prepare_run(self, run_ctx) -> None: + """Model-level preparation (called once before block iteration starts).""" + return - return current_input_ids, current_input_others + def finalize_run(self, run_ctx) -> None: + """Model-level teardown (called once after all blocks are processed). - @torch.no_grad() - def _get_diffusion_block_outputs( - self, - block: torch.nn.Module, - input_ids: Union[torch.Tensor, dict], - input_others, - bs: int, - device: Union[str, torch.device], - cache_device: Union[str, torch.device], - save_output: bool = True, - ): - """Compute block outputs for diffusion models. - - Uses ``self.DIFFUSION_OUTPUT_CONFIGS`` to map block class names to their - output keys. Subclasses override ``DIFFUSION_OUTPUT_CONFIGS`` to add - support for new diffusion architectures. + Must be idempotent – the Compressor calls this inside a ``try/finally``. 
""" - output = defaultdict(list) - output_config = self.DIFFUSION_OUTPUT_CONFIGS.get(block.__class__.__name__, ["hidden_states"]) - if isinstance(input_ids, dict): - nsamples = len(input_ids["hidden_states"]) - else: - nsamples = len(input_ids) - - for i in range(0, nsamples, bs): - end_index = min(nsamples, i + bs) - indices = torch.arange(i, end_index).to(torch.long) - tmp_input_ids, tmp_input_others = self._sampling_inputs( - input_ids, - input_others, - indices, - self.seqlen, - self.batch_dim, - share_cache_keys=self.model_context.shared_cache_keys, - ) - if isinstance(tmp_input_ids, dict): - hidden_states = tmp_input_ids.pop("hidden_states") - tmp_input_others.update(tmp_input_ids) - tmp_input_ids = hidden_states - - tmp_output = block_forward( - block, - tmp_input_ids, - tmp_input_others, - self.model_context.amp, - self.model_context.amp_dtype, - device, - None, - ) - if isinstance(tmp_output, torch.Tensor): - tmp_output = [tmp_output] - assert len(output_config) == len(tmp_output) - tmp_output = dict(zip(output_config, tmp_output)) - - if save_output: - for name, out in tmp_output.items(): - if self.batch_size == 1: - output[name].append(out.to(cache_device)) - else: - output[name].extend(list(torch.split(out.to(cache_device), 1, dim=self.batch_dim))) - self.compress_context.clear_memory() - - return output + return diff --git a/auto_round/algorithms/quantization/config.py b/auto_round/algorithms/quantization/config.py index d99219e51..f2e412db6 100644 --- a/auto_round/algorithms/quantization/config.py +++ b/auto_round/algorithms/quantization/config.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass from enum import Enum from typing import ClassVar, Union -from auto_round.algorithms.alg_config import AlgConfig from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG from auto_round.logger import logger from auto_round.schemes import QuantizationScheme @@ -29,28 +27,57 @@ class BackendDataType(str, Enum): FP8 = "fp8" -@dataclass(kw_only=True) class QuantizationConfig: - _alg_cls: ClassVar[str] = None - - # quantization args - bits: int = None - group_size: int = None - sym: bool = None - data_type: str = None - act_bits: int = None - act_group_size: int = None - act_sym: bool = None - act_data_type: str = None - act_dynamic: bool = None - super_bits: int = None - super_group_size: int = None - scale_dtype: str = None - ignore_layers: str = "" - quant_lm_head: bool = False - to_quant_block_names: Union[str, list, None] = None - - def __post_init__(self): + """Common quantization configuration shared by block quantizers. + + Args: + bits: Weight quantization bit width. + group_size: Weight quantization group size. Use -1 for per-channel, + 0 for per-tensor, or a positive integer for grouped quantization. + sym: Whether to use symmetric weight quantization. + data_type: Weight quantization data type, such as int, mx_fp, + nv_fp, or fp8 variants. + act_bits: Activation quantization bit width. + act_group_size: Activation quantization group size. + act_sym: Whether to use symmetric activation quantization. + act_data_type: Activation quantization data type. + act_dynamic: Whether activation quantization should be dynamic. + super_bits: Bit width used for double quantization metadata. + super_group_size: Group size used for double quantization metadata. 
+ """ + + _scheme_fields: ClassVar[set[str]] = set(QuantizationScheme.get_attributes()) + + def __init__(self, *, scheme: QuantizationScheme = None, **kwargs): + object.__setattr__(self, "scheme", scheme if scheme is not None else QuantizationScheme.empty()) + object.__setattr__(self, "_user_set_scheme_fields", set()) + + unknown = [] + for key, value in kwargs.items(): + if key in self._scheme_fields: + setattr(self.scheme, key, value) + self._user_set_scheme_fields.add(key) + else: + unknown.append(key) + if unknown: + unknown_args = ", ".join(repr(arg) for arg in unknown) + raise TypeError(f"Unexpected quantization config argument(s): {unknown_args}") + + self._check_partial_config() + + def __getattr__(self, name): + if name in self._scheme_fields: + return getattr(self.scheme, name, None) + raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}") + + def __setattr__(self, name, value): + if name in self._scheme_fields and "scheme" in self.__dict__: + setattr(self.scheme, name, value) + self._user_set_scheme_fields.add(name) + return + object.__setattr__(self, name, value) + + def _check_partial_config(self): # Run block-wise validation early (at construction time, before model loading). # Scheme resolution is deferred to BaseCompressor.post_init() via SchemeMixin. 
# Guard with None checks in case the user hasn't explicitly set data_type/bits diff --git a/auto_round/algorithms/quantization/registry.py b/auto_round/algorithms/quantization/registry.py new file mode 100644 index 000000000..82a77742c --- /dev/null +++ b/auto_round/algorithms/quantization/registry.py @@ -0,0 +1,11 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +from auto_round.algorithms.registry import list_registered_algorithms, register_algorithm, resolve_alg_config + + +def register_alg(alias, factory): + register_algorithm(alias, aliases=(alias,), config_factory=factory) + + +__all__ = ["register_alg", "resolve_alg_config", "list_registered_algorithms"] diff --git a/auto_round/algorithms/quantization/rtn/config.py b/auto_round/algorithms/quantization/rtn/config.py index 5e27dbccb..03dc1829f 100644 --- a/auto_round/algorithms/quantization/rtn/config.py +++ b/auto_round/algorithms/quantization/rtn/config.py @@ -18,14 +18,22 @@ class RTNConfig(QuantizationConfig): - _alg_cls = "RTNQuantizer" - def __init__( self, *, disable_opt_rtn: bool = None, **kwargs, ): + """Initialize an RTN configuration. + + Args: + disable_opt_rtn: Whether to disable the optimized RTN path. + ``None`` keeps the default heuristic, True forces plain + RTN, and False forces the optimized implementation. + **kwargs: Common quantization arguments forwarded to + QuantizationConfig, such as bits, group_size, sym, + data_type, and activation quantization fields. 
+ """ # pop before super().__init__ so it doesn't leak into QuantizationConfig as an unknown kwarg enable_opt_rtn = kwargs.pop("enable_opt_rtn", None) super().__init__(**kwargs) @@ -47,5 +55,7 @@ def __init__( ) disable_opt_rtn = False self.disable_opt_rtn = disable_opt_rtn - if not self.disable_opt_rtn: - self._alg_cls = "OptimizedRTNQuantizer" + + +class OptimizedRTNConfig(RTNConfig): + pass diff --git a/auto_round/algorithms/quantization/rtn/quantizer.py b/auto_round/algorithms/quantization/rtn/quantizer.py index d1c2c0779..d8e5978be 100644 --- a/auto_round/algorithms/quantization/rtn/quantizer.py +++ b/auto_round/algorithms/quantization/rtn/quantizer.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict +from contextlib import contextmanager from typing import Any, Callable, Optional, Union import accelerate import torch -from auto_round.algorithms.quantization.base import BaseQuantizers -from auto_round.algorithms.quantization.rtn.config import RTNConfig +from auto_round.algorithms.quantization.base import BaseQuantizer, RTNLayerFallbackMixin +from auto_round.algorithms.quantization.rtn.config import OptimizedRTNConfig, RTNConfig from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer -from auto_round.algorithms.quantization.utils import register_imatrix_hooks +from auto_round.algorithms.registry import register_pipeline_member from auto_round.compressors.utils import ( IndexSampler, block_forward, @@ -56,15 +57,14 @@ from auto_round.wrapper import WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block -class RTNQuantizer(BaseQuantizers): +@register_pipeline_member(RTNConfig) +class RTNQuantizer(RTNLayerFallbackMixin, BaseQuantizer): def __init__(self, config: RTNConfig): - BaseQuantizers.__init__(self, config) + BaseQuantizer.__init__(self, config) @torch.no_grad() - def quantize_block( - self, block: 
torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs - ) -> dict: + def quantize_block(self, ctx) -> dict: """Apply zero-shot RTN quantization to a block. Pure-algorithm entry point. Infrastructure (materialize, shard writing, @@ -79,6 +79,7 @@ def quantize_block( Returns: dict: Empty dict (zero-shot RTN has no tunable parameters to return). """ + block = ctx.block if ( self.config.is_act_nv_fp or self.config.is_static_afp8 @@ -103,10 +104,11 @@ def quantize_layer(self, name: str, dtype: torch.dtype = None) -> None: self.quantize_layer_via_rtn(name) +@register_pipeline_member(OptimizedRTNConfig) class OptimizedRTNQuantizer(RTNQuantizer): def __init__(self, config: RTNConfig): - BaseQuantizers.__init__(self, config) + BaseQuantizer.__init__(self, config) self.data_type = config.data_type self.group_size = config.group_size self.infer_bs_coeff = config.infer_bs_coeff @@ -114,16 +116,36 @@ def __init__(self, config: RTNConfig): self.enable_alg_ext = True - def register_calibration_hooks(self, model, *, act_max: bool = True, imatrix: bool = True): - hook_handles = super().register_calibration_hooks(model, act_max=act_max, imatrix=imatrix) - if imatrix and self.enable_imatrix: - hook_handles.extend(register_imatrix_hooks(self, model, with_count=True)) - return hook_handles + @contextmanager + def block_forward_hooks(self, ctx): + with super().block_forward_hooks(ctx) as hook_handles: + if self.enable_imatrix: + hook_handles.extend(self._register_imatrix_hooks(ctx.block, with_count=True)) + yield hook_handles + + def _register_imatrix_hooks(self, model, *, with_count: bool = False): + def collect_imatrix(module, input, output): + input = input[0] if isinstance(input, (tuple, list)) else input + flattened = input.reshape(-1, input.shape[-1]).to(torch.float32) + squared = torch.sum(torch.pow(flattened, 2), dim=0).to(torch.float32) + + if not hasattr(module, "imatrix"): + module.imatrix = squared + if with_count: + module.imatrix_cnt = 
input.shape[0] + return + module.imatrix += squared.to(module.imatrix.device) + if with_count: + module.imatrix_cnt += input.shape[0] + + handles = [] + for _, module in model.named_modules(): + if isinstance(module, self.supported_types) and check_to_quantized(module): + handles.append(module.register_forward_hook(collect_imatrix)) + return handles @torch.no_grad() - def quantize_block( - self, block: torch.nn.Module, input_ids=None, input_others=None, reference_output=None, **kwargs - ): + def quantize_block(self, ctx): """Apply imatrix-informed RTN quantization to a block. Pure-algorithm entry point. Device placement and cleanup are handled @@ -137,6 +159,7 @@ def quantize_block( input_others: Unused for optimized RTN. reference_output: Unused for optimized RTN. """ + block = ctx.block update_block_global_scale_if_needed(block, self.data_type, self.group_size) if ( self.config.is_act_nv_fp diff --git a/auto_round/algorithms/quantization/sign_round/config.py b/auto_round/algorithms/quantization/sign_round/config.py index 2aad5e08e..674b4611c 100644 --- a/auto_round/algorithms/quantization/sign_round/config.py +++ b/auto_round/algorithms/quantization/sign_round/config.py @@ -18,18 +18,7 @@ class SignRoundConfig(QuantizationConfig): - """ - - Args: - iters (int): Number of iterations (default is 200). - lr (float): The learning rate (default is 0.005). - minmax_lr (float): The learning rate for min-max tuning (default is None). - lr_scheduler: The learning rate scheduler to be used. - enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True). - enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning - """ - - _alg_cls = "SignRoundQuantizer" + """Configuration for SignRound-style block quantization.""" def __init__( self, @@ -51,6 +40,38 @@ def __init__( enable_adam: bool = False, **kwargs, ): + """Initialize a SignRound configuration. + + Args: + iters: Number of optimization iterations for each quantized + block. 
+ lr: Learning rate used by the main rounding optimization. + If None, a heuristic based on ``iters`` is used. + minmax_lr: Learning rate used by min-max tuning. If None, it + falls back to ``lr``. + lr_scheduler: Optional learning-rate scheduler name or + scheduler object used by the optimizer. + momentum: Momentum factor used by the optimizer. + nblocks: Number of blocks to optimize together. + enable_minmax_tuning: Whether to tune weight min/max ranges. + enable_norm_bias_tuning: Whether to tune normalization and + bias terms. + gradient_accumulate_steps: Number of gradient accumulation + steps used per optimization update. + enable_alg_ext: Whether to enable the experimental SignRound + extension implementation. + not_use_best_mse: Whether to skip restoring the best-MSE + checkpoint during tuning. + dynamic_max_gap: Maximum dynamic gap used by adaptive tuning + logic. + enable_quanted_input: Whether each block should consume the + quantized output of previous blocks during calibration. + optimizer: Optional optimizer name override. + enable_adam: Whether to use the Adam-based SignRound variant. + **kwargs: Common quantization arguments forwarded to + QuantizationConfig, such as bits, group_size, sym, + data_type, and activation quantization fields. + """ super().__init__(**kwargs) self.gradient_accumulate_steps = gradient_accumulate_steps self.iters = iters @@ -87,11 +108,6 @@ def __init__( self.optimizer = optimizer self.enable_adam = enable_adam - if self.enable_adam: - self._alg_cls = "AdamRoundQuantizer" - elif self.enable_alg_ext: - self._alg_cls = "SignRoundV2Quantizer" - def check_configs(self) -> None: """Checks if the configurations are valid. 
@@ -106,3 +122,11 @@ def check_configs(self) -> None: raise ValueError("`nblocks` must be positive") if self.gradient_accumulate_steps <= 0: raise ValueError("`gradient_accumulate_steps` must be positive") + + +class AdamRoundConfig(SignRoundConfig): + pass + + +class SignRoundV2Config(SignRoundConfig): + pass diff --git a/auto_round/algorithms/quantization/sign_round/quantizer.py b/auto_round/algorithms/quantization/sign_round/quantizer.py index c53bb8b87..6fa4fe552 100644 --- a/auto_round/algorithms/quantization/sign_round/quantizer.py +++ b/auto_round/algorithms/quantization/sign_round/quantizer.py @@ -21,9 +21,10 @@ import torch from torch import autocast -from auto_round.algorithms.quantization.base import BaseQuantizers +from auto_round.algorithms.quantization.base import BaseQuantizer, RTNLayerFallbackMixin from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig from auto_round.algorithms.quantization.sign_round.sign_sgd import SignSGD +from auto_round.algorithms.registry import register_pipeline_member from auto_round.compressors.utils import ( IndexSampler, block_forward, @@ -36,6 +37,7 @@ from auto_round.utils import ( check_to_quantized, compile_func, + convert_module_to_hp_if_necessary, get_module, htcore, is_auto_device_mapping, @@ -54,7 +56,8 @@ from auto_round.wrapper import WrapperLinear, unwrapper_block, unwrapper_layer, wrapper_block -class SignRoundQuantizer(BaseQuantizers): +@register_pipeline_member(SignRoundConfig) +class SignRoundQuantizer(RTNLayerFallbackMixin, BaseQuantizer): def __init__(self, config: SignRoundConfig): super().__init__(config) @@ -75,28 +78,6 @@ def __init__(self, config: SignRoundConfig): self.optimizer = self._get_optimizer(optimizer=config.optimizer) self.wrapper_block = wrapper_block - def _get_current_output(self, output: list[torch.Tensor], indices: list[int]) -> torch.Tensor: - if self.model_context.is_diffusion: - assert "hidden_states" in output - current_output = 
[output["hidden_states"][x] for x in indices] - current_output = torch.cat(current_output, dim=self.batch_dim) - return current_output - current_output = [output[x] for x in indices] - current_output = torch.cat(current_output, dim=self.batch_dim) - return current_output - - def _get_current_num_elm( - self, - input_ids: list[torch.Tensor], - indices: list[int], - ) -> int: - if self.model_context.is_diffusion: - current_input_ids = [input_ids["hidden_states"][i] for i in indices] - return sum(id.numel() for id in current_input_ids) - - current_input_ids = [input_ids[i] for i in indices] - return sum(id.numel() for id in current_input_ids) - def _get_non_zero_cnt(self, tensor: list[torch.Tensor], indices: list[int]) -> int: current_tensors = [tensor[i] for i in indices] non_zero_cnt = 0 @@ -135,17 +116,7 @@ def _get_loss( return loss - def quantize_block( - self, - block: torch.nn.Module, - input_ids: Union[list[torch.Tensor], dict], - input_others: dict, - reference_output, - *, - loss_device: Union[str, torch.device], - mid_iter_mem_check: bool = False, - **kwargs, - ) -> dict: + def quantize_block(self, ctx) -> dict: """Apply the AutoRound optimization algorithm to a block. This is the pure-algorithm entry point. All infrastructure concerns @@ -154,21 +125,17 @@ def quantize_block( before and after this call. Args: - block: Module already placed on the correct device(s). - input_ids: Calibration inputs (already on cache_device). - input_others: Additional inputs for the block's forward pass. - reference_output: FP reference outputs collected by the Compressor. - loss_device: Device on which to compute the MSE loss. - mid_iter_mem_check: Pre-evaluated by the Compressor as - ``low_gpu_mem_usage and card_0_in_high_risk``. When True, - triggers mid-iteration memory threshold checks to reduce - fragmentation on the primary GPU. + ctx: Per-block pipeline context. ``ctx.io`` owns calibration inputs, + reference outputs, and mini-batch block forwards. 
Returns: best_params: Best quantization parameters found during optimization. Empty dict if no trainable parameters were found. """ + block = ctx.block device = self.compress_context.device + loss_device = ctx.loss_device + mid_iter_mem_check = ctx.mid_iter_mem_check quantized_layer_names, unquantized_layer_names = self.wrapper_block( block, @@ -226,10 +193,8 @@ def quantize_block( else: lr_schedule = copy.deepcopy(self.lr_scheduler) - if isinstance(input_ids, dict): # input_ids of Flux is dict - nsamples = len(input_ids["hidden_states"]) - else: - nsamples = len(input_ids) + active_inputs = ctx.io.get_inputs(ctx.io.active_source) + nsamples = len(active_inputs["hidden_states"]) if isinstance(active_inputs, dict) else len(active_inputs) last_best_iter = 0 best_loss = torch.finfo(torch.float).max num_elm = 1 @@ -246,7 +211,7 @@ def quantize_block( # We assume the block input and output shape is same if self.gradient_accumulate_steps != 1 and not self.attention_mask: whole_indices = torch.arange(global_batch_size) - num_elm = self._get_current_num_elm(input_ids, whole_indices) + num_elm = ctx.io.count_input_elements(whole_indices) setup_ddp_if_needed_(self, block, self.compress_context.device_list) index_sampler = IndexSampler(nsamples, global_batch_size) batch_size = self.batch_size @@ -261,9 +226,10 @@ def quantize_block( for batch_start in range(0, len(global_indices), batch_size): indices = global_indices[batch_start : batch_start + batch_size] - current_output = self._get_current_output(reference_output, indices) - current_output = to_device(current_output, loss_device) - output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device) + current_output = ctx.io.select_reference_outputs(indices, device=loss_device) + output_q = ctx.io.forward_batch( + block, self, indices, source=ctx.io.active_source, device=device, cache_device=loss_device + ) loss = self._get_loss(output_q, current_output, indices, mse_loss, device) num_elm = 
1 if num_elm <= 0 else num_elm total_loss += loss.item() / num_elm @@ -336,7 +302,7 @@ def quantize_layer_outside_block( Args: layer_name (str): The name of the layer to quantize. - input_ids (list[torch.Tensor], optional): Input data for quantization. + input_ids (torch.Tensor): Input data for quantization. q_inputs (torch.Tensor, optional): Quantized input data. Defaults to None. device (torch.device, optional): The device to use for quantization. Defaults to torch.device("cpu"). @@ -377,7 +343,7 @@ def quantize_layer_outside_block( static_attention_dtype, ): tmp_inputs = q_inputs if q_inputs is not None else input_ids - hook_handles = self.register_calibration_hooks(layer) + hook_handles = self._register_act_max_hooks(layer) with torch.no_grad(): for input in tmp_inputs: layer(input) @@ -439,9 +405,9 @@ def quantize_layer_outside_block( if gradient_accumulate_steps != 1 and not self.attention_mask: whole_indices = torch.arange(global_batch_size) if q_inputs is not None: - num_elm = self._get_current_num_elm(q_inputs, whole_indices) + num_elm = self._count_layer_input_elements(q_inputs, whole_indices) else: - num_elm = self._get_current_num_elm(input_ids, whole_indices) + num_elm = self._count_layer_input_elements(input_ids, whole_indices) index_sampler = IndexSampler(nsamples, global_batch_size) @@ -535,6 +501,9 @@ def _get_optimizer(self, optimizer: Any): ) return SignSGD + def _count_layer_input_elements(self, input_ids, indices: list) -> int: + return sum(input_ids[i].numel() for i in indices) + def _get_scaler(self): """Returns scaler, in SignRound, no need to use scaler.""" return None diff --git a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py index b8ff6c9bb..55e31a31c 100644 --- a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py +++ b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing
permissions and # limitations under the License. -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from functools import partial from typing import Callable, Union @@ -20,9 +20,9 @@ import transformers from torch import autocast -from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig, SignRoundV2Config from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer -from auto_round.algorithms.quantization.utils import register_imatrix_hooks +from auto_round.algorithms.registry import register_pipeline_member from auto_round.data_type.gguf import ( double_quant_tensor_sym_rtn, quant_tensor_gguf_asym_dq, @@ -137,7 +137,7 @@ class SignRoundDQWrapperLinear(WrapperLinear): def __init__(self, *args, **kwargs): if "enable_minmax_tuning" in kwargs: logger.warning_once("disable minmax tuning for a little better accuracy and lower cost") - kwargs["enable_minmax_tuning"] = False # a little faster and better + kwargs["enable_minmax_tuning"] = False super().__init__(*args, **kwargs) self.prev_scale = None self.prev_wmin = None @@ -180,7 +180,6 @@ def _init_tuning_params_and_quant_func(self): @torch.no_grad() def _run_search(self, weight, v): - """Per-format scale/wmin search separated from the (compilable) quant func.""" from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K bits = self.orig_layer.bits @@ -201,8 +200,6 @@ def _run_search(self, weight, v): split_num=1, v=v_r, ) - # Search funcs use ``@torch.inference_mode()``; clone to detach so - # autograd may consume them. 
return { "scale": scale.clone(), "wmin": wmin.clone(), @@ -210,7 +207,6 @@ def _run_search(self, weight, v): "d_wmin": d_wmin.clone(), } - # sym path group_size = 16 super_group_size = 16 t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) @@ -232,13 +228,12 @@ def _run_search(self, weight, v): def _qdq_weight(self, value, min_scale, max_scale): if not self._is_dq_path: - # Non-dq path keeps the original behavior (base class handles it). return super()._qdq_weight(value, min_scale, max_scale) if self.orig_layer.bits >= 16: return self.orig_layer.weight, None, None min_bound, max_bound = self.minmax_scale_bound - min_scale.data.clamp_(min_bound, max_bound) # TODO this one could be deleted + min_scale.data.clamp_(min_bound, max_bound) max_scale.data.clamp_(min_bound, max_bound) weight = self.orig_layer.weight if weight.device.type == "meta": @@ -246,7 +241,6 @@ def _qdq_weight(self, value, min_scale, max_scale): if isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D): weight = weight.t() - # Re-search every 10 steps; otherwise reuse the cached search results. 
iter_v = getattr(self, "cur_iter", 0) need_search = (iter_v == 0) or (iter_v == -1) or (self.prev_scale is None) if need_search: @@ -294,6 +288,7 @@ def _qdq_weight(self, value, min_scale, max_scale): return weight_q, scale_out, zp_out +@register_pipeline_member(SignRoundV2Config) class SignRoundV2Quantizer(SignRoundQuantizer): """SignRound variant using the open algorithm-extension path in the new architecture.""" @@ -361,15 +356,31 @@ def _get_loss( ) return super()._get_loss(output_q, current_output, indices, mse_loss, device) - def register_calibration_hooks(self, model, *, act_max: bool = True, imatrix: bool = True): - hook_handles = super().register_calibration_hooks(model, act_max=act_max, imatrix=imatrix) - if not imatrix: - return hook_handles + @contextmanager + def block_forward_hooks(self, ctx): + with super().block_forward_hooks(ctx) as hook_handles: + if not self._is_wint4aint4(): + hook_handles.extend(self._register_imatrix_hooks(ctx.block)) + yield hook_handles - is_wint4aint4 = ("int4" in self.act_data_type or ("int" in self.act_data_type and self.act_bits == 4)) and ( + def _is_wint4aint4(self): + return ("int4" in self.act_data_type or ("int" in self.act_data_type and self.act_bits == 4)) and ( "int4" in self.data_type or ("int" in self.data_type and self.bits == 4) ) - if is_wint4aint4: - return hook_handles - hook_handles.extend(register_imatrix_hooks(self, model)) - return hook_handles + + def _register_imatrix_hooks(self, model): + def collect_imatrix(module, input, output): + input = input[0] if isinstance(input, (tuple, list)) else input + flattened = input.reshape(-1, input.shape[-1]).to(torch.float32) + squared = torch.sum(torch.pow(flattened, 2), dim=0).to(torch.float32) + + if not hasattr(module, "imatrix"): + module.imatrix = squared + return + module.imatrix += squared.to(module.imatrix.device) + + handles = [] + for _, module in model.named_modules(): + if isinstance(module, self.supported_types) and check_to_quantized(module): + 
handles.append(module.register_forward_hook(collect_imatrix)) + return handles diff --git a/auto_round/algorithms/registry.py b/auto_round/algorithms/registry.py new file mode 100644 index 000000000..d3a16e6cb --- /dev/null +++ b/auto_round/algorithms/registry.py @@ -0,0 +1,168 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +import importlib +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from auto_round.algorithms.base import BasePipelineMember + + +@dataclass +class AlgRegistryEntry: + name: str + aliases: tuple[str, ...] = () + config_factory: Callable[[], object] | None = None + cli_handler: type | None = None + summary: str = "" + alias_factories: dict[str, Callable[[], object]] = field(default_factory=dict) + + +_ALG_REGISTRY: dict[str, AlgRegistryEntry] = {} +_ALIAS_TO_NAME: dict[str, str] = {} +_CONFIG_IMPL_REGISTRY: dict[type, type["BasePipelineMember"]] = {} +_builtin_algorithms_registered = False +_pipeline_members_registered = False + + +def _ensure_builtin_algorithms_registered() -> None: + global _builtin_algorithms_registered + if _builtin_algorithms_registered: + return + importlib.import_module("auto_round.cli.algorithms") + _builtin_algorithms_registered = True + + +def _ensure_pipeline_members_registered() -> None: + global _pipeline_members_registered + if _pipeline_members_registered: + return + for module_name in ( + "auto_round.algorithms.quantization.rtn.quantizer", + "auto_round.algorithms.quantization.sign_round.quantizer", + "auto_round.algorithms.quantization.sign_roundv2.quantizer", + "auto_round.algorithms.quantization.adam_round.adam", + "auto_round.algorithms.transforms.awq.quantizer", + ): + importlib.import_module(module_name) + _pipeline_members_registered = True + + +def register_algorithm( + name: str, + *, + aliases: tuple[str, ...] 
= (), + config_factory: Callable[[], object] | None = None, + cli_handler: type | None = None, + summary: str = "", + alias_factories: dict[str, Callable[[], object]] | None = None, +) -> None: + key = name.strip().lower() + entry = _ALG_REGISTRY.get(key) + if entry is None: + entry = AlgRegistryEntry(name=key) + _ALG_REGISTRY[key] = entry + + merged_aliases = tuple( + dict.fromkeys((entry.aliases or ()) + tuple(a.strip().lower() for a in aliases if a.strip())) + ) + if config_factory is not None: + entry.config_factory = config_factory + if cli_handler is not None: + entry.cli_handler = cli_handler + if summary: + entry.summary = summary + if alias_factories: + entry.alias_factories.update({k.strip().lower(): v for k, v in alias_factories.items()}) + entry.aliases = merged_aliases + + _ALIAS_TO_NAME[key] = key + for alias in merged_aliases: + _ALIAS_TO_NAME[alias] = key + + +def resolve_algorithm_alias(alias: str) -> str | None: + _ensure_builtin_algorithms_registered() + return _ALIAS_TO_NAME.get(alias.strip().lower()) + + +def get_algorithm_entry(name: str) -> AlgRegistryEntry: + _ensure_builtin_algorithms_registered() + canonical = resolve_algorithm_alias(name) + if canonical is None or canonical not in _ALG_REGISTRY: + raise KeyError(name) + return _ALG_REGISTRY[canonical] + + +def iter_algorithm_entries() -> list[AlgRegistryEntry]: + _ensure_builtin_algorithms_registered() + return list(_ALG_REGISTRY.values()) + + +def resolve_alg_config(alias: str) -> object: + _ensure_builtin_algorithms_registered() + canonical = resolve_algorithm_alias(alias) + if canonical is None: + raise ValueError( + f"Unknown algorithm alias '{alias}'. Supported aliases: {sorted(_ALIAS_TO_NAME.keys())}. " + "If you are adding a new algorithm, register it via auto_round.algorithms.registry.register_algorithm()." 
+ ) + + entry = _ALG_REGISTRY[canonical] + factory = entry.alias_factories.get(alias.strip().lower(), entry.config_factory) + if factory is None: + raise ValueError(f"Algorithm alias '{alias}' is registered but has no config factory.") + return factory() + + +def list_registered_algorithms() -> list[str]: + _ensure_builtin_algorithms_registered() + return sorted(_ALIAS_TO_NAME.keys()) + + +def register_pipeline_member(config_cls: type): + def _decorator(member_cls: type["BasePipelineMember"]) -> type["BasePipelineMember"]: + _CONFIG_IMPL_REGISTRY[config_cls] = member_cls + return member_cls + + return _decorator + + +def resolve_pipeline_member(config: object) -> type["BasePipelineMember"]: + _ensure_pipeline_members_registered() + config_cls = type(config) + for cls in config_cls.__mro__: + member_cls = _CONFIG_IMPL_REGISTRY.get(cls) + if member_cls is not None: + return member_cls + raise ValueError(f"Unknown algorithm config type {config_cls.__name__!r}.") + + +def coerce_config_class(config: object, target_cls: type) -> object: + if type(config) is target_cls: + return config + new_config = copy.copy(config) + if hasattr(config, "scheme") and getattr(config, "scheme", None) is not None: + new_config.scheme = config.scheme.copy() + if hasattr(config, "_user_set_scheme_fields"): + new_config._user_set_scheme_fields = set(getattr(config, "_user_set_scheme_fields", set())) + new_config.__class__ = target_cls + return new_config + + +def normalize_algorithm_config(config: object) -> object: + from auto_round.algorithms.quantization.rtn.config import OptimizedRTNConfig, RTNConfig + from auto_round.algorithms.quantization.sign_round.config import AdamRoundConfig, SignRoundConfig, SignRoundV2Config + + if type(config) is RTNConfig and not getattr(config, "disable_opt_rtn", False): + return coerce_config_class(config, OptimizedRTNConfig) + if type(config) is SignRoundConfig: + if getattr(config, "enable_adam", False): + return coerce_config_class(config, 
AdamRoundConfig) + if getattr(config, "enable_alg_ext", False): + return coerce_config_class(config, SignRoundV2Config) + return config diff --git a/auto_round/algorithms/transforms/__init__.py b/auto_round/algorithms/transforms/__init__.py index d2b63d9e7..12eeb37da 100644 --- a/auto_round/algorithms/transforms/__init__.py +++ b/auto_round/algorithms/transforms/__init__.py @@ -20,7 +20,7 @@ Current algorithms ------------------ * **hadamard** – Block-diagonal Hadamard rotations (QuaRot / SpinQuant style). - See :mod:`auto_round.algorithms.transforms.rotation`. + See :mod:`auto_round.algorithms.transforms.quarot`. * **spinquant** – SpinQuant/QuaRot multi-level rotation (R1–R4) with optional online hooks, trainable rotations, and known Hadamard matrices for non-pow2. See :mod:`auto_round.algorithms.transforms.spinquant`. @@ -45,6 +45,7 @@ import torch from auto_round.algorithms.transforms.base import ( + BaseWeightTransformer, BaseRotation, BaseRotationConfig, SerializerMixin, @@ -52,7 +53,7 @@ check_supported_schemes, _ensure_registry_populated, ) -from auto_round.algorithms.transforms.rotation import ( +from auto_round.algorithms.transforms.quarot import ( HadamardRotation, apply_rotation_transform, normalize_rotation_config as _normalize_hadamard_config, @@ -61,6 +62,7 @@ __all__ = [ # Base interfaces + "BaseWeightTransformer", "BaseRotation", "BaseRotationConfig", "SerializerMixin", @@ -82,6 +84,18 @@ ] +def __getattr__(name): + if name == "AWQConfig": + from auto_round.algorithms.transforms.awq.config import AWQConfig + + return AWQConfig + if name == "AWQQuantizer": + from auto_round.algorithms.transforms.awq.quantizer import AWQQuantizer + + return AWQQuantizer + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def normalize_rotation_config( config: Any, ) -> BaseRotationConfig | None: diff --git a/auto_round/algorithms/alg_config.py b/auto_round/algorithms/transforms/awq/__init__.py similarity index 76% rename from 
auto_round/algorithms/alg_config.py rename to auto_round/algorithms/transforms/awq/__init__.py index d9d5f0c75..5637a3e0b 100644 --- a/auto_round/algorithms/alg_config.py +++ b/auto_round/algorithms/transforms/awq/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.algorithms.transforms.awq.quantizer import AWQQuantizer -class AlgConfig: - def __init__(self): - pass +__all__ = ["AWQConfig", "AWQQuantizer"] diff --git a/auto_round/algorithms/transforms/awq/config.py b/auto_round/algorithms/transforms/awq/config.py new file mode 100644 index 000000000..2f87fe7b3 --- /dev/null +++ b/auto_round/algorithms/transforms/awq/config.py @@ -0,0 +1,92 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_round.algorithms.quantization.config import QuantizationConfig + + +class AWQConfig(QuantizationConfig): + """Configuration for AWQ (Activation-Aware Weight Quantization). + + AWQ is a **pre-processing** algorithm (``role="preprocess"``). It + protects salient weight channels by analyzing activation patterns and + applying channel-wise scaling to reduce quantization error. After + smoothing, a separate ``block_quantizer`` (RTN, SignRound, …) performs + the actual weight compression. 
+ + The quantization parameters (``bits``, ``group_size``, ``sym``, + ``data_type``, …) on this config are used *only* for the internal grid + search loss calculation (quantize-dequantize during scale selection). + The definitive quantization parameters for the final weight compression + step come from the pipeline's ``block_quantizer`` config. + """ + + def __init__( + self, + *, + duo_scaling: bool | str = True, + n_grid: int = 20, + seqlen: int = 2048, + nsamples: int = 128, + batch_size: int = 8, + apply_smooth: bool = True, + mappings: list[dict] | None = None, + **kwargs, + ): + """Initialize an AWQ configuration. + + Args: + duo_scaling: Whether AWQ should use activation-aware and + weight-aware scaling together. Use True to always enable + duo scaling, False to use activation-only scaling, or + "both" to search both modes and keep the better result. + n_grid: Number of grid-search points used when searching the + AWQ scaling ratio. + seqlen: Calibration sequence length retained for compatibility + with standalone AWQ entry points. + nsamples: Number of calibration samples retained for compatibility + with standalone AWQ entry points. + batch_size: Batch size retained for compatibility with standalone + AWQ entry points. + apply_smooth: Whether to apply AWQ smoothing before the downstream + block quantizer. + mappings: Optional explicit AWQ smooth/balance mappings. Each + item should contain ``smooth_layer`` and + ``balance_layers`` entries. If None, mappings are inferred + automatically from the model structure. + **kwargs: Common quantization arguments forwarded to + QuantizationConfig, such as bits, group_size, sym, + data_type, and activation quantization fields. 
+ """ + super().__init__(**kwargs) + + if isinstance(duo_scaling, str) and duo_scaling != "both": + raise ValueError(f"duo_scaling must be True, False, or 'both', got '{duo_scaling!r}'") + self.duo_scaling = duo_scaling + self.n_grid = n_grid + self.seqlen = seqlen + self.nsamples = nsamples + self.batch_size = batch_size + self.apply_smooth = apply_smooth + self.mappings = mappings + self.infer_bs_coeff = 1 + self.batch_dim = None + # NOTE: enable_quanted_input is NOT set here. It belongs to the + # block_quantizer (RTN/AutoRound), not to AWQ. See §3.7.1. + + def __repr__(self) -> str: + return ( + f"AWQConfig(duo_scaling={self.duo_scaling!r}, n_grid={self.n_grid}, " + f"bits={self.bits}, group_size={self.group_size}, sym={self.sym}, " + f"mappings={'' if self.mappings else 'auto'})" + ) diff --git a/auto_round/algorithms/quantization/awq/mappings.py b/auto_round/algorithms/transforms/awq/mappings.py similarity index 99% rename from auto_round/algorithms/quantization/awq/mappings.py rename to auto_round/algorithms/transforms/awq/mappings.py index 32617632f..d7f6fd147 100644 --- a/auto_round/algorithms/quantization/awq/mappings.py +++ b/auto_round/algorithms/transforms/awq/mappings.py @@ -386,7 +386,7 @@ def _build_hybrid_attention_mappings(model: torch.nn.Module) -> list[AWQMapping] mappings.append(AWQMapping(r"up_proj$", [r"down_proj$"])) - logger.warning( + logger.info( f"Built dynamic hybrid-attention AWQ mappings: " f"{len(full_indices)} full-attention, {len(linear_indices)} linear-attention, " f"projections={linear_proj_names}, MoE={is_moe}" diff --git a/auto_round/algorithms/transforms/awq/quantizer.py b/auto_round/algorithms/transforms/awq/quantizer.py new file mode 100644 index 000000000..26cb8aba6 --- /dev/null +++ b/auto_round/algorithms/transforms/awq/quantizer.py @@ -0,0 +1,566 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AWQ (Activation-Aware Weight Quantization) quantizer. + +Algorithm: +1. Collect per-channel activation magnitudes during calibration. +2. For each smooth-balance mapping, perform a grid search over scaling ratios + to find the one that minimises quantization error (output-based loss). +3. Apply the best channel-wise scaling: + - balance_layer.weight *= scales + - smooth_layer.weight /= scales (or smooth_layer.bias /= scales if 1-D) +4. Weight compression is delegated to the pipeline's block_quantizer. + +Reference implementations: + - AutoAWQ: https://github.com/casper-hansen/AutoAWQ + - llm-compressor: https://github.com/vllm-project/llm-compressor +""" + +from __future__ import annotations + +import gc +import inspect +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +import torch + +from auto_round.algorithms.pipeline import ( + ActCalibPolicy, + CalibTiming, + InputSource, +) +from auto_round.algorithms.registry import register_pipeline_member +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.algorithms.transforms.awq.mappings import ( + ResolvedMapping, + _extract_block_prefix, + check_model_compatibility, + resolve_mappings, +) +from auto_round.algorithms.transforms.base import BaseWeightTransformer +from auto_round.data_type.utils import ( + get_quant_func, + reshape_pad_tensor_by_group_size, + revert_tensor_by_pad, +) +from auto_round.logger import logger + +if TYPE_CHECKING: + from auto_round.algorithms.pipeline import BlockContext, RunContext + + 
+@register_pipeline_member(AWQConfig) +class AWQQuantizer(BaseWeightTransformer): + """AWQ quantizer: activation-aware weight smoothing pre-processor. + + Inherits :class:`~auto_round.algorithms.transforms.base.BaseWeightTransformer`. + It smooths block weights in-place; actual weight compression (RTN / + SignRound) is performed by the pipeline's ``block_quantizer``. + """ + + def __init__(self, config: AWQConfig): + super().__init__(config) + self.duo_scaling: bool | str = config.duo_scaling + self.n_grid: int = config.n_grid + + self._user_mappings: list[dict] | None = config.mappings + + # Set at runtime by the compressor's post_init() via ``pre.layer_config = self.layer_config``. + self.layer_config: dict | None = None + + self._resolved_mappings: list[ResolvedMapping] = [] + self._block_mappings: dict[str, list[ResolvedMapping]] = {} + + self._activation_stats: dict[str, list] = {} + self._parent_args_cache: dict[torch.nn.Module, list[dict]] = {} + self._parent_signatures: dict[int, inspect.Signature] = {} + + self._finalized: bool = False + + # ── Algorithm Fusion: lifecycle hook implementations ────────────────────── + + def bind(self, compressor) -> None: + """Wire shared state and force AWQ onto single-block scheduling.""" + super().bind(compressor) + nblocks = getattr(compressor, "nblocks", 1) + if nblocks > 1: + logger.warning( + "AWQ does not support nblocks > 1 (got nblocks=%s). 
" "Falling back to nblocks=1.", + nblocks, + ) + compressor.nblocks = 1 + + def prepare_run(self, run_ctx: "RunContext") -> None: + """Validate compatibility, resolve model-wide mappings, and group by block prefix.""" + report = check_model_compatibility(run_ctx.model, self._user_mappings) + for warning in report["warnings"]: + logger.warning(warning) + + # ── Resolve all model-level mappings (name-only, no module caching) ── + self._resolved_mappings = resolve_mappings(run_ctx.model, self._user_mappings) + if not self._resolved_mappings: + raise ValueError( + "AWQ: no layer mappings were resolved for this model. " + f"Model class: {type(run_ctx.model).__name__}. " + "To add support, provide explicit 'mappings' in AWQConfig, or " + "add an entry to auto_round/algorithms/transforms/awq/mappings.py." + ) + + # Group mappings by block prefix for O(1) lookup during block iteration. + self._block_mappings = {} + for m in self._resolved_mappings: + prefix = _extract_block_prefix(m.smooth_name) + self._block_mappings.setdefault(prefix, []).append(m) + + if run_ctx.compress_context is not None: + run_ctx.compress_context.cache_device = torch.device("cpu") + + logger.info( + "AWQ: resolved %d mappings across %d blocks.", + len(self._resolved_mappings), + len(self._block_mappings), + ) + self._finalized = False + + def get_act_calib_policy(self, ctx: "BlockContext"): + """AWQ W4A16 (weight-only): no activation calibration needed.""" + # AWQ pre-processing does not collect act-calib stats; that is the + # block_quantizer's concern. For W8A8/static activation, a post-smooth + # forward may be needed — handled via the block_quantizer's policy. + return ActCalibPolicy(when=CalibTiming.SKIP, source=InputSource.FP_CACHE) + + @contextmanager + def block_forward_hooks(self, ctx: "BlockContext"): + """Register AWQ activation-stats and parent-kwargs hooks. + + Hooks are registered on the *current block's* smooth sources and + parent modules. 
All handles are removed when this context manager + exits (before ``__exit__`` returns), regardless of exceptions. + """ + handles = [] + block_mappings = self._block_mappings.get(ctx.block_name, []) + if block_mappings: + handles = self._register_awq_hooks(ctx.model, ctx.block, ctx.block_name) + try: + yield handles + finally: + for h in handles: + h.remove() + handles.clear() + + def pre_quantize_block(self, ctx: "BlockContext") -> None: + """Apply AWQ smoothing for this block and mark modified params. + + Called after the reference forward (activation stats collected) and + before the block quantizer runs. + """ + if len(ctx.block_names) != 1: + raise ValueError(f"AWQ requires nblocks=1, got {len(ctx.block_names)} blocks: {ctx.block_names}.") + block_name = ctx.block_names[0] + block_mappings = self._block_mappings.get(block_name, []) + if not block_mappings: + logger.debug("AWQ: no mappings for block '%s', skipping.", block_name) + return + self._smooth_block(block_name, block_mappings) + modified = [] + for mapping in block_mappings: + modified.extend(mapping.balance_names) + modified.append(mapping.smooth_name) + ctx.mark_modified_fp_params(modified) + + def post_quantize_block(self, ctx: "BlockContext") -> None: + """Release per-block AWQ caches to free memory.""" + block_mappings = self._block_mappings.get(ctx.block_name, []) + if not block_mappings: + return + for m in block_mappings: + self._activation_stats.pop(m.smooth_name, None) + seen_parents: set[int] = set() + for m in block_mappings: + pid = id(m.parent) + if pid not in seen_parents: + seen_parents.add(pid) + self._parent_args_cache.pop(m.parent, None) + + def finalize_run(self, run_ctx: "RunContext") -> None: + """Idempotent global teardown. 
Safe to call inside try/finally.""" + if self._finalized: + return + self._activation_stats.clear() + self._parent_args_cache.clear() + self._parent_signatures.clear() + self._finalized = True + logger.debug("AWQ: finalize_quantization complete.") + + # ── Hook registration ───────────────────────────────────────────────────── + + def _register_awq_hooks( + self, + model: torch.nn.Module, + block: torch.nn.Module, + block_name: str, + ) -> list: + """Register activation-stats and parent-kwargs hooks for one block.""" + handles = [] + mappings = self._block_mappings.get(block_name, []) + smooth_names = {m.smooth_name for m in mappings} + + # ── Smooth-layer activation-stats hooks ─────────────────────────────── + # Priority: smooth source forward_hook (output stats). + # Each smooth source is hooked exactly once (set de-duplication via name). + for name, module in block.named_modules(): + full_name = f"{block_name}.{name}" if name else block_name + if full_name not in smooth_names: + continue + + def _make_stats_hook(layer_name: str): + + def hook_fn(mod, args, output): + x = output[0] if isinstance(output, tuple) else output + if x is None or x.numel() == 0: + return + channel_sum = x.detach().float().flatten(0, -2).abs().sum(dim=0).cpu() + count = x[..., 0].numel() + if layer_name not in self._activation_stats: + self._activation_stats[layer_name] = [ + torch.zeros_like(channel_sum), + 0, + ] + self._activation_stats[layer_name][0] += channel_sum + self._activation_stats[layer_name][1] += count + + return hook_fn + + h = module.register_forward_hook(_make_stats_hook(full_name)) + handles.append(h) + + # ── Parent-kwargs hooks ─────────────────────────────────────────────── + # One forward_pre_hook per unique parent module in the current block. 
+ parent_modules_hooked: set[int] = set() + for mapping in mappings: + parent = mapping.parent + hook_target = mapping.activation_hook_target + if hook_target: + target_parent = dict(model.named_modules()).get(hook_target) + if target_parent is None: + logger.warning( + "AWQ: activation_hook_target '%s' for '%s' was not found; using resolved parent '%s'.", + hook_target, + mapping.smooth_name, + mapping.parent_name, + ) + else: + parent = target_parent + if id(parent) in parent_modules_hooked: + continue + parent_modules_hooked.add(id(parent)) + + if parent not in self._parent_args_cache: + self._parent_args_cache[parent] = [] + + def _make_parent_hook(parent_module: torch.nn.Module): + + def hook_fn(mod, args, kwargs): + cls_id = id(type(mod)) + if cls_id not in self._parent_signatures: + self._parent_signatures[cls_id] = inspect.signature(mod.forward) + sig = self._parent_signatures[cls_id] + try: + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + except TypeError: + return # signature mismatch; skip this sample + + param = next(mod.parameters(), None) + w_dtype = param.dtype if param is not None else None + + stored: dict[str, Any] = {} + for k, v in bound.arguments.items(): + if isinstance(v, torch.Tensor): + v = v.detach() + if w_dtype and v.is_floating_point() and v.dtype != w_dtype: + v = v.to(w_dtype) + stored[k] = v + elif isinstance(v, tuple) and any(isinstance(t, torch.Tensor) for t in v): + stored[k] = tuple( + ( + t.detach().to(w_dtype) + if (w_dtype and isinstance(t, torch.Tensor) and t.is_floating_point()) + else (t.detach() if isinstance(t, torch.Tensor) else t) + ) + for t in v + ) + elif hasattr(v, "key_cache"): + stored[k] = None # Null out KV cache objects + else: + stored[k] = v + + self._parent_args_cache[parent_module].append(stored) + + return hook_fn + + h = parent.register_forward_pre_hook(_make_parent_hook(parent), with_kwargs=True) + handles.append(h) + + return handles + + # ── Smoothing (grid search + scale apply) 
───────────────────────────────── + + def _smooth_block(self, block_prefix: str, block_mappings: list) -> None: + """Run grid search and apply AWQ scales for one block.""" + for mapping in block_mappings: + if mapping.smooth_name not in self._activation_stats: + logger.warning( + "AWQ: no activation stats for '%s' in block '%s'; skipping.", + mapping.smooth_name, + block_prefix, + ) + continue + + act_sum, act_count = self._activation_stats[mapping.smooth_name] + if act_count == 0: + logger.warning( + "AWQ: zero activation count for '%s' in block '%s'; skipping.", + mapping.smooth_name, + block_prefix, + ) + continue + + x_mean = (act_sum / act_count).to(torch.float32) + del act_sum + + best_scales = self._grid_search_scales(mapping, x_mean) + if best_scales is not None: + self._apply_scales(mapping, best_scales) + + # Release parent kwargs after ALL mappings for this block are processed. + seen_parents: set[int] = set() + for mapping in block_mappings: + pid = id(mapping.parent) + if pid not in seen_parents: + seen_parents.add(pid) + self._parent_args_cache.pop(mapping.parent, None) + + def _get_grid_search_params(self) -> list[tuple[float, bool]]: + """Return (ratio, use_duo_scaling) tuples for the grid search.""" + match self.duo_scaling: + case "both": + n = max(int(self.n_grid / 2), 2) + return [(idx / (n - 1), duo) for idx in range(n) for duo in [False, True]] + case False: + n = max(self.n_grid, 2) + return [(idx / (n - 1), False) for idx in range(n)] + case True: + n = max(self.n_grid, 3) + return [(0.0, False)] + [(idx / (n - 2), True) for idx in range(n - 1)] + case _: + raise ValueError(f"Unexpected duo_scaling value: {self.duo_scaling!r}") + + @staticmethod + def _compute_layer_means(layers: list[torch.nn.Module], group_size: int) -> torch.Tensor: + """Per-channel mean of normalised weights across all balance layers.""" + weight = torch.cat([m.weight.detach().float() for m in layers], dim=0) + org_shape = weight.shape + gs = group_size if group_size > 0 
else org_shape[1] + weight, _, pad_len = reshape_pad_tensor_by_group_size(weight, gs) + w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6) + w_scale = revert_tensor_by_pad(w_scale, orig_shape=org_shape, pad_len=pad_len) + return w_scale.mean(0) + + @torch.no_grad() + def _grid_search_scales( + self, + mapping: ResolvedMapping, + x_mean: torch.Tensor, + ) -> torch.Tensor | None: + """Find the best scaling ratio for *mapping* via output-based loss.""" + device = mapping.balance_layers[0].weight.device + x_mean = x_mean.to(device) + + group_size = self.group_size if (self.group_size is not None and self.group_size > 0) else -1 + if self.duo_scaling is not False: + w_mean = self._compute_layer_means(mapping.balance_layers, group_size).to(device) + + parent_kwargs_list = self._parent_args_cache.get(mapping.parent, []) + use_parent_forward = len(parent_kwargs_list) > 0 + + if use_parent_forward: + fp16_outputs = self._run_parent_samples(mapping.parent, parent_kwargs_list) + if not fp16_outputs or all(f.numel() == 0 for f in fp16_outputs): + use_parent_forward = False + + orig_state = {bl: bl.weight.data.clone() for bl in mapping.balance_layers} + if not use_parent_forward: + orig_weights = orig_state # same reference is fine + + # Pre-resolve quant function once to avoid repeated dispatch in loop. 
+ ref_layer = mapping.balance_layers[0] + ref_name = getattr(ref_layer, "global_name", None) or "" + ref_cfg = (self.layer_config or {}).get(ref_name, {}) + try: + cached_quant_func, _ = get_quant_func( + ref_cfg.get("data_type", self.data_type), + ref_cfg.get("bits", self.bits), + ref_cfg.get("sym", self.sym), + disable_opt_rtn=True, + group_size=ref_cfg.get("group_size", self.group_size), + iters=0, + ) + except Exception as exc: + logger.debug("AWQ: failed to pre-resolve quant function for '%s': %s", ref_name, exc) + cached_quant_func = None + + best_error = float("inf") + best_scales = None + best_ratio = -1 + + for ratio, use_duo in self._get_grid_search_params(): + if use_duo: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + scales_view = scales.view(1, -1).to(device) + + if use_parent_forward: + for bl in mapping.balance_layers: + bl.weight.data.copy_(orig_state[bl] * scales_view) + w_qdq = self._quantize_dequantize_weight(bl, bl.weight.data.float(), quant_func=cached_quant_func) + if w_qdq is not None: + bl.weight.data = (w_qdq / scales_view).to(bl.weight.dtype) + else: + bl.weight.data.copy_(orig_state[bl]) + + int_w_outputs = self._run_parent_samples(mapping.parent, parent_kwargs_list) + total_loss = self._compute_loss(fp16_outputs, int_w_outputs) + del int_w_outputs + for bl in mapping.balance_layers: + bl.weight.data.copy_(orig_state[bl]) + else: + total_loss = 0.0 + for bl in mapping.balance_layers: + w_orig = orig_weights[bl].to(device) + w_qdq = self._quantize_dequantize_weight(bl, w_orig * scales_view, quant_func=cached_quant_func) + if w_qdq is None: + total_loss = float("inf") + break + total_loss += (w_orig - w_qdq / scales_view).pow(2).sum().item() + + if total_loss < best_error: + best_error = total_loss + best_scales = 
scales.clone() + best_ratio = ratio + + if best_ratio < 0: + logger.warning("AWQ: grid search failed for '%s': no finite error.", mapping.smooth_name) + return None + + logger.debug("AWQ '%s': best_ratio=%.2f, best_error=%.3e", mapping.smooth_name, best_ratio, best_error) + return best_scales + + @torch.no_grad() + def _run_parent_samples( + self, + parent: torch.nn.Module, + kwargs_list: list[dict], + ) -> list[torch.Tensor]: + outputs = [] + for stored_kwargs in kwargs_list: + out = parent(**stored_kwargs) + if isinstance(out, tuple): + out = out[0] + outputs.append(out) + return outputs + + @staticmethod + @torch.no_grad() + def _compute_loss( + fp16_outputs: list[torch.Tensor], + int_w_outputs: list[torch.Tensor], + ) -> float: + device = fp16_outputs[0].device + loss = torch.tensor(0.0, device=device) + num_elements = torch.tensor(0, device=device, dtype=torch.long) + for fp16_out, int_w_out in zip(fp16_outputs, int_w_outputs): + loss += torch.nn.functional.mse_loss( + fp16_out.float(), + int_w_out.to(fp16_out.device).float(), + reduction="sum", + ) + num_elements += fp16_out.numel() + if num_elements == 0: + return float("inf") + return (loss / num_elements).item() + + def _quantize_dequantize_weight( + self, + layer: torch.nn.Module, + weight: torch.Tensor, + quant_func=None, + ) -> torch.Tensor | None: + """Quantize-dequantize a weight tensor using the layer's config. + + Used internally for grid search loss calculation only. Does NOT + modify the layer's stored weights. 
+ """ + layer_name = getattr(layer, "global_name", None) or "" + config = (self.layer_config or {}).get(layer_name, {}) + bits = config.get("bits", self.bits) + group_size = config.get("group_size", self.group_size) + sym = config.get("sym", self.sym) + data_type = config.get("data_type", self.data_type) + + if quant_func is None: + try: + quant_func, _ = get_quant_func( + data_type, + bits, + sym, + disable_opt_rtn=True, + group_size=group_size, + iters=0, + ) + except Exception as exc: + logger.debug("AWQ: failed to resolve quant function for '%s': %s", layer_name, exc) + return None + + if quant_func is None: + return None + + try: + qdq_weight, _, _ = quant_func(weight, bits=bits, group_size=group_size) + return qdq_weight + except Exception as exc: + logger.debug("AWQ: quantize-dequantize failed for '%s': %s", layer_name, exc) + return None + + @torch.no_grad() + def _apply_scales(self, mapping: ResolvedMapping, scales: torch.Tensor) -> None: + """Apply computed AWQ scales to smooth and balance layers in-place.""" + for bl in mapping.balance_layers: + s = scales.to(bl.weight.device).view(1, -1) + bl.weight.data.mul_(s) + + smooth = mapping.smooth_layer + s = scales.to(smooth.weight.device) + if smooth.weight.ndim == 1: + smooth.weight.data.div_(s) + else: + smooth.weight.data[-s.size(0) :].div_(s.view(-1, 1)) + + if hasattr(smooth, "bias") and smooth.bias is not None: + smooth.bias.data.div_(s) diff --git a/auto_round/algorithms/transforms/base.py b/auto_round/algorithms/transforms/base.py index db48c6059..4a36b8080 100644 --- a/auto_round/algorithms/transforms/base.py +++ b/auto_round/algorithms/transforms/base.py @@ -27,6 +27,21 @@ import torch import torch.nn as nn +from auto_round.algorithms.base import BasePipelineMember + + +class BaseWeightTransformer(BasePipelineMember): + """Base class for weight-transformation algorithms in a QuantizationPipeline.""" + + def pre_quantize_block(self, ctx) -> None: + """Called after the reference forward, before block 
quantization.""" + return + + def post_quantize_block(self, ctx) -> None: + """Called after the block quantizer completes.""" + return + + # --------------------------------------------------------------------------- # Config base # --------------------------------------------------------------------------- @@ -159,7 +174,7 @@ def _ensure_registry_populated() -> None: # Import each sub-package here. Add new entries as more algorithms land. import importlib - for sub in ("rotation", "spinquant"): + for sub in ("quarot", "spinquant"): try: importlib.import_module(f"auto_round.algorithms.transforms.{sub}") except ImportError: diff --git a/auto_round/algorithms/transforms/rotation/__init__.py b/auto_round/algorithms/transforms/quarot/__init__.py similarity index 85% rename from auto_round/algorithms/transforms/rotation/__init__.py rename to auto_round/algorithms/transforms/quarot/__init__.py index 640cf06ab..3faf385e4 100644 --- a/auto_round/algorithms/transforms/rotation/__init__.py +++ b/auto_round/algorithms/transforms/quarot/__init__.py @@ -13,15 +13,15 @@ # limitations under the License. 
"""Hadamard rotation sub-package for ``algorithms/transforms``.""" -from auto_round.algorithms.transforms.rotation.apply import ( +from auto_round.algorithms.transforms.quarot.apply import ( HadamardRotation, apply_rotation_transform, ) -from auto_round.algorithms.transforms.rotation.config import ( +from auto_round.algorithms.transforms.quarot.config import ( RotationConfig, normalize_rotation_config, ) -from auto_round.algorithms.transforms.rotation.transforms import ( +from auto_round.algorithms.transforms.quarot.transforms import ( HADAMARDS, HadamardTransform, RandomHadamardTransform, diff --git a/auto_round/algorithms/transforms/rotation/apply.py b/auto_round/algorithms/transforms/quarot/apply.py similarity index 91% rename from auto_round/algorithms/transforms/rotation/apply.py rename to auto_round/algorithms/transforms/quarot/apply.py index 62ceb02b5..9656002cc 100644 --- a/auto_round/algorithms/transforms/rotation/apply.py +++ b/auto_round/algorithms/transforms/quarot/apply.py @@ -27,8 +27,8 @@ import tqdm from auto_round.algorithms.transforms.base import BaseRotation -from auto_round.algorithms.transforms.rotation.config import RotationConfig, normalize_rotation_config -from auto_round.algorithms.transforms.rotation.transforms import build_hadamard_transform +from auto_round.algorithms.transforms.quarot.config import RotationConfig, normalize_rotation_config +from auto_round.algorithms.transforms.quarot.transforms import build_hadamard_transform from auto_round.compressors.utils import is_nv_fp from auto_round.experimental.qmodules.base import QModuleBase @@ -44,7 +44,7 @@ def _triton_available(data_type: str = "mx_fp") -> bool: if not torch.cuda.is_available(): return False - from auto_round.algorithms.transforms.rotation.utils.triton.mxfp4 import ( # noqa: F401 + from auto_round.algorithms.transforms.quarot.utils.triton.mxfp4 import ( # noqa: F401 mxfp4_forward_kernel_wrapper, ) @@ -67,7 +67,7 @@ class HadamardRotation(BaseRotation): Or directly:: - 
from auto_round.algorithms.transforms.rotation import apply_rotation_transform + from auto_round.algorithms.transforms.quarot import apply_rotation_transform model = apply_rotation_transform(model, config=RotationConfig(), need_calibration=True) """ @@ -109,13 +109,13 @@ def apply_to_model( # Dispatch by backend. The transform backend (triton-fused per-Linear) # is implemented below; the inplace (QuaRot) backend is delegated to - # :mod:`auto_round.algorithms.transforms.rotation.inplace`. - from auto_round.algorithms.transforms.rotation.dispatcher import resolve_hadamard_backend + # :mod:`auto_round.algorithms.transforms.quarot.inplace`. + from auto_round.algorithms.transforms.quarot.dispatcher import resolve_hadamard_backend backend = resolve_hadamard_backend(cfg, data_type) if backend == "inplace": import auto_round.envs as envs - from auto_round.algorithms.transforms.rotation.inplace import apply_rotation_transform as _inplace_apply + from auto_round.algorithms.transforms.quarot.inplace import apply_rotation_transform as _inplace_apply # Resolve fuse flag: explicit > env var > default(False). 
fuse_online_to_weight = cfg.fuse_online_to_weight @@ -181,7 +181,7 @@ def _apply_to_module( def _apply_input_transform(module: torch.nn.Module, config: RotationConfig, data_type: str = "mx_fp") -> None: """Register a forward pre-hook that applies the Hadamard to the input activation.""" - from auto_round.algorithms.transforms.rotation.utils.matrix import multihead_matmul + from auto_round.algorithms.transforms.quarot.utils.matrix import multihead_matmul inp_transform = build_hadamard_transform( **config.model_dump(), @@ -197,7 +197,7 @@ def _apply_input_transform(module: torch.nn.Module, config: RotationConfig, data hadamard_weight = None if _triton_available(data_type): - from auto_round.algorithms.transforms.rotation.utils.triton.mxfp4 import mxfp4_forward_kernel_wrapper + from auto_round.algorithms.transforms.quarot.utils.triton.mxfp4 import mxfp4_forward_kernel_wrapper def _input_hook(self, args): x = args[0] @@ -232,7 +232,7 @@ def _apply_weight_transform( config: RotationConfig, ) -> None: """Fuse or patch the Hadamard rotation into the weight of *module*.""" - from auto_round.algorithms.transforms.rotation.patch import ( + from auto_round.algorithms.transforms.quarot.patch import ( patch_quantlinear, patch_wrapperlinear_to_apply_transform, patch_wrapperwalayer_forward_to_apply_transform, @@ -248,7 +248,7 @@ def _apply_weight_transform( # For random Hadamard, save the matrix as a submodule for serialisation. 
if config.hadamard_type == "random_hadamard": - from auto_round.algorithms.transforms.rotation.patch import patch_quantlinear as _patch_ql + from auto_round.algorithms.transforms.quarot.patch import patch_quantlinear as _patch_ql _patch_ql(w_transform) diff --git a/auto_round/algorithms/transforms/rotation/config.py b/auto_round/algorithms/transforms/quarot/config.py similarity index 85% rename from auto_round/algorithms/transforms/rotation/config.py rename to auto_round/algorithms/transforms/quarot/config.py index 629714557..761b3b318 100644 --- a/auto_round/algorithms/transforms/rotation/config.py +++ b/auto_round/algorithms/transforms/quarot/config.py @@ -19,13 +19,13 @@ Two implementation backends share this one schema (method B): * ``backend="inplace"`` – QuaRot-style residual-stream rotation, implemented - under :mod:`auto_round.algorithms.transforms.rotation.inplace`. Works for any + under :mod:`auto_round.algorithms.transforms.quarot.inplace`. Works for any weight/activation dtype and can optionally fuse the online Hadamard into weights (``fuse_online_to_weight=True``). * ``backend="transform"`` – Per-Linear weight + activation Hadamard with a fused triton kernel, implemented under - :mod:`auto_round.algorithms.transforms.rotation.apply`. Supports only + :mod:`auto_round.algorithms.transforms.quarot.apply`. Supports only MXFP4 / NVFP4 and cannot fuse online to weight. * ``backend="auto"`` – dispatcher picks inplace when a fused online rotation @@ -83,6 +83,27 @@ class RotationConfig(BaseModel, BaseRotationConfig): model_config = {"arbitrary_types_allowed": True} + def __init__(self, **data): + """Initialize a Hadamard rotation configuration. + + Args: + algorithm: Canonical algorithm name used for registry lookup. + backend: Rotation backend to use. ``auto`` lets AutoRound pick + an implementation, ``inplace`` uses QuaRot-style online + rotation, and ``transform`` uses the transform backend. + block_size: Grouped Hadamard block size. 
None keeps the backend + default behavior. + hadamard_type: Hadamard transform variant, such as + ``hadamard``, ``random_hadamard``, or ``quarot_hadamard``. + fuse_online_to_weight: Whether online Hadamard rotation should + be fused into the weights when supported. + allow_online_rotation: Whether online activation rotation is + allowed. + random_seed: Internal flag used by random Hadamard paths. + **data: Additional Pydantic field values forwarded to BaseModel. + """ + super().__init__(**data) + @field_validator("backend") @classmethod def _validate_backend(cls, v: str) -> str: diff --git a/auto_round/algorithms/transforms/rotation/dispatcher.py b/auto_round/algorithms/transforms/quarot/dispatcher.py similarity index 93% rename from auto_round/algorithms/transforms/rotation/dispatcher.py rename to auto_round/algorithms/transforms/quarot/dispatcher.py index ef47c2e56..e15236398 100644 --- a/auto_round/algorithms/transforms/rotation/dispatcher.py +++ b/auto_round/algorithms/transforms/quarot/dispatcher.py @@ -9,11 +9,11 @@ Two backend implementations exist: -* ``inplace`` – :mod:`auto_round.algorithms.transforms.rotation.inplace` +* ``inplace`` – :mod:`auto_round.algorithms.transforms.quarot.inplace` QuaRot-style residual-stream rotation. Works for any weight/activation dtype. Optionally fuses the online Hadamard into weights (``fuse_online_to_weight=True``). -* ``transform`` – :mod:`auto_round.algorithms.transforms.rotation` +* ``transform`` – :mod:`auto_round.algorithms.transforms.quarot` Per-Linear weight + activation Hadamard with a fused triton kernel. Only supports MXFP4 / NVFP4 and **cannot** fuse online to weight. 
@@ -33,7 +33,7 @@ import torch import auto_round.envs as envs -from auto_round.algorithms.transforms.rotation.config import RotationConfig, normalize_rotation_config +from auto_round.algorithms.transforms.quarot.config import RotationConfig, normalize_rotation_config from auto_round.compressors.utils import is_mx_fp, is_nv_fp from auto_round.utils import logger @@ -122,7 +122,7 @@ def apply_hadamard_rotation( if backend == "inplace": logger.warning("this backend does not support real exporting, please export the model to fake format") - from auto_round.algorithms.transforms.rotation.inplace import apply_rotation_transform + from auto_round.algorithms.transforms.quarot.inplace import apply_rotation_transform # block_size -> group_size (None / -1 / 0 means full-dimension) bs = config.block_size @@ -145,7 +145,7 @@ def apply_hadamard_rotation( supported_hadamard_types = ("hadamard", "random_hadamard") if config.hadamard_type not in supported_hadamard_types: raise ValueError("this backend only supports hadamard or random_hadamard") - from auto_round.algorithms.transforms.rotation.apply import apply_rotation_transform + from auto_round.algorithms.transforms.quarot.apply import apply_rotation_transform return apply_rotation_transform(model, config, data_type=data_type) else: diff --git a/auto_round/algorithms/transforms/rotation/inplace/__init__.py b/auto_round/algorithms/transforms/quarot/inplace/__init__.py similarity index 56% rename from auto_round/algorithms/transforms/rotation/inplace/__init__.py rename to auto_round/algorithms/transforms/quarot/inplace/__init__.py index d7b85d9b1..53fa25fc1 100644 --- a/auto_round/algorithms/transforms/rotation/inplace/__init__.py +++ b/auto_round/algorithms/transforms/quarot/inplace/__init__.py @@ -5,7 +5,7 @@ Canonical home of the residual-stream Hadamard rotation implementation. 
""" -from auto_round.algorithms.transforms.rotation.inplace.apply import apply_rotation_transform # noqa: F401 -from auto_round.algorithms.transforms.rotation.inplace.hooks import clear_random_hadamard_cache # noqa: F401 +from auto_round.algorithms.transforms.quarot.inplace.apply import apply_rotation_transform # noqa: F401 +from auto_round.algorithms.transforms.quarot.inplace.hooks import clear_random_hadamard_cache # noqa: F401 __all__ = ["apply_rotation_transform", "clear_random_hadamard_cache"] diff --git a/auto_round/algorithms/transforms/rotation/inplace/apply.py b/auto_round/algorithms/transforms/quarot/inplace/apply.py similarity index 99% rename from auto_round/algorithms/transforms/rotation/inplace/apply.py rename to auto_round/algorithms/transforms/quarot/inplace/apply.py index 3177d6fa9..749013065 100644 --- a/auto_round/algorithms/transforms/rotation/inplace/apply.py +++ b/auto_round/algorithms/transforms/quarot/inplace/apply.py @@ -14,7 +14,7 @@ import torch import tqdm -from auto_round.algorithms.transforms.rotation.inplace.hooks import ( +from auto_round.algorithms.transforms.quarot.inplace.hooks import ( CrossHeadOnlineHadamardHook, FullOnlineHadamardHook, GroupOnlineHadamardHook, @@ -29,7 +29,7 @@ get_hadK, get_or_create_random_hadamard, ) -from auto_round.algorithms.transforms.rotation.inplace.model_config import ( +from auto_round.algorithms.transforms.quarot.inplace.model_config import ( MAPPING_REGISTRY, RotationMapping, _resolve, diff --git a/auto_round/algorithms/transforms/rotation/inplace/hooks.py b/auto_round/algorithms/transforms/quarot/inplace/hooks.py similarity index 99% rename from auto_round/algorithms/transforms/rotation/inplace/hooks.py rename to auto_round/algorithms/transforms/quarot/inplace/hooks.py index 2b3d26c6e..6b86eee4d 100644 --- a/auto_round/algorithms/transforms/rotation/inplace/hooks.py +++ b/auto_round/algorithms/transforms/quarot/inplace/hooks.py @@ -295,7 +295,7 @@ def register_online_had_hooks(model, mapping=None, 
fp32_had=False, use_fast_had= list of hook handles (call ``handle.remove()`` to detach). """ if mapping is None: - from auto_round.algorithms.transforms.rotation.inplace.model_config import infer_mapping_from_model + from auto_round.algorithms.transforms.quarot.inplace.model_config import infer_mapping_from_model mapping = infer_mapping_from_model(model) @@ -352,7 +352,7 @@ def get_hadK(n: int, transpose=False) -> (torch.Tensor, int): K = 1 return hadK, K else: - from auto_round.algorithms.transforms.rotation.utils.math import _fetch_hadamard_divisor + from auto_round.algorithms.transforms.quarot.utils.math import _fetch_hadamard_divisor hadK = _fetch_hadamard_divisor(n, torch.float, torch.device("cpu")) if transpose: diff --git a/auto_round/algorithms/transforms/rotation/inplace/model_config.py b/auto_round/algorithms/transforms/quarot/inplace/model_config.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/inplace/model_config.py rename to auto_round/algorithms/transforms/quarot/inplace/model_config.py diff --git a/auto_round/algorithms/transforms/rotation/patch.py b/auto_round/algorithms/transforms/quarot/patch.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/patch.py rename to auto_round/algorithms/transforms/quarot/patch.py diff --git a/auto_round/algorithms/transforms/rotation/transforms.py b/auto_round/algorithms/transforms/quarot/transforms.py similarity index 97% rename from auto_round/algorithms/transforms/rotation/transforms.py rename to auto_round/algorithms/transforms/quarot/transforms.py index 78dcd0d6b..71320e61e 100644 --- a/auto_round/algorithms/transforms/rotation/transforms.py +++ b/auto_round/algorithms/transforms/quarot/transforms.py @@ -27,11 +27,11 @@ import torch import torch.nn as nn -from auto_round.algorithms.transforms.rotation.utils.math import ( +from auto_round.algorithms.transforms.quarot.utils.math import ( deterministic_hadamard_matrix, random_hadamard_matrix, ) -from 
auto_round.algorithms.transforms.rotation.utils.matrix import apply_transform_weight +from auto_round.algorithms.transforms.quarot.utils.matrix import apply_transform_weight __all__ = [ "HadamardTransform", diff --git a/auto_round/algorithms/transforms/rotation/utils/__init__.py b/auto_round/algorithms/transforms/quarot/utils/__init__.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/utils/__init__.py rename to auto_round/algorithms/transforms/quarot/utils/__init__.py diff --git a/auto_round/algorithms/transforms/rotation/utils/hadamards.safetensors b/auto_round/algorithms/transforms/quarot/utils/hadamards.safetensors similarity index 100% rename from auto_round/algorithms/transforms/rotation/utils/hadamards.safetensors rename to auto_round/algorithms/transforms/quarot/utils/hadamards.safetensors diff --git a/auto_round/algorithms/transforms/rotation/utils/math.py b/auto_round/algorithms/transforms/quarot/utils/math.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/utils/math.py rename to auto_round/algorithms/transforms/quarot/utils/math.py diff --git a/auto_round/algorithms/transforms/rotation/utils/matrix.py b/auto_round/algorithms/transforms/quarot/utils/matrix.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/utils/matrix.py rename to auto_round/algorithms/transforms/quarot/utils/matrix.py diff --git a/auto_round/algorithms/transforms/rotation/utils/triton/__init__.py b/auto_round/algorithms/transforms/quarot/utils/triton/__init__.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/utils/triton/__init__.py rename to auto_round/algorithms/transforms/quarot/utils/triton/__init__.py diff --git a/auto_round/algorithms/transforms/rotation/utils/triton/mxfp4.py b/auto_round/algorithms/transforms/quarot/utils/triton/mxfp4.py similarity index 100% rename from auto_round/algorithms/transforms/rotation/utils/triton/mxfp4.py rename to 
auto_round/algorithms/transforms/quarot/utils/triton/mxfp4.py diff --git a/auto_round/algorithms/transforms/spinquant/inplace/__init__.py b/auto_round/algorithms/transforms/spinquant/inplace/__init__.py index dc80ac2a5..ac3071a93 100644 --- a/auto_round/algorithms/transforms/spinquant/inplace/__init__.py +++ b/auto_round/algorithms/transforms/spinquant/inplace/__init__.py @@ -4,7 +4,7 @@ """ SpinQuant in-place application sub-package. -Follows the same structure as ``auto_round.algorithms.transforms.rotation.inplace``. +Follows the same structure as ``auto_round.algorithms.transforms.quarot.inplace``. """ from auto_round.algorithms.transforms.spinquant.inplace.apply import ( diff --git a/auto_round/algorithms/transforms/spinquant/inplace/apply.py b/auto_round/algorithms/transforms/spinquant/inplace/apply.py index c831c193f..f4e7cba9f 100644 --- a/auto_round/algorithms/transforms/spinquant/inplace/apply.py +++ b/auto_round/algorithms/transforms/spinquant/inplace/apply.py @@ -6,7 +6,7 @@ This module provides ``apply_spinquant_in_place`` and hook registration that follow the same patterns used by AutoRound's -``auto_round.algorithms.transforms.rotation.inplace`` package. +``auto_round.algorithms.transforms.quarot.inplace`` package. R3 rotation uses the architecture-generic monkeypatch approach from QuaRot/Quark: we replace ``apply_rotary_pos_emb`` in the attention forward's globals with a diff --git a/auto_round/algorithms/transforms/spinquant/rotation_utils.py b/auto_round/algorithms/transforms/spinquant/rotation_utils.py index aff9b2395..65678405d 100644 --- a/auto_round/algorithms/transforms/spinquant/rotation_utils.py +++ b/auto_round/algorithms/transforms/spinquant/rotation_utils.py @@ -154,7 +154,7 @@ def matmul_hadU(X: torch.Tensor, hadamard_K: Optional[torch.Tensor] = None, K: O # the fallback implementations below are used. 
# --------------------------------------------------------------------------- try: - from auto_round.algorithms.transforms.rotation.utils.matrix import apply_transform_weight + from auto_round.algorithms.transforms.quarot.utils.matrix import apply_transform_weight except ImportError: # Fallback for standalone usage. def apply_transform_weight( @@ -462,7 +462,7 @@ def fuse_rmsnorm_in_model(model: nn.Module) -> None: """ # Attempt to use AutoRound's model_config layer discovery if available. try: - from auto_round.algorithms.transforms.rotation.inplace.model_config import get_scaling_layers + from auto_round.algorithms.transforms.quarot.inplace.model_config import get_scaling_layers layer_paths = get_scaling_layers(model.config.model_type if hasattr(model, "config") else "") if layer_paths: @@ -756,7 +756,7 @@ def get_model_arch_info(model: nn.Module) -> dict: def get_attention_layers(model: nn.Module): """Yield attention modules using model_config if available, else fall back.""" try: - from auto_round.algorithms.transforms.rotation.inplace.model_config import get_attention_layers as _get + from auto_round.algorithms.transforms.quarot.inplace.model_config import get_attention_layers as _get return _get(model) except ImportError: diff --git a/auto_round/autoround.py b/auto_round/autoround.py index c2e2742c4..e45200933 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -23,7 +23,77 @@ if TYPE_CHECKING: from auto_round.auto_scheme.gen_auto_scheme import AutoScheme from auto_round.compressors.base import BaseCompressor - from auto_round.compressors.config import ExtraConfig + + +_COMPAT_KWARGS = { + "format", + "bits", + "group_size", + "sym", + "data_type", + "act_bits", + "act_group_size", + "act_sym", + "act_data_type", + "act_dynamic", + "super_bits", + "super_group_size", + "scale_dtype", + "ignore_layers", + "quant_lm_head", + "to_quant_block_names", + "model_free", + "disable_model_free", + "model_dtype", + "trust_remote_code", + "amp", + 
"nblocks", + "lr", + "minmax_lr", + "enable_minmax_tuning", + "enable_norm_bias_tuning", + "enable_quanted_input", + "enable_opt_rtn", + "disable_deterministic_algorithms", + "enable_deterministic_algorithms", + "static_kv_dtype", + "static_attention_dtype", + "rotation_config", + "processor", + "image_processor", + "template", + "extra_data_dir", + "quant_nontext_module", + "guidance_scale", + "num_inference_steps", + "generator_seed", + "duo_scaling", + "n_grid", + "mappings", + "algorithm", + "optimizer", + "lr_scheduler", + "not_use_best_mse", + "dynamic_max_gap", + "momentum", + "device", +} + + +def _filter_supported_compat_kwargs(kwargs: dict) -> dict: + supported = {} + unknown = [] + for key, value in kwargs.items(): + if key in _COMPAT_KWARGS: + supported[key] = value + else: + unknown.append(key) + if unknown: + logger.warning_once( + "AutoRound compatibility path received unsupported kwargs %s. They will be ignored.", + ", ".join(sorted(unknown)), + ) + return supported class AutoRound: @@ -45,7 +115,7 @@ class AutoRound: enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers. 
""" - SKIP_ARGS = ("local_args", "kwargs", "cls", "model_cls", "dynamic_compressor", "extra_config") + SKIP_ARGS = ("local_args", "kwargs", "cls", "model_cls", "dynamic_compressor", "alg_configs") bits: int | None group_size: int | tuple | None @@ -65,7 +135,7 @@ def __new__( model: Union[torch.nn.Module, str], tokenizer=None, platform: str = "hf", - scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16", + scheme: Union[str, dict, QuantizationScheme, "AutoScheme"] = "W4A16", layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, @@ -78,10 +148,10 @@ def __new__( enable_torch_compile: bool = False, seed: int = 42, enable_adam: bool = False, - extra_config: "ExtraConfig" = None, enable_alg_ext: bool = False, disable_opt_rtn: bool | None = None, low_cpu_mem_usage: bool = True, + alg_configs=None, **kwargs, ) -> "BaseCompressor": """Initialize AutoRound with quantization and tuning configuration. @@ -103,7 +173,6 @@ def __new__( enable_torch_compile (bool, optional): Enable torch.compile for low cost in quantization. Defaults to False. seed (int, optional): Random seed. Defaults to 42. enable_adam (bool, optional): Enable Adam-based optimizer. Defaults to False. - extra_config(ExtraConfig, optional): Extra configuration for lots of configurations. Defaults to None. enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2) for better accuracy. Defaults to False. disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0) for fast quatnziation @@ -150,19 +219,62 @@ def __new__( ... 
} """ - local_args = {k: v for k, v in locals().items() if k not in cls.SKIP_ARGS} - if extra_config is not None: - for key, value in extra_config.to_dict().items(): - if value is None: - continue - if key in local_args: - local_args[key] = value - else: - kwargs[key] = value + # Short-circuit: if alg_configs is provided, bypass AutoRoundCompatible and go directly + # to the new-arch entry point to avoid duplicate keyword argument errors. + if alg_configs is not None: + from auto_round.compressors.entry import AutoRound as _NewAutoRound + from auto_round.compressors.entry import filter_supported_entry_kwargs + + entry_kwargs = filter_supported_entry_kwargs(kwargs, context="AutoRound") + + return _NewAutoRound( + alg_configs=alg_configs, + model=model, + tokenizer=tokenizer, + platform=platform, + format=entry_kwargs.pop("format", None), + scheme=scheme, + low_gpu_mem_usage=low_gpu_mem_usage, + device_map=device_map, + iters=iters, + gradient_accumulate_steps=gradient_accumulate_steps, + enable_torch_compile=enable_torch_compile, + seed=seed, + low_cpu_mem_usage=low_cpu_mem_usage, + layer_config=layer_config, + nsamples=nsamples, + seqlen=seqlen, + **entry_kwargs, + ) + + compat_kwargs = _filter_supported_compat_kwargs(kwargs) + compat_kwargs.update( + enable_adam=enable_adam, + enable_alg_ext=enable_alg_ext, + disable_opt_rtn=disable_opt_rtn, + ) from auto_round.compressors.entry import AutoRoundCompatible - return AutoRoundCompatible(**local_args, **kwargs) + return AutoRoundCompatible( + model=model, + tokenizer=tokenizer, + platform=platform, + scheme=scheme, + layer_config=layer_config, + dataset=dataset, + iters=iters, + seqlen=seqlen, + nsamples=nsamples, + batch_size=batch_size, + gradient_accumulate_steps=gradient_accumulate_steps, + low_gpu_mem_usage=low_gpu_mem_usage, + device_map=device_map, + enable_torch_compile=enable_torch_compile, + seed=seed, + low_cpu_mem_usage=low_cpu_mem_usage, + **compat_kwargs, + ) @classmethod @torch.no_grad() diff --git 
a/auto_round/calibration/state.py b/auto_round/calibration/state.py index f4fdfc9f7..49c10e512 100644 --- a/auto_round/calibration/state.py +++ b/auto_round/calibration/state.py @@ -15,7 +15,7 @@ This dataclass owns every per-run calibration field shared between :class:`~auto_round.compressors.base.BaseCompressor` and -:class:`~auto_round.algorithms.quantization.base.BaseQuantizers`: +:class:`~auto_round.algorithms.quantization.base.BaseQuantizer`: - Cache state ``(inputs, to_cached_layers, last_cache_name, blocks_requiring_input_ids)`` - Per-batch shape state ``(attention_mask, batch_dim)`` diff --git a/auto_round/cli/__init__.py b/auto_round/cli/__init__.py new file mode 100644 index 000000000..f52f12468 --- /dev/null +++ b/auto_round/cli/__init__.py @@ -0,0 +1,12 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""CLI package for AutoRound. + +Modules: + main.py - command routing, RECIPES, tune(), eval entry points + parser.py - argparse parser construction (all explicit flags here) + algorithms.py - per-algorithm flag registration and config building +""" + +from auto_round.cli.main import run, run_best, run_light, run_rtn, run_opt_rtn, run_mllm, run_eval # noqa: F401 diff --git a/auto_round/cli/algorithms.py b/auto_round/cli/algorithms.py new file mode 100644 index 000000000..7046e6f4b --- /dev/null +++ b/auto_round/cli/algorithms.py @@ -0,0 +1,406 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""Algorithm discovery, registration, listing, and config building for the CLI. + +To add a new algorithm: +1. Write a class that extends AlgorithmHandler and set name/aliases/summary. +2. Implement register(group) to declare argparse flags. +3. Implement build(args, common_kwargs) to return a config object. + +No manual registry update needed — subclasses are auto-registered on definition. 
+""" + +from __future__ import annotations + +import argparse +from abc import ABC, abstractmethod +from typing import Any, ClassVar + +from auto_round.algorithms.registry import ( + get_algorithm_entry, + iter_algorithm_entries, + register_algorithm, + resolve_algorithm_alias, +) + +# ============================================================================ +# Base class + registry +# ============================================================================ + + +class AlgorithmHandler(ABC): + """Bundles everything the CLI needs to know about one algorithm. + + Concrete subclasses are auto-registered in the class-level registry + the moment their class body is processed. + """ + + name: str # canonical name used in --algorithm + aliases: tuple[str, ...] = () # all accepted names, including canonical + summary: str = "" # one-liner shown by `auto_round list alg` + config_factory: ClassVar[type | None] = None + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + # A class that overrides both register and build is considered a concrete + # algorithm implementation and must also declare a `name` class attribute. + if "register" in cls.__dict__ and "build" in cls.__dict__: + if "name" not in cls.__dict__: + raise TypeError(f"{cls.__name__} must define a 'name' class attribute " f"(e.g. 
name = 'my_algo').") + register_algorithm( + cls.name, + aliases=cls.aliases, + config_factory=cls.config_factory, + cli_handler=cls, + summary=cls.summary, + ) + + # ------------------------------------------------------------------ + # Abstract interface — implement in each subclass + # ------------------------------------------------------------------ + + @abstractmethod + def register(self, group) -> None: + """Add argparse arguments to *group*.""" + + @abstractmethod + def build(self, args, common_kwargs: dict[str, Any]) -> Any: + """Build and return an algorithm config from parsed *args*.""" + + # ------------------------------------------------------------------ + # Registry operations — called on the class, not on instances + # ------------------------------------------------------------------ + + @classmethod + def get(cls, name: str) -> AlgorithmHandler: + """Return the handler for a canonical algorithm name. Raises KeyError if unknown.""" + entry = get_algorithm_entry(name) + if entry.cli_handler is None: + raise KeyError(f"No handler registered for algorithm '{name}'.") + return entry.cli_handler() + + @classmethod + def resolve_alias(cls, user_name: str) -> str | None: + """Resolve a user-supplied algorithm name or alias to the canonical name. + + Returns None instead of raising so callers can silently skip unknowns. 
+ """ + return resolve_algorithm_alias(user_name) + + @classmethod + def add_groups(cls, parser) -> None: + """Add an argparse argument group for every registered algorithm.""" + for entry in iter_algorithm_entries(): + if entry.cli_handler is None: + continue + handler = entry.cli_handler() + group = parser.add_argument_group(f"Algorithm: {handler.name}") + handler.register(group) + + @classmethod + def build_configs(cls, args, common_kwargs: dict[str, Any]) -> list: + """Build the ordered algorithm config list from parsed CLI args.""" + raw = getattr(args, "algorithm", None) or "" + names = [n.strip().lower() for n in raw.split(",") if n.strip()] + + # Infer hadamard when --rotation_type is given + if getattr(args, "rotation_hadamard_type", None) and "hadamard" not in names: + names.append("hadamard") + + # Resolve aliases, drop unknowns, deduplicate preserving order + seen: set[str] = set() + canonical: list[str] = [] + for n in names: + c = cls.resolve_alias(n) + if c and c not in seen: + canonical.append(c) + seen.add(c) + + # Default quantization algorithm if none was specified + if not ({"awq", "rtn", "auto_round"} & seen): + canonical.append("rtn" if getattr(args, "iters", 0) == 0 else "auto_round") + + return [cls.get(name).build(args, common_kwargs) for name in canonical] + + @classmethod + def format_listing(cls) -> str: + """Render the short `list alg` output.""" + lines = [] + for entry in iter_algorithm_entries(): + if entry.cli_handler is None: + continue + handler = entry.cli_handler() + other = [a for a in handler.aliases if a != handler.name] + alias_str = f" (aliases: {', '.join(other)})" if other else "" + lines.append(f"- {handler.name}{alias_str}: {handler.summary}") + return "\n".join(lines) + + @classmethod + def format_detail(cls, name: str) -> str: + """Render detailed help text for one algorithm.""" + canon = cls.resolve_alias(name) + if canon is None: + supported = [entry.name for entry in iter_algorithm_entries() if entry.cli_handler is 
not None] + raise ValueError(f"Unknown algorithm '{name}'. " f"Supported: {', '.join(supported)}.") + handler = cls.get(canon) + lines = [f"{handler.name}: {handler.summary}"] + other = [a for a in handler.aliases if a != handler.name] + if other: + lines.append(f"Aliases: {', '.join(other)}") + temp = argparse.ArgumentParser(add_help=False) + group = temp.add_argument_group(f"Flags for {handler.name}") + handler.register(group) + for action in group._group_actions: + flags = ", ".join(action.option_strings) + default = f" (default: {action.default})" if action.default is not None else "" + lines.append(f" {flags}: {action.help or ''}{default}") + return "\n".join(lines) + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _parse_bool_or_mode(value: str) -> bool | str: + """Parse AWQ duo_scaling's tri-state: true / false / both.""" + lowered = value.strip().lower() + if lowered == "true": + return True + if lowered == "false": + return False + if lowered == "both": + return "both" + raise argparse.ArgumentTypeError("Expected one of: true, false, both") + + +# ============================================================================ +# Algorithm implementations (auto-registered via __init_subclass__) +# ============================================================================ + + +class AWQ(AlgorithmHandler): + name = "awq" + aliases = ("awq",) + summary = "Activation-Aware Weight Quantization (pre-processing)." 
+ config_factory = None + + def register(self, group) -> None: + group.add_argument( + "--awq-duo-scaling", + dest="awq_duo_scaling", + default=True, + type=_parse_bool_or_mode, + metavar="{true,false,both}", + help="Use activation+weight duo scaling (true/false/both).", + ) + group.add_argument( + "--awq-n-grid", + dest="awq_n_grid", + default=20, + type=int, + help="Number of grid-search points for AWQ scaling ratio.", + ) + + def build(self, args, common_kwargs: dict[str, Any]): + from auto_round.algorithms.transforms.awq.config import AWQConfig + + return AWQConfig( + duo_scaling=getattr(args, "awq_duo_scaling", True), + n_grid=getattr(args, "awq_n_grid", 20), + **common_kwargs, + ) + + +class RTN(AlgorithmHandler): + name = "rtn" + aliases = ("rtn",) + summary = "Round-To-Nearest quantization." + config_factory = None + + def register(self, group) -> None: + mutex = group.add_mutually_exclusive_group() + mutex.add_argument( + "--disable_opt_rtn", + dest="disable_opt_rtn", + default=None, + action="store_const", + const=True, + help="Force plain RTN (disable optimized path).", + ) + mutex.add_argument( + "--enable_opt_rtn", + dest="disable_opt_rtn", + action="store_const", + const=False, + help="Force optimized RTN path.", + ) + + def build(self, args, common_kwargs: dict[str, Any]): + from auto_round.algorithms.quantization.rtn.config import RTNConfig + + return RTNConfig( + disable_opt_rtn=getattr(args, "disable_opt_rtn", None), + **common_kwargs, + ) + + +class AutoRound(AlgorithmHandler): + name = "auto_round" + aliases = ("auto_round", "autoround", "sign_round", "signround") + summary = "SignRound-style iterative block quantization." + config_factory = None + + def register(self, group) -> None: + group.add_argument( + "--iters", "--iter", default=None, type=int, help="Number of optimization iterations per block." 
+ ) + group.add_argument("--lr", default=None, type=float, help="Learning rate for rounding optimization.") + group.add_argument("--minmax_lr", default=None, type=float, help="Learning rate for min-max tuning.") + group.add_argument("--momentum", default=0.0, type=float, help="Momentum factor for the optimizer.") + group.add_argument("--nblocks", default=1, type=int, help="Number of blocks to optimize together.") + group.add_argument( + "--enable_minmax_tuning", + default=True, + action=argparse.BooleanOptionalAction, + help="Tune weight min/max ranges.", + ) + group.add_argument( + "--enable_norm_bias_tuning", + default=False, + action=argparse.BooleanOptionalAction, + help="Tune normalization and bias terms.", + ) + group.add_argument( + "--gradient_accumulate_steps", default=1, type=int, help="Gradient accumulation steps per update." + ) + group.add_argument( + "--enable_alg_ext", + default=False, + action=argparse.BooleanOptionalAction, + help="Enable experimental SignRound extension.", + ) + group.add_argument( + "--not_use_best_mse", + default=False, + action=argparse.BooleanOptionalAction, + help="Skip restoring best-MSE checkpoint.", + ) + group.add_argument( + "--enable_quanted_input", + default=True, + action=argparse.BooleanOptionalAction, + help="Consume quantized output of previous blocks.", + ) + group.add_argument( + "--enable_adam", + default=False, + action=argparse.BooleanOptionalAction, + help="Use the Adam-based SignRound variant.", + ) + + def build(self, args, common_kwargs: dict[str, Any]): + from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig + + return SignRoundConfig( + iters=getattr(args, "iters", 200), + lr=getattr(args, "lr", None), + minmax_lr=getattr(args, "minmax_lr", None), + momentum=getattr(args, "momentum", 0.0), + nblocks=getattr(args, "nblocks", 1), + enable_minmax_tuning=getattr(args, "enable_minmax_tuning", True), + enable_norm_bias_tuning=getattr(args, "enable_norm_bias_tuning", False), + 
gradient_accumulate_steps=getattr(args, "gradient_accumulate_steps", 1), + enable_alg_ext=getattr(args, "enable_alg_ext", False), + not_use_best_mse=getattr(args, "not_use_best_mse", False), + enable_quanted_input=getattr(args, "enable_quanted_input", True), + enable_adam=getattr(args, "enable_adam", False), + **common_kwargs, + ) + + +class Hadamard(AlgorithmHandler): + name = "hadamard" + aliases = ("hadamard", "random_hadamard", "quarot_hadamard") + summary = "Hadamard rotation/transform applied before quantization." + config_factory = None + + def register(self, group) -> None: + group.add_argument( + "--rotation_type", + "--rotation-hadamard-type", + dest="rotation_hadamard_type", + default=None, + choices=["hadamard", "random_hadamard", "quarot_hadamard"], + help="Hadamard transform variant.", + ) + group.add_argument( + "--rotation_backend", + dest="rotation_backend", + default="auto", + choices=["auto", "inplace", "transform"], + help="Rotation backend to use.", + ) + group.add_argument( + "--rotation_block_size", + dest="rotation_block_size", + default=None, + type=int, + help="Grouped Hadamard block size.", + ) + group.add_argument( + "--fuse_online_to_weight", + default=None, + action=argparse.BooleanOptionalAction, + help="Fuse online Hadamard rotation into weights.", + ) + group.add_argument( + "--allow_online_rotation", + default=True, + action=argparse.BooleanOptionalAction, + help="Allow online activation rotation.", + ) + + def build(self, args, common_kwargs: dict[str, Any]): + from auto_round.algorithms.transforms.quarot.config import RotationConfig + + hadamard_type = getattr(args, "rotation_hadamard_type", None) or "hadamard" + return RotationConfig( + hadamard_type=hadamard_type, + backend=getattr(args, "rotation_backend", "auto"), + block_size=getattr(args, "rotation_block_size", None), + fuse_online_to_weight=getattr(args, "fuse_online_to_weight", None), + allow_online_rotation=getattr(args, "allow_online_rotation", True), + ) + + +def 
_register_builtin_algorithm_factories() -> None: + from auto_round.algorithms.quantization.rtn.config import RTNConfig + from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig + from auto_round.algorithms.transforms.awq.config import AWQConfig + from auto_round.algorithms.transforms.quarot.config import RotationConfig + + register_algorithm("rtn", aliases=("rtn",), config_factory=RTNConfig, cli_handler=RTN, summary=RTN.summary) + register_algorithm( + "auto_round", + aliases=("auto_round", "autoround", "sign_round", "signround"), + config_factory=SignRoundConfig, + cli_handler=AutoRound, + summary=AutoRound.summary, + ) + register_algorithm("awq", aliases=("awq",), config_factory=AWQConfig, cli_handler=AWQ, summary=AWQ.summary) + register_algorithm( + "hadamard", + aliases=("hadamard", "random_hadamard", "quarot_hadamard"), + config_factory=RotationConfig, + cli_handler=Hadamard, + summary=Hadamard.summary, + alias_factories={ + "random_hadamard": lambda: RotationConfig(hadamard_type="random_hadamard"), + "quarot_hadamard": lambda: RotationConfig(hadamard_type="quarot_hadamard"), + }, + ) + + +_register_builtin_algorithm_factories() diff --git a/auto_round/cli/main.py b/auto_round/cli/main.py new file mode 100644 index 000000000..f44f498ea --- /dev/null +++ b/auto_round/cli/main.py @@ -0,0 +1,475 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""CLI entry points: command routing, RECIPES, tune, eval, list. + +This module is the single place that wires together: + parser.py - argparse construction (all flag declarations here) + algorithms.py - algorithm config building + +All console_scripts (auto_round, auto-round-best, etc.) point here. 
+""" + +from __future__ import annotations + +import argparse +import sys + +from auto_round.cli.algorithms import AlgorithmHandler +from auto_round.cli.parser import ( + add_common_quantization_arguments, + build_eval_parser, + build_list_parser, + build_quantize_parser, + build_root_parser, +) + + +def _extract_common_quantization_kwargs(args) -> dict: + """Map parsed CLI args back to QuantizationConfig constructor kwargs. + + Handles inverted flags: --asym -> sym, --act_asym -> act_sym, + --disable_act_dynamic -> act_dynamic. + When the flag was not set (False), the value is None (defer to scheme). + """ + return { + "bits": args.bits, + "group_size": args.group_size, + "sym": None if not args.asym else False, + "data_type": args.data_type, + "act_bits": args.act_bits, + "act_group_size": args.act_group_size, + "act_sym": None if not args.act_asym else False, + "act_data_type": args.act_data_type, + "act_dynamic": None if not args.disable_act_dynamic else False, + "super_bits": args.super_bits, + "super_group_size": args.super_group_size, + } + + +def _build_entry_base_kwargs( + args, *, model_name, scheme, low_cpu_mem_usage, enable_torch_compile, layer_config +) -> dict: + return { + "model": model_name, + "platform": args.platform, + "format": args.format, + "scheme": scheme, + "dataset": args.dataset, + "seqlen": args.seqlen, + "nsamples": args.nsamples, + "batch_size": args.batch_size, + "low_gpu_mem_usage": args.low_gpu_mem_usage, + "low_cpu_mem_usage": low_cpu_mem_usage, + "device_map": args.device_map, + "enable_torch_compile": enable_torch_compile, + "seed": args.seed, + "layer_config": layer_config, + "model_dtype": args.model_dtype, + "trust_remote_code": not args.disable_trust_remote_code, + } + + +def _build_entry_route_kwargs(args) -> dict: + return { + "model_free": args.model_free, + "disable_model_free": args.disable_model_free, + } + + +def _build_entry_compressor_kwargs(args) -> dict: + return { + "scale_dtype": args.scale_dtype, + 
"ignore_layers": args.ignore_layers, + "quant_lm_head": args.quant_lm_head, + "to_quant_block_names": args.to_quant_block_names, + } + + +def _build_entry_model_type_kwargs(args) -> dict: + return { + "quant_nontext_module": args.quant_nontext_module, + "extra_data_dir": args.extra_data_dir, + "template": args.template, + "guidance_scale": args.guidance_scale, + "num_inference_steps": args.num_inference_steps, + "generator_seed": args.generator_seed, + } + + +def _to_autoround_kwargs(args, *, model_name, scheme, low_cpu_mem_usage, enable_torch_compile, layer_config) -> dict: + """Collect only the kwargs accepted by the new AutoRound entry API.""" + kwargs = _build_entry_base_kwargs( + args, + model_name=model_name, + scheme=scheme, + low_cpu_mem_usage=low_cpu_mem_usage, + enable_torch_compile=enable_torch_compile, + layer_config=layer_config, + ) + kwargs.update(_build_entry_route_kwargs(args)) + kwargs.update(_build_entry_compressor_kwargs(args)) + kwargs.update(_build_entry_model_type_kwargs(args)) + return kwargs + + +RECIPES = { + "default": { + "batch_size": 8, + "iters": 200, + "seqlen": 2048, + "nsamples": 128, + "lr": None, + }, + "best": { + "batch_size": 8, + "iters": 1000, + "seqlen": 2048, + "nsamples": 512, + "lr": None, + }, + "light": { + "batch_size": 8, + "iters": 50, + "seqlen": 2048, + "nsamples": 128, + "lr": 5e-3, + }, + "rtn": {"batch_size": 8, "iters": 0, "seqlen": 2048, "nsamples": 1, "lr": None, "disable_opt_rtn": True}, + "opt_rtn": {"batch_size": 8, "iters": 0, "seqlen": 2048, "nsamples": 128, "lr": None, "disable_opt_rtn": False}, +} + +# ============================================================================ +# list subcommand +# ============================================================================ + + +def list_item(argv=None): + args = build_list_parser().parse_args(argv) + if args.item in {"format", "formats"}: + from auto_round.formats import OutputFormat + + print("AutoRound supported output formats and quantization 
scheme:") + print(OutputFormat.get_support_matrix()) + elif args.item in {"alg", "algs", "algorithm", "algorithms"}: + if args.name: + print(AlgorithmHandler.format_detail(args.name)) + else: + print("AutoRound supported algorithms:") + print(AlgorithmHandler.format_listing()) + print("\nUse `auto_round list alg <name>` or `auto_round --algorithm <name> --help` for details.") + else: + raise ValueError(f"Unsupported list target: {args.item}") + + +# ============================================================================ +# quantize subcommand +# ============================================================================ + + +def _print_algorithm_help(argv: list[str]) -> bool: + """If --algorithm <name> --help is present, print algorithm-focused help and return True.""" + if not any(flag in argv for flag in ("-h", "--help")): + return False + + # Pre-parse just --algorithm + + pre = argparse.ArgumentParser(add_help=False) + pre.add_argument("--algorithm", default=None) + known, _ = pre.parse_known_args(argv) + names = [n.strip().lower() for n in (known.algorithm or "").split(",") if n.strip()] + if not names: + return False + + # Resolve aliases to canonical names, silently ignore unknowns + canonical_names: list[str] = [] + for name in names: + canon = AlgorithmHandler.resolve_alias(name) + if canon and canon not in canonical_names: + canonical_names.append(canon) + if not canonical_names: + return False + + mini = argparse.ArgumentParser( + prog=f"auto_round --algorithm {','.join(canonical_names)}", + description=f"Flags for algorithm(s): {', '.join(canonical_names)}. 
" + "Use `auto_round --help` for the full argument list.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + quant_group = mini.add_argument_group("Common Quantization Arguments") + add_common_quantization_arguments(quant_group) + for name in canonical_names: + alg_group = mini.add_argument_group(f"Algorithm: {name}") + AlgorithmHandler.get(name).register(alg_group) + mini.print_help() + return True + + +def start(recipe="default", argv=None): + recipe_defaults = RECIPES[recipe] + argv = list(sys.argv[1:] if argv is None else argv) + + if _print_algorithm_help(argv): + return + + parser = build_quantize_parser(prog="auto_round quantize") + args = parser.parse_args(argv) + + # Apply recipe defaults for fields the user didn't set + for key, value in recipe_defaults.items(): + if getattr(args, key, None) is None: + setattr(args, key, value) + + tune(args) + + +def tune(args): + assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set." + if args.model is None: + args.model = args.model_name + if args.eval_bs is None: + args.eval_bs = "auto" + + from transformers.utils.versions import require_version + + if args.tasks is not None: + require_version( + "lm_eval>=0.4.2", + "lm-eval is required for evaluation, please install it with `pip install 'lm-eval>=0.4.2'`", + ) + + from auto_round.utils import logger + + if args.low_cpu_mem_usage: + logger.warning( + "`low_cpu_mem_usage` is deprecated and is now enabled by default. " + "To disable it, use `--disable_low_cpu_mem_usage`." 
+ ) + + if args.format is None: + args.format = "auto_round" + + formats = args.format.lower().replace(" ", "").split(",") + from auto_round.utils import SUPPORTED_FORMATS + + for fmt in formats: + if fmt not in SUPPORTED_FORMATS: + raise ValueError(f"{fmt} is not supported, we only support {SUPPORTED_FORMATS}") + + if "auto_gptq" in args.format and args.asym is True: + logger.warning( + "the auto_gptq kernel has issues with asymmetric quantization. 
" + "It is recommended to use sym quantization or --format='auto_round'" + ) + + if "marlin" in args.format and args.asym is True: + raise RuntimeError("marlin backend only supports sym quantization, please remove --asym") + + from auto_round.utils import get_device_and_parallelism + + device_str, use_auto_mapping = get_device_and_parallelism(args.device_map) + + if args.enable_torch_compile: + logger.info( + "`torch.compile` is enabled to reduce tuning costs. " + "If it causes issues, you can disable it by removing `--enable_torch_compile` argument." + ) + + model_name = args.model + if model_name[-1] == "/": + model_name = model_name[:-1] + logger.info(f"start to quantize {model_name}") + + from auto_round.compressors.base import BaseCompressor + from auto_round.compressors.entry import AutoRound as PipelineAutoRound + + if "bloom" in model_name: + args.low_gpu_mem_usage = False + + if args.quant_lm_head: + for fmt in formats: + if "auto_round" not in fmt and "fake" not in fmt and "mlx" not in fmt: + auto_round_formats = [s for s in SUPPORTED_FORMATS if s.startswith("auto_round") or s == "mlx"] + raise ValueError( + f"{fmt} is not supported for lm-head quantization, please change to {auto_round_formats}" + ) + + enable_torch_compile = bool(args.enable_torch_compile) # parsed flag, not sys.argv: honors explicit argv and abbreviations + scheme = args.scheme.upper() + + from auto_round.schemes import PRESET_SCHEMES + + if scheme not in PRESET_SCHEMES: + raise ValueError(f"{scheme} is not supported. only {PRESET_SCHEMES.keys()} are supported ") + + if args.disable_deterministic_algorithms: + logger.warning( + "default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated," + " please use enable_deterministic_algorithms instead. 
" + ) + + from auto_round.utils import parse_layer_config_arg + + layer_config = {} + if args.layer_config: + layer_config = parse_layer_config_arg(args.layer_config) + args.layer_config = layer_config + + low_cpu_mem_usage = True + if args.disable_low_cpu_mem_usage: + low_cpu_mem_usage = False + + from auto_round.auto_scheme import AutoScheme + + if args.avg_bits is not None: + if args.options is None: + raise ValueError("please set --options for auto scheme") + if enable_torch_compile: + logger.warning( + "`enable_torch_compile=True` with AutoScheme may cause compile errors " + "on some models. If so, try removing `--enable_torch_compile`." + ) + scheme = AutoScheme( + options=args.options, + avg_bits=args.avg_bits, + shared_layers=args.shared_layers, + ignore_scale_zp_bits=args.ignore_scale_zp_bits, + low_gpu_mem_usage=True, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + common_kwargs = _extract_common_quantization_kwargs(args) + alg_configs = AlgorithmHandler.build_configs(args, common_kwargs) + + from auto_round.utils import clear_memory + + autoround: BaseCompressor = PipelineAutoRound( + alg_configs=alg_configs if len(alg_configs) > 1 else alg_configs[0], + **_to_autoround_kwargs( + args, + model_name=model_name, + scheme=scheme, + low_cpu_mem_usage=low_cpu_mem_usage, + enable_torch_compile=enable_torch_compile, + layer_config=layer_config, + ), + ) + + model, folders = autoround.quantize_and_save(args.output_dir, format=args.format) # pylint: disable=no-member + tokenizer = autoround.tokenizer # pylint: disable=no-member + clear_memory() + + from auto_round.eval.evaluation import run_model_evaluation + + run_model_evaluation(model, tokenizer, autoround, folders, formats, device_str, args) + + +# ============================================================================ +# eval subcommand +# ============================================================================ + + +def setup_eval_parser(argv=None): + parser = build_eval_parser(prog="auto_round 
eval") + return parser.parse_args(argv) + + +def run_eval(argv=None): + from auto_round.eval.eval_cli import eval, eval_task_by_task + from auto_round.logger import logger + from auto_round.utils import is_gguf_model, is_mllm_model + + args = setup_eval_parser(argv) + assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set." + + if args.model is None: + args.model = args.model_name + if "llama" in args.model.lower() and not args.add_bos_token: + logger.warning("set add_bos_token=True for llama model.") + args.add_bos_token = True + if not is_gguf_model(args.model) and is_mllm_model(args.model): + args.mllm = True + + if args.eval_task_by_task: + eval_task_by_task( + model=args.model, + device=args.device_map, + limit=args.limit, + tasks=args.tasks, + batch_size=args.eval_bs, + trust_remote_code=not args.disable_trust_remote_code, + eval_model_dtype=args.eval_model_dtype, + add_bos_token=args.add_bos_token, + ) + else: + eval(args) + + +# ============================================================================ +# Command routing +# ============================================================================ + + +def _normalize_cli_invocation(argv): + """Normalize legacy invocation styles to (command, rest_argv).""" + if argv and argv[0] in {"quantize", "list", "eval", "help"}: + return argv[0], argv[1:] + if "--list" in argv: + normalized = list(argv) + normalized.remove("--list") + return "list", normalized + if "--eval" in argv: + normalized = list(argv) + normalized.remove("--eval") + return "eval", normalized + return "quantize", list(argv) + + +def _print_help(topic=None): + if topic == "quantize": + build_quantize_parser(prog="auto_round quantize").print_help() + return + if topic == "list": + build_list_parser(prog="auto_round list").print_help() + return + if topic == "eval": + build_eval_parser(prog="auto_round eval").print_help() + return + build_root_parser().print_help() + + +def run(): + argv = list(sys.argv[1:]) + 
command, command_argv = _normalize_cli_invocation(argv) + + if command == "help": + root_args = build_root_parser().parse_args(argv) + _print_help(root_args.topic) + return + if command == "list": + list_item(command_argv) + return + if command == "eval": + run_eval(command_argv) + return + start(argv=command_argv) + + +def run_best(): + start("best") + + +def run_light(): + start("light") + + +def run_rtn(): + start("rtn") + + +def run_opt_rtn(): + start("opt_rtn") + + +def run_mllm(): + run() diff --git a/auto_round/cli/parser.py b/auto_round/cli/parser.py new file mode 100644 index 000000000..c823022ca --- /dev/null +++ b/auto_round/cli/parser.py @@ -0,0 +1,273 @@ +# # Copyright (C) 2026 Intel Corporation +# # SPDX-License-Identifier: Apache-2.0 + +"""Argparse parser construction for all CLI subcommands. + +All static flags are declared directly via add_argument() — no intermediate +dataclass wrappers. To add a new flag, find the right argument group below +and add a line. + +Public API: + build_quantize_parser() — the main quantize parser + build_list_parser() — the `list` subcommand parser + build_eval_parser() — the `eval` subcommand parser + build_root_parser() — top-level router +""" + +from __future__ import annotations + +import argparse + +from auto_round.cli.algorithms import AlgorithmHandler +from auto_round.eval.eval_cli import EvalArgumentParser + + +def _parse_group_size(s: str): + """Parse group_size: a plain int, or comma-separated ints for block-wise fp8.""" + if s.lstrip("-").isdigit(): + return int(s) + return tuple(int(x.strip()) for x in s.split(",")) + + +def add_common_quantization_arguments(group) -> None: + """Add common quantization flags to an argparse group. + + To add a new flag, append an add_argument() call here and mirror it in + _extract_common_quantization_kwargs() in main.py. + """ + group.add_argument("--scheme", default="W4A16", type=str, help="Quantization scheme preset, e.g. 
W4A16, W8A16.") + group.add_argument("--bits", default=None, type=int, help="Weight quantization bit width.") + group.add_argument( + "--group_size", + default=None, + type=_parse_group_size, + help="Weight group size: positive int, -1 per-channel, 0 per-tensor.", + ) + group.add_argument( + "--asym", default=False, action="store_true", help="Use asymmetric weight quantization instead of symmetric." + ) + group.add_argument( + "--data_type", "--dtype", default=None, type=str, help="Weight quantization data type, e.g. int, fp8." + ) + group.add_argument("--act_bits", default=None, type=int, help="Activation quantization bit width.") + group.add_argument("--act_group_size", default=None, type=int, help="Activation quantization group size.") + group.add_argument( + "--act_asym", + default=False, + action="store_true", + help="Use asymmetric activation quantization instead of symmetric.", + ) + group.add_argument( + "--act_data_type", "--act_dtype", default=None, type=str, help="Activation quantization data type." + ) + group.add_argument( + "--disable_act_dynamic", + default=False, + action="store_true", + help="Use static activation quantization instead of dynamic.", + ) + group.add_argument("--super_bits", default=None, type=int, help="Bit width for double quantization metadata.") + group.add_argument( + "--super_group_size", default=None, type=int, help="Group size for double quantization metadata." + ) + group.add_argument( + "--scale_dtype", + default=None, + choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], + help="Data type used to store quantization scales.", + ) + group.add_argument( + "--ignore_layers", + "--fp_layers", + default="", + type=str, + help="Comma-separated layer names to keep in higher precision.", + ) + group.add_argument( + "--quant_lm_head", default=False, action=argparse.BooleanOptionalAction, help="Quantize the lm_head module." 
+ ) + group.add_argument( + "--to_quant_block_names", default=None, type=str, help="Comma-separated subset of block names to quantize." + ) + + +def build_quantize_parser(*, prog: str = "auto_round quantize") -> argparse.ArgumentParser: + """Build the quantize parser with all argument groups.""" + parser = argparse.ArgumentParser(prog=prog) + + # ---- Model / Runtime ---- + rt = parser.add_argument_group("Runtime Arguments") + rt.add_argument("model", default=None, nargs="?", help="Path to the pre-trained model or Hugging Face model id.") + rt.add_argument( + "--model_name", + "--model", + "--model_name_or_path", + default="facebook/opt-125m", + help="Path to the pre-trained model or Hugging Face model id.", + ) + rt.add_argument("--model_dtype", default=None, help="Model dtype used when loading the model.") + rt.add_argument("--platform", default="hf", help="Model loading platform. Options: hf or model_scope.") + rt.add_argument( + "--batch_size", "--train_bs", "--bs", default=None, type=int, help="Batch size for calibration and tuning." + ) + rt.add_argument("--seqlen", "--seq_len", default=None, type=int, help="Sequence length of the calibration samples.") + rt.add_argument("--nsamples", "--nsample", default=None, type=int, help="Number of calibration samples to use.") + rt.add_argument( + "--device_map", "--device", "--devices", default="0", type=str, help="Device mapping used for quantization." + ) + rt.add_argument( + "--dataset", default="NeelNanda/pile-10k", type=str, help="Calibration dataset or local dataset path." + ) + rt.add_argument("--seed", default=42, type=int, help="Random seed for reproducibility.") + rt.add_argument( + "--format", "--formats", default="auto_round", type=str, help="Output format for the quantized model." + ) + rt.add_argument( + "--algorithm", default=None, type=str, help="Comma-separated algorithms such as 'awq' or 'awq,auto_round'." 
+ ) + rt.add_argument("--output_dir", default="./tmp_autoround", type=str, help="Directory to save quantized artifacts.") + rt.add_argument("--avg_bits", "--target_bits", default=None, type=float, help="Average target bits for AutoScheme.") + rt.add_argument("--options", default=None, type=str, help="AutoScheme options, for example 'W4A16,W8A16'.") + rt.add_argument( + "--low_gpu_mem_usage", action="store_true", help="Enable memory-efficient mode by offloading features to CPU." + ) + rt.add_argument( + "--low_cpu_mem_usage", + action="store_true", + help="Deprecated compatibility flag. Low CPU memory mode is enabled by default.", + ) + rt.add_argument("--disable_low_cpu_mem_usage", action="store_true", help="Disable low CPU memory mode.") + rt.add_argument("--enable_torch_compile", action="store_true", help="Enable torch.compile during quantization.") + rt.add_argument( + "--disable_trust_remote_code", action="store_true", help="Disable trust_remote_code when loading models." + ) + rt.add_argument( + "--layer_config", default=None, type=str, help="Per-layer quantization overrides encoded as JSON-like text." 
+ ) + rt.add_argument( + "--shared_layers", + type=str, + nargs="+", + action="append", + default=None, + help="Ensure listed layers share the same quantization data type.", + ) + rt.add_argument( + "--static_kv_dtype", + default=None, + type=str, + choices=["fp8", "float8_e4m3fn"], + help="Static KV-cache quantization data type.", + ) + rt.add_argument( + "--static_attention_dtype", + default=None, + type=str, + choices=["fp8", "float8_e4m3fn"], + help="Static attention quantization data type.", + ) + + # ---- Evaluation ---- + ev = parser.add_argument_group("Evaluation Arguments") + ev.add_argument( + "--tasks", + "--task", + nargs="?", + const="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext," + "truthfulqa_mc1,openbookqa,boolq,arc_easy,arc_challenge", + default=None, + help="LM-Eval tasks to run after quantization.", + ) + ev.add_argument("--eval_bs", default=None, type=int, help="Batch size for evaluation.") + ev.add_argument( + "--limit", type=float, default=None, metavar="N|0 argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog=prog) + parser.add_argument("item", type=str, help="item to list, e.g., format, alg") + parser.add_argument("name", nargs="?", default=None, help="optional specific format/algorithm name") + return parser + + +def build_eval_parser(*, prog: str = "auto_round eval") -> argparse.ArgumentParser: + return EvalArgumentParser(prog=prog) + + +def build_root_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="auto_round", + description="AutoRound command line interface.", + ) + subparsers = parser.add_subparsers(dest="command") + subparsers.add_parser("quantize", help="Quantize a model.", add_help=False) + subparsers.add_parser("list", help="List supported algorithms or formats.", add_help=False) + subparsers.add_parser("eval", help="Evaluate a model.", add_help=False) + help_parser = subparsers.add_parser("help", help="Show help for the CLI or a subcommand.") + help_parser.add_argument("topic", 
nargs="?", choices=["quantize", "list", "eval"], default=None) + return parser diff --git a/auto_round/compressors/__init__.py b/auto_round/compressors/__init__.py index 531fc237f..a9b4467af 100644 --- a/auto_round/compressors/__init__.py +++ b/auto_round/compressors/__init__.py @@ -12,20 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Lazy imports to avoid circular dependencies -# Users should import from specific modules instead of this __init__.py +# Lazy imports to avoid circular dependencies. from typing import TYPE_CHECKING if TYPE_CHECKING: from auto_round.compressors.base import BaseCompressor - from auto_round.compressors.config import ( - DiffusionExtraConfig, - ExtraConfig, - MLLMExtraConfig, - SchemeExtraConfig, - TuningExtraConfig, - ) from auto_round.compressors.data_driven import CalibratedRTNCompressor, DataDrivenCompressor from auto_round.compressors.entry import AutoRoundCompatible, AutoRound from auto_round.compressors.model_free import ModelFreeCompressor @@ -39,11 +31,6 @@ "ZeroShotCompressor", "AutoRoundCompatible", "ModelFreeCompressor", - "ExtraConfig", - "TuningExtraConfig", - "SchemeExtraConfig", - "MLLMExtraConfig", - "DiffusionExtraConfig", ] @@ -74,20 +61,4 @@ def __getattr__(name): from auto_round.compressors.model_free import ModelFreeCompressor return ModelFreeCompressor - elif name in ("ExtraConfig", "TuningExtraConfig", "SchemeExtraConfig", "MLLMExtraConfig", "DiffusionExtraConfig"): - from auto_round.compressors.config import ( - DiffusionExtraConfig, - ExtraConfig, - MLLMExtraConfig, - SchemeExtraConfig, - TuningExtraConfig, - ) - - return { - "ExtraConfig": ExtraConfig, - "TuningExtraConfig": TuningExtraConfig, - "SchemeExtraConfig": SchemeExtraConfig, - "MLLMExtraConfig": MLLMExtraConfig, - "DiffusionExtraConfig": DiffusionExtraConfig, - }[name] raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/auto_round/compressors/base.py 
b/auto_round/compressors/base.py index f970d89f9..92143566a 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -21,8 +21,7 @@ import torch from transformers import AutoConfig, set_seed -from auto_round.algorithms.alg_config import AlgConfig -from auto_round.algorithms.quantization import BaseQuantizers, QuantizationConfig +from auto_round.algorithms.quantization import BaseQuantizer, QuantizationConfig from auto_round.algorithms.transforms import ( BaseRotationConfig, apply_rotation, @@ -107,6 +106,47 @@ class SerializedCompressorConfig: SERIALIZATION_KEYS = tuple(field.name for field in fields(SerializedCompressorConfig)) +def collect_user_scheme_overrides(configs: list) -> dict[str, Any]: + scheme_fields = {f.name for f in fields(QuantizationScheme)} + user_scheme_overrides = {} + user_scheme_sources = {} + for config in configs: + for key in getattr(config, "_user_set_scheme_fields", set()): + if key not in scheme_fields: + continue + value = getattr(config, key, None) + if value is None: + continue + if key in user_scheme_overrides and value != user_scheme_overrides[key]: + prev_config, prev_value = user_scheme_sources[key] + raise ValueError( + f"Conflicting shared scheme field {key!r}: " + f"{type(prev_config).__name__}.{key}={prev_value!r}, " + f"{type(config).__name__}.{key}={value!r}. " + "Use the same value for shared fields or pass scheme arguments through Compressor." 
+ ) + user_scheme_overrides[key] = value + user_scheme_sources[key] = (config, value) + return user_scheme_overrides + + +def _make_compressor_scheme_property(name): + def getter(self): + scheme_context = getattr(self, "scheme_context", None) + if scheme_context is not None: + return getattr(scheme_context, name) + return self.__dict__.get(name, getattr(type(self), name, None)) + + def setter(self, value): + scheme_context = getattr(self, "scheme_context", None) + if scheme_context is not None: + setattr(scheme_context, name, value) + else: + self.__dict__[name] = value + + return property(getter, setter) + + class BaseCompressor(object): need_calib: bool = True compress_context: CompressContext = None @@ -129,6 +169,9 @@ class BaseCompressor(object): quant_lm_head: bool = False _scheme_resolved: bool = False scheme_generator = None + _scheme_context_fields = set(QuantizationScheme.get_attributes()) + for _scheme_field in QuantizationScheme.get_attributes(): + locals()[_scheme_field] = _make_compressor_scheme_property(_scheme_field) @staticmethod def _preload_model_config(model: Union[torch.nn.Module, str], trust_remote_code: bool) -> Optional[AutoConfig]: @@ -148,7 +191,7 @@ def _preload_model_config(model: Union[torch.nn.Module, str], trust_remote_code: def __init__( self, - config: Union[AlgConfig, list[AlgConfig]], + config: Union[object, list[object]], model: Union[torch.nn.Module, str], tokenizer=None, platform="hf", @@ -162,6 +205,10 @@ def __init__( layer_config=None, nsamples: int = None, seqlen: int = None, + scale_dtype=None, + ignore_layers: str = "", + quant_lm_head: bool = False, + to_quant_block_names=None, **kwargs, ): # ``CalibrationState`` is the single source of truth for calibration @@ -188,10 +235,25 @@ def __init__( self.quantize_config = None self.rotation_configs: list[BaseRotationConfig] = [] _config_list = config if isinstance(config, list) else [config] - for _cfg in _config_list: - if isinstance(_cfg, QuantizationConfig): - 
self.quantize_config = _cfg - elif isinstance(_cfg, BaseRotationConfig): + # Keep full list for pipeline construction (includes preprocessor configs). + self._alg_configs: list = list(_config_list) + from auto_round.algorithms.pipeline import split_quantization_configs + + _preprocessor_configs, _block_quantizer_configs = split_quantization_configs(self._alg_configs) + if len(_block_quantizer_configs) > 1: + raise ValueError( + f"Only one block-quantization config is allowed, but got {len(_block_quantizer_configs)}: " + f"{[type(c).__name__ for c in _block_quantizer_configs]}" + ) + if _block_quantizer_configs: + self.quantize_config = _block_quantizer_configs[0] + elif _preprocessor_configs: + from auto_round.algorithms.quantization.rtn.config import RTNConfig as _RTNConfig + + self.quantize_config = _RTNConfig() + self._alg_configs.append(self.quantize_config) + for _cfg in self._alg_configs: + if isinstance(_cfg, BaseRotationConfig): self.rotation_configs.append(_cfg) assert self.quantize_config is not None, "QuantizationConfig is required for Compressor" @@ -200,6 +262,10 @@ def __init__( # ``self._calibration_state`` (seeded above) and exposed via # ``@property`` forwarders. self.layer_config = layer_config + self.scale_dtype = scale_dtype + self.ignore_layers = ignore_layers + self.quant_lm_head = quant_lm_head + self.to_quant_block_names = to_quant_block_names # ``post_init()`` may run before ``quantize_and_save()`` in tests and # compatibility paths, so seed the same default used by # ``quantize_and_save(..., inplace=True)`` here. @@ -207,6 +273,7 @@ def __init__( # Scheme is passed directly to the compressor, not stored in QuantizationConfig. self.scheme = scheme + self.scheme_context = None # Calibrator strategy (auto_round.calibration.base.Calibrator). 
Constructed # lazily by ``DataDrivenCompressor.post_init`` based on ``_get_calibrator_kind()``; @@ -323,9 +390,6 @@ def __init__( ) self.shard_writer = None - # scale_dtype is resolved in quantizer.resolve_scheme() after scheme resolution, - # so it is not initialized here to avoid premature evaluation with an unresolved scheme. - # Flag for post_init idempotency. Set to False here so post_init() can be called # either via quantize_and_save() (preferred, outside inference_mode) or directly # from quantize() as a fallback for non-AutoScheme cases. @@ -342,6 +406,13 @@ def __init__( self.has_variable_block_shape = False + # ── Convenience properties ──────────────────────────────────────────────── + + @property + def tokenizer(self): + """Convenience accessor for the tokenizer stored in ``model_context``.""" + return self.model_context.tokenizer + # ── Scheme resolution ───────────────────────────────────────────────────── def resolve_scheme(self, model_context=None, compress_context=None, dataset: str = None) -> None: @@ -360,18 +431,13 @@ def resolve_scheme(self, model_context=None, compress_context=None, dataset: str if dataset is not None: self.dataset = dataset - scheme_fields = {f.name for f in fields(QuantizationScheme)} - user_scheme_overrides = { - k: getattr(self.quantize_config, k) - for k in scheme_fields - if getattr(self.quantize_config, k, None) is not None - } + user_scheme_overrides = collect_user_scheme_overrides(self._alg_configs) default_scheme, self.is_auto_scheme, final_attrs = _parse_scheme(self.scheme, user_scheme_overrides) - for key, value in final_attrs.items(): - setattr(self.quantize_config, key, value) - if hasattr(self, key): - setattr(self, key, value) + self.scheme_context = QuantizationScheme.from_dict(final_attrs) + for config in self._alg_configs: + if hasattr(config, "scheme"): + config.scheme = self.scheme_context self.quantize_config.check_config() self.orig_scheme = copy.deepcopy(self.scheme) @@ -405,7 +471,6 @@ def 
_scheme_post_init(self) -> None: ) if self.to_quant_block_names is None and self.quant_block_list: self.to_quant_block_names = extract_block_names_to_str(self.quant_block_list) - self.quantize_config.to_quant_block_names = self.to_quant_block_names self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed) @@ -776,8 +841,6 @@ def _resolve_scheme(self) -> None: - ``self.quantize_config`` is a valid :class:`QuantizationConfig`. Work performed: - - Seeds scheme-related attrs (``scale_dtype``, ``ignore_layers``, - ``quant_lm_head``, ``to_quant_block_names``) from ``quantize_config``. - Calls :meth:`resolve_scheme` to derive ``data_type``, ``bits``, ``sym``, ``scale_dtype`` etc. and write them back to both ``self`` and ``self.quantize_config``. @@ -785,16 +848,8 @@ def _resolve_scheme(self) -> None: Postconditions: - ``self.scheme`` and ``self.quantize_config`` carry resolved scheme attrs. """ - cfg = self.quantize_config - self.scale_dtype = cfg.scale_dtype - # self.layer_config is already set from __init__ (direct compressor param). - self.ignore_layers = cfg.ignore_layers - self.quant_lm_head = cfg.quant_lm_head - self.to_quant_block_names = cfg.to_quant_block_names if self.to_quant_block_names is None: self.to_quant_block_names = getattr(self.model_context.model, "_autoround_to_quant_block_names", None) - if self.to_quant_block_names is not None: - self.quantize_config.to_quant_block_names = self.to_quant_block_names # Resolve the scheme (pure config work: sets data_type / bits / sym / # scale_dtype etc. on both self and self.quantize_config). @@ -814,18 +869,51 @@ def _build_quantizer(self) -> None: been synced back to ``self.quantize_config``. Work performed: - - Constructs ``self.quantizer`` from the resolved config. + - Constructs the block_quantizer from the resolved config. + - Wraps it in a :class:`~auto_round.algorithms.pipeline.QuantizationPipeline` + so that the entire compressor operates through the pipeline abstraction. 
- Calls ``quantizer.bind(self)`` so the quantizer pulls ``model_context`` / ``compress_context`` / ``scale_dtype`` / ``CalibrationState`` from this compressor. ``quantizer.model`` is a property that reads ``model_context.model``. + - Exposes ``self.quantizer`` as a ``@property`` (see below) that + transparently delegates to ``self.pipeline.block_quantizer`` so all + existing call-sites continue to work without modification. Postconditions: - - ``self.quantizer`` is ready and shares ``CalibrationState`` with - the compressor. + - ``self.pipeline`` is a ``QuantizationPipeline`` wrapping the block quantizer. + - ``self.quantizer`` (via property) is ready and shares ``CalibrationState`` + with the compressor. + """ + from auto_round.algorithms.pipeline import QuantizationPipeline + + self._pipeline = QuantizationPipeline.from_configs(self._alg_configs, compressor=self) + + @property + def quantizer(self) -> BaseQuantizer: + """Transparent forwarder to ``self.pipeline.block_quantizer``. + + All existing ``self.quantizer.xxx`` call-sites continue to work + unchanged. New code should prefer ``self.pipeline`` for pipeline-aware + operations. """ - self.quantizer = BaseQuantizers.from_config(self.quantize_config) - self.quantizer.bind(self) + _pipeline = self.__dict__.get("_pipeline") + if _pipeline is not None: + return _pipeline.block_quantizer + return object.__getattribute__(self, "_quantizer") # raise AttributeError, not KeyError, so hasattr() works + + @quantizer.setter + def quantizer(self, value) -> None: + _pipeline = self.__dict__.get("_pipeline") + if _pipeline is not None: + _pipeline.block_quantizer = value + else: + self.__dict__["_quantizer"] = value + + @property + def pipeline(self): + """The active :class:`~auto_round.algorithms.pipeline.QuantizationPipeline`.""" + return self._pipeline def _resolve_formats(self) -> None: """Phase 2 – Format resolution and config attr sync. 
@@ -1032,14 +1120,25 @@ def _build_layer_config(self) -> None: self.quantizer.scale_dtype = self.scale_dtype self.quantizer.ignore_layers = self.ignore_layers + from auto_round.algorithms.pipeline import sync_shared_config_from + + sync_shared_config_from(self.quantizer.config, [pre.config for pre in self._pipeline.preprocessors]) + + # Also sync runtime-only state to all preprocessors in the pipeline so + # they have access to per-layer quant config during pre-processing (e.g. + # AWQ grid search uses layer_config to look up bits/group_size for each layer). + for pre in self._pipeline.preprocessors: + pre.layer_config = self.layer_config + pre.scale_dtype = self.scale_dtype + def _hardware_setup(self) -> None: """Phase 5 – Hardware and compile configuration. Preconditions: - Phase 4 complete: ``layer_config`` is built and ``has_qlayer_outside_block`` is known. - - ``self.quantize_config.data_type`` is the final resolved value - (needed by :meth:`_finalize_torch_compile`). + - ``self.quantize_config.data_type`` is the final resolved value + (needed by :meth:`_finalize_torch_compile`). Work performed: - Applies the device map via :func:`~auto_round.utils.device.set_non_auto_device_map`. @@ -1085,10 +1184,31 @@ def __getattr__(self, name: str) -> Any: if name in self.__dict__: return self.__dict__[name] - for obj in ["quantizer", "quantize_config", "model_context", "compress_context"]: - if obj not in self.__dict__: + # Never proxy private/dunder attributes — they should be set explicitly + # in __init__. Proxying them hides bugs (e.g. missing _post_init_done) + # and can cause infinite recursion. 
+ if name.startswith("_"): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # Delegate to block_quantizer: access _pipeline directly from __dict__ to + # avoid recursion (quantizer is now a @property backed by _pipeline; going + # through the property inside __getattr__ would re-trigger __getattr__ + # if _pipeline itself isn't ready yet). + _pipeline = self.__dict__.get("_pipeline") + if _pipeline is not None: + try: + return object.__getattribute__(_pipeline.block_quantizer, name) + except AttributeError: + pass + + for attr in ["quantize_config", "model_context", "compress_context"]: + # These are regular instance attributes; use object.__getattribute__ + # so Python's normal descriptor protocol is used without re-entering + # __getattr__ on self. + try: + obj = object.__getattribute__(self, attr) + except AttributeError: continue - obj = object.__getattribute__(self, obj) try: return object.__getattribute__(obj, name) except AttributeError: @@ -1105,9 +1225,11 @@ def calibration_state(self): def calibration_state(self, value) -> None: self._calibration_state = value # Re-wire quantizer if it already exists so they keep sharing. - q = self.__dict__.get("quantizer") - if q is not None: - q.calibration_state = value + # quantizer is now a @property forwarding to _pipeline.block_quantizer; + # use _pipeline directly to avoid triggering __getattr__ loops. + _pipeline = self.__dict__.get("_pipeline") + if _pipeline is not None: + _pipeline.block_quantizer.calibration_state = value @property def inputs(self) -> dict: diff --git a/auto_round/compressors/config.py b/auto_round/compressors/config.py deleted file mode 100644 index dd028b6cd..000000000 --- a/auto_round/compressors/config.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright (c) 2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from dataclasses import dataclass, fields -from typing import Any, Callable, Optional, Union - -import torch - - -class ExtraConfig: - """Class for extra or legacy configs.""" - - _model_config = None - _scheme_config = None - _tuning_config = None - _mllm_config = None - _diffusion_config = None - - def __init__( - self, - # tuning - amp: bool = True, - disable_opt_rtn: bool | None = None, - enable_alg_ext: bool = False, - enable_minmax_tuning: bool = True, - enable_norm_bias_tuning: bool = False, - enable_quanted_input: bool = True, - enable_deterministic_algorithms: bool = False, - lr: float = None, - lr_scheduler: Callable = None, - minmax_lr: float = None, - nblocks: int = 1, - to_quant_block_names: Union[str, list, None] = None, - scale_dtype: str = "fp16", - # scheme - bits: int = None, - group_size: int = None, - sym: bool = None, - data_type: str = None, - act_bits: int = None, - act_group_size: int = None, - act_sym: bool = None, - act_data_type: str = None, - act_dynamic: bool = None, - super_bits: int = None, - super_group_size: int = None, - static_kv_dtype: Union[str, torch.dtype] = None, - quant_lm_head: bool = False, - ignore_layers: str = None, - # mllm - processor: Callable = None, - image_processor: Callable = None, - quant_nontext_module: bool = False, - extra_data_dir: str = None, - template: str = None, - # diffusion - guidance_scale: float = 7.5, - num_inference_steps: int = 50, - generator_seed: int = None, - ): - """Initialize - - Args: - amp (bool): Whether to use automatic mixed precision 
(default is True). - disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to True. - enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False. - enable_minmax_tuning (bool, optional): Enable weight min-max tuning. Defaults to True. - enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning. - enable_quanted_input (bool): Whether to use quantized input data (default is True). - enable_deterministic_algorithms (bool): Whether to use deterministic_algorithms. - lr (float): The learning rate (default is 0.005). - lr_scheduler: The learning rate scheduler to be used. - minmax_lr (float): The learning rate for min-max tuning (default is None). - nblocks (int): Number of blocks (default is 1). - quant_lm_head (bool): Whether to quant lm_head. - to_quant_block_names (str|list): Names of quantitative blocks, please use commas to separate them. - scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels - bits (int, optional): Weight quantization bits. Defaults to 4. - group_size (int, optional): Weight quantization group size. Defaults to 128. - sym (bool, optional): Symmetric weight quantization. Defaults to True. - data_type (str, optional): Weight data type string, e.g., "int". Defaults to "int". - act_bits (int, optional): Activation quantization bits. Defaults to 16. - act_group_size (int, optional): Activation group size. Defaults to None. - act_sym (bool, optional): Symmetric activation quantization. Defaults to None. - act_data_type (str, optional): Activation data type; inherits weight dtype if None and act_bits < 16. - act_dynamic (bool, optional): Dynamic activation quantization. Defaults to True. - super_bits (int): number of scale and mins quant bits for double quant. - super_group_size (int): the number of super group size when use double quant. - static_kv_dtype (str): The data type of kv-cache to be used. 
- processor: Any multi-modal model will require an object to encode or - decode the data that groups several modalities (among text, vision and audio). - image_processor: Image processor for special model like llava. - quant_nontext_module: Whether to quantize nontext module. - extra_data_dir: The path of extra data such as images, audio and videos. - template: The path or name of template used to specify process for different MLLMs. - guidance_scale (float): Control how much the image generation process follows the text prompt. - The more it is, the more closely it follows the prompt (default is 7.5). - num_inference_steps (int): The reference number of denoising steps (default is 50). - generator_seed (int): A seed that controls the initial noise for image generation (default is None). - """ - self.tuning_config = TuningExtraConfig( - amp=amp, - disable_opt_rtn=disable_opt_rtn, - enable_alg_ext=enable_alg_ext, - enable_minmax_tuning=enable_minmax_tuning, - enable_norm_bias_tuning=enable_norm_bias_tuning, - enable_quanted_input=enable_quanted_input, - enable_deterministic_algorithms=enable_deterministic_algorithms, - lr=lr, - lr_scheduler=lr_scheduler, - minmax_lr=minmax_lr, - nblocks=nblocks, - to_quant_block_names=to_quant_block_names, - scale_dtype=scale_dtype, - ) - self.scheme_config = SchemeExtraConfig( - bits=bits, - group_size=group_size, - sym=sym, - data_type=data_type, - act_bits=act_bits, - act_group_size=act_group_size, - act_sym=act_sym, - act_data_type=act_data_type, - act_dynamic=act_dynamic, - super_bits=super_bits, - super_group_size=super_group_size, - static_kv_dtype=static_kv_dtype, - quant_lm_head=quant_lm_head, - ignore_layers=ignore_layers, - ) - self.mllm_config = MLLMExtraConfig( - processor=processor, - image_processor=image_processor, - quant_nontext_module=quant_nontext_module, - extra_data_dir=extra_data_dir, - template=template, - ) - self.diffusion_config = DiffusionExtraConfig( - guidance_scale=guidance_scale, - 
num_inference_steps=num_inference_steps, - generator_seed=generator_seed, - ) - - @property - def tuning_config(self): - return self._tuning_config - - @tuning_config.setter - def tuning_config(self, config: TuningExtraConfig): - assert isinstance( - config, TuningExtraConfig - ), f"tuning_config should be ModelExtraConfig, but got {config.__class__.__name__}" - self._tuning_config = config - - @property - def scheme_config(self): - return self._scheme_config - - @scheme_config.setter - def scheme_config(self, config: SchemeExtraConfig): - assert isinstance( - config, SchemeExtraConfig - ), f"scheme_config should be SchemeExtraConfig, but got {config.__class__.__name__}" - self._scheme_config = config - - @property - def mllm_config(self): - return self._mllm_config - - @mllm_config.setter - def mllm_config(self, config: MLLMExtraConfig): - if config is None: - self._mllm_config = None - else: - assert isinstance( - config, MLLMExtraConfig - ), f"mllm_config should be MLLMExtraConfig, but got {config.__class__.__name__}" - self._mllm_config = config - - @property - def diffusion_config(self): - return self._diffusion_config - - @diffusion_config.setter - def diffusion_config(self, config: DiffusionExtraConfig): - if config is None: - self._diffusion_config = None - else: - assert isinstance( - config, DiffusionExtraConfig - ), f"diffusion_config should be DiffusionExtraConfig, but got {config.__class__.__name__}" - self._diffusion_config = config - - def to_dict(self): - output_dict = {} - for config in self.__dict__.values(): - if config: - output_dict.update(config.to_dict()) - return output_dict - - -@dataclass -class BaseExtraConfig: - - @classmethod - def get_attributes(cls: "BaseExtraConfig") -> list[str]: - return [field.name for field in fields(cls)] - - def __getitem__(self, key: str): - if key not in self.get_attributes(): - raise KeyError(f"{key} is not a valid attribute") - return getattr(self, key) - - def __setitem__(self, key: str, value: None | int 
| str): - if key not in self.get_attributes(): - raise KeyError(f"{key} is not a valid attribute") - setattr(self, key, value) - - def __contains__(self, item): - return item in self.get_attributes() - - def to_dict(self): - return self.__dict__ - - def is_default(self): - for field in fields(self): - default_value = field.default - current_value = getattr(self, field.name) - if current_value != default_value: - return False - return True - - -@dataclass -class TuningExtraConfig(BaseExtraConfig): - amp: bool = True - disable_opt_rtn: bool | None = None - enable_alg_ext: bool = False - enable_minmax_tuning: bool = True - enable_norm_bias_tuning: bool = False - enable_quanted_input: bool = True - enable_deterministic_algorithms: bool = False - lr: float = None - lr_scheduler: Callable = None - minmax_lr: float = None - nblocks: int = 1 - to_quant_block_names: Union[str, list, None] = None - scale_dtype: str = "fp16" - - -@dataclass -class SchemeExtraConfig(BaseExtraConfig): - bits: int = None - group_size: int = None - sym: bool = None - data_type: str = None - act_bits: int = None - act_group_size: int = None - act_sym: bool = None - act_data_type: str = None - act_dynamic: bool = None - super_bits: int = None - super_group_size: int = None - static_kv_dtype: Union[str, torch.dtype] = None - static_attention_dtype: Union[str, torch.dtype] = None - quant_lm_head: bool = False - ignore_layers: str = None - - -@dataclass -class MLLMExtraConfig(BaseExtraConfig): - processor: Callable = None - image_processor: Callable = None - quant_nontext_module: bool = False - extra_data_dir: str = None - template: str = None - - -@dataclass -class DiffusionExtraConfig(BaseExtraConfig): - guidance_scale: float = 7.5 - num_inference_steps: int = 50 - generator_seed: int = None diff --git a/auto_round/compressors/data_driven.py b/auto_round/compressors/data_driven.py index 2fb3358e6..c6cf5aebc 100644 --- a/auto_round/compressors/data_driven.py +++ 
b/auto_round/compressors/data_driven.py @@ -15,6 +15,7 @@ import gc import time import traceback +from contextlib import ExitStack from functools import partial from typing import Any, Callable, Optional, Union @@ -25,7 +26,6 @@ from tqdm import tqdm from auto_round import envs -from auto_round.algorithms.alg_config import AlgConfig from auto_round.calibration.utils import ( _infer_last_cache_name, _split_inputs_diffusion, @@ -75,7 +75,7 @@ class DataDrivenCompressor(BaseCompressor): def __init__( self, - config: Union[AlgConfig, list[AlgConfig]], + config: Union[object, list[object]], model: Union[torch.nn.Module, str], tokenizer=None, platform="hf", @@ -89,6 +89,8 @@ def __init__( low_cpu_mem_usage: bool = True, **kwargs, ): + if iters is None: + iters = 200 self.iters = iters super().__init__( config=config, @@ -314,6 +316,20 @@ def quantize_block( if not self._post_init_done: self.post_init() + if len(self.quant_block_list) != 1 or len(self.quant_block_list[0]) != 1: + raise ValueError( + f"{self.__class__.__name__}.quantize_block supports exactly one target block, " + f"but quant_block_list is {self.quant_block_list!r}. " + "Use to_quant_block_names to select a single block." + ) + expected_block_name = self.quant_block_list[0][0] + actual_block_name = getattr(block, "global_name", None) + if actual_block_name is not None and actual_block_name != expected_block_name: + raise ValueError( + f"quantize_block received block {actual_block_name!r}, but cached inputs are for " + f"{expected_block_name!r}. Pass the matching block or update to_quant_block_names." + ) + # When called from LLM-Compressor, `wrapped_model` is a single decoder layer # (not the full VL model), so it must not be treated as an MLLM regardless of # whether the original model had multimodal assets. 
Force is_mllm=False for @@ -367,44 +383,114 @@ def quantize_block( continue add_hook_to_module(m, AlignDevicesHook(m.tuning_device, io_same_device=True), True) - # ── Infrastructure: collect reference output and act_max ────────────── + blk_name = self.quant_block_list[0][0] bs = self.quantizer.batch_size * self.quantizer.infer_bs_coeff - if q_input is None: - hook_handles = self.quantizer.register_calibration_hooks(block) - reference_output = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) - for h in hook_handles: - h.remove() + mid_iter_mem_check = self.compress_context.low_gpu_mem_usage and card_0_in_high_risk + + if not hasattr(self.quantizer, "create_block_io"): + if q_input is None: + hook_handles = self.quantizer.register_calibration_hooks(block) + reference_output = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + for h in hook_handles: + h.remove() + else: + reference_output = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + hook_handles = self.quantizer.register_calibration_hooks(block) + if hook_handles: + self.quantizer._get_block_outputs(block, q_input, input_others, bs, save_output=False) + for h in hook_handles: + h.remove() + if input_ids is not q_input: + clear_memory(input_ids, device_list=self.compress_context.device_list) + else: + clear_memory(device_list=self.compress_context.device_list) + input_ids = q_input + + self.quantizer.quantize_block( + block, + input_ids, + input_others, + reference_output, + loss_device=loss_device, + mid_iter_mem_check=mid_iter_mem_check, + ) + + if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): + set_amax_for_all_moe_layers(block, attr_name="act_max") + + if self.quantizer.enable_quanted_input: + q_outputs = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + else: + q_outputs = None + + if len(self.compress_context.device_list) > 1: + accelerate.hooks.remove_hook_from_submodules(block) + 
mv_module_from_gpu(block) + return q_outputs, reference_output + + from auto_round.algorithms.pipeline import BlockContext, InputSource + + ctx = BlockContext( + model=self.model_context.model, + block=block, + block_names=[blk_name], + block_name=blk_name, + block_index=0, + io=self.quantizer.create_block_io(input_ids, input_others, q_input, block), + bs=bs, + loss_device=loss_device, + device=device, + mid_iter_mem_check=mid_iter_mem_check, + is_mllm=False, + is_diffusion=False, + ) + policy = self.pipeline.get_merged_policy(ctx) + + if policy.source == InputSource.QUANTIZED_INPUT and q_input is not None: + with ExitStack() as fwd_stack: + self.pipeline.enter_preprocessor_hooks(ctx, fwd_stack) + reference_output = ctx.io.collect_outputs( + block, self.quantizer, source=InputSource.FP_CACHE, batch_size=bs + ) + with ExitStack() as fwd_stack: + self.pipeline.enter_quantizer_hooks(ctx, fwd_stack) + ctx.io.collect_outputs( + block, self.quantizer, source=InputSource.QUANTIZED_INPUT, batch_size=bs, save=False + ) else: - reference_output = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) - hook_handles = self.quantizer.register_calibration_hooks(block) - if hook_handles: - self.quantizer._get_block_outputs(block, q_input, input_others, bs, save_output=False) - for h in hook_handles: - h.remove() + with ExitStack() as fwd_stack: + self.pipeline.enter_block_forward_hooks(ctx, fwd_stack) + reference_output = ctx.io.collect_outputs( + block, self.quantizer, source=InputSource.FP_CACHE, batch_size=bs + ) + + if q_input is not None: if input_ids is not q_input: clear_memory(input_ids, device_list=self.compress_context.device_list) else: clear_memory(device_list=self.compress_context.device_list) input_ids = q_input - # ── Pure algorithm: delegates to quantizer ──────────────────────────── - mid_iter_mem_check = self.compress_context.low_gpu_mem_usage and card_0_in_high_risk - self.quantizer.quantize_block( - block, - input_ids, - input_others, - 
reference_output, - loss_device=loss_device, - mid_iter_mem_check=mid_iter_mem_check, - ) + ctx.reference_output = reference_output + + # pre_quantize_block: consolidate stats and apply weight transforms. + for pre in self.pipeline.preprocessors: + pre.pre_quantize_block(ctx) + + # ── Pure algorithm: block_quantizer.quantize_block ──────────────────── + self.pipeline.block_quantizer.quantize_block(ctx) + + # ── Pipeline lifecycle: post_quantize_block ─────────────────────────── + for pre in self.pipeline.preprocessors: + pre.post_quantize_block(ctx) # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): set_amax_for_all_moe_layers(block, attr_name="act_max") # ── Collect quantized-block outputs ─────────────────────────────────── - if self.quantizer.enable_quanted_input: - q_outputs = self.quantizer._get_block_outputs(block, input_ids, input_others, bs) + if self.pipeline.block_quantizer.enable_quanted_input: + q_outputs = ctx.io.collect_outputs(block, self.quantizer, source=ctx.io.active_source, batch_size=bs) else: q_outputs = None @@ -444,19 +530,6 @@ def _quantize_blocks( input_ids, input_others = self._preprocess_block_inputs(inputs) - # For diffusion models, the heuristic split ("hidden_state" in key) may - # place keys like encoder_hidden_states in input_ids even though they are - # not block outputs. Move those to input_others so they persist across - # blocks (only output keys get refreshed via reference_output each iteration). 
- if self.model_context.is_diffusion and isinstance(input_ids, dict): - first_block = get_module(model, block_names[0]) - output_config = self.quantizer.DIFFUSION_OUTPUT_CONFIGS.get( - first_block.__class__.__name__, ["hidden_states"] - ) - extra_keys = [k for k in list(input_ids.keys()) if k not in output_config] - for k in extra_keys: - input_others[k] = input_ids.pop(k) - if pbar is None: pbar = tqdm(range(0, len(block_names), nblocks)) @@ -516,26 +589,59 @@ def _quantize_blocks( continue add_hook_to_module(_mod, AlignDevicesHook(_mod.tuning_device, io_same_device=True), True) - # ── Infrastructure: collect reference output and act_max ────────── + # ── Pipeline lifecycle: per-block setup ─────────────────────────── + from auto_round.algorithms.pipeline import BlockContext, InputSource + + current_block_names = ( + block_name_or_names if isinstance(block_name_or_names, list) else [block_name_or_names] + ) + current_block_name = current_block_names[0] if len(current_block_names) == 1 else str(block_name_or_names) bs = self.quantizer.batch_size * self.quantizer.infer_bs_coeff - if q_input is None: - hook_handles = self.quantizer.register_calibration_hooks(m) - reference_output = self.quantizer._get_block_outputs( - m, input_ids, input_others, bs, device_override=loss_device - ) - for h in hook_handles: - h.remove() + mid_iter_mem_check = self.compress_context.low_gpu_mem_usage and card_0_in_high_risk + + ctx = BlockContext( + model=model, + block=m, + block_names=current_block_names, + block_name=current_block_name, + block_index=i, + io=self.quantizer.create_block_io(input_ids, input_others, q_input, m), + bs=bs, + loss_device=loss_device, + device=self.compress_context.device, + mid_iter_mem_check=mid_iter_mem_check, + is_mllm=self.model_context.is_mllm, + is_diffusion=self.model_context.is_diffusion, + pbar=pbar, + ) + + # ── Infrastructure: collect reference output and act calib ──────── + # All forward hooks (preprocessor stats + act-calib) are active during 
+ # the reference forward and removed when the ExitStack exits. + policy = self.pipeline.get_merged_policy(ctx) + + if policy.source == InputSource.QUANTIZED_INPUT and q_input is not None: + # First: reference forward with FP inputs and preprocessor hooks only. + with ExitStack() as fwd_stack: + self.pipeline.enter_preprocessor_hooks(ctx, fwd_stack) + reference_output = ctx.io.collect_outputs( + m, self.quantizer, source=InputSource.FP_CACHE, batch_size=bs + ) + # Second: quantizer stats forward with q_input. + with ExitStack() as fwd_stack: + self.pipeline.enter_quantizer_hooks(ctx, fwd_stack) + ctx.io.collect_outputs( + m, self.quantizer, source=InputSource.QUANTIZED_INPUT, batch_size=bs, save=False + ) else: - reference_output = self.quantizer._get_block_outputs( - m, input_ids, input_others, bs, device_override=loss_device - ) - hook_handles = self.quantizer.register_calibration_hooks(m) - if hook_handles: - self.quantizer._get_block_outputs( - m, q_input, input_others, bs, save_output=False, device_override=loss_device + # Unified: reference forward with all hooks active (or no hooks). 
+ with ExitStack() as fwd_stack: + self.pipeline.enter_block_forward_hooks(ctx, fwd_stack) + reference_output = ctx.io.collect_outputs( + m, self.quantizer, source=InputSource.FP_CACHE, batch_size=bs ) - for h in hook_handles: - h.remove() + + ctx.reference_output = reference_output # ── Infrastructure: swap q_input ────────────────────────────────── if q_input is not None: @@ -545,24 +651,24 @@ def _quantize_blocks( clear_memory(device_list=self.compress_context.device_list) input_ids = q_input - # ── Pure algorithm: delegates to quantizer ──────────────────────── - mid_iter_mem_check = self.compress_context.low_gpu_mem_usage and card_0_in_high_risk - self.quantizer.quantize_block( - m, - input_ids, - input_others, - reference_output, - loss_device=loss_device, - mid_iter_mem_check=mid_iter_mem_check, - ) + # ── Pipeline lifecycle: pre_quantize_block (stats consolidation + weight transforms) ── + for pre in self.pipeline.preprocessors: + pre.pre_quantize_block(ctx) + + # ── Pure algorithm: block_quantizer.quantize_block ──────────────── + self.pipeline.block_quantizer.quantize_block(ctx) + + # ── Pipeline lifecycle: post_quantize_block ─────────────────────── + for pre in self.pipeline.preprocessors: + pre.post_quantize_block(ctx) # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): set_amax_for_all_moe_layers(m, attr_name="act_max") # ── Infrastructure: collect q_outputs if needed ─────────────────── - if self.quantizer.enable_quanted_input: - q_input = self.quantizer._get_block_outputs(m, input_ids, input_others, bs) + if self.pipeline.block_quantizer.enable_quanted_input: + q_input = ctx.io.collect_outputs(m, self.quantizer, source=ctx.io.active_source, batch_size=bs) else: q_input = None @@ -589,7 +695,12 @@ def _quantize_blocks( if hasattr(_mod, "bits") and check_to_quantized(_mod): from auto_round.compressors.utils import immediate_pack as _immediate_pack - 
_immediate_pack(_mod.global_name, self.quantizer.layer_config) + module_name = getattr(_mod, "global_name", None) + if module_name is None and nblocks == 1 and _n: + module_name = f"{n}.{_n}" + if module_name is None: + continue + _immediate_pack(module_name, self.quantizer.layer_config) input_ids = next_input_ids @@ -667,7 +778,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: last_cache_name=_last_cache_name, ) self.inputs = all_inputs - is_quantized_embedding = self._quantize_embedding_layer() + is_quantized_embedding = self.quantizer.quantize_embedding_layer() clear_memory(device_list=self.compress_context.device_list) all_q_inputs = None if is_quantized_embedding: @@ -701,38 +812,64 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: pbar = tqdm(range(0, len(all_blocks[0]), self.nblocks)) # move the alg warning outside pbar start_time = time.time() - for block_names in all_blocks: - inputs = all_inputs[block_names[0]] - all_inputs.pop(block_names[0]) - q_inputs = None - if all_q_inputs is not None: - q_inputs = all_q_inputs[block_names[0]] - all_q_inputs.pop(block_names[0]) - inputs, q_inputs = _update_inputs(inputs, q_inputs) + # ── Pipeline lifecycle: prepare_quantization (model-level setup) ────── + from auto_round.algorithms.pipeline import RunContext - clear_memory(self.inputs, device_list=self.compress_context.device_list) + run_ctx = RunContext( + model=self.model_context.model, + all_blocks=all_blocks, + layer_names=layer_names, + formats=getattr(self, "formats", None), + scheme=getattr(self.quantize_config, "scheme", None), + alg_configs=getattr(self, "alg_configs", []), + model_context=self.model_context, + compress_context=self.compress_context, + ) + for alg in self.pipeline.all(): + alg.prepare_run(run_ctx) - if "input_ids" in inputs.keys(): - total_samples = len(inputs["input_ids"]) - if total_samples < self.quantizer.batch_size: - self.quantizer.batch_size = total_samples - logger.warning(f"force the train batch 
size to {total_samples}") - - self._quantize_blocks( - self.model_context.model, - inputs, - block_names, - q_input=q_inputs if q_inputs is not None else None, - nblocks=self.nblocks, - pbar=pbar, - input_others_extra_blocks=all_inputs, - ) - if self.compress_context.is_immediate_packing and len(self.formats) != 1: - raise ValueError( - f"Expected exactly one packing format when 'immediate_packing' is True, " - f"but got {len(self.formats)} formats." + try: + for block_names in all_blocks: + inputs = all_inputs[block_names[0]] + all_inputs.pop(block_names[0]) + q_inputs = None + if all_q_inputs is not None: + q_inputs = all_q_inputs[block_names[0]] + all_q_inputs.pop(block_names[0]) + + inputs, q_inputs = _update_inputs(inputs, q_inputs) + + clear_memory(self.inputs, device_list=self.compress_context.device_list) + + if "input_ids" in inputs.keys(): + total_samples = len(inputs["input_ids"]) + if total_samples < self.quantizer.batch_size: + self.quantizer.batch_size = total_samples + logger.warning(f"force the train batch size to {total_samples}") + + self._quantize_blocks( + self.model_context.model, + inputs, + block_names, + q_input=q_inputs if q_inputs is not None else None, + nblocks=self.nblocks, + pbar=pbar, + input_others_extra_blocks=all_inputs, ) + if self.compress_context.is_immediate_packing and len(self.formats) != 1: + raise ValueError( + f"Expected exactly one packing format when 'immediate_packing' is True, " + f"but got {len(self.formats)} formats." 
+ ) + finally: + # ── Pipeline lifecycle: finalize_quantization (model-level teardown) ─ + for alg in self.pipeline.all(): + try: + alg.finalize_run(run_ctx) + except Exception as _fe: + logger.warning("finalize_run error in %s: %s", type(alg).__name__, _fe) + pbar.set_description("Quantizing done") pbar.close() if self.compress_context.low_cpu_mem_usage: @@ -783,6 +920,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: """ # TODO currently we take all the layers outside blocks as post block layers which is not optimal # if there is no input for layer, we use rtn + for layer_name in copy.deepcopy(layer_names): if layer_name not in layer_inputs: if self.act_bits < 16 and not self.act_dynamic: @@ -892,7 +1030,7 @@ class CalibratedRTNCompressor(DataDrivenCompressor): def __init__( self, - config: AlgConfig, + config: object, model: torch.nn.Module, **kwargs, ): @@ -1011,7 +1149,7 @@ def process_input_others(input_others): self.quantizer.batch_size, self.compress_context.device, ) - if len(self.compress_context.device_list) > 1 and not self.model_context.is_diffusion: + if len(self.compress_context.device_list) > 1: from accelerate.hooks import AlignDevicesHook, add_hook_to_module for _, _mod in block.named_modules(): @@ -1021,17 +1159,28 @@ def process_input_others(input_others): else: block = block.to(self.compress_context.device) - # ── Infrastructure: register act_max hook and run forward pass ── - hook_handles = self.quantizer.register_calibration_hooks(block, imatrix=False) - block_input_ids = input_ids # keep reference for quantize_block - input_ids = self.quantizer._get_block_outputs( - block, - input_ids, - input_others, - self.quantizer.batch_size * self.quantizer.infer_bs_coeff, + # ── Infrastructure: collect block outputs and hook stats ── + from auto_round.algorithms.pipeline import BlockContext, InputSource + + block_input_ids = input_ids + bs = self.quantizer.batch_size * self.quantizer.infer_bs_coeff + ctx = BlockContext( + 
model=self.model_context.model, + block=block, + block_names=[block_name], + block_name=block_name, + block_index=0, + io=self.quantizer.create_block_io(input_ids, input_others, None, block), + bs=bs, + device=self.compress_context.device, + is_mllm=self.model_context.is_mllm, + is_diffusion=self.model_context.is_diffusion, ) - for h in hook_handles: - h.remove() + with ExitStack() as fwd_stack: + self.pipeline.enter_block_forward_hooks(ctx, fwd_stack) + input_ids = ctx.io.collect_outputs( + block, self.quantizer, source=InputSource.FP_CACHE, batch_size=bs + ) if len(self.compress_context.device_list) > 1: accelerate.hooks.remove_hook_from_submodules(block) @@ -1041,7 +1190,9 @@ def process_input_others(input_others): self.compress_context.clear_memory() # ── Pure algorithm ──────────────────────────────────────────── - self.quantizer.quantize_block(block, block_input_ids, input_others, block_name=block_name) + ctx.io.fp_inputs = block_input_ids + ctx.reference_output = input_ids + self.quantizer.quantize_block(ctx) # ── Infrastructure: cleanup ─────────────────────────────────── mv_module_from_gpu(block) @@ -1078,14 +1229,6 @@ def process_input_others(input_others): # shard_writer(self, is_finalize=True) def _quant_rtn_with_imatrix(self) -> None: - """Performs RTN quantization using input activation statistics (imatrix). - - OptimizedRTNQuantizer owns imatrix hook registration. This method only - enables the quantizer-side collection path and keeps the OOM fallback. 
- - Returns: - None - """ logger.info("start to compute imatrix") self.quantizer.enable_imatrix = True @@ -1098,7 +1241,6 @@ def _quant_rtn_with_imatrix(self) -> None: if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: dispatch_model(model, model.hf_device_map) - hooks = self.quantizer.register_calibration_hooks(model, act_max=False) try: if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: import accelerate @@ -1111,7 +1253,6 @@ def _quant_rtn_with_imatrix(self) -> None: cuda_error_msg = traceback.format_exc() try: logger.error(cuda_error_msg) - # Final fallback: warn and use CPU-only quantization logger.warning( "Fallback to CPU. " "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`." @@ -1130,8 +1271,6 @@ def _quant_rtn_with_imatrix(self) -> None: except Exception as e: raise finally: - for hook in hooks: - hook.remove() self.quantizer.enable_imatrix = False def quantize(self): @@ -1155,7 +1294,7 @@ def _quantize_impl(self): formats = getattr(self, "formats", None) or [] if not (any(fmt.is_gguf() for fmt in formats) or self.super_bits is not None): - self._quantize_embedding_layer() # leave to gguf itself to handle + self.quantizer.quantize_embedding_layer() # leave to gguf itself to handle # Release memory clear_memory(device_list=self.compress_context.device_list) diff --git a/auto_round/compressors/entry.py b/auto_round/compressors/entry.py index d001aaf72..87905f5af 100644 --- a/auto_round/compressors/entry.py +++ b/auto_round/compressors/entry.py @@ -6,20 +6,90 @@ import torch -from auto_round.algorithms.alg_config import AlgConfig -from auto_round.algorithms.quantization.awq.config import AWQConfig -from auto_round.algorithms.quantization.rtn.config import RTNConfig +from auto_round.algorithms.pipeline import split_quantization_configs +from auto_round.algorithms.quantization.config import QuantizationConfig +from auto_round.algorithms.quantization.rtn.config import OptimizedRTNConfig, 
RTNConfig from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig -from auto_round.algorithms.transforms import normalize_rotation_config as _normalize_any_rotation_config -from auto_round.algorithms.transforms.base import BaseRotationConfig as _BaseRotationConfig -from auto_round.algorithms.transforms.rotation.config import RotationConfig as _NewArchRotationConfig +from auto_round.algorithms.registry import normalize_algorithm_config, resolve_alg_config +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.algorithms.transforms.quarot.config import RotationConfig as _NewArchRotationConfig from auto_round.auto_scheme.gen_auto_scheme import AutoScheme +from auto_round.compressors.base import BaseCompressor from auto_round.compressors.data_driven import CalibratedRTNCompressor, DataDrivenCompressor from auto_round.compressors.utils import check_need_act_calibration from auto_round.compressors.zero_shot import ZeroShotCompressor from auto_round.logger import logger from auto_round.schemes import QuantizationScheme, _parse_scheme +_ENTRY_ROUTE_KWARGS = {"model_free", "disable_model_free", "disable_opt_rtn"} +_ENTRY_COMPRESSOR_KWARGS = {"scale_dtype", "ignore_layers", "quant_lm_head", "to_quant_block_names"} +_ENTRY_BASE_KWARGS = { + "format", + "dataset", + "batch_size", + "model_dtype", + "trust_remote_code", + "amp", + "nblocks", + "disable_deterministic_algorithms", + "enable_deterministic_algorithms", + "static_kv_dtype", + "static_attention_dtype", +} +_ENTRY_MLLM_KWARGS = {"processor", "image_processor", "template", "extra_data_dir", "quant_nontext_module"} +_ENTRY_DIFFUSION_KWARGS = {"guidance_scale", "num_inference_steps", "generator_seed"} +_ENTRY_ALLOWED_KWARGS = ( + _ENTRY_ROUTE_KWARGS | _ENTRY_COMPRESSOR_KWARGS | _ENTRY_BASE_KWARGS | _ENTRY_MLLM_KWARGS | _ENTRY_DIFFUSION_KWARGS +) + + +def filter_supported_entry_kwargs(kwargs: dict[str, Any], *, context: str) -> dict[str, Any]: + """Return only kwargs 
supported by the new entry API. + + Unsupported kwargs are ignored with a warning so callers can cleanly migrate + without leaking old-API parameters into compressor constructors. + """ + + supported = {} + unknown = [] + for key, value in kwargs.items(): + if key in _ENTRY_ALLOWED_KWARGS: + supported[key] = value + else: + unknown.append(key) + if unknown: + logger.warning_once( + "%s received unsupported kwargs %s. They will be ignored.", + context, + ", ".join(sorted(unknown)), + ) + return supported + + +def _split_entry_kwargs(kwargs: dict[str, Any]) -> dict[str, dict[str, Any]]: + """Partition new-entry kwargs by ownership.""" + + kwargs = filter_supported_entry_kwargs(kwargs, context="AutoRound entry") + buckets = { + "route": {}, + "compressor": {}, + "base": {}, + "mllm": {}, + "diffusion": {}, + } + for key, value in kwargs.items(): + if key in _ENTRY_ROUTE_KWARGS: + buckets["route"][key] = value + elif key in _ENTRY_COMPRESSOR_KWARGS: + buckets["compressor"][key] = value + elif key in _ENTRY_BASE_KWARGS: + buckets["base"][key] = value + elif key in _ENTRY_MLLM_KWARGS: + buckets["mllm"][key] = value + elif key in _ENTRY_DIFFUSION_KWARGS: + buckets["diffusion"][key] = value + return buckets + def _preview_resolved_attrs(config, scheme=None) -> dict: """Resolve scheme attributes without mutating config, for routing decisions. 
@@ -36,7 +106,7 @@ def _preview_resolved_attrs(config, scheme=None) -> dict: if isinstance(scheme, AutoScheme): # AutoScheme needs model info — cannot preview, rely on raw config attrs return {} - scheme_attr_names = QuantizationScheme.get_attributes() + scheme_attr_names = tuple(config._scheme_fields) user_overrides = {k: getattr(config, k) for k in scheme_attr_names if getattr(config, k, None) is not None} try: _, _, final_attrs = _parse_scheme(scheme, user_overrides) @@ -59,7 +129,7 @@ def _eager_validate_scheme(config, scheme=None) -> None: if isinstance(scheme, AutoScheme): return - scheme_attr_names = QuantizationScheme.get_attributes() + scheme_attr_names = tuple(config._scheme_fields) user_overrides = {k: getattr(config, k) for k in scheme_attr_names if getattr(config, k, None) is not None} try: _, _, final_attrs = _parse_scheme(scheme, user_overrides) @@ -71,6 +141,9 @@ def _eager_validate_scheme(config, scheme=None) -> None: import copy temp_config = copy.copy(config) + if hasattr(config, "scheme"): + temp_config.scheme = config.scheme.copy() + temp_config._user_set_scheme_fields = set(getattr(config, "_user_set_scheme_fields", set())) for key, value in final_attrs.items(): setattr(temp_config, key, value) temp_config.check_config() # raises ValueError / NotImplementedError if invalid @@ -138,33 +211,113 @@ def is_gguf_k_target(value) -> bool: return False +def _resolve_quant_config_for_routing(alg_configs) -> tuple[list, list, QuantizationConfig]: + preprocessor_configs, block_quant_configs = split_quantization_configs(alg_configs) + if len(block_quant_configs) == 0 and preprocessor_configs: + from auto_round.algorithms.quantization.rtn.config import RTNConfig as _RTNConfig + + return preprocessor_configs, block_quant_configs, _RTNConfig() + if len(block_quant_configs) > 1: + raise ValueError( + f"Only one block-quantization config is allowed, but got {len(block_quant_configs)}: " + f"{[type(c).__name__ for c in block_quant_configs]}" + ) + if 
len(block_quant_configs) == 1: + return preprocessor_configs, block_quant_configs, block_quant_configs[0] + raise ValueError( + "At least one quantization algorithm config is required. " + "Pass a block quantizer such as RTNConfig or SignRoundConfig, " + "or a quantization preprocessor such as AWQConfig." + ) + + +def _build_model_type_ctor_kwargs(model, base_kwargs, mllm_kwargs, diffusion_kwargs) -> tuple[str, dict[str, Any]]: + from auto_round.utils.model import detect_model_type + + model_type = detect_model_type(model) + has_multimodal_assets = mllm_kwargs.get("processor") is not None or mllm_kwargs.get("image_processor") is not None + if has_multimodal_assets and model_type != "mllm": + model_type = "mllm" + + ctor_kwargs = dict(base_kwargs) + if model_type == "mllm": + ctor_kwargs.update(mllm_kwargs) + if model_type == "diffusion": + ctor_kwargs.update(diffusion_kwargs) + return model_type, ctor_kwargs + + +def _select_rtn_compressor_base_cls(quant_config: RTNConfig, scheme, format, base_kwargs) -> type: + enable_imatrix = False + resolved_attrs = {} + disable_opt_rtn = getattr(quant_config, "disable_opt_rtn", False) + + # If disable_opt_rtn was not explicitly set and scheme is W8A16/W8A8, + # auto-disable optimization to improve efficiency. + if getattr(quant_config, "orig_disable_opt_rtn", None) is None: + if isinstance(scheme, str) and scheme.upper() in ["W8A16", "W8A8"]: + logger.warning("`disable_opt_rtn` is turned on for W8A16/W8A8 quantization to improve efficiency.") + disable_opt_rtn = True + quant_config.disable_opt_rtn = True + + if not disable_opt_rtn: + has_gguf_k = is_gguf_k_target(format) or is_gguf_k_target(scheme) + if has_gguf_k: + enable_imatrix = True + else: + # Resolve scheme attrs for routing. SchemeMixin will do the authoritative + # resolution later; this preview only chooses the compressor class. 
+ resolved_attrs = _preview_resolved_attrs(quant_config, scheme) + sym = resolved_attrs.get("sym", getattr(quant_config, "sym", None)) + data_type = resolved_attrs.get("data_type", getattr(quant_config, "data_type", "") or "") + bits = resolved_attrs.get("bits", getattr(quant_config, "bits", None)) + if sym is not None and sym is False: + enable_imatrix = False + elif data_type == "int" and (bits is None or bits < 8): + enable_imatrix = True + elif is_weight_scheme(scheme): + enable_imatrix = True + + resolved_attrs = resolved_attrs if not disable_opt_rtn else _preview_resolved_attrs(quant_config, scheme) + act_bits = resolved_attrs.get("act_bits", getattr(quant_config, "act_bits", None)) + act_data_type = resolved_attrs.get("act_data_type", getattr(quant_config, "act_data_type", None)) + act_dynamic = resolved_attrs.get("act_dynamic", getattr(quant_config, "act_dynamic", None)) + is_act_quantize = act_bits is not None and act_bits <= 8 + needs_act_calib = is_act_quantize and check_need_act_calibration( + act_dynamic, + act_data_type, + act_bits if act_bits is not None else 16, + static_kv_dtype=base_kwargs.get("static_kv_dtype"), + static_attention_dtype=base_kwargs.get("static_attention_dtype"), + ) + + # AutoScheme always requires calibration data for delta-loss based scheme + # selection, regardless of whether imatrix is needed. + quant_config.enable_imatrix = enable_imatrix + if enable_imatrix or needs_act_calib or isinstance(scheme, AutoScheme): + if not isinstance(quant_config, OptimizedRTNConfig): + quant_config.__class__ = OptimizedRTNConfig + return CalibratedRTNCompressor + + if isinstance(quant_config, OptimizedRTNConfig): + quant_config.__class__ = RTNConfig + return ZeroShotCompressor + + class AutoRound(object): - # Mapping from string alias to config class (and optional defaults override). 
- _CONFIG_ALIASES: dict[str, type] = { - "sign_round": SignRoundConfig, - "signround": SignRoundConfig, - "rtn": RTNConfig, - "hadamard": _NewArchRotationConfig, - } @classmethod - def _resolve_config(cls, config: Union[str, AlgConfig, list]) -> Union[AlgConfig, list[AlgConfig]]: + def _resolve_config(cls, config: Union[str, object, list]) -> Union[object, list[object]]: """Convert string alias(es) to the corresponding config instance(s) with default parameters.""" if isinstance(config, str): - key = config.strip().lower() - # Handle spinquant/quarot via unified normalizer - if key in ("spinquant", "quarot"): - return _normalize_any_rotation_config(key) - if key not in cls._CONFIG_ALIASES: - raise ValueError(f"Unknown config alias '{config}'. " f"Supported: {list(cls._CONFIG_ALIASES.keys())}") - return cls._CONFIG_ALIASES[key]() + return resolve_alg_config(config) if isinstance(config, list): return [cls._resolve_config(c) for c in config] return config def __new__( cls, - alg_configs: Union[str, AlgConfig, list[Union[str, AlgConfig]]], + alg_configs: Union[str, object, list[Union[str, object]]], model: Union[torch.nn.Module, str], tokenizer=None, platform="hf", @@ -181,34 +334,57 @@ def __new__( nsamples: int = None, seqlen: int = None, **kwargs, - ): - from auto_round.algorithms.quantization.config import QuantizationConfig + ) -> "BaseCompressor": + from auto_round.utils.model import is_model_free_route + + split_kwargs = _split_entry_kwargs(kwargs) + route_kwargs = dict(split_kwargs["route"]) + compressor_kwargs = dict(split_kwargs["compressor"]) + base_kwargs = dict(split_kwargs["base"]) + mllm_kwargs = dict(split_kwargs["mllm"]) + diffusion_kwargs = dict(split_kwargs["diffusion"]) # Resolve string alias(es) to config instance(s) before routing. alg_configs = cls._resolve_config(alg_configs) - - # Extract the single QuantizationConfig from a list; validate at most one exists. 
if isinstance(alg_configs, list): - quant_configs = [c for c in alg_configs if isinstance(c, QuantizationConfig)] - if len(quant_configs) == 0: - raise ValueError("At least one QuantizationConfig (SignRoundConfig / RTNConfig) is required.") - if len(quant_configs) > 1: - raise ValueError( - f"Only one QuantizationConfig is allowed, but got {len(quant_configs)}: " - f"{[type(c).__name__ for c in quant_configs]}" - ) - quant_config = quant_configs[0] + alg_configs = [normalize_algorithm_config(cfg) for cfg in alg_configs] else: - quant_config = alg_configs + alg_configs = normalize_algorithm_config(alg_configs) + configs_for_routing = alg_configs if isinstance(alg_configs, list) else [alg_configs] + preprocessor_configs, _, quant_config = _resolve_quant_config_for_routing(configs_for_routing) + + # Model-free routing is now supported directly by the new entry path. + model_free_iters = 0 if isinstance(quant_config, RTNConfig) else getattr(quant_config, "iters", None) + model_free_disable_opt_rtn = getattr(quant_config, "disable_opt_rtn", None) + if is_model_free_route(model, scheme, model_free_iters, model_free_disable_opt_rtn, route_kwargs): + from auto_round.compressors.model_free import ModelFreeCompressor + + if not isinstance(model, str): + raise ValueError("model_free=True requires `model` to be a HuggingFace ID or local path string.") + if not bool(route_kwargs.get("model_free", False)): + logger.info( + "Auto-routing to model-free quantization " + "(iters=0, disable_opt_rtn=True, supported scheme). " + "Pass disable_model_free=True to use the regular flow." + ) + return ModelFreeCompressor( + model_name_or_path=model, + scheme=scheme, + layer_config=layer_config, + tokenizer=tokenizer, + device_map=device_map, + **compressor_kwargs, + **base_kwargs, + **mllm_kwargs, + **diffusion_kwargs, + **route_kwargs, + ) # Eagerly validate scheme constraints that do not require model info. 
# This mirrors old-arch _check_configs() called at __init__ time so that # callers get ValueError/NotImplementedError on construction, not deferred. _eager_validate_scheme(quant_config, scheme) - # Explicitly build the dict of constructor args to forward to the - # compressor. This avoids the fragile locals()-based approach that - # required a growing SKIP_ARGS blocklist. local_args = dict( model=model, tokenizer=tokenizer, @@ -225,94 +401,25 @@ def __new__( layer_config=layer_config, nsamples=nsamples, seqlen=seqlen, + **compressor_kwargs, ) + model_type, ctor_kwargs = _build_model_type_ctor_kwargs(model, base_kwargs, mllm_kwargs, diffusion_kwargs) - # Detect model type to determine if we need special compressor - from auto_round.utils.model import detect_model_type - - model_type = detect_model_type(model) - - # If the user explicitly passes processor/image_processor, treat as MLLM even if - # auto-detection missed it (mirrors the has_multimodal_assets check in autoround.py). - has_multimodal_assets = kwargs.get("processor") is not None or kwargs.get("image_processor") is not None - if has_multimodal_assets and model_type != "mllm": - model_type = "mllm" - - # Pop kwargs that are only consumed by specific Mixins so they don't - # leak through to BaseCompressor as unrecognized keys. - if model_type != "diffusion": - for _k in ("guidance_scale", "num_inference_steps", "generator_seed"): - kwargs.pop(_k, None) - if model_type != "mllm": - for _k in ("processor", "image_processor", "template", "extra_data_dir", "quant_nontext_module"): - kwargs.pop(_k, None) - kwargs.pop("disable_opt_rtn", None) # consumed by RTN routing above, not a compressor param + # Preprocessor algorithms (AWQ, …) require a data-driven host so that + # the per-block preprocessor lifecycle (prepare_block_group -> + # block_forward_hooks -> pre_quantize_block -> pre_quantize_block -> + # post_quantize_block) actually runs. 
CalibratedRTNCompressor's + # Preprocessor algorithms require DataDrivenCompressor for per-block lifecycle hooks. + # The pipeline auto-appends RTN when no block_quantizer is supplied. + if preprocessor_configs: + return _get_compressor_class(model_type, DataDrivenCompressor)(alg_configs, **local_args, **ctor_kwargs) if isinstance(quant_config, SignRoundConfig): - return _get_compressor_class(model_type, DataDrivenCompressor)(alg_configs, **local_args, **kwargs) - - elif isinstance(quant_config, AWQConfig): - # AWQ requires calibration for activation collection + smoothing - quant_config._alg_cls = "AWQQuantizer" - return _get_compressor_class(model_type, CalibratedRTNCompressor)(alg_configs, **local_args, **kwargs) + return _get_compressor_class(model_type, DataDrivenCompressor)(alg_configs, **local_args, **ctor_kwargs) elif isinstance(quant_config, RTNConfig): - enable_imatrix = False - _resolved = {} - disable_opt_rtn = getattr(quant_config, "disable_opt_rtn", False) - # If disable_opt_rtn was not explicitly set and scheme is W8A16/W8A8, - # auto-disable optimization to improve efficiency. - if getattr(quant_config, "orig_disable_opt_rtn", None) is None: - if isinstance(scheme, str) and scheme.upper() in ["W8A16", "W8A8"]: - logger.warning("`disable_opt_rtn` is turned on for W8A16/W8A8 quantization to improve efficiency.") - disable_opt_rtn = True - quant_config.disable_opt_rtn = True - if not disable_opt_rtn: - has_gguf_k = is_gguf_k_target(format) or is_gguf_k_target(scheme) - if has_gguf_k: - enable_imatrix = True - else: - # Resolve scheme attrs for routing (config hasn't been through - # SchemeMixin yet; user may have specified only scheme="W4A16"). 
- _resolved = _preview_resolved_attrs(quant_config, scheme) - _sym = _resolved.get("sym", getattr(quant_config, "sym", None)) - _data_type = _resolved.get("data_type", getattr(quant_config, "data_type", "") or "") - _bits = _resolved.get("bits", getattr(quant_config, "bits", None)) - if _sym is not None and _sym is False: - enable_imatrix = False - elif _data_type == "int" and (_bits is None or _bits < 8): - enable_imatrix = True - elif is_weight_scheme(scheme): - enable_imatrix = True - else: - _resolved = {} - - _resolved = _resolved if not disable_opt_rtn else _preview_resolved_attrs(quant_config, scheme) - _act_bits = _resolved.get("act_bits", getattr(quant_config, "act_bits", None)) - _act_data_type = _resolved.get("act_data_type", getattr(quant_config, "act_data_type", None)) - _act_dynamic = _resolved.get("act_dynamic", getattr(quant_config, "act_dynamic", None)) - _is_act_quantize = _act_bits is not None and _act_bits <= 8 - needs_act_calib = _is_act_quantize and check_need_act_calibration( - _act_dynamic, - _act_data_type, - _act_bits if _act_bits is not None else 16, - static_kv_dtype=kwargs.get("static_kv_dtype"), - static_attention_dtype=kwargs.get("static_attention_dtype"), - ) - - # AutoScheme always requires calibration data for delta-loss based - # scheme selection, regardless of whether imatrix is needed. 
- from auto_round.auto_scheme.gen_auto_scheme import AutoScheme as _AutoScheme - - is_auto_scheme = isinstance(scheme, _AutoScheme) - quant_config.enable_imatrix = enable_imatrix - - if enable_imatrix or needs_act_calib or is_auto_scheme: - quant_config._alg_cls = "OptimizedRTNQuantizer" - return _get_compressor_class(model_type, CalibratedRTNCompressor)(alg_configs, **local_args, **kwargs) - else: - quant_config._alg_cls = "RTNQuantizer" - return _get_compressor_class(model_type, ZeroShotCompressor)(alg_configs, **local_args, **kwargs) + base_cls = _select_rtn_compressor_base_cls(quant_config, scheme, format, base_kwargs) + return _get_compressor_class(model_type, base_cls)(alg_configs, **local_args, **ctor_kwargs) class AutoRoundCompatible: @@ -369,14 +476,7 @@ class AutoRoundCompatible: @staticmethod def _pop_config_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: """Extract old-API config kwargs and split them by config type.""" - common_keys = ( - "ignore_layers", - "quant_lm_head", - "scale_dtype", - "super_bits", - "super_group_size", - "to_quant_block_names", - ) + common_keys = ("super_bits", "super_group_size") auto_round_only_keys = ( "nblocks", "enable_alg_ext", @@ -397,6 +497,157 @@ def _pop_config_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str auto_round_kwargs[key] = kwargs.pop(key) return common_kwargs, auto_round_kwargs + @staticmethod + def _pop_compressor_only_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + return { + "scale_dtype": kwargs.pop("scale_dtype", None), + "ignore_layers": kwargs.pop("ignore_layers", ""), + "quant_lm_head": kwargs.pop("quant_lm_head", False), + "to_quant_block_names": kwargs.pop("to_quant_block_names", None), + } + + @staticmethod + def _resolve_compat_algorithm(algorithm, iters) -> str: + if algorithm and algorithm.lower() == "awq": + return "awq" + if (algorithm and algorithm.lower() == "rtn") or iters == 0: + return "rtn" + return "signround" + + @staticmethod + def 
_pop_shared_quant_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + return { + "bits": kwargs.pop("bits", None), + "group_size": kwargs.pop("group_size", None), + "sym": kwargs.pop("sym", None), + "data_type": kwargs.pop("data_type", None), + "act_bits": kwargs.pop("act_bits", None), + "act_group_size": kwargs.pop("act_group_size", None), + "act_sym": kwargs.pop("act_sym", None), + "act_data_type": kwargs.pop("act_data_type", None), + "act_dynamic": kwargs.pop("act_dynamic", None), + } + + @staticmethod + def _build_awq_config( + shared_quant_kwargs: dict[str, Any], + *, + seqlen, + nsamples, + batch_size, + kwargs, + common_config_kwargs, + ): + return AWQConfig( + **shared_quant_kwargs, + duo_scaling=kwargs.pop("duo_scaling", True), + n_grid=kwargs.pop("n_grid", 20), + seqlen=seqlen, + nsamples=nsamples, + batch_size=batch_size, + mappings=kwargs.pop("mappings", None), + **common_config_kwargs, + ) + + @staticmethod + def _build_rtn_config(shared_quant_kwargs: dict[str, Any], *, kwargs, common_config_kwargs): + cfg = RTNConfig( + **shared_quant_kwargs, + disable_opt_rtn=kwargs.pop("disable_opt_rtn", None), + enable_opt_rtn=kwargs.pop("enable_opt_rtn", None), + **common_config_kwargs, + ) + return normalize_algorithm_config(cfg) + + @staticmethod + def _build_signround_config( + shared_quant_kwargs: dict[str, Any], + *, + iters, + gradient_accumulate_steps, + kwargs, + common_config_kwargs, + auto_round_config_kwargs, + ): + cfg = SignRoundConfig( + **shared_quant_kwargs, + iters=iters, + gradient_accumulate_steps=gradient_accumulate_steps, + lr=kwargs.pop("lr", None), + minmax_lr=kwargs.pop("minmax_lr", None), + enable_minmax_tuning=kwargs.pop("enable_minmax_tuning", True), + enable_norm_bias_tuning=kwargs.pop("enable_norm_bias_tuning", False), + enable_quanted_input=kwargs.pop("enable_quanted_input", True), + **common_config_kwargs, + **auto_round_config_kwargs, + ) + return normalize_algorithm_config(cfg) + + @classmethod + def _build_alg_config( + cls, + *, 
+ algorithm, + iters, + gradient_accumulate_steps, + seqlen, + nsamples, + batch_size, + kwargs, + common_config_kwargs, + auto_round_config_kwargs, + ): + alg_name = cls._resolve_compat_algorithm(algorithm, iters) + shared_quant_kwargs = cls._pop_shared_quant_kwargs(kwargs) + + if alg_name == "awq": + return cls._build_awq_config( + shared_quant_kwargs, + seqlen=seqlen, + nsamples=nsamples, + batch_size=batch_size, + kwargs=kwargs, + common_config_kwargs=common_config_kwargs, + ) + if alg_name == "rtn": + return cls._build_rtn_config( + shared_quant_kwargs, + kwargs=kwargs, + common_config_kwargs=common_config_kwargs, + ) + return cls._build_signround_config( + shared_quant_kwargs, + iters=iters, + gradient_accumulate_steps=gradient_accumulate_steps, + kwargs=kwargs, + common_config_kwargs=common_config_kwargs, + auto_round_config_kwargs=auto_round_config_kwargs, + ) + + @staticmethod + def _build_entry_forward_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: + format_name = kwargs.pop("format", None) + rotation_config = kwargs.pop("rotation_config", None) + mllm_kwargs = { + "processor": kwargs.pop("processor", None), + "image_processor": kwargs.pop("image_processor", None), + "template": kwargs.pop("template", None), + "extra_data_dir": kwargs.pop("extra_data_dir", None), + "quant_nontext_module": kwargs.pop("quant_nontext_module", False), + } + diffusion_kwargs = { + "guidance_scale": kwargs.pop("guidance_scale", 7.5), + "num_inference_steps": kwargs.pop("num_inference_steps", 50), + "generator_seed": kwargs.pop("generator_seed", None), + } + return { + "format": format_name, + "rotation_config": rotation_config, + **mllm_kwargs, + **diffusion_kwargs, + **kwargs, + } + def __new__( cls, model: Union[torch.nn.Module, str], @@ -417,16 +668,10 @@ def __new__( low_cpu_mem_usage: bool = True, algorithm: str = None, **kwargs, - ): + ) -> "BaseCompressor": """Create AutoRoundCompatible instance using new AutoRound architecture. 
This method translates old AutoRoundCompatible API to new AutoRound API. - - Args: - algorithm: Quantization algorithm to use. Options: - - None or "auto_round": SignSGD-based optimization (default when iters > 0) - - "rtn": Round-to-nearest (default when iters == 0) - - "awq": Activation-Aware Weight Quantization (AWQ smoothing + RTN) """ from auto_round.utils import is_diffusion_model, is_mllm_model from auto_round.utils.model import is_model_free_route @@ -441,6 +686,8 @@ def __new__( if is_model_free_route(model, scheme, iters, kwargs.get("disable_opt_rtn"), kwargs): from auto_round.compressors.model_free import ModelFreeCompressor + compressor_only_kwargs = cls._pop_compressor_only_kwargs(kwargs) + if not isinstance(model, str): raise ValueError("model_free=True requires `model` to be a HuggingFace ID or local path string.") if not bool(kwargs.get("model_free", False)): @@ -455,131 +702,38 @@ def __new__( layer_config=layer_config, tokenizer=tokenizer, device_map=device_map, + **compressor_only_kwargs, **kwargs, ) # -------------------------------------------------------------------- + compressor_only_kwargs = cls._pop_compressor_only_kwargs(kwargs) common_config_kwargs, auto_round_config_kwargs = cls._pop_config_kwargs(kwargs) - # Extract quantization parameters from kwargs or use defaults - bits = kwargs.pop("bits", None) - group_size = kwargs.pop("group_size", None) - sym = kwargs.pop("sym", None) - data_type = kwargs.pop("data_type", None) - act_bits = kwargs.pop("act_bits", None) - act_group_size = kwargs.pop("act_group_size", None) - act_sym = kwargs.pop("act_sym", None) - act_data_type = kwargs.pop("act_data_type", None) - act_dynamic = kwargs.pop("act_dynamic", None) - enable_opt_rtn = kwargs.pop("enable_opt_rtn", None) - lr = kwargs.pop("lr", None) - minmax_lr = kwargs.pop("minmax_lr", None) - enable_minmax_tuning = kwargs.pop("enable_minmax_tuning", True) - enable_norm_bias_tuning = kwargs.pop("enable_norm_bias_tuning", False) - enable_quanted_input 
= kwargs.pop("enable_quanted_input", True) - - # Pop AWQ-only kwargs early so they don't leak into non-AWQ constructors - duo_scaling = kwargs.pop("duo_scaling", True) - n_grid = kwargs.pop("n_grid", 20) - awq_mappings = kwargs.pop("mappings", None) - - # Decide which algorithm to use - if algorithm and algorithm.lower() == "awq": - # AWQ mode: activation-aware weight quantization - config = AWQConfig( - bits=bits, - group_size=group_size, - sym=sym, - data_type=data_type, - act_bits=act_bits, - act_group_size=act_group_size, - act_sym=act_sym, - act_data_type=act_data_type, - act_dynamic=act_dynamic, - duo_scaling=duo_scaling, - n_grid=n_grid, - seqlen=seqlen, - nsamples=nsamples, - batch_size=batch_size, - mappings=awq_mappings, - **common_config_kwargs, - ) - elif (algorithm and algorithm.lower() == "rtn") or iters == 0: - # RTN mode - disable_opt_rtn = kwargs.pop("disable_opt_rtn", None) - config = RTNConfig( - bits=bits, - group_size=group_size, - sym=sym, - data_type=data_type, - act_bits=act_bits, - act_group_size=act_group_size, - act_sym=act_sym, - act_data_type=act_data_type, - act_dynamic=act_dynamic, - disable_opt_rtn=disable_opt_rtn, - enable_opt_rtn=enable_opt_rtn, - **common_config_kwargs, - ) - else: - # AutoRoundCompatible mode - config = SignRoundConfig( - iters=iters, - gradient_accumulate_steps=gradient_accumulate_steps, - bits=bits, - group_size=group_size, - sym=sym, - data_type=data_type, - act_bits=act_bits, - act_group_size=act_group_size, - act_sym=act_sym, - act_data_type=act_data_type, - act_dynamic=act_dynamic, - lr=lr, - minmax_lr=minmax_lr, - enable_minmax_tuning=enable_minmax_tuning, - enable_norm_bias_tuning=enable_norm_bias_tuning, - enable_quanted_input=enable_quanted_input, - **common_config_kwargs, - **auto_round_config_kwargs, - ) - - # Determine output format if specified - format = kwargs.pop("format", None) + config = cls._build_alg_config( + algorithm=algorithm, + iters=iters, + 
gradient_accumulate_steps=gradient_accumulate_steps, + seqlen=seqlen, + nsamples=nsamples, + batch_size=batch_size, + kwargs=kwargs, + common_config_kwargs=common_config_kwargs, + auto_round_config_kwargs=auto_round_config_kwargs, + ) - # Extract rotation_config (old-API kwarg) and thread it into alg_configs. - # In old arch this was a standalone keyword arg; the new arch passes rotation - # transforms as part of the alg_configs list. All backends (auto / inplace / - # transform) are dispatched inside ``HadamardRotation.apply_to_model``. - # Also supports SpinQuantConfig and string shorthands ("quarot", "spinquant"). - _rotation_config_raw = kwargs.pop("rotation_config", None) + forward_kwargs = cls._build_entry_forward_kwargs(kwargs) + format_name = forward_kwargs.pop("format", None) + _rotation_config_raw = forward_kwargs.pop("rotation_config", None) if _rotation_config_raw is not None: - if isinstance(_rotation_config_raw, _BaseRotationConfig): - # Already a valid config (RotationConfig, SpinQuantConfig, etc.) + if isinstance(_rotation_config_raw, _NewArchRotationConfig): _rc = _rotation_config_raw elif isinstance(_rotation_config_raw, dict): - # Use unified normalizer which dispatches by "algorithm" key - _rc = _normalize_any_rotation_config(_rotation_config_raw) - elif isinstance(_rotation_config_raw, str): - # String shorthands: "quarot", "spinquant", "hadamard", - # "random_hadamard", "default", etc. 
- _rc = _normalize_any_rotation_config(_rotation_config_raw) + _rc = _NewArchRotationConfig.model_validate(_rotation_config_raw) else: + # str alias ("default", "random_hadamard", …) -> default config _rc = _NewArchRotationConfig() - if _rc is not None: - config = [config, _rc] - - # Extract MLLM-specific parameters - processor = kwargs.pop("processor", None) - image_processor = kwargs.pop("image_processor", None) - template = kwargs.pop("template", None) - extra_data_dir = kwargs.pop("extra_data_dir", None) - quant_nontext_module = kwargs.pop("quant_nontext_module", False) - - # Extract Diffusion-specific parameters - guidance_scale = kwargs.pop("guidance_scale", 7.5) - num_inference_steps = kwargs.pop("num_inference_steps", 50) - generator_seed = kwargs.pop("generator_seed", None) + config = [config, _rc] # Check model type for logging (use warning_once to avoid repeating for every block # when called from LLM-Compressor which instantiates AutoRound per block) @@ -596,7 +750,7 @@ def __new__( model=model, tokenizer=tokenizer, platform=platform, - format=format, + format=format_name, scheme=scheme, dataset=dataset, iters=iters, @@ -610,18 +764,8 @@ def __new__( nsamples=nsamples, seqlen=seqlen, batch_size=batch_size, - # MLLM parameters - processor=processor, - image_processor=image_processor, - template=template, - extra_data_dir=extra_data_dir, - quant_nontext_module=quant_nontext_module, - # Diffusion parameters - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - generator_seed=generator_seed, - # Pass remaining kwargs - **kwargs, + **compressor_only_kwargs, + **forward_kwargs, ) return compressor diff --git a/auto_round/compressors/mllm_mixin.py b/auto_round/compressors/mllm_mixin.py index 26bf4d06a..97f5926a0 100644 --- a/auto_round/compressors/mllm_mixin.py +++ b/auto_round/compressors/mllm_mixin.py @@ -68,8 +68,8 @@ def __init__( new_grad_acc = batch_size * grad_acc kwargs["gradient_accumulate_steps"] = new_grad_acc 
kwargs["batch_size"] = 1 - # Also patch ``gradient_accumulate_steps`` on AlgConfig (still - # owned there) so behaviour matches the old arch. + # Also patch ``gradient_accumulate_steps`` on algorithm configs so + # behaviour matches the old arch. _alg_cfg = args[0] if args else None if _alg_cfg is not None: cfgs = _alg_cfg if isinstance(_alg_cfg, list) else [_alg_cfg] diff --git a/auto_round/compressors/zero_shot.py b/auto_round/compressors/zero_shot.py index 3b471c6a2..f4668be5e 100644 --- a/auto_round/compressors/zero_shot.py +++ b/auto_round/compressors/zero_shot.py @@ -17,7 +17,7 @@ import torch from tqdm import tqdm -from auto_round.algorithms.alg_config import AlgConfig +from auto_round.algorithms.pipeline import BlockContext from auto_round.compressors.base import BaseCompressor from auto_round.compressors.utils import is_nv_fp, is_static_wfp8afp8 from auto_round.logger import logger @@ -43,7 +43,7 @@ class ZeroShotCompressor(BaseCompressor): def __init__( self, - config: Union[AlgConfig, list[AlgConfig]], + config: Union[object, list[object]], model: Union[torch.nn.Module, str], tokenizer=None, platform="hf", @@ -82,9 +82,9 @@ def quantize_block( """Quantize a single block via RTN (public API for LLM-Compressor). ZeroShotCompressor does not need calibration data, so ``inputs`` and - ``q_input`` are accepted for interface compatibility but not used for - algorithm purposes. The block is materialized, converted to the target - dtype, moved to ``device``, and quantized in-place via RTN. + ``q_input`` are accepted for interface compatibility + but not used for algorithm purposes. The block is materialized, converted + to the target dtype, moved to ``device``, and quantized in-place via RTN. Returns: tuple: ``(None, None)`` — RTN does not produce reference outputs. 
@@ -100,7 +100,16 @@ def quantize_block( convert_module_to_hp_if_necessary(block, self.model_context.amp_dtype, device) block = block.to(device) - self.quantizer.quantize_block(block) + ctx = BlockContext( + model=self.model_context.model, + block=block, + block_names=[getattr(block, "global_name", "")], + block_name=getattr(block, "global_name", ""), + block_index=0, + io=self.quantizer.create_block_io(None, {}, None, block), + device=device, + ) + self.quantizer.quantize_block(ctx) # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): @@ -122,7 +131,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: formats = self.formats if isinstance(self.formats, list) else [] if not (any(fmt.is_gguf() for fmt in formats) or self.super_bits is not None): - self._quantize_embedding_layer() # leave to gguf itself to handle + self.quantizer.quantize_embedding_layer() # leave to gguf itself to handle # Release memory clear_memory(device_list=self.device_list) @@ -164,7 +173,16 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: materialize_model_(block) # ── Pure algorithm ──────────────────────────────────────── - self.quantizer.quantize_block(block) + ctx = BlockContext( + model=self.model_context.model, + block=block, + block_names=[block_name], + block_name=block_name, + block_index=0, + io=self.quantizer.create_block_io(None, {}, None, block), + device=self.compress_context.device, + ) + self.quantizer.quantize_block(ctx) # ── MoE scale alignment for FP8 dispatch efficiency ──────────────── if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer): diff --git a/auto_round/context/__init__.py b/auto_round/context/__init__.py index 14a492441..7548429bf 100644 --- a/auto_round/context/__init__.py +++ b/auto_round/context/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + +from auto_round.schemes import QuantizationScheme + +__all__ = ["QuantizationScheme"] diff --git a/auto_round/context/model.py b/auto_round/context/model.py index 2cd5273d1..1bae19f71 100644 --- a/auto_round/context/model.py +++ b/auto_round/context/model.py @@ -250,7 +250,7 @@ def apply_patches(self, formats): """Apply format-specific model structure patches. Must be called after formats are resolved (list[OutputFormat]) and before - BaseQuantizers.post_init() so that configure_layer_config() operates on the + BaseQuantizer.post_init() so that configure_layer_config() operates on the final model structure (post update_module). Eliminates the need for a subsequent refresh_quantizer_for_initialized_model() call. """ diff --git a/auto_round/experimental/apply_rotation_transform.py b/auto_round/experimental/apply_rotation_transform.py index 520f6b2fb..c04da06af 100644 --- a/auto_round/experimental/apply_rotation_transform.py +++ b/auto_round/experimental/apply_rotation_transform.py @@ -3,10 +3,10 @@ """Backward-compat re-export shim. The canonical implementation now lives in -:mod:`auto_round.algorithms.transforms.rotation.dispatcher`. +:mod:`auto_round.algorithms.transforms.quarot.dispatcher`. 
""" -from auto_round.algorithms.transforms.rotation.dispatcher import ( # noqa: F401 +from auto_round.algorithms.transforms.quarot.dispatcher import ( # noqa: F401 apply_hadamard_rotation, resolve_hadamard_backend, ) diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py index d27c73abb..d199b024c 100644 --- a/auto_round/experimental/utils.py +++ b/auto_round/experimental/utils.py @@ -16,8 +16,8 @@ import torch -from auto_round.algorithms.transforms.rotation.config import RotationConfig -from auto_round.algorithms.transforms.rotation.transforms import HADAMARDS +from auto_round.algorithms.transforms.quarot.config import RotationConfig +from auto_round.algorithms.transforms.quarot.transforms import HADAMARDS from auto_round.compressors.utils import is_mx_fp, is_nv_fp from auto_round.utils import logger @@ -130,7 +130,7 @@ def is_triton_kernel_available(data_type: str) -> bool: return False try: - from auto_round.algorithms.transforms.rotation.utils.triton.mxfp4 import ( # pylint: disable=E0401 + from auto_round.algorithms.transforms.quarot.utils.triton.mxfp4 import ( # pylint: disable=E0401 mxfp4_forward_kernel_wrapper, ) except Exception: @@ -140,19 +140,19 @@ def is_triton_kernel_available(data_type: str) -> bool: def dump_group_size_to_rotation_config(rotation_config: str | dict | RotationConfig, group_size: int): - from auto_round.algorithms.transforms.rotation.config import dump_group_size_to_rotation_config as _impl + from auto_round.algorithms.transforms.quarot.config import dump_group_size_to_rotation_config as _impl return _impl(rotation_config, group_size) def to_dict_rotation_config(rotation_config: str | dict | RotationConfig): - from auto_round.algorithms.transforms.rotation.config import to_dict_rotation_config as _impl + from auto_round.algorithms.transforms.quarot.config import to_dict_rotation_config as _impl return _impl(rotation_config) def normalize_rotation_config(rotation_config: str | dict | RotationConfig | None, 
data_type: str) -> dict[str, Any]: - from auto_round.algorithms.transforms.rotation.config import normalize_rotation_config as _impl + from auto_round.algorithms.transforms.quarot.config import normalize_rotation_config as _impl return _impl(rotation_config, data_type) diff --git a/auto_round/formats.py b/auto_round/formats.py index 57b6e4f3c..bbc89cc66 100644 --- a/auto_round/formats.py +++ b/auto_round/formats.py @@ -768,13 +768,14 @@ class GGUFFormat(OutputFormat): def __init__(self, format: str, ar: BaseCompressor): if format.startswith("gguf:"): self._original_format = format # preserve "gguf:q2_k_mixed" etc. for Phase 2b - self.gguf_args_check(ar, format, model_type=ModelType.TEXT) - if ar.mllm: - self.gguf_args_check(ar, format, model_type=ModelType.MMPROJ) - self.output_format = "gguf" self.backend_cls = GGUFFormat self.backend = GGUFFormat(format.split(":")[-1], ar) + + resolved_format = self.backend.output_format + self.gguf_args_check(ar, resolved_format, model_type=ModelType.TEXT) + if ar.mllm: + self.gguf_args_check(ar, resolved_format, model_type=ModelType.MMPROJ) else: scheme = ar.scheme gguf_format = f"gguf:{format.lower()}" @@ -782,12 +783,12 @@ def __init__(self, format: str, ar: BaseCompressor): from auto_round.schemes import _handle_special_schemes from auto_round.utils.model import is_moe_model - if format.lower() == "q2_k_mixed" and getattr(ar, "iters", 0) > 0 and not is_moe_model(ar.model): + if format.lower() == "q2_k_mixed" and (getattr(ar, "iters", 0) or 0) > 0 and not is_moe_model(ar.model): logger.warning( "gguf:q2_k_mixed only supports MoE models with iters>0. " - "It is not an MoE model, falling back to gguf:q4_k_m." + "It is not an MoE model, falling back to gguf:q2_k_s." 
) - gguf_format = "gguf:q4_k_m" + gguf_format = "gguf:q2_k_s" else: ar.layer_config = _handle_special_schemes( gguf_format, ar.layer_config, ar.model, quant_nontext_module=ar.quant_nontext_module diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index 7d19b1e3e..b6de09d66 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -875,8 +875,8 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M rotation_config = getattr(quantization_config, "rotation_config", None) if rotation_config is not None and rotation_config: - from auto_round.algorithms.transforms.rotation.apply import apply_rotation_transform - from auto_round.algorithms.transforms.rotation.config import RotationConfig + from auto_round.algorithms.transforms.quarot.apply import apply_rotation_transform + from auto_round.algorithms.transforms.quarot.config import RotationConfig # apply forward hook act_rotation_config = RotationConfig( diff --git a/auto_round/schemes.py b/auto_round/schemes.py index d351a3f07..2d4673de5 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -42,12 +42,27 @@ class QuantizationScheme: super_group_size: Optional[int] = None rotation_config: Optional[dict] = None + @classmethod + def empty(cls): + return cls(**{field.name: None for field in fields(cls)}) + @classmethod def from_dict(cls, config: dict): field_names = {f.name for f in fields(cls)} filtered_config = {k: v for k, v in config.items() if k in field_names} return cls(**filtered_config) + def to_dict(self) -> dict: + return asdict(self) + + def copy(self): + return copy.deepcopy(self) + + def update_from_dict(self, config: dict) -> None: + for key, value in config.items(): + if key in self.get_attributes(): + setattr(self, key, value) + @classmethod def get_attributes(cls: "QuantizationScheme") -> list[str]: return [field.name for field in fields(cls)] diff --git a/auto_round/utils/model.py 
b/auto_round/utils/model.py index 805c596d3..363fa0ff9 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -2294,8 +2294,8 @@ def is_model_free_route( """ from auto_round.compressors.model_free import is_model_free_supported_scheme - explicit = bool(kwargs.pop("model_free", False)) - disabled = bool(kwargs.pop("disable_model_free", False)) + explicit = bool(kwargs.get("model_free", False)) + disabled = bool(kwargs.get("disable_model_free", False)) if explicit: return True # Only auto-route when format is auto_round (or not specified). diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index bd6a78570..0f6c222ea 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -335,6 +335,12 @@ def unwrapper(self, best_params): Returns: torch.nn.Module: The unwrapped and restored original layer. """ + + def _preserve_global_name(layer): + if hasattr(self.orig_layer, "global_name") and not hasattr(layer, "global_name"): + layer.global_name = self.orig_layer.global_name + return layer + best_params = best_params or {} v = best_params.get("value", torch.tensor(0.0)).to(self.device) min_scale = best_params.get("min_scale", torch.tensor(1.0)).to(self.device) @@ -441,9 +447,9 @@ def _set_dict_attr(attr_dict, attr_name): enable_torch_compile=self.enable_torch_compile, device=self.device, ) - return wrapper_layer + return _preserve_global_name(wrapper_layer) - return self.orig_layer + return _preserve_global_name(self.orig_layer) def linear_forward(self, x, weight, bias): """Performs the forward pass for a linear layer. diff --git a/docs/step_by_step.md b/docs/step_by_step.md index 4a2a8052d..5467c1922 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -483,6 +483,42 @@ We will try to optimize the RAM usage in the future. The RAM usage is about 1.1- Embedding layer is not supported in AutoScheme, it will use the best scheme in options. 
+### AWQ Quantization Algorithm + +AWQ (`algorithm="awq"`) is a pre-processing quantization algorithm that analyzes activation patterns and applies channel-wise scaling to protect salient weights. It runs BEFORE the actual quantization (RTN by default, or auto_round/SignRound). + +#### CLI Usage +```bash +# AWQ + default RTN (iters=0 auto-selected) +auto-round --model Qwen/Qwen3-0.6B --algorithm awq --scheme W4A16 + +# AWQ + AutoRound optimization +auto-round --model Qwen/Qwen3-0.6B --algorithm awq,auto_round --scheme W4A16 + +# AWQ flags +--awq-duo-scaling true|false|both (default: true) +--awq-n-grid 20 (default: 20) +``` + +#### API Usage +```python +from auto_round import AutoRound +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig + +# AWQ + default RTN (simplest) +ar = AutoRound(model, tokenizer, algorithm="awq", scheme="W4A16") + +# AWQ + AutoRound via alg_configs (explicit pipeline) +ar = AutoRound(model, tokenizer, alg_configs=[AWQConfig(), SignRoundConfig(iters=200)], scheme="W4A16") +ar.quantize_and_save(output_dir="./qmodel") +``` + +**Important Note**: `algorithm="awq"` (quantization algorithm) and `format="auto_awq"` (export format) are independent. You can use: +- `algorithm="awq"` + `format="auto_round"`: AWQ smoothing + AutoRound packing +- `algorithm="auto_round"` + `format="auto_awq"`: No AWQ smoothing + AutoAWQ packing + + ### OPT RTN Mode AutoRound also supports Optimized RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Setting `iters=0` tp enable it and we recommend using `group_size=32` for better results. 
Check [accuracy comparison](./opt_rtn.md) between RTN and OPT RTN mode diff --git a/docs/step_by_step_CN.md b/docs/step_by_step_CN.md index cf4347df0..b2de1fb32 100644 --- a/docs/step_by_step_CN.md +++ b/docs/step_by_step_CN.md @@ -480,7 +480,45 @@ ar.quantize_and_save() #### 局限性 AutoScheme 目前还**不支持对嵌入层(Embedding layer)进行自动量化**。该层将直接采用候选方案中精度最高的配置。 + +### AWQ 量化算法 + +AWQ(`algorithm="awq"`)是一种预处理量化算法,通过分析激活分布并应用通道缩放(channel-wise scaling)来保护重要的权重。它在实际量化(默认为 RTN,或使用 auto_round/SignRound)之前运行。 + +#### 命令行用法 +```bash +# AWQ + 默认 RTN (自动选择 iters=0) +auto-round --model Qwen/Qwen3-0.6B --algorithm awq --scheme W4A16 + +# AWQ + AutoRound 优化 +auto-round --model Qwen/Qwen3-0.6B --algorithm awq,auto_round --scheme W4A16 + +# AWQ 相关参数 +--awq-duo-scaling true|false|both (默认: true) +--awq-n-grid 20 (默认: 20) +``` + +#### API 用法 +```python +from auto_round import AutoRound +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig + +# AWQ + 默认 RTN (最简用法) +ar = AutoRound(model, tokenizer, algorithm="awq", scheme="W4A16") + +# 通过 alg_configs 指定 AWQ + AutoRound (显式流水线) +ar = AutoRound(model, tokenizer, alg_configs=[AWQConfig(), SignRoundConfig(iters=200)], scheme="W4A16") +ar.quantize_and_save(output_dir="./qmodel") +``` + +**重要提示**:`algorithm="awq"`(量化算法)与 `format="auto_awq"`(导出格式)是相互独立的。你可以使用: +- `algorithm="awq"` + `format="auto_round"`:AWQ 平滑 + AutoRound 打包 +- `algorithm="auto_round"` + `format="auto_awq"`:不使用 AWQ 平滑 + AutoAWQ 打包 + + ### OPT-RTN 模式 + AutoRound 还提供优化版 RTN(Round-To-Nearest,就近舍入)模式,无需标定数据即可实现快速基线量化。**启用方式为 `iters=0`**。同时为获得更好的效果,推荐搭配 `group_size=32` 。RTN 与 OPT RTN 模式的精度对比详见[《精度对比报告》](./opt_rtn.md)。 对于 GGUF 格式,我们参考 llamacpp 的思路,优化了 RTN 算法。若需使用原始(非优化)RTN 算法,开启 `--disable_opt_rtn` 即可。 diff --git a/setup.py b/setup.py index 6bb184bcb..bac4362f7 100644 --- a/setup.py +++ b/setup.py @@ -187,7 +187,7 @@ def fetch_requirements(path): package_data={ "": [ "mllm/templates/*.json", - 
"algorithms/transforms/rotation/utils/hadamards.safetensors", + "algorithms/transforms/quarot/utils/hadamards.safetensors", ] }, ) diff --git a/test/test_cpu/algorithms/test_awq.py b/test/test_cpu/algorithms/test_awq.py index 85a3e815f..e6b55e3a1 100644 --- a/test/test_cpu/algorithms/test_awq.py +++ b/test/test_cpu/algorithms/test_awq.py @@ -176,7 +176,7 @@ def _save_dir(self, tmp_path): def test_awq_moe_dynamic_smoothing(self, tiny_qwen_moe_model_path): """AWQ dynamic smoothing should resolve mappings on a MoE model without error.""" - from auto_round.algorithms.quantization.awq.mappings import resolve_mappings + from auto_round.algorithms.transforms.awq.mappings import resolve_mappings model = AutoModelForCausalLM.from_pretrained( tiny_qwen_moe_model_path, diff --git a/test/test_cpu/core/test_awq_autoround_smoke.py b/test/test_cpu/core/test_awq_autoround_smoke.py new file mode 100644 index 000000000..0ef86a18d --- /dev/null +++ b/test/test_cpu/core/test_awq_autoround_smoke.py @@ -0,0 +1,21 @@ +"""Minimal runtime smoke for AWQ + AutoRound fusion.""" + +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.compressors.entry import AutoRound + + +def test_awq_plus_autoround_quantize_smoke(tiny_opt_model_path): + ar = AutoRound( + [AWQConfig(n_grid=2), SignRoundConfig(iters=1)], + tiny_opt_model_path, + scheme="W4A16", + nsamples=1, + seqlen=8, + low_cpu_mem_usage=False, + ) + + model, layer_config = ar.quantize() + + assert model is not None + assert layer_config diff --git a/test/test_cpu/core/test_pipeline_fail_fast.py b/test/test_cpu/core/test_pipeline_fail_fast.py new file mode 100644 index 000000000..dbc91e0d7 --- /dev/null +++ b/test/test_cpu/core/test_pipeline_fail_fast.py @@ -0,0 +1,138 @@ +"""Fast unit tests for algorithm registry and pipeline construction.""" + +import pytest + +from auto_round.algorithms.pipeline import ( + QuantizationPipeline, 
+ get_algorithm_class, + resolve_shared_config_values, + split_quantization_configs, + sync_shared_config_from, +) +from auto_round.algorithms.quantization import registry as _r +from auto_round.algorithms.quantization.config import QuantizationConfig +from auto_round.algorithms.quantization.rtn.config import OptimizedRTNConfig, RTNConfig +from auto_round.algorithms.quantization.rtn.quantizer import RTNQuantizer +from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig +from auto_round.algorithms.transforms.awq.config import AWQConfig +from auto_round.algorithms.transforms.quarot.config import RotationConfig +from auto_round.compressors.base import collect_user_scheme_overrides +from auto_round.compressors.entry import AutoRound as NewAutoRound +from auto_round.logger import logger + + +class PartialSharedConfig(RTNConfig): + def __init__(self, *, weight_clip_ratio=None, **kwargs): + super().__init__(**kwargs) + self.weight_clip_ratio = weight_clip_ratio + + +class NoWeightClipConfig(RTNConfig): + pass + + +def test_split_awq_plus_rtn(): + pre, block = split_quantization_configs([AWQConfig(), RTNConfig()]) + assert len(pre) == 1 and type(pre[0]).__name__ == "AWQConfig" + assert len(block) == 1 and type(block[0]).__name__ == "RTNConfig" + + +def test_pipeline_preprocessor_only_auto_appends_rtn(): + pipeline = QuantizationPipeline.from_configs([AWQConfig()]) + assert type(pipeline.preprocessors[0]).__name__ == "AWQQuantizer" + assert isinstance(pipeline.block_quantizer, RTNQuantizer) + + +def test_pipeline_duplicate_preprocessor_rejected(): + with pytest.raises(ValueError, match="Duplicate preprocessor"): + QuantizationPipeline.from_configs([AWQConfig(), AWQConfig()]) + + +def test_pipeline_multiple_block_quantizers_rejected(): + with pytest.raises(ValueError, match="exactly one block-quantization config"): + QuantizationPipeline.from_configs([RTNConfig(), SignRoundConfig()]) + + +def test_registry_builtin_aliases_and_unknown(): + assert 
isinstance(_r.resolve_alg_config("RTN"), RTNConfig) + assert isinstance(_r.resolve_alg_config("awq"), AWQConfig) + assert isinstance(_r.resolve_alg_config("autoround"), SignRoundConfig) + with pytest.raises(ValueError, match="Unknown algorithm alias"): + _r.resolve_alg_config("definitely_not_registered_abc123") + + +def test_registry_resolves_variant_configs_to_registered_members(): + assert get_algorithm_class(OptimizedRTNConfig()) is not None + assert get_algorithm_class(SignRoundConfig(enable_adam=True)).__name__ == "AdamRoundQuantizer" + + +def test_entry_rejects_configs_without_quantization_members(): + with pytest.raises(ValueError, match="At least one quantization algorithm config"): + NewAutoRound(alg_configs=[RotationConfig()], model="dummy-model") + + +def test_entry_warns_and_drops_unsupported_kwargs(monkeypatch, tiny_opt_model_path): + calls = [] + + def _record_warning(message, *args): + calls.append(message % args) + + monkeypatch.setattr(logger, "warning_once", _record_warning) + + NewAutoRound( + alg_configs=RTNConfig(disable_opt_rtn=True), + model=tiny_opt_model_path, + scheme="W4A16", + nsamples=1, + seqlen=8, + low_cpu_mem_usage=False, + nonsense_kwarg=123, + ) + + assert any("unsupported kwargs nonsense_kwarg" in msg for msg in calls) + + +def test_shared_config_values_inherit_across_matching_attrs_only(): + awq = PartialSharedConfig(weight_clip_ratio=0.9) + smoothquant_like = NoWeightClipConfig() + signround = PartialSharedConfig(weight_clip_ratio=None) + + resolve_shared_config_values([awq, smoothquant_like, signround]) + + assert signround.weight_clip_ratio == 0.9 + assert not hasattr(smoothquant_like, "weight_clip_ratio") + + +def test_shared_config_values_reject_conflicts(): + with pytest.raises(ValueError, match="Conflicting shared config field 'weight_clip_ratio'"): + resolve_shared_config_values( + [PartialSharedConfig(weight_clip_ratio=0.8), PartialSharedConfig(weight_clip_ratio=0.9)] + ) + + +def 
test_shared_config_sync_from_source_skips_missing_attrs(): + source = PartialSharedConfig(weight_clip_ratio=0.75) + target = PartialSharedConfig() + no_clip_target = NoWeightClipConfig() + + sync_shared_config_from(source, [target, no_clip_target, RotationConfig()]) + + assert target.weight_clip_ratio == 0.75 + assert not hasattr(no_clip_target, "weight_clip_ratio") + + +def test_user_scheme_overrides_merge_across_all_configs(): + awq = AWQConfig(bits=8) + rtn = RTNConfig() + assert collect_user_scheme_overrides([awq, rtn])["bits"] == 8 + + resolve_shared_config_values([awq, rtn]) + + assert rtn.bits == 8 + + +def test_user_scheme_overrides_reject_explicit_conflicts(): + with pytest.raises(ValueError, match="Conflicting shared scheme field 'bits'"): + collect_user_scheme_overrides([AWQConfig(bits=8), RTNConfig(bits=4)]) + with pytest.raises(ValueError, match="Conflicting shared scheme field 'bits'"): + resolve_shared_config_values([AWQConfig(bits=8), RTNConfig(bits=4)]) diff --git a/test/test_cpu/export/test_gguf_format.py b/test/test_cpu/export/test_gguf_format.py index 56daf87f9..8ad73cf5a 100644 --- a/test/test_cpu/export/test_gguf_format.py +++ b/test/test_cpu/export/test_gguf_format.py @@ -8,6 +8,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from auto_round import AutoRound +from auto_round.algorithms.quantization.rtn.config import OptimizedRTNConfig from ...helpers import eval_generated_prompt, get_model_path, get_tiny_model, save_tiny_model @@ -63,7 +64,7 @@ def test_q2_k_s_routes_calibrated_rtn(self, tiny_qwen_model_path): ) assert type(autoround).__name__ == "CalibratedRTNCompressor" - assert autoround.quantize_config._alg_cls == "OptimizedRTNQuantizer" + assert isinstance(autoround.quantize_config, OptimizedRTNConfig) def test_func(self): bits, group_size, sym = 4, 128, True diff --git a/test/test_cpu/models/test_audio_model.py b/test/test_cpu/models/test_audio_model.py index aaa29bc26..59bbec20e 100644 --- 
a/test/test_cpu/models/test_audio_model.py +++ b/test/test_cpu/models/test_audio_model.py @@ -196,10 +196,10 @@ class TestStableAudioRegistration: """Verify StableAudio-specific registrations.""" def test_config_and_special_registered(self): - from auto_round.algorithms.quantization.base import BaseQuantizers + from auto_round.algorithms.quantization.base import DiffusionMixin - assert "StableAudioDiTBlock" in BaseQuantizers.DIFFUSION_OUTPUT_CONFIGS - assert BaseQuantizers.DIFFUSION_OUTPUT_CONFIGS["StableAudioDiTBlock"] == ["hidden_states"] + assert "StableAudioDiTBlock" in DiffusionMixin.DIFFUSION_OUTPUT_CONFIGS + assert DiffusionMixin.DIFFUSION_OUTPUT_CONFIGS["StableAudioDiTBlock"] == ["hidden_states"] assert "StableAudioDiTModel" in SPECIAL_SHARED_CACHE_KEYS assert "encoder_hidden_states" in SPECIAL_SHARED_CACHE_KEYS["StableAudioDiTModel"] diff --git a/test/test_cpu/quantization/test_model_free.py b/test/test_cpu/quantization/test_model_free.py index 736b6c910..3ec26c0c4 100644 --- a/test/test_cpu/quantization/test_model_free.py +++ b/test/test_cpu/quantization/test_model_free.py @@ -345,9 +345,10 @@ class TestCliAutoRouting: def test_auto_routes(self, tmp_path): model_dir = _make_model_dir(tmp_path, _LLAMA_CFG, {"layer.weight": torch.randn(64, 128)}) out_dir = str(tmp_path / "out") - from auto_round.__main__ import BasicArgumentParser, tune + from auto_round.cli.main import tune + from auto_round.cli.parser import build_quantize_parser - args = BasicArgumentParser().parse_args( + args = build_quantize_parser().parse_args( [ "--model", model_dir, @@ -366,9 +367,9 @@ def test_auto_routes(self, tmp_path): assert _read_qconfig(out_dir).get("model_free") is True def test_disable_model_free_flag(self): - from auto_round.__main__ import BasicArgumentParser + from auto_round.cli.parser import build_quantize_parser - args = BasicArgumentParser().parse_args( + args = build_quantize_parser().parse_args( [ "--model", "dummy", diff --git 
a/test/test_cpu/utils/test_alg_ext.py b/test/test_cpu/utils/test_alg_ext.py index 2daf2aada..d5f64fb72 100644 --- a/test/test_cpu/utils/test_alg_ext.py +++ b/test/test_cpu/utils/test_alg_ext.py @@ -21,7 +21,7 @@ def test_alg_ext(self, tiny_opt_model_path, tiny_qwen_model_path): ar.quantize() def test_alg_ext_import(self): - from auto_round.alg_ext import wrapper_autoround + from auto_round.algorithms.quantization.sign_roundv2 import SignRoundV2Quantizer def test_all_support_dtype(self, tiny_opt_model_path): model_name = tiny_opt_model_path diff --git a/test/test_cpu/utils/test_cli_usage.py b/test/test_cpu/utils/test_cli_usage.py index 64308cf43..0911e4005 100644 --- a/test/test_cpu/utils/test_cli_usage.py +++ b/test/test_cpu/utils/test_cli_usage.py @@ -131,7 +131,7 @@ def test_parse_layer_config_with_single_escaped_regex_keys(): def test_run_rtn_uses_zero_shot_recipe(monkeypatch): - from auto_round import __main__ as cli_main + from auto_round.cli import main as cli_main captured = {} @@ -160,7 +160,7 @@ def fake_tune(args): def test_run_rtn_preserves_eval_args(monkeypatch, tmp_path): - from auto_round import __main__ as cli_main + from auto_round.cli import main as cli_main captured = {} @@ -198,7 +198,7 @@ def fake_tune(args): def test_run_opt_rtn_uses_recipe(monkeypatch): - from auto_round import __main__ as cli_main + from auto_round.cli import main as cli_main captured = {} diff --git a/test/test_cuda/algorithms/test_alg_ext.py b/test/test_cuda/algorithms/test_alg_ext.py index e661e22f9..9d23aa86c 100644 --- a/test/test_cuda/algorithms/test_alg_ext.py +++ b/test/test_cuda/algorithms/test_alg_ext.py @@ -31,11 +31,11 @@ def test_gguf_q2_k_s_uses_dq_wrapper_block(self, tiny_qwen_model_path): gguf:q2_k_s overrides data_type to "int_asym_dq" at format-resolution time. 
The quantizer must be created *after* that override so that - wrapper_autoround() sees the final data_type and sets dq_wrapper_block + SignRoundV2 sees the final data_type and sets dq_wrapper_block (which wraps layers with DQWrapperLinear) instead of falling back to the plain wrapper_block (which produces WrapperLinear). """ - from auto_round.alg_ext import dq_wrapper_block + from auto_round.algorithms.quantization.sign_roundv2.quantizer import SignRoundDQWrapperLinear ar = AutoRound( tiny_qwen_model_path, @@ -50,8 +50,8 @@ def test_gguf_q2_k_s_uses_dq_wrapper_block(self, tiny_qwen_model_path): # create_quantizer → ...). quantizer only exists afterwards. ar.post_init() - assert ar.quantizer.wrapper_block.__name__ == dq_wrapper_block.__name__, ( - f"Expected wrapper_block to be '{dq_wrapper_block.__name__}', " + assert ar.quantizer.wrapper_block.keywords["wrapper_cls"] is SignRoundDQWrapperLinear, ( + f"Expected wrapper_block to use '{SignRoundDQWrapperLinear.__name__}', " f"got '{ar.quantizer.wrapper_block.__name__}'. " "This likely means the quantizer was created before GGUF format " "overrides were applied (data_type was not yet 'int_asym_dq')." 
diff --git a/test/test_cuda/algorithms/test_awq.py b/test/test_cuda/algorithms/test_awq.py index 5222a98da..2f4f38f37 100644 --- a/test/test_cuda/algorithms/test_awq.py +++ b/test/test_cuda/algorithms/test_awq.py @@ -126,7 +126,7 @@ def _save_dir(self, tmp_path): def test_awq_moe_dynamic_smoothing(self, tiny_qwen_moe_model_path): """AWQ mapping resolution works on MoE model.""" - from auto_round.algorithms.quantization.awq.mappings import resolve_mappings + from auto_round.algorithms.transforms.awq.mappings import resolve_mappings model = AutoModelForCausalLM.from_pretrained( tiny_qwen_moe_model_path, diff --git a/test/test_cuda/models/test_audio_model.py b/test/test_cuda/models/test_audio_model.py index 2d32053e3..e9996a054 100644 --- a/test/test_cuda/models/test_audio_model.py +++ b/test/test_cuda/models/test_audio_model.py @@ -261,10 +261,10 @@ class TestStableAudioRegistration: """Verify StableAudio-specific registrations.""" def test_config_and_special_registered(self): - from auto_round.algorithms.quantization.base import BaseQuantizers + from auto_round.algorithms.quantization.base import DiffusionMixin - assert "StableAudioDiTBlock" in BaseQuantizers.DIFFUSION_OUTPUT_CONFIGS - assert BaseQuantizers.DIFFUSION_OUTPUT_CONFIGS["StableAudioDiTBlock"] == ["hidden_states"] + assert "StableAudioDiTBlock" in DiffusionMixin.DIFFUSION_OUTPUT_CONFIGS + assert DiffusionMixin.DIFFUSION_OUTPUT_CONFIGS["StableAudioDiTBlock"] == ["hidden_states"] assert "StableAudioDiTModel" in SPECIAL_SHARED_CACHE_KEYS assert "encoder_hidden_states" in SPECIAL_SHARED_CACHE_KEYS["StableAudioDiTModel"] diff --git a/test/test_cuda/quantization/test_torch_compile.py b/test/test_cuda/quantization/test_torch_compile.py index a578e8821..f451c6b5b 100644 --- a/test/test_cuda/quantization/test_torch_compile.py +++ b/test/test_cuda/quantization/test_torch_compile.py @@ -6,7 +6,7 @@ from transformers import AutoTokenizer from auto_round import AutoRound -from auto_round.algorithms.quantization.base 
import BaseQuantizers +from auto_round.algorithms.quantization.base import BaseQuantizer from auto_round.algorithms.quantization.rtn.config import RTNConfig from auto_round.compressors.utils import block_forward @@ -76,7 +76,7 @@ def test_gguf_q2ks_torch_compile_iters0(self, tiny_qwen_model_path): def test_opt_rtn_uses_plain_block_forward(self): config = RTNConfig(bits=4, data_type="int", act_bits=16, disable_opt_rtn=False) - quantizer = BaseQuantizers(config) + quantizer = BaseQuantizer(config) quantizer.compress_context = SimpleNamespace(enable_torch_compile=True, device="cpu") assert quantizer._resolve_block_forward() is block_forward