Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ test = [
"physical-validation>=1.0.5",
"platformdirs>=4.0.0",
"psutil>=7.0.0",
"pymatgen>=2025.6.14",
"pytest-cov>=6",
"pytest>=8",
"spglib>=2.6",
"vesin[torch]>=0.5.3",
]
vesin = ["vesin[torch]>=0.5.3"]
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"]
symmetry = ["moyopy>=0.7.8"]
mace = ["mace-torch>=0.3.15"]
mattersim = ["mattersim>=1.2.2"]
Expand Down
115 changes: 115 additions & 0 deletions tests/models/test_electrostatics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for the electrostatics ModelInterface wrappers."""

import traceback # noqa: I001

import pytest
import torch
from ase.build import bulk

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test

try:
from torch_sim.models.electrostatics import DSFCoulombModel, EwaldModel, PMEModel
except (ImportError, OSError, RuntimeError):
pytest.skip(
f"nvalchemiops not installed: {traceback.format_exc()}",
allow_module_level=True,
)


def _make_charged_state(
device: torch.device = DEVICE,
dtype: torch.dtype = DTYPE,
) -> ts.SimState:
"""Build a small NaCl-like state with alternating +1/-1 site charges."""
atoms = bulk("NaCl", crystalstructure="rocksalt", a=5.64, cubic=True)
state = ts.io.atoms_to_state(atoms, device, dtype)
n = state.n_atoms
charges = torch.empty(n, dtype=dtype, device=device)
charges[::2] = 1.0
charges[1::2] = -1.0
state._atom_extras["partial_charges"] = charges # noqa: SLF001
return state


@pytest.fixture
def dsf_model() -> DSFCoulombModel:
return DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)


@pytest.fixture
def ewald_model() -> EwaldModel:
return EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)


@pytest.fixture
def pme_model() -> PMEModel:
return PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)


def _add_partial_charges(state: ts.SimState) -> ts.SimState:
"""Inject alternating +/-0.5 site charges into a state."""
n = state.n_atoms
charges = torch.zeros(n, dtype=state.positions.dtype, device=state.device)
charges[::2] = 0.5
charges[1::2] = -0.5
state._atom_extras["partial_charges"] = charges # noqa: SLF001
return state


test_dsf_model_outputs = make_validate_model_outputs_test(
model_fixture_name="dsf_model",
device=DEVICE,
dtype=DTYPE,
state_modifiers=[_add_partial_charges],
)
test_ewald_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ewald_model",
device=DEVICE,
dtype=DTYPE,
state_modifiers=[_add_partial_charges],
)
test_pme_model_outputs = make_validate_model_outputs_test(
model_fixture_name="pme_model",
device=DEVICE,
dtype=DTYPE,
state_modifiers=[_add_partial_charges],
)


def test_dsf_nonzero_energy() -> None:
"""Charged system should produce nonzero electrostatic energy."""
model = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
state = _make_charged_state()
out = model(state)
assert out["energy"].abs().item() > 0


def test_ewald_pme_energy_agreement() -> None:
"""Ewald and PME should give the same converged Coulomb energy."""
state = _make_charged_state()
ewald = EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
pme = PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
torch.testing.assert_close(
ewald(state)["energy"], pme(state)["energy"], atol=1e-3, rtol=1e-3
)


def test_sum_model_lj_plus_dsf() -> None:
"""LJ + DSF should be additive through SumModel."""
from torch_sim.models.interface import SumModel
from torch_sim.models.lennard_jones import LennardJonesModel

lj = LennardJonesModel(
sigma=2.8, epsilon=0.01, cutoff=7.0, device=DEVICE, dtype=DTYPE
)
dsf = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
combined = SumModel(lj, dsf)
state = _make_charged_state()
lj_out = lj(state)
dsf_out = dsf(state)
sum_out = combined(state)
torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + dsf_out["energy"])
torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + dsf_out["forces"])
12 changes: 6 additions & 6 deletions torch_sim/_duecredit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar


if TYPE_CHECKING:
from collections.abc import Callable

_F = TypeVar("_F", bound="Callable[..., Any]")


class InactiveDueCreditCollector:
"""Just a stub at the Collector which would not do anything."""

def _donothing(self, *_args: Any, **_kwargs: Any) -> None:
"""Perform no good and no bad."""

def dcite(
self, *_args: Any, **_kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def dcite(self, *_args: Any, **_kwargs: Any) -> Callable[[_F], _F]:
"""If I could cite I would."""

def nondecorating_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def nondecorating_decorator(func: _F) -> _F:
return func

return nondecorating_decorator
Expand Down Expand Up @@ -56,7 +56,7 @@ def _disable_duecredit(exc: Exception) -> None:

def dcite(
doi: str, description: str | None = None, *, path: str | None = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[_F], _F]:
"""Create a duecredit decorator from a DOI and description."""
kwargs: dict[str, Any] = (
{"description": description} if description is not None else {}
Expand Down
Loading
Loading