
Add dtype parameter to kspaceFirstOrder() (#695) #716

Open
waltsims wants to merge 7 commits into master from feature-data-cast-modern-api

Conversation

@waltsims
Owner

@waltsims waltsims commented May 4, 2026

Closes #695.

What

Exposes precision control on the modern unified API. Pythonic / numpy-idiomatic naming and accepted input forms:

| input | resolved precision |
| --- | --- |
| `None` (default) | `np.float64` |
| `np.float64` / `"float64"` / `"double"` / `float` / `np.dtype("f8")` | `np.float64` |
| `np.float32` / `"float32"` / `"single"` / `np.dtype("f4")` | `np.float32` |
| `"off"` (legacy MATLAB alias) | `np.float64` |
| anything else (`np.float16`, `np.complex64`, `"quad"`, …) | `ValueError` |

The MATLAB aliases ("off", "single", "double") are kept as compat shortcuts for users porting from the legacy SimulationOptions.data_cast or MATLAB k-Wave's DataCast. Everything else uses np.dtype() for normalization, matching the broader numpy/scipy/torch convention.

Why dtype instead of data_cast

data_cast is a MATLAB term. The numpy ecosystem (numpy, pandas, jax, torch) uses dtype and accepts dtype-like inputs broadly. The modern unified API is a fresh design — it should follow the Python idiom rather than the MATLAB one. The MATLAB-style strings still work, so MATLAB users lose nothing.

How

Python backend plumbs dtype through Simulation, which now stores self._dtype and uses it for every state-array allocation: p, u, rho_split, sensor-data buffers, PML arrays, source signal buffers, and the _expand_to_grid helper for sound_speed / density / alpha_coeff / BonA / p0. Default behavior unchanged (float64 everywhere).
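A rough sketch of what dtype-threaded allocation looks like (the class and attribute names mirror the description above; the body is an assumption, not the actual Simulation code):

```python
import numpy as np

# Illustrative only: every state array and every expanded medium property
# is allocated in the single stored precision.
class SimulationSketch:
    def __init__(self, shape, dtype=np.float64):
        self._dtype = np.dtype(dtype).type
        ndim = len(shape)
        self.p = np.zeros(shape, dtype=self._dtype)
        self.u = [np.zeros(shape, dtype=self._dtype) for _ in range(ndim)]
        self.rho_split = [np.zeros(shape, dtype=self._dtype) for _ in range(ndim)]

    def _expand_to_grid(self, value, shape):
        # scalars and arrays alike land in the simulation precision
        if np.isscalar(value):
            return np.full(shape, value, dtype=self._dtype)
        return np.asarray(value, dtype=self._dtype)
```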

C++ backend intentionally has no effect — the binary uses fixed internal precision regardless of HDF5 input dtype. Setting dtype to anything other than np.float64 with backend='cpp' emits a UserWarning explaining this and pointing users at backend='python' for precision control.

Test plan

New file tests/test_data_cast.py (22 tests):

  • test_python_backend_float64_inputs parametrized over [None, np.float64, "float64", "double", float, "off", np.dtype("f8")] — every form resolves to float64 output
  • test_python_backend_float32_inputs parametrized over [np.float32, "float32", "single", np.dtype("f4")] — every form resolves to float32 output
  • test_default_dtype_is_float64 — calling without the kwarg gives float64 (back-compat)
  • test_invalid_dtype_raises parametrized over [np.float16, np.complex64, "float16", "complex64", "quad", 42, "garbage"] — all raise ValueError
  • test_python_single_vs_double_numerical_agreement — single and double runs agree to within 1e-4 relative error
  • test_cpp_backend_warns_on_non_float64_dtype — UserWarning fires before the binary runs
  • test_cpp_backend_silent_on_default_dtype — no warning on default

Wider suite verified (62 tests): test_native_solver, test_ivp_homogeneous_medium, test_issue_664_alpha_power_near_unity all pass.

  • CI green

Greptile Summary

This PR exposes precision control on the modern kspaceFirstOrder() API via a dtype parameter, wiring float32/float64 through Simulation state arrays, PML, k-space operators, and sensor buffers. The _resolve_dtype resolver, MATLAB-alias compatibility, and the C++ backend warning are all correctly implemented.

One normalization gap remains: in Simulation.__init__, the elif dtype in (np.float32, np.float64) branch stores the raw input as self._dtype without converting it to a scalar-type object, causing the subsequent is check for _complex_dtype to silently select complex128 instead of complex64 when a dtype instance (e.g. np.dtype("float32")) is passed directly.

Additionally, dt_over_rho0 is still computed as a Python float divided by a float32 array — on numpy < 2 this yields float64 regardless of self._dtype, causing self.u and self.rho_split to be rebound to float64 after the first step. The new .astype() casts on FFT outputs mitigate the spectral-operator drift but not this direct scalar-arithmetic promotion, so the p_final/p_max/p_rms float32 assertions will likely fail on numpy 1.26.x (CI green is still unchecked).
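The dtype-instance pitfall is easy to demonstrate: `np.dtype("float32")` compares equal to `np.float32` but is a different object, so an `is` check needs `np.dtype(...).type` normalization first:

```python
import numpy as np

raw = np.dtype("float32")      # a dtype *instance*, not the scalar type
assert raw == np.float32       # equality holds
assert raw is not np.float32   # identity does not

normalized = np.dtype(raw).type          # the one-line normalization
complex_dtype = np.complex64 if normalized is np.float32 else np.complex128
```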

Confidence Score: 3/5

Safe to merge for numpy ≥ 2 users; float32 tests are likely to fail on numpy 1.26.x due to unresolved Python-scalar promotion in dt_over_rho0

The dtype plumbing is thorough and well-tested, and _resolve_dtype, the C++ warning, and alias handling are all correct. However, CI green is still unchecked, and the previously-flagged dt_over_rho0 root cause (Python float / float32 → float64 on numpy < 2) is not addressed in this diff — the .astype() casts on FFT round-trips do not prevent self.u and self.rho_split from becoming float64 mid-step, which would cause p_final dtype assertions to fail on numpy 1.26.x. A minor dtype-instance vs type-object normalization issue in Simulation.__init__ also needs a one-line fix.

kwave/solvers/kspace_solver.py — dt_over_rho0 computation (line 663) and mass-conservation update (line 716) still use bare self.dt (Python float), causing float64 promotion of state arrays on numpy < 2

Important Files Changed

| Filename | Overview |
| --- | --- |
| kwave/kspaceFirstOrder.py | Adds dtype parameter and _resolve_dtype resolver; validation, MATLAB aliases, and C++ backend warning all look correct |
| kwave/solvers/kspace_solver.py | Adds _array_sum, casts PML/k_list/FFT results, and threads dtype through all fields; Simulation.__init__ stores raw input in self._dtype without normalizing to a type object, so the is check for _complex_dtype silently produces complex128 when a dtype instance is passed directly |
| tests/test_data_cast.py | Comprehensive 22-test suite covering dtype resolution, precision preservation, BonA path, staggered velocity, C++ backend warning, and single-vs-double numerical agreement |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["kspaceFirstOrder(dtype=...)"] --> B["_resolve_dtype(value)"]
    B -->|"None / 'off'"| C["np.float64"]
    B -->|"np.dtype(value).type"| D{"float32 or float64?"}
    D -->|"float32"| E["np.float32"]
    D -->|"float64"| C
    D -->|"other / TypeError"| F["ValueError"]
    E --> G["Simulation(dtype=np.float32)"]
    C --> G2["Simulation(dtype=np.float64)"]
    G --> H["self._dtype = np.float32\nself._complex_dtype = np.complex64"]
    G2 --> H2["self._dtype = np.float64\nself._complex_dtype = np.complex128"]
    H --> I["setup():\nc0, rho0, p, u, rho_split,\npml, k_list cast to float32\ndt_over_rho0 = self.dt(Python float)/rho"]
    I --> J["step():\nP = fftn(p).astype(complex64)\ngrad_p = _diff(...).astype(float32)\ndiv_u = _diff(...).astype(float32)\nu, rho_split, p updated"]
    J --> K["run():\nsensor_data (float32 buffers)\np_final = self.p.copy()"]
    K --> L["backend='cpp': UserWarning if dtype≠float64"]
```

Comments Outside Diff (2)

  1. kwave/solvers/kspace_solver.py, line 663 (link)

    P1 dt_over_rho0 computed as Python float / float32 → silently float64 on numpy < 2

    self.dt is stored as float(self.kgrid.dt) (a Python float, equivalent to float64). Dividing by a float32 array gives a float64 result under numpy < 2 (NEP 50 changed this in numpy 2.0). As a consequence, on every step(): self.dt_over_rho0[i] * grad_p_i (line 706) is float64 × float32 → float64, so self.u[i] is rebound to a float64 array after the very first step. The same Python-scalar promotion also affects line 716 (self.dt * self.rho0 * div_u_i * nl_factor), so rho_split[i] and, through rho_total, self.p also become float64. The sensor-data buffer (sensor_data["p"]) is pre-allocated float32 and silently narrows values on in-place assignment, so result["p"] tests pass, but result["p_final"] (line 773) is self.p[interior].copy() — no narrowing — and will be float64 even when dtype=np.float32 is requested, breaking the dtype contract on numpy < 2.
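One version-independent way to close this promotion path, sketched under the assumption that `self.dt` starts life as a Python float, is to keep every scalar in the simulation dtype before it meets a state array:

```python
import numpy as np

# Sketch: a scalar held as a numpy float32 cannot promote float32 arrays,
# regardless of the numpy version's promotion rules.
dtype = np.float32
dt = dtype(1e-7)                      # instead of a bare Python float
rho0 = np.full((8, 8), 1000.0, dtype=dtype)

dt_over_rho0 = dt / rho0              # float32 scalar / float32 array
grad_p = np.ones((8, 8), dtype=dtype)
u_update = dt_over_rho0 * grad_p      # stays float32 on numpy 1.x and 2.x
```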

  2. kwave/solvers/kspace_solver.py, line 429-447 (link)

    P1 Complex k-space operators still promote to complex128 on numpy < 2

    k_list entries are now correctly cast to self._dtype (float32 when requested), but self.c_ref and self.dt are Python floats (float64), so self.c_ref * k_mag * self.dt / 2 is float64 on numpy < 2, making kappa and source_kappa float64. The 1j * Python complex literal then forces op_grad_list and op_div_list to complex128 rather than complex64. Cast kappa/source_kappa to self._dtype and the final operators to self._complex_dtype after construction.
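The suggested casts can be sketched like this (variable names follow the comment; the `sinc`-based kappa construction is an assumption about the solver's k-space correction, not verbatim code):

```python
import numpy as np

# Sketch of the proposed post-construction casts.
_dtype = np.float32
_complex_dtype = np.complex64

k_mag = np.linspace(0.0, 100.0, 64, dtype=_dtype)
c_ref, dt = 1500.0, 1e-7              # Python floats, as in the review

# Force real precision after mixing with Python-float scalars...
kappa = np.sinc(c_ref * k_mag * dt / 2).astype(_dtype)
# ...and force complex precision after the 1j literal enters the chain.
op_grad = (1j * k_mag * kappa).astype(_complex_dtype)
```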

Reviews (5): Last reviewed commit: "Add staggered-velocity dtype regression ..."

Exposes precision control on the modern unified API to match the
legacy SimulationOptions.data_cast and MATLAB k-Wave's DataCast.

  data_cast='off'    -> np.float64 (default; matches legacy)
  data_cast='double' -> np.float64 (alias for 'off', MATLAB compat)
  data_cast='single' -> np.float32 (~half memory, faster, lower accuracy)

Python backend: plumbs through Simulation, which now stores self._dtype
and uses it for all state arrays (p, u, rho_split, sensor_data buffers,
PML arrays, alpha_coeff/BonA/p0 expansions, source signal buffers).
Default behavior unchanged (float64 everywhere).

C++ backend: data_cast has no effect — the binary uses fixed internal
precision regardless of HDF5 input dtype. Setting anything other than
'off'/'double' with backend='cpp' emits a UserWarning explaining this
and pointing users at backend='python' for precision control.

Tests: 8 new in tests/test_data_cast.py covering output dtype matches
request, default behavior unchanged, invalid value raises, single vs
double numerical agreement within float32 tolerance, and the C++
warn/silent paths. Wider suite (62 tests across native_solver,
ivp_homogeneous, issue_664) still passes.

Closes #695.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@codecov

codecov Bot commented May 4, 2026

Codecov Report

❌ Patch coverage is 83.58209% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.23%. Comparing base (a1b054c) to head (d545952).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| kwave/solvers/kspace_solver.py | 77.08% | 10 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #716      +/-   ##
==========================================
+ Coverage   74.82%   75.23%   +0.40%     
==========================================
  Files          56       56              
  Lines        8096     8132      +36     
  Branches     1577     1586       +9     
==========================================
+ Hits         6058     6118      +60     
+ Misses       1422     1395      -27     
- Partials      616      619       +3     
| Flag | Coverage Δ |
| --- | --- |
| 3.10 | 75.19% <83.58%> (+0.39%) ⬆️ |
| 3.11 | 75.19% <83.58%> (+0.39%) ⬆️ |
| 3.12 | 75.19% <83.58%> (+0.39%) ⬆️ |
| 3.13 | 75.19% <83.58%> (+0.39%) ⬆️ |
| macos-latest | 75.06% <83.58%> (+0.30%) ⬆️ |
| ubuntu-latest | 75.11% <83.58%> (+0.35%) ⬆️ |
| windows-latest | 74.98% <83.58%> (+0.22%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment thread kwave/solvers/kspace_solver.py Outdated
Comment thread tests/test_data_cast.py
Make the precision parameter Pythonic instead of MATLAB-stringly-typed.
The numpy ecosystem's convention is to accept dtype-like inputs broadly
(numpy types, strings, dtype objects), and the modern API should
follow that idiom rather than the legacy SimulationOptions.data_cast
naming.

Accepted forms (resolved via _resolve_dtype, which uses np.dtype()):

  None / np.float64 / "float64" / "double" / float / "off" / np.dtype("f8")
    -> np.float64 (default)
  np.float32 / "float32" / "single" / np.dtype("f4")
    -> np.float32

The MATLAB aliases ("off", "single", "double") are kept as compat
shortcuts so users porting from the legacy API or MATLAB k-Wave have
zero friction. Anything resolving to a non-float32/float64 type
(np.float16, np.complex64, etc.) raises ValueError -- the solver
isn't validated for those.

C++ backend warns when dtype is not np.float64 (binary uses fixed
internal precision regardless).

Tests: 22 (was 8) parametrized over every input form. Wider suite
(62 tests) still passes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@waltsims waltsims changed the title Add data_cast parameter to kspaceFirstOrder() (#695) Add dtype parameter to kspaceFirstOrder() (#695) May 4, 2026
@waltsims
Owner Author

waltsims commented May 4, 2026

@greptile-apps re-review

Comment thread kwave/solvers/kspace_solver.py Outdated
waltsims and others added 3 commits May 4, 2026 04:20
Greptile P2 (test for p_final dtype) caught a real bug: with
dtype='single', p_final came back as float64 even though sensor_data
buffers (p, p_max, p_min, p_rms) were correctly float32.

Root cause: two sources of float64 leaking into the hot loop:

1. xp.fft.fftfreq returns float64; k_list, kappa, op_grad/div_list,
   _k_mag inherited it. _diff's FFT round-trip (float64 op * complex64
   field) upcasts to complex128, .real => float64. Result: self.p and
   self.u rebound to float64 mid-step() despite being allocated as
   float32. p_final = self.p[interior].copy() picked up float64.
   sensor_data buffers stayed float32 because writes are in-place into
   the pre-allocated buffer (silent narrowing on assignment).
2. get_pml returns float64 unconditionally; the per-step pml multiply
   was a second upcast path independent of (1).

Both cast sites now apply .astype(self._dtype) at construction time,
keeping the entire compute pipeline in the user's requested precision.

Test updated: float32 / float64 input parametrizations now request
('p', 'p_final', 'p_max', 'p_min', 'p_rms') and assert every field's
dtype matches. Verified: float32 inputs => all five fields float32;
float64 => all five float64.

Bonus: helpful error for torch / jax / tensorflow dtype objects via
duck-typed __module__ check (no framework imports needed); cupy works
for free since cp.float32 is np.float32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…odern-api

# Conflicts:
#	kwave/solvers/kspace_solver.py
Greptile spotted a third dtype-drift path: ``sum(rho_split)`` in
``_nl_factor`` and in the equation-of-state line starts its accumulator
from Python ``int 0``. Under numpy < 2 (pre-NEP 50), ``int + float32 -> float64``, so:

  nl_factor = (2 * sum(rho_split) + rho0) / rho0

is float64 even when rho_split is float32.  The product
``rho0 * div_u_i * nl_factor`` in mass conservation then upcasts the
rho_split arrays to float64 on the very first step.  Specifically affects
any simulation that enables BonA.

Fix: ``_array_sum`` helper that starts the accumulator from
``arrays[0]`` so the dtype is preserved.  Used in both call sites
(_nl_factor lambda and equation-of-state rho_total).

Test added: test_python_backend_dtype_preserved_with_nonlinearity exercises
the BonA path with dtype=np.float32 and asserts p / p_final / p_max all
remain float32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@waltsims
Owner Author

waltsims commented May 4, 2026

@greptile-apps re-review

End-to-end verification under numpy 1.26.4 (with my dtype tests requesting
('p', 'p_final', 'p_max', 'p_min', 'p_rms')) showed self.p still upcast to
float64 mid-step despite all the prior precision fixes.

Root cause: numpy < 2's `np.fft.fftn` always returns complex128 regardless
of input precision -- a known difference resolved in numpy 2 (NEP 50 era).
The k-space ops (op_grad/div_list, unstagger_ops) being complex64 isn't
enough; multiplying complex64 by complex128 upcasts to complex128, and
.real -> float64 propagates back into self.p.

Fix: introduce self._complex_dtype (complex64 / complex128 matching
self._dtype), and cast every fftn / ifftn .real result back to the
intended precision. Three call sites: step() momentum loop, sensor-data
unstagger, _diff helper. ``copy=False`` makes the cast a no-op when the
dtype already matches (numpy 2, cupy).

Greptile flagged dt_over_rho0, kappa, and sum() as P1 dtype-promotion
paths. Verified empirically on numpy 1.26.4 that none of those upcast in
practice (Python float / int / complex are weak types in numpy 1.x as
well) -- but the test failures Greptile would have seen had the same
visible signature, so the diagnosis pointed at the right region.

Tests: 86 (24 dtype + 62 wider suite) pass on numpy 1.26.4 AND numpy 2.2.6.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
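The cast-after-FFT pattern from this commit can be illustrated as a standalone sketch (not the solver's actual call sites):

```python
import numpy as np

_dtype, _complex_dtype = np.float32, np.complex64

# numpy < 2 returns complex128 from fftn regardless of input precision;
# casting the result (copy=False makes it a no-op on numpy >= 2) keeps
# the pipeline in the requested precision on both majors.
p = np.random.default_rng(0).standard_normal((16, 16)).astype(_dtype)
P = np.fft.fftn(p).astype(_complex_dtype, copy=False)
p_back = np.fft.ifftn(P).real.astype(_dtype, copy=False)
```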
@waltsims
Owner Author

waltsims commented May 4, 2026

Greptile flagged dt_over_rho0, kappa / op_grad/div_list, and sum() as P1 dtype-promotion paths under numpy<2. Verified empirically on numpy 1.26.4 (with NEP 50 disabled): none of those promote in practice — Python float/int/complex are weak types in numpy 1.x as well as 2.x.

But Greptile was right that the diff was unsafe on numpy<2 — the actual culprit was different: np.fft.fftn always returns complex128 in numpy<2, regardless of input precision (a known difference resolved in numpy 2). My earlier complex64 op_grad/div_list weren't enough: complex64 op × complex128 P → complex128, .real → float64 propagating back into self.p.

Fixed in 22ec78d: self._complex_dtype (complex64/complex128) + cast every fftn/ifftn .real result back to the intended precision (3 call sites). Tests now pass on numpy 1.26.4 AND numpy 2.2.6 (86 each).

So same conclusion (not safe on numpy<2 before this commit), different mechanism. Re-triggering review now.

@greptile-apps re-review

Comment on lines 675 to 676
# Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
Contributor


P1 unstagger_ops also built with 1j * literal → always complex128 on numpy < 2

The same Python-literal promotion issue as op_grad_list/op_div_list applies here: -1j * self.k_list[ax] is complex128 on numpy < 2, regardless of k_list dtype. Cast to self._complex_dtype after construction to keep the unstaggering FFT in the correct precision.

Suggested change:

```diff
 # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
-self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
+self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]
```

Owner Author


Verified empirically on numpy 1.26.4: -1j * k_f32 * spacing / 2 evaluates to complex64 and np.exp(...) preserves complex64 — Python complex is a weak type just like Python float/int, so the multiplication chain stays in complex64 when k_list is float32. So unstagger_ops is already complex64 without an explicit cast.

Additionally, the prior commit (22ec78d) already added .astype(self._dtype, copy=False) to the staggered ifftn .real line in step(), so the recorded output stays float32 even if unstagger_ops were ever to change precision.

Added a regression test (d545952) that records ('p', 'ux_staggered', 'uy_staggered') with dtype=np.float32 and asserts each is float32 — locks the behavior in for both numpy 1.26.4 and 2.2.6.

@greptile-apps re-review

Contributor


The empirical result doesn't match what NEP 50's own migration table documents. The NEP 50 spec explicitly shows the old (numpy < 2) vs new (numpy ≥ 2 / NEP 50) behaviour:

| Expression | numpy < 2 (old) | numpy ≥ 2 (NEP 50) |
| --- | --- | --- |
| `(3j + array(3, complex64)).dtype` | complex128 | complex64 |
| `(float32(1) + 1j).dtype` | complex128 | complex64 |

Source: https://numpy.org/neps/nep-0050-scalar-promotion.html

Python complex literals are not weak types in numpy < 2 — they behave like np.complex128 in the type promotion system. The "weak scalar" treatment is exactly what NEP 50 introduced. So on numpy 1.26.4, -1j * k_f32 * spacing / 2 → complex128, and np.exp(...) of a complex128 array → complex128.

It's possible the empirical test ran against a numpy 2.x environment, or that k_list[ax] was already float64 at the point of measurement (defeating the test). The regression test added in d545952 checks recorded sensor_data["p"] and "ux_staggered" dtypes, but those go through the .astype(self._dtype, copy=False) guard added in step() — so they'll pass regardless of what unstagger_ops itself is. The dtype of unstagger_ops isn't directly asserted.

The fix remains: cast the list comprehension result to self._complex_dtype:

Suggested change:

```diff
 # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
-self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
+self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]
```


Greptile flagged unstagger_ops as a P1 (claimed 1j * k_f32 -> complex128
on numpy<2). Empirically that's not true under numpy 1.26.4 -- Python
complex is a weak type just like Python float/int, so unstagger_ops
stays complex64 when k_list is float32. The recorded ux_staggered output
also stays float32 because the prior commit already cast the unstagger
ifftn .real result to self._dtype.

Adding a test that records ('p', 'ux_staggered', 'uy_staggered') and
asserts each is float32 under dtype=np.float32 -- guards against any
future regression in the staggered-velocity path on either numpy major
version. 25 tests pass on numpy 1.26.4 and numpy 2.2.6.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@waltsims
Owner Author

waltsims commented May 4, 2026

@greptile-apps re-review


Development

Successfully merging this pull request may close these issues.

[ENHANCE] Add a parameter to kspaceFirstOrder() to control the data type that variables are cast to before computation
