remove gradient in state.positions after mace forward#540

Open

thomasloux wants to merge 1 commit into TorchSim:main from thomasloux:fix/mace-positions-gradient

Conversation

@thomasloux (Collaborator) commented Apr 9, 2026

Summary

MACE adds a gradient requirement to positions by running data["positions"].requires_grad_(True):
https://github.com/ACEsuit/mace/blob/main/mace/modules/models.py#L776
Because the TorchSim MACE interface passes state.positions directly, the gradient requirement flows back into TorchSim's algorithms. In the relaxation script below, this ends up failing with a seemingly random error. I include a Claude Code diagnostic explaining the steps for those interested:

Root Cause Chain
The gradient contamination follows this path across BFGS steps:

Step 1 — The seed

  1. MACE model calls positions.requires_grad_(True) internally to compute forces via torch.autograd.grad. This is an in-place operation on state.positions (since
    data_dict["positions"] is the same tensor object).
  2. After the model returns, forces are detached, but state.positions still has requires_grad=True.
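The seed step can be reproduced in isolation with a few lines of plain PyTorch. The model below is a made-up stand-in for MACE, not its real API; the point is only the in-place side effect:

```python
import torch

# Hypothetical stand-in for the MACE forward pass: it enables gradients
# on its input in place so it can compute forces via autograd.
def model_forward(data: dict) -> torch.Tensor:
    data["positions"].requires_grad_(True)        # in-place side effect
    energy = (data["positions"] ** 2).sum()
    forces = -torch.autograd.grad(energy, data["positions"])[0]
    return forces.detach()                        # forces come back detached

positions = torch.zeros(2, 3)                     # stands in for state.positions
forces = model_forward({"positions": positions})  # dict holds the SAME tensor object

print(forces.requires_grad)     # False: the forces are clean
print(positions.requires_grad)  # True: the caller's tensor was mutated
```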

Step 1 — Hessian contamination

  1. frac_positions = torch.linalg.solve(deform_grad, state.positions) — inherits requires_grad from positions
  2. dpos = pos_new - pos_old — inherits from frac_positions
  3. Hessian update terms term1, term2 inherit from dpos
  4. state.hessian[idx] = H - term1 - term2 — this in-place IndexPut makes state.hessian part of the autograd graph (IndexPutBackward0)
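This inheritance chain is easy to demonstrate with toy shapes (the names and shapes below are illustrative, not TorchSim's actual ones):

```python
import torch

# Positions "contaminated" by the model's in-place requires_grad_(True)
positions = torch.zeros(2, 3).requires_grad_(True)
deform_grad = torch.eye(3)

# frac_positions inherits requires_grad from positions
frac_positions = torch.linalg.solve(deform_grad, positions.T).T
dpos = frac_positions - positions.detach()        # inherits from frac_positions

hessian = torch.zeros(6, 6)                       # plain buffer, no grad initially
term = dpos.reshape(-1, 1) @ dpos.reshape(1, -1)  # update term, requires grad
hessian[:, :] = hessian - term                    # in-place write drags the buffer in

print(hessian.requires_grad)          # True
print(hessian.grad_fn is not None)    # True: the buffer is now in the graph
```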

Step 2 — Cell contamination

  1. H_group = state.hessian[...] — inherits requires_grad from contaminated hessian
  2. Eigendecomposition → step_group requires grad → step_dense requires grad → dr_cell requires grad
  3. cell_positions_new = state.cell_positions + dr_cell → requires grad
  4. deform_grad_new = torch.matrix_exp(...) → requires grad
  5. state.row_vector_cell = torch.bmm(ref_cell, deform_grad_new.T) → state.cell gets requires_grad=True
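The same inheritance carries through the linear-algebra ops in this step; a toy 3x3 example (shapes and names made up):

```python
import torch

# A "contaminated" hessian, as produced by the previous step
hessian = torch.eye(3) * torch.ones(1, requires_grad=True)

evals, evecs = torch.linalg.eigh(hessian)         # both inherit requires_grad
step = evecs @ torch.diag(1.0 / evals) @ evecs.T  # stand-in for the BFGS step
deform_grad_new = torch.linalg.matrix_exp(step)   # still carries a grad_fn
cell = torch.eye(3) @ deform_grad_new.T           # "state.cell" now requires grad

print(cell.requires_grad)  # True
```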

Step 2+ — Forces contamination

  1. _symmetrize_rank1 uses state.row_vector_cell (requires grad) as lattice
  2. symmetrize_rank1 returns a tensor with grad_fn (through inv(lattice) and @ lattice)
  3. vectors[start:end] = symmetrize_rank1(...) — in-place CopyBackwards makes forces require grad
  4. Forces with grad → torch.split in _split_state → views with SplitWithSizesBackward0
  5. post_init tries in-place modification on these views → CRASH
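The final crash can be reproduced directly: PyTorch forbids in-place writes to the views produced by multi-output ops such as torch.split.

```python
import torch

# Forces carrying a grad_fn, as after the contamination above
forces = torch.randn(6, 3, requires_grad=True) * 1.0

# Splitting produces views tracked by SplitWithSizesBackward0
per_system = torch.split(forces, [3, 3])

caught = None
try:
    per_system[0] += 1.0   # in-place write on a multi-output view
except RuntimeError as err:
    caught = err           # PyTorch forbids this and raises

print(type(caught).__name__)  # RuntimeError
```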

Script to reproduce:

```sh
uv sync --extra mace
uv run --with pymatgen --with moyopy python fix_relax_gradient.py
```

```python
import torch
import torch_sim as ts
from torch_sim import CellFilter, Optimizer
from torch_sim.models.mace import MaceModel
from mace.calculators import mace_mp
from pymatgen.core.structure import Structure

# structure = Structure.from_file("Ti.cif")
# Create an hcp titanium structure
structure = Structure(
    lattice=[[2.95, 0, 0], [-1.475, 2.556, 0], [0, 0, 4.68]],
    species=["Ti"] * 2,
    coords=[[0, 0, 0], [1 / 3, 2 / 3, 1 / 2]],
)
structure_list = [structure]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

optimizer_enum = Optimizer("bfgs")
convergence_fn = ts.generate_force_convergence_fn(
    force_tol=0.01, include_cell_forces=True
)

model = mace_mp(
    model="medium-mpa-0",
    dispersion=False,
    default_dtype=dtype,
    device=device,
    return_raw_model=True,
)
model = MaceModel(model, compute_stress=True, device=device)

autobatcher = ts.InFlightAutoBatcher(
    model, memory_scales_with="n_atoms", max_memory_scaler=5_000
)

init_kwargs: dict = {"cell_filter": CellFilter("frechet")}

states = ts.initialize_state(structure_list, dtype=model.dtype, device=model.device)
states.constraints = ts.constraints.FixSymmetry.from_state(states, symprec=0.1)

final_state = ts.optimize(
    system=states,
    model=model,
    optimizer=optimizer_enum,
    convergence_fn=convergence_fn,
    max_steps=100,
    init_kwargs=init_kwargs,
    autobatcher=autobatcher,
)
```

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

@thomasloux (Collaborator, Author)

I'm fine with other solutions, such as passing a copy of state.positions, but removing the gradient requirement seems like the lightest solution.
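For concreteness, both options can be sketched with a stand-in forward pass (all names here are made up for illustration, not the actual TorchSim patch):

```python
import torch

def run_model(positions: torch.Tensor) -> torch.Tensor:
    # Stand-in for the MACE forward, which flips requires_grad in place
    positions.requires_grad_(True)
    energy = (positions ** 2).sum()
    return -torch.autograd.grad(energy, positions)[0].detach()

# Option A (this PR's approach): clear the flag after the forward pass
positions = torch.zeros(2, 3)
forces = run_model(positions)
positions.requires_grad_(False)       # undo the model's side effect

# Option B: pass a copy so the side effect never touches state.positions
positions2 = torch.zeros(2, 3)
forces2 = run_model(positions2.clone())

print(positions.requires_grad)   # False
print(positions2.requires_grad)  # False
```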

@CompRhys (Member) commented Apr 9, 2026

LGTM. Just for awareness: as we move to an external posture for MACE with the mace 0.3.16 release, this will likely change upstream. I hope this tooling PR goes in before then, and I don't think this issue is addressed over there.

cf. #524

@thomasloux (Collaborator, Author)

This should be fine with the new external MACE model, because the positions used by the model are the positions after wrapping; when wrapped_positions = wrapped_positions.requires_grad_(True) runs, the flag does not propagate to the original state.positions.
In that case we don't need to merge this PR.
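A small sketch of why wrapping breaks the side-effect chain (the wrapping here is a toy modulo, not the real implementation): the wrapped tensor is a new allocation, not a view, so the in-place flag only touches the copy.

```python
import torch

positions = torch.tensor([[1.3, 0.0, 0.0]])  # stands in for state.positions
cell_length = 1.0

wrapped = positions % cell_length  # new tensor, not a view of positions
wrapped.requires_grad_(True)       # in-place flag lands on the copy only

print(wrapped.requires_grad)       # True
print(positions.requires_grad)     # False: original state.positions untouched
```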

@thomasloux (Collaborator, Author)

But this problem is a good reminder to be careful about side effects when modifying input arguments.

@CompRhys (Member) commented Apr 9, 2026

Merge or close? Did you try checking out that PR to see whether it fixes the issue? If so, I would appreciate a bump on that thread to prompt ilyas towards merging and cutting a new mace-torch release.
