Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
bbe1fb1
refine devices
wenhuach21 Jun 1, 2026
1aa5efd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2026
fc14c01
refine devices
wenhuach21 Jun 2, 2026
705bd30
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 2, 2026
9b19e74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
a9f4b45
refine devices
wenhuach21 Jun 2, 2026
2e03b7c
refine devices
wenhuach21 Jun 2, 2026
5775f29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2026
a4141d3
clean a little
wenhuach21 Jun 3, 2026
4db2409
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
2a32290
update
wenhuach21 Jun 3, 2026
5026a90
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 3, 2026
9977980
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
d5cd6aa
fix ut
wenhuach21 Jun 3, 2026
182168f
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 3, 2026
181e664
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
ab95d3d
fix code scan
wenhuach21 Jun 3, 2026
8882122
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 3, 2026
a0bee00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
905df02
fix some issues
wenhuach21 Jun 3, 2026
cc73855
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 3, 2026
78c15bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
20fab69
try to fix ut
wenhuach21 Jun 4, 2026
8db09e4
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 4, 2026
7320efc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
a8d8c85
update
wenhuach21 Jun 4, 2026
20df398
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
632f200
update
wenhuach21 Jun 4, 2026
9affafe
update
wenhuach21 Jun 4, 2026
0aef05b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
ff60860
update
wenhuach21 Jun 4, 2026
8f563c5
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 4, 2026
6efade3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
542cf94
Potential fix for pull request finding
wenhuach21 Jun 4, 2026
919594b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
7dc8fc2
Potential fix for pull request finding
wenhuach21 Jun 4, 2026
07e86ab
Potential fix for pull request finding
wenhuach21 Jun 4, 2026
e351f98
Merge branch 'main' into refine_device_1
wenhuach21 Jun 4, 2026
cf17da9
hot fix for gemma4-12b
wenhuach21 Jun 4, 2026
7c3573d
Merge branch 'refine_device_1' of https://github.com/intel/auto-round…
wenhuach21 Jun 4, 2026
65ef503
update
wenhuach21 Jun 4, 2026
5f227f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def tune(args):
"lm-eval is required for evaluation, please install it with `pip install 'lm-eval>=0.4.2'`",
)

from auto_round.utils import detect_device, get_library_version, logger
from auto_round.utils import get_library_version, get_major_device, logger

if args.low_cpu_mem_usage:
logger.warning(
Expand All @@ -638,8 +638,6 @@ def tune(args):
if "marlin" in args.format and args.asym is True:
raise RuntimeError("marlin backend only supports sym quantization, please remove --asym")

device_str, use_auto_mapping = get_device_and_parallelism(args.device_map)

if args.enable_torch_compile:
logger.info(
"`torch.compile` is enabled to reduce tuning costs. "
Expand Down Expand Up @@ -809,7 +807,7 @@ def tune(args):
clear_memory()

# ======================= Model evaluation =======================
run_model_evaluation(model, tokenizer, autoround, folders, formats, device_str, args)
run_model_evaluation(model, tokenizer, autoround, folders, formats, args)


def setup_eval_parser():
Expand Down
3 changes: 2 additions & 1 deletion auto_round/algorithms/quantization/adam_round/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer
from auto_round.schemes import QuantizationScheme
from auto_round.utils import check_is_cpu, htcore, is_hpex_available
from auto_round.utils.device_manager import device_manager


class AdamRoundQuantizer(SignRoundQuantizer):
Expand All @@ -37,7 +38,7 @@ def _get_optimizer(self, optimizer):

def _get_scaler(self):
scaler = None
if self.model_context.amp and not check_is_cpu(self.compress_context.device):
if self.model_context.amp and not check_is_cpu(device_manager.device):
from torch.cuda.amp import GradScaler

scaler = GradScaler(init_scale=1024, growth_interval=100000)
Expand Down
5 changes: 3 additions & 2 deletions auto_round/algorithms/quantization/awq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
set_amax_for_all_moe_layers,
set_module,
)
from auto_round.utils.device_manager import device_manager
from auto_round.wrapper import WrapperLinear
from auto_round.wrapper import WrapperMultiblock as _WrapperMultiblock

Expand Down Expand Up @@ -325,9 +326,9 @@ def quantize_layer(self, name: str, dtype: torch.dtype = None) -> None:
if dtype is not None:
m = m.to(dtype)

m = convert_module_to_hp_if_necessary(m, self.model_context.amp_dtype, self.compress_context.device)
m = convert_module_to_hp_if_necessary(m, self.model_context.amp_dtype, device_manager.device)
set_module(self.model, name, m)
tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.compress_context.device
tuning_device = m.tuning_device if hasattr(m, "tuning_device") else device_manager.device

try:
m = m.to(tuning_device)
Expand Down
15 changes: 8 additions & 7 deletions auto_round/algorithms/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
get_module,
set_module,
)
from auto_round.utils.device_manager import device_manager
from auto_round.wrapper import WrapperLinear


