Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
if key.endswith("_scale_inv"):
raise ValueError("fp8 checkpoint is not supported.")
if ds_ckpt.is_key_allowed(key, ds_ckpt.MTP_KEYS_TO_SKIP):
mapped_key = ds_ckpt.hf_to_maxtext_mapping(layer, num_experts, first_num_dense_layers, base_num_decoder_layers)[
key
]
chkpt_vars[mapped_key] = f.get_tensor(key)
mapped_key = ds_ckpt.hf_to_maxtext_mapping(
layer, num_experts, first_num_dense_layers, base_num_decoder_layers
).get(key)
if mapped_key:
Copy link
Copy Markdown
Collaborator

@RissyRan RissyRan Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add an else branch here to make future debugging easier?

if mapped_key:
    chkpt_vars[mapped_key] = f.get_tensor(key)
else:
    # This catches keys that are allowed but missing from the mapping dictionary
    print(f"[DEBUG] Key allowed but no mapping found: {key}")

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, done.

chkpt_vars[mapped_key] = f.get_tensor(key)
Comment thread
gagika marked this conversation as resolved.
else:
# This catches keys that are allowed but missing from the mapping dictionary
max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.")

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

Expand Down
Loading