Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,14 +1193,22 @@ def set_attr(model, key, new_attr):


def get_module(module, key):
"""Get module from model by key name using PyTorch native API.

Missing paths return `None` to preserve legacy non-fail-fast behavior.
"""
Get module from model by key name using PyTorch native API with a
fallback to manual traversal for backward compatibility.
"""
try:
return module.get_submodule(key)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we had changed it to use get_submodule API.

except (AttributeError, KeyError):
return None
pass

attrs = key.split(".")
for attr in attrs:
try:
module = getattr(module, attr)
except AttributeError:
return None
return module


def set_module(model, key, new_module):
Expand Down