Expand Down Expand Up @@ -220,7 +221,7 @@ def _quantize_embedding_layer(self):
# Attempt quantization on GPU, fall back to CPU if OOM
try:
weight, scale, zp = quant_func(
module.weight.to(dtype=dtype, device=self.compress_context.device),
module.weight.to(dtype=dtype, device=device_manager.device),
**{
k: config.get(k, None)
for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]
Expand Down Expand Up @@ -259,7 +260,7 @@ def _quantize_embedding_layer(self):
del weight
del scale
del zp
clear_memory(device_list=self.compress_context.device_list)
clear_memory(device_list=device_manager.device_list)

return is_quantized

Expand Down Expand Up @@ -294,9 +295,9 @@ def quantize_block(
def quantize_layer_via_rtn(self, layer_name: str, disable_opt_rtn: bool | None = None) -> None:
"""Quantize one layer with RTN and handle optional immediate pack/save."""
layer = get_module(self.model, layer_name)
layer = convert_module_to_hp_if_necessary(layer, self.model_context.amp_dtype, self.compress_context.device)
layer = convert_module_to_hp_if_necessary(layer, self.model_context.amp_dtype, device_manager.device)
set_module(self.model, layer_name, layer)
tuning_device = layer.tuning_device if hasattr(layer, "tuning_device") else self.compress_context.device
tuning_device = layer.tuning_device if hasattr(layer, "tuning_device") else device_manager.device
if (
self.compress_context.is_immediate_packing
and self.compress_context.formats[0].is_gguf()
Expand Down Expand Up @@ -424,7 +425,7 @@ def _get_block_outputs(
"""
diffusion_fn = getattr(self, "_get_diffusion_block_outputs", None)
if getattr(self.model_context, "is_diffusion", False):
device = device_override if device_override is not None else self.compress_context.device
device = device_override if device_override is not None else device_manager.device
return self._get_diffusion_block_outputs(
block,
input_ids,
Expand Down Expand Up @@ -455,7 +456,7 @@ def _get_block_outputs(
tmp_input_others,
self.model_context.amp,
self.model_context.amp_dtype,
self.compress_context.device,
device_manager.device,
).to(self.compress_context.cache_device)
if save_output:
if self.batch_size == 1:
Expand Down Expand Up @@ -487,7 +488,7 @@ def _resolve_block_forward(self):
elif self.compress_context.enable_torch_compile:
compiled = self.__dict__.get("_compiled_block_forward")
if compiled is None:
compiled = compile_func(block_forward, self.compress_context.device)
compiled = compile_func(block_forward, device_manager.device)
self._compiled_block_forward = compiled
self._resolved_block_forward = compiled
else:
Expand Down
31 changes: 0 additions & 31 deletions auto_round/algorithms/quantization/rtn/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Any, Callable, Optional, Union

import accelerate
import torch

from auto_round.algorithms.quantization.base import BaseQuantizers
from auto_round.algorithms.quantization.rtn.config import RTNConfig
from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer
from auto_round.algorithms.quantization.utils import register_imatrix_hooks
from auto_round.compressors.utils import (
IndexSampler,
block_forward,
check_need_act_calibration,
check_skippable_keywords,
collect_best_params,
get_shared_keys,
infer_bits_by_data_type,
init_cache,
reset_params,
set_layer_config,
)
from auto_round.data_type.utils import update_block_global_scale_if_needed
from auto_round.logger import logger
from auto_round.utils import (
check_to_quantized,
get_lm_head_name,
get_module,
htcore,
is_auto_device_mapping,
is_hpex_available,
memory_monitor,
set_amax_for_all_moe_layers,
set_module,
)
from auto_round.utils.device import (
clear_memory_if_reached_threshold,
get_major_device,
parse_available_devices,
set_auto_device_map_for_block_with_tuning,
set_non_auto_device_map,
)
from auto_round.wrapper import WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block


class RTNQuantizer(BaseQuantizers):
Expand Down
18 changes: 6 additions & 12 deletions auto_round/algorithms/quantization/sign_round/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,16 @@
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, update_fused_layer_global_scales
from auto_round.logger import logger
from auto_round.utils import (
check_to_quantized,
compile_func,
get_module,
htcore,
is_auto_device_mapping,
is_hpex_available,
memory_monitor,
mv_module_from_gpu,
set_amax_for_all_moe_layers,
set_module,
to_device,
)
from auto_round.utils.device import (
clear_memory_if_reached_threshold,
set_auto_device_map_for_block_with_tuning,
)
from auto_round.utils.device import clear_memory_if_reached_threshold
from auto_round.utils.device_manager import device_manager
from auto_round.utils.distributed import setup_ddp_if_needed_
from auto_round.wrapper import WrapperLinear, unwrapper_block, unwrapper_layer, wrapper_block

Expand Down Expand Up @@ -168,7 +162,7 @@ def quantize_block(
best_params: Best quantization parameters found during optimization.
Empty dict if no trainable parameters were found.
"""
device = self.compress_context.device
device = device_manager.device

quantized_layer_names, unquantized_layer_names = self.wrapper_block(
block,
Expand Down Expand Up @@ -247,7 +241,7 @@ def quantize_block(
if self.gradient_accumulate_steps != 1 and not self.attention_mask:
whole_indices = torch.arange(global_batch_size)
num_elm = self._get_current_num_elm(input_ids, whole_indices)
setup_ddp_if_needed_(self, block, self.compress_context.device_list)
setup_ddp_if_needed_(self, block, device_manager.device_list)
index_sampler = IndexSampler(nsamples, global_batch_size)
batch_size = self.batch_size
for i in range(self.iters):
Expand All @@ -270,13 +264,13 @@ def quantize_block(

if mid_iter_mem_check:
# clear memory to avoid OOM due to memory fragmentation
clear_memory_if_reached_threshold(threshold=0.5, device_list=self.compress_context.device_list)
clear_memory_if_reached_threshold(threshold=0.5, device_list=device_manager.device_list)

self._scale_loss_and_backward(scaler, loss)

if mid_iter_mem_check:
# clear memory to avoid OOM due to memory fragmentation
clear_memory_if_reached_threshold(threshold=0.8, device_list=self.compress_context.device_list)
clear_memory_if_reached_threshold(threshold=0.8, device_list=device_manager.device_list)

if i == 0:
init_loss = total_loss
Expand Down
7 changes: 4 additions & 3 deletions auto_round/auto_scheme/delta_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
to_dtype,
)
from auto_round.utils.device import MemoryMonitor
from auto_round.utils.device_manager import get_current_device_manager
from auto_round.utils.offload import OffloadManager
from auto_round.wrapper import WrapperLinear

Expand Down Expand Up @@ -441,8 +442,7 @@ def backward_pre_hook(module, grad_input):
"""Hook executed before backward propagation."""
global last_grad_input
last_grad_input = grad_input
if torch.cuda.is_available():
torch.cuda.synchronize()
get_current_device_manager().synchronize()
raise MyCustomError("Interrupt backward pass")

for data in dataloader:
Expand Down Expand Up @@ -599,7 +599,8 @@ def get_score_for_scheme(
with torch.no_grad():
if low_gpu_mem_usage:
device = m.tuning_device if hasattr(m, "tuning_device") else major_device
if "cuda" in device or "xpu" in device:
# Any non-CPU device (cuda/xpu/hpu/...) is consolidated to the major device.
if str(device).split(":")[0] not in ("cpu", "meta", "disk"):
device = major_device
else:
device = m.weight.device
Expand Down
8 changes: 8 additions & 0 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def __new__(
... # ...
... }
"""
if torch.mps.is_available() and (
device_map == 0 or device_map == "0" or device_map is None or device_map == "auto"
):
logger.warning(
"MPS detected. Using CPU by default to avoid potential memory issues. "
"Set --device_map=mps to force MPS usage."
)
device_map = "cpu"

local_args = {k: v for k, v in locals().items() if k not in cls.SKIP_ARGS}
if extra_config is not None:
Expand Down
3 changes: 2 additions & 1 deletion auto_round/calibration/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from auto_round.calibration.llm import LLMCalibrator
from auto_round.calibration.register import register_calibrator
from auto_round.logger import logger
from auto_round.utils.device_manager import device_manager
from auto_round.utils.model import wrap_block_forward_positional_to_kwargs


Expand Down Expand Up @@ -95,7 +96,7 @@ def calib(self, nsamples: int, bs: int) -> None:
)
exit(-1)

target_device = c.compress_context.device
target_device = device_manager.device
if pipe.device != torch.device(target_device):
pipe.to(target_device)
pipeline_fn = getattr(pipe, "_autoround_pipeline_fn", None)
Expand Down
3 changes: 2 additions & 1 deletion auto_round/calibration/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from auto_round.utils import clear_memory, to_device, to_dtype
from auto_round.utils.device_manager import device_manager

__all__ = ["split_inputs", "preprocess_block_inputs"]

Expand Down Expand Up @@ -84,7 +85,7 @@ def preprocess_block_inputs(
is_diffusion=model_context.is_diffusion,
shared_cache_keys=getattr(model_context, "shared_cache_keys", ()),
)
clear_memory(device_list=compress_context.device_list)
clear_memory(device_list=device_manager.device_list)
tmp_dtype = model_context.amp_dtype if model_context.amp else torch.float32
if input_ids is not None:
input_ids = to_device(input_ids, compress_context.cache_device)
Expand Down
11 changes: 6 additions & 5 deletions auto_round/calibration/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
to_dtype,
)
from auto_round.utils.device import parse_available_devices
from auto_round.utils.device_manager import device_manager


@register_calibrator("llm")
Expand Down Expand Up @@ -96,9 +97,9 @@ def collect(self, block_names, nsamples, layer_names=None, last_cache_name=None)
c.model_context.model, device_map=c.model_context.model.hf_device_map
)
else:
if str(c.model_context.model.device) == "cpu" and (not c.compress_context.device.startswith("hpu")):
if str(c.model_context.model.device) == "cpu" and (not device_manager.device.startswith("hpu")):
no_split_modules = list(getattr(c.model_context.model, "_no_split_modules", []))
devices = parse_available_devices(c.compress_context.device_map)
devices = parse_available_devices(device_manager.device_map)

max_memory = get_max_memory()
new_max_memory = {}
Expand All @@ -113,7 +114,7 @@ def collect(self, block_names, nsamples, layer_names=None, last_cache_name=None)
device = 0
else:
raise ValueError(
f"Unsupported device {device} in device_map: {c.compress_context.device_map}"
f"Unsupported device {device} in device_map: {device_manager.device_map}"
)
if device not in max_memory:
continue
Expand Down Expand Up @@ -164,7 +165,7 @@ def collect(self, block_names, nsamples, layer_names=None, last_cache_name=None)
else:
raise
else:
c.model_context.model = c.model_context.model.to(c.compress_context.device)
c.model_context.model = c.model_context.model.to(device_manager.device)

all_inputs = self.cache_inter_data(
block_names, nsamples, layer_names=layer_names, last_cache_name=last_cache_name
Expand All @@ -186,7 +187,7 @@ def collect(self, block_names, nsamples, layer_names=None, last_cache_name=None)
)
accelerate.hooks.remove_hook_from_submodules(c.model_context.model)
c.model_context.model = mv_module_from_gpu(c.model_context.model)
clear_memory(device_list=c.compress_context.device_list)
clear_memory(device_list=device_manager.device_list)
# On cpu, we use rtn mode for layers in layer_names (post v0.51).
all_inputs = self.cache_inter_data(
block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
Expand Down
Loading
Loading