Add dtype parameter to kspaceFirstOrder() (#695) #716
Conversation
Exposes precision control on the modern unified API to match the legacy SimulationOptions.data_cast and MATLAB k-Wave's DataCast.

data_cast='off' -> np.float64 (default; matches legacy)
data_cast='double' -> np.float64 (alias for 'off', MATLAB compat)
data_cast='single' -> np.float32 (~half memory, faster, lower accuracy)

Python backend: plumbs the setting through Simulation, which now stores self._dtype and uses it for all state arrays (p, u, rho_split, sensor_data buffers, PML arrays, alpha_coeff/BonA/p0 expansions, source signal buffers). Default behavior unchanged (float64 everywhere).

C++ backend: data_cast has no effect — the binary uses fixed internal precision regardless of HDF5 input dtype. Setting anything other than 'off'/'double' with backend='cpp' emits a UserWarning explaining this and pointing users at backend='python' for precision control.

Tests: 8 new in tests/test_data_cast.py covering output dtype matches request, default behavior unchanged, invalid value raises, single vs double numerical agreement within float32 tolerance, and the C++ warn/silent paths. Wider suite (62 tests across native_solver, ivp_homogeneous, issue_664) still passes.

Closes #695.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
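The "~half memory" figure is just array storage; a standalone numpy check of the per-array saving (illustration only, not part of the PR):

```python
import numpy as np

# A 256**3 pressure field: float64 vs float32 storage.
p64 = np.zeros((256, 256, 256), dtype=np.float64)
p32 = np.zeros((256, 256, 256), dtype=np.float32)
print(p64.nbytes / 2**20, "MiB")   # 128.0 MiB
print(p32.nbytes / 2**20, "MiB")   # 64.0 MiB
```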
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #716 +/- ##
==========================================
+ Coverage 74.82% 75.23% +0.40%
==========================================
Files 56 56
Lines 8096 8132 +36
Branches 1577 1586 +9
==========================================
+ Hits 6058 6118 +60
+ Misses 1422 1395 -27
- Partials 616 619 +3
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Make the precision parameter Pythonic instead of MATLAB-stringly-typed.
The numpy ecosystem's convention is to accept dtype-like inputs broadly
(numpy types, strings, dtype objects), and the modern API should
follow that idiom rather than the legacy SimulationOptions.data_cast
naming.
Accepted forms (resolved via _resolve_dtype, which uses np.dtype()):
None / np.float64 / "float64" / "double" / float / "off" / np.dtype("f8")
-> np.float64 (default)
np.float32 / "float32" / "single" / np.dtype("f4")
-> np.float32
The MATLAB aliases ("off", "single", "double") are kept as compat
shortcuts so users porting from the legacy API or MATLAB k-Wave have
zero friction. Anything resolving to a non-float32/float64 type
(np.float16, np.complex64, etc.) raises ValueError -- the solver
isn't validated for those.
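A minimal sketch of the resolution behaviour described above, assuming only what the commit states (the _resolve_dtype name and the use of np.dtype()); the PR's actual implementation may differ:

```python
import numpy as np

_MATLAB_ALIASES = {"off": np.float64, "double": np.float64, "single": np.float32}

def _resolve_dtype(value):
    if value is None:
        return np.float64
    if isinstance(value, str) and value.lower() in _MATLAB_ALIASES:
        return _MATLAB_ALIASES[value.lower()]
    try:
        resolved = np.dtype(value).type      # accepts numpy types, strings, dtype objects
    except TypeError as exc:
        raise ValueError(f"Unsupported dtype: {value!r}") from exc
    if resolved not in (np.float32, np.float64):
        raise ValueError(f"dtype must resolve to float32 or float64, got {np.dtype(resolved)}")
    return resolved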
C++ backend warns when dtype is not np.float64 (binary uses fixed
internal precision regardless).
Tests: 22 (was 8) parametrized over every input form. Wider suite
(62 tests) still passes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@greptile-apps re-review
Greptile P2 (test for p_final dtype) caught a real bug: with
dtype='single', p_final came back as float64 even though sensor_data
buffers (p, p_max, p_min, p_rms) were correctly float32.
Root cause: two sources of float64 leaking into the hot loop:
1. xp.fft.fftfreq returns float64; k_list, kappa, op_grad/div_list,
_k_mag inherited it. _diff's FFT round-trip (float64 op * complex64
field) upcasts to complex128, .real => float64. Result: self.p and
self.u rebound to float64 mid-step() despite being allocated as
float32. p_final = self.p[interior].copy() picked up float64.
sensor_data buffers stayed float32 because writes are in-place into
the pre-allocated buffer (silent narrowing on assignment).
2. get_pml returns float64 unconditionally; the per-step pml multiply
was a second upcast path independent of (1).
Both cast sites now apply .astype(self._dtype) at construction time,
keeping the entire compute pipeline in the user's requested precision.
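A standalone numpy illustration of that leak and of the construction-time cast; this mirrors the pattern described above, not the solver's actual code:

```python
import numpy as np

dtype = np.float32
k = 2 * np.pi * np.fft.fftfreq(64, d=1e-4)      # fftfreq always returns float64
field = np.zeros(64, dtype=np.complex64)         # spectral field at the requested precision

print((k * field).real.dtype)                    # float64 -- the float64 operator wins the promotion
print((k.astype(dtype) * field).real.dtype)      # float32 -- cast the operator once, at construction
```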
Test updated: float32 / float64 input parametrizations now request
('p', 'p_final', 'p_max', 'p_min', 'p_rms') and assert every field's
dtype matches. Verified: float32 inputs => all five fields float32;
float64 => all five float64.
Bonus: helpful error for torch / jax / tensorflow dtype objects via
duck-typed __module__ check (no framework imports needed); cupy works
for free since cp.float32 is np.float32.
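The commit does not show the check itself; a hypothetical sketch of what a duck-typed __module__ test can look like (helper name, module list, and message are illustrative assumptions, not the PR's code):

```python
import numpy as np

def _looks_like_framework_dtype(value) -> bool:
    # torch.float32 / tf.float32 are instances of framework dtype classes; the
    # class's defining module names the framework without importing it.
    module = getattr(type(value), "__module__", "") or ""
    return module.split(".")[0] in {"torch", "tensorflow", "jax", "jaxlib"}

# cupy needs no special case: cp.float32 *is* np.float32, so np.dtype() handles it.
```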
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…odern-api
# Conflicts:
#   kwave/solvers/kspace_solver.py
Greptile spotted a third dtype-drift path: ``sum(rho_split)`` in ``_nl_factor`` and in the equation-of-state line starts its accumulator from Python ``int 0``. Under numpy < 2 (pre-NEP 50 promotion), ``int + float32 -> float64``, so:

    nl_factor = (2 * sum(rho_split) + rho0) / rho0

is float64 even when rho_split is float32. The product ``rho0 * div_u_i * nl_factor`` in mass conservation then upcasts the rho_split arrays to float64 on the very first step. This specifically affects any simulation that enables BonA.

Fix: an ``_array_sum`` helper that starts the accumulator from ``arrays[0]`` so the dtype is preserved. Used at both call sites (the _nl_factor lambda and the equation-of-state rho_total).

Test added: test_python_backend_dtype_preserved_with_nonlinearity exercises the BonA path with dtype=np.float32 and asserts p / p_final / p_max all remain float32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
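A minimal sketch of an accumulator-free sum along the lines described above (the helper name _array_sum comes from the commit; the body is an assumption):

```python
import numpy as np

def _array_sum(arrays):
    # Start from the first array instead of Python int 0 so the result
    # keeps the arrays' dtype (e.g. float32) on every numpy version.
    total = arrays[0].copy()
    for a in arrays[1:]:
        total += a
    return total

rho_split = [np.ones(4, dtype=np.float32) for _ in range(3)]
print(_array_sum(rho_split).dtype)   # float32
# builtin sum(rho_split) starts its accumulator from int 0, the path flagged above
```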
@greptile-apps re-review
End-to-end verification under numpy 1.26.4 (with my dtype tests requesting
('p', 'p_final', 'p_max', 'p_min', 'p_rms')) showed self.p still upcast to
float64 mid-step despite all the prior precision fixes.
Root cause: numpy < 2's `np.fft.fftn` always returns complex128 regardless
of input precision -- a known difference resolved in numpy 2 (NEP 50 era).
The k-space ops (op_grad/div_list, unstagger_ops) being complex64 isn't
enough; multiplying complex64 by complex128 upcasts to complex128, and
.real -> float64 propagates back into self.p.
Fix: introduce self._complex_dtype (complex64 / complex128 matching
self._dtype), and cast every fftn / ifftn .real result back to the
intended precision. Three call sites: step() momentum loop, sensor-data
unstagger, _diff helper. ``copy=False`` makes the cast a no-op when the
dtype already matches (numpy 2, cupy).
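A standalone sketch of the pattern (not the solver's exact code): pick the complex dtype from the real one and cast the FFT round-trip back on the way out:

```python
import numpy as np

dtype = np.float32
complex_dtype = np.complex64 if dtype == np.float32 else np.complex128

p = np.zeros((32, 32), dtype=dtype)
op = np.ones((32, 32), dtype=complex_dtype)            # k-space operator at requested precision

P = np.fft.fftn(p).astype(complex_dtype, copy=False)   # numpy < 2 returns complex128 here
grad = np.fft.ifftn(op * P).real.astype(dtype, copy=False)
print(grad.dtype)                                       # float32 on numpy 1.x and 2.x
```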
Greptile flagged dt_over_rho0, kappa, and sum() as P1 dtype-promotion
paths. Verified empirically on numpy 1.26.4 that none of those upcast in
practice (Python float / int / complex are weak types in numpy 1.x as
well) -- but the test failures Greptile would have seen had the same
visible signature, so the diagnosis pointed at the right region.
Tests: 86 (24 dtype + 62 wider suite) pass on numpy 1.26.4 AND numpy 2.2.6.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Greptile flagged this, and Greptile was right that the diff was unsafe on numpy<2, but the actual culprit was different; fixed in 22ec78d. Same conclusion (not safe on numpy<2 before this commit), different mechanism. Re-triggering review now.
@greptile-apps re-review
    # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
    self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
unstagger_ops also built with 1j * literal → always complex128 on numpy < 2
The same Python-literal promotion issue as op_grad_list/op_div_list applies here: -1j * self.k_list[ax] is complex128 on numpy < 2, regardless of k_list dtype. Cast to self._complex_dtype after construction to keep the unstaggering FFT in the correct precision.
Suggested change:

     # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
-    self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
+    self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]
Verified empirically on numpy 1.26.4: -1j * k_f32 * spacing / 2 evaluates to complex64 and np.exp(...) preserves complex64 — Python complex is a weak type just like Python float/int, so the multiplication chain stays in complex64 when k_list is float32. So unstagger_ops is already complex64 without an explicit cast.
Additionally, the prior commit (22ec78d) already added .astype(self._dtype, copy=False) to the staggered ifftn .real line in step(), so the recorded output stays float32 even if unstagger_ops were ever to change precision.
Added a regression test (d545952) that records ('p', 'ux_staggered', 'uy_staggered') with dtype=np.float32 and asserts each is float32 — locks the behavior in for both numpy 1.26.4 and 2.2.6.
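For anyone wanting to reproduce either side of this, a one-off check on the installed numpy (the printed dtype is exactly the point in dispute; this snippet is not part of the PR):

```python
import numpy as np

k_f32 = np.fft.fftfreq(64).astype(np.float32)
op = np.exp(-1j * k_f32 * 1e-4 / 2)      # Python complex literal times a float32 array
print(np.__version__, op.dtype)
```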
@greptile-apps re-review
The empirical result doesn't match what NEP 50's own migration table documents. The NEP 50 spec explicitly shows the old (numpy < 2) vs new (numpy ≥ 2 / NEP 50) behaviour:
| Expression | numpy < 2 (old) | numpy ≥ 2 (NEP 50) |
|---|---|---|
| `(3j + array(3, complex64)).dtype` | complex128 | complex64 |
| `(float32(1) + 1j).dtype` | complex128 | complex64 |
Source: https://numpy.org/neps/nep-0050-scalar-promotion.html
Python complex literals are not weak types in numpy < 2 — they behave like np.complex128 in the type promotion system. The "weak scalar" treatment is exactly what NEP 50 introduced. So on numpy 1.26.4, -1j * k_f32 * spacing / 2 → complex128, and np.exp(...) of a complex128 array → complex128.
It's possible the empirical test ran against a numpy 2.x environment, or that k_list[ax] was already float64 at the point of measurement (defeating the test). The regression test added in d545952 checks recorded sensor_data["p"] and "ux_staggered" dtypes, but those go through the .astype(self._dtype, copy=False) guard added in step() — so they'll pass regardless of what unstagger_ops itself is. The dtype of unstagger_ops isn't directly asserted.
The fix remains: cast the list comprehension result to self._complex_dtype:
     # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
-    self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
+    self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]
Greptile flagged unstagger_ops as a P1 (claimed 1j * k_f32 -> complex128
on numpy<2). Empirically that's not true under numpy 1.26.4 -- Python
complex is a weak type just like Python float/int, so unstagger_ops
stays complex64 when k_list is float32. The recorded ux_staggered output
also stays float32 because the prior commit already cast the unstagger
ifftn .real result to self._dtype.
Adding a test that records ('p', 'ux_staggered', 'uy_staggered') and
asserts each is float32 under dtype=np.float32 -- guards against any
future regression in the staggered-velocity path on either numpy major
version. 25 tests pass on numpy 1.26.4 and numpy 2.2.6.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@greptile-apps re-review
Closes #695.
What
Exposes precision control on the modern unified API. Pythonic / numpy-idiomatic naming and accepted input forms:
| Input | Resolves to |
|---|---|
| None (default) | np.float64 |
| np.float64 / "float64" / "double" / float / np.dtype("f8") | np.float64 |
| np.float32 / "float32" / "single" / np.dtype("f4") | np.float32 |
| "off" (legacy MATLAB alias) | np.float64 |
| Anything else (np.float16, np.complex64, "quad", …) | ValueError |

The MATLAB aliases ("off", "single", "double") are kept as compat shortcuts for users porting from the legacy SimulationOptions.data_cast or MATLAB k-Wave's DataCast. Everything else uses np.dtype() for normalization, matching the broader numpy/scipy/torch convention.

Why

dtype instead of data_cast: data_cast is a MATLAB term. The numpy ecosystem (numpy, pandas, jax, torch) uses dtype and accepts dtype-like inputs broadly. The modern unified API is a fresh design — it should follow the Python idiom rather than the MATLAB one. The MATLAB-style strings still work, so MATLAB users lose nothing.

How
Python backend plumbs dtype through Simulation, which now stores self._dtype and uses it for every state-array allocation: p, u, rho_split, sensor-data buffers, PML arrays, source signal buffers, and the _expand_to_grid helper for sound_speed / density / alpha_coeff / BonA / p0. Default behavior unchanged (float64 everywhere).

C++ backend intentionally has no effect — the binary uses fixed internal precision regardless of HDF5 input dtype. Setting dtype to anything other than np.float64 with backend='cpp' emits a UserWarning explaining this and pointing users at backend='python' for precision control.
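A hypothetical sketch of that guard; the function name and message wording are illustrative assumptions, not the PR's exact code:

```python
import warnings
import numpy as np

def _warn_if_cpp_ignores_dtype(backend: str, dtype) -> None:
    # The C++ binary computes at fixed internal precision, so a non-default
    # dtype request cannot be honoured on that backend.
    if backend == "cpp" and np.dtype(dtype) != np.dtype(np.float64):
        warnings.warn(
            f"backend='cpp' ignores dtype={np.dtype(dtype).name}: the binary uses fixed "
            "internal precision regardless of the HDF5 input dtype. "
            "Use backend='python' for precision control.",
            UserWarning,
        )
```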
Test plan

New file tests/test_data_cast.py (22 tests):

- test_python_backend_float64_inputs parametrized over [None, np.float64, "float64", "double", float, "off", np.dtype("f8")] — every form resolves to float64 output
- test_python_backend_float32_inputs parametrized over [np.float32, "float32", "single", np.dtype("f4")] — every form resolves to float32 output
- test_default_dtype_is_float64 — calling without the kwarg gives float64 (back-compat)
- test_invalid_dtype_raises parametrized over [np.float16, np.complex64, "float16", "complex64", "quad", 42, "garbage"] — all raise ValueError
- test_python_single_vs_double_numerical_agreement — single and double runs agree to within 1e-4 relative error (see the standalone sketch after this list)
- test_cpp_backend_warns_on_non_float64_dtype — UserWarning fires before the binary runs
- test_cpp_backend_silent_on_default_dtype — no warning on default
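The standalone sketch referenced above: a spectral derivative computed in both precisions as a stand-in for the solver, showing float32 round-off sits well inside the 1e-4 tolerance (illustration only, not the PR's test):

```python
import numpy as np

x = np.linspace(0, 2 * np.pi, 128, endpoint=False)
f64 = np.sin(3 * x)                                   # float64 field
f32 = f64.astype(np.float32)

k = 2j * np.pi * np.fft.fftfreq(x.size, d=x[1] - x[0])
d64 = np.fft.ifft(k * np.fft.fft(f64)).real
d32 = np.fft.ifft(k.astype(np.complex64) * np.fft.fft(f32).astype(np.complex64)).real.astype(np.float32)

rel_err = np.max(np.abs(d64 - d32)) / np.max(np.abs(d64))
assert rel_err < 1e-4                                  # float32 round-off is far below the test tolerance
print(rel_err)
```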
Wider suite verified (62 tests): test_native_solver, test_ivp_homogeneous_medium, test_issue_664_alpha_power_near_unity all pass.

Greptile Summary
This PR exposes precision control on the modern kspaceFirstOrder() API via a dtype parameter, wiring float32/float64 through Simulation state arrays, PML, k-space operators, and sensor buffers. The _resolve_dtype resolver, MATLAB-alias compatibility, and the C++ backend warning are all correctly implemented.

One normalization gap remains: in Simulation.__init__, the elif dtype in (np.float32, np.float64) branch stores the raw input as self._dtype without converting it to a scalar-type object, causing the subsequent `is` check for _complex_dtype to silently select complex128 instead of complex64 when a dtype instance (e.g. np.dtype("float32")) is passed directly. Additionally, dt_over_rho0 is still computed as a Python float divided by a float32 array — on numpy < 2 this yields float64 regardless of self._dtype, causing self.u and self.rho_split to be rebound to float64 after the first step; the new .astype() casts on FFT outputs mitigate the spectral-operator drift but not this direct scalar-arithmetic promotion, so the p_final / p_max / p_rms float32 assertions will likely fail on numpy 1.26.x (CI green is still unchecked).

Confidence Score: 3/5
Safe to merge for numpy ≥ 2 users; float32 tests are likely to fail on numpy 1.26.x due to unresolved Python-scalar promotion in dt_over_rho0
The dtype plumbing is thorough and well-tested, and _resolve_dtype, the C++ warning, and alias handling are all correct. However, CI green is still unchecked, and the previously-flagged dt_over_rho0 root cause (Python float / float32 → float64 on numpy < 2) is not addressed in this diff — the .astype() casts on FFT round-trips do not prevent self.u and self.rho_split from becoming float64 mid-step, which would cause p_final dtype assertions to fail on numpy 1.26.x. A minor dtype-instance vs type-object normalization issue in Simulation.__init__ also needs a one-line fix.
kwave/solvers/kspace_solver.py — dt_over_rho0 computation (line 663) and mass-conservation update (line 716) still use bare self.dt (Python float), causing float64 promotion of state arrays on numpy < 2
Important Files Changed
- dtype parameter and _resolve_dtype resolver; validation, MATLAB aliases, and C++ backend warning all look correct
- _array_sum, casts PML/k_list/FFT results, and threads dtype through all fields; Simulation.__init__ stores raw input in self._dtype without normalizing to a type object, so the `is` check for _complex_dtype silently produces complex128 when a dtype instance is passed directly

Flowchart
(Mermaid flowchart: kspaceFirstOrder(dtype=...) → _resolve_dtype → np.float32 / np.float64 / ValueError → Simulation setting self._dtype and self._complex_dtype → setup() casting c0, rho0, p, u, rho_split, PML, and k_list → step() casting FFT round-trips back to the requested precision → run() filling sensor_data buffers and copying p_final, plus the backend='cpp' UserWarning branch.)

Comments Outside Diff (2)
1. kwave/solvers/kspace_solver.py, line 663 (link): dt_over_rho0 computed as Python float / float32 → silently float64 on numpy < 2. self.dt is stored as float(self.kgrid.dt) (a Python float, equivalent to float64). Dividing by a float32 array gives a float64 result under numpy < 2 (NEP 50 changed this in numpy 2.0). As a consequence, on every step(): self.dt_over_rho0[i] * grad_p_i (line 706) is float64 × float32 → float64, so self.u[i] is rebound to a float64 array after the very first step. The same Python-scalar promotion also affects line 716 (self.dt * self.rho0 * div_u_i * nl_factor), so rho_split[i] and, through rho_total, self.p also become float64. The sensor-data buffer (sensor_data["p"]) is pre-allocated float32 and silently narrows values on in-place assignment, so result["p"] tests pass, but result["p_final"] (line 773) is self.p[interior].copy() — no narrowing — and will be float64 even when dtype=np.float32 is requested, breaking the dtype contract on numpy < 2.

2. kwave/solvers/kspace_solver.py, line 429-447 (link): k-space operators still end up complex128 on numpy < 2. The k_list entries are now correctly cast to self._dtype (float32 when requested), but self.c_ref and self.dt are Python floats (float64), so self.c_ref * k_mag * self.dt / 2 is float64 on numpy < 2, making kappa and source_kappa float64. The 1j * Python complex literal then forces op_grad_list and op_div_list to complex128 rather than complex64. Cast kappa / source_kappa to self._dtype and the final operators to self._complex_dtype after construction.

Reviews (5): Last reviewed commit: "Add staggered-velocity dtype regression ..